diff --git a/.agents/skills/component-refactoring/SKILL.md b/.agents/skills/component-refactoring/SKILL.md index 140e0ef434..0ed18d71d1 100644 --- a/.agents/skills/component-refactoring/SKILL.md +++ b/.agents/skills/component-refactoring/SKILL.md @@ -187,53 +187,12 @@ const Template = useMemo(() => { **When**: Component directly handles API calls, data transformation, or complex async operations. -**Dify Convention**: Use `@tanstack/react-query` hooks from `web/service/use-*.ts` or create custom data hooks. - -```typescript -// ❌ Before: API logic in component -const MCPServiceCard = () => { - const [basicAppConfig, setBasicAppConfig] = useState({}) - - useEffect(() => { - if (isBasicApp && appId) { - (async () => { - const res = await fetchAppDetail({ url: '/apps', id: appId }) - setBasicAppConfig(res?.model_config || {}) - })() - } - }, [appId, isBasicApp]) - - // More API-related logic... -} - -// ✅ After: Extract to data hook using React Query -// use-app-config.ts -import { useQuery } from '@tanstack/react-query' -import { get } from '@/service/base' - -const NAME_SPACE = 'appConfig' - -export const useAppConfig = (appId: string, isBasicApp: boolean) => { - return useQuery({ - enabled: isBasicApp && !!appId, - queryKey: [NAME_SPACE, 'detail', appId], - queryFn: () => get(`/apps/${appId}`), - select: data => data?.model_config || {}, - }) -} - -// Component becomes cleaner -const MCPServiceCard = () => { - const { data: config, isLoading } = useAppConfig(appId, isBasicApp) - // UI only -} -``` - -**React Query Best Practices in Dify**: -- Define `NAME_SPACE` for query key organization -- Use `enabled` option for conditional fetching -- Use `select` for data transformation -- Export invalidation hooks: `useInvalidXxx` +**Dify Convention**: +- This skill is for component decomposition, not query/mutation design. +- When refactoring data fetching, follow `web/AGENTS.md`. +- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling. +- Do not introduce deprecated `useInvalid` / `useReset`. +- Do not add thin passthrough `useQuery` wrappers during refactoring; only extract a custom hook when it truly orchestrates multiple queries/mutations or shared derived state. **Dify Examples**: - `web/service/use-workflow.ts` diff --git a/.agents/skills/component-refactoring/references/hook-extraction.md b/.agents/skills/component-refactoring/references/hook-extraction.md index a8d75deffd..0d567eb2a6 100644 --- a/.agents/skills/component-refactoring/references/hook-extraction.md +++ b/.agents/skills/component-refactoring/references/hook-extraction.md @@ -155,48 +155,14 @@ const Configuration: FC = () => { ## Common Hook Patterns in Dify -### 1. Data Fetching Hook (React Query) +### 1. Data Fetching / Mutation Hooks -```typescript -// Pattern: Use @tanstack/react-query for data fetching -import { useQuery, useQueryClient } from '@tanstack/react-query' -import { get } from '@/service/base' -import { useInvalid } from '@/service/use-base' +When hook extraction touches query or mutation code, do not use this reference as the source of truth for data-layer patterns. -const NAME_SPACE = 'appConfig' - -// Query keys for cache management -export const appConfigQueryKeys = { - detail: (appId: string) => [NAME_SPACE, 'detail', appId] as const, -} - -// Main data hook -export const useAppConfig = (appId: string) => { - return useQuery({ - enabled: !!appId, - queryKey: appConfigQueryKeys.detail(appId), - queryFn: () => get(`/apps/${appId}`), - select: data => data?.model_config || null, - }) -} - -// Invalidation hook for refreshing data -export const useInvalidAppConfig = () => { - return useInvalid([NAME_SPACE]) -} - -// Usage in component -const Component = () => { - const { data: config, isLoading, error, refetch } = useAppConfig(appId) - const invalidAppConfig = useInvalidAppConfig() - - const handleRefresh = () => { - invalidAppConfig() // Invalidates cache and triggers refetch - } - - return
...
-} -``` +- Follow `web/AGENTS.md` first. +- Use `frontend-query-mutation` for contracts, query shape, data-fetching wrappers, query/mutation call-site patterns, conditional queries, invalidation, and mutation error handling. +- Do not introduce deprecated `useInvalid` / `useReset`. +- Do not extract thin passthrough `useQuery` hooks; only extract orchestration hooks. ### 2. Form State Hook diff --git a/.agents/skills/frontend-query-mutation/SKILL.md b/.agents/skills/frontend-query-mutation/SKILL.md new file mode 100644 index 0000000000..49888bdb66 --- /dev/null +++ b/.agents/skills/frontend-query-mutation/SKILL.md @@ -0,0 +1,44 @@ +--- +name: frontend-query-mutation +description: Guide for implementing Dify frontend query and mutation patterns with TanStack Query and oRPC. Trigger when creating or updating contracts in web/contract, wiring router composition, consuming consoleQuery or marketplaceQuery in components or services, deciding whether to call queryOptions() directly or extract a helper or use-* hook, handling conditional queries, cache invalidation, mutation error handling, or migrating legacy service calls to contract-first query and mutation helpers. +--- + +# Frontend Query & Mutation + +## Intent + +- Keep contract as the single source of truth in `web/contract/*`. +- Prefer contract-shaped `queryOptions()` and `mutationOptions()`. +- Keep invalidation and mutation flow knowledge in the service layer. +- Keep abstractions minimal to preserve TypeScript inference. + +## Workflow + +1. Identify the change surface. + - Read `references/contract-patterns.md` for contract files, router composition, client helpers, and query or mutation call-site shape. + - Read `references/runtime-rules.md` for conditional queries, invalidation, error handling, and legacy migrations. + - Read both references when a task spans contract shape and runtime behavior. +2. Implement the smallest abstraction that fits the task. + - Default to direct `useQuery(...)` or `useMutation(...)` calls with oRPC helpers at the call site. + - Extract a small shared query helper only when multiple call sites share the same extra options. + - Create `web/service/use-{domain}.ts` only for orchestration or shared domain behavior. +3. Preserve Dify conventions. + - Keep contract inputs in `{ params, query?, body? }` shape. + - Bind invalidation in the service-layer mutation definition. + - Prefer `mutate(...)`; use `mutateAsync(...)` only when Promise semantics are required. + +## Files Commonly Touched + +- `web/contract/console/*.ts` +- `web/contract/marketplace.ts` +- `web/contract/router.ts` +- `web/service/client.ts` +- `web/service/use-*.ts` +- component and hook call sites using `consoleQuery` or `marketplaceQuery` + +## References + +- Use `references/contract-patterns.md` for contract shape, router registration, query and mutation helpers, and anti-patterns that degrade inference. +- Use `references/runtime-rules.md` for conditional queries, invalidation, `mutate` versus `mutateAsync`, and legacy migration rules. + +Treat this skill as the single query and mutation entry point for Dify frontend work. Keep detailed rules in the reference files instead of duplicating them in project docs. diff --git a/.agents/skills/frontend-query-mutation/agents/openai.yaml b/.agents/skills/frontend-query-mutation/agents/openai.yaml new file mode 100644 index 0000000000..87f7ae6ea4 --- /dev/null +++ b/.agents/skills/frontend-query-mutation/agents/openai.yaml @@ -0,0 +1,4 @@ +interface: + display_name: "Frontend Query & Mutation" + short_description: "Dify TanStack Query and oRPC patterns" + default_prompt: "Use this skill when implementing or reviewing Dify frontend contracts, query and mutation call sites, conditional queries, invalidation, or legacy query/mutation migrations." diff --git a/.agents/skills/frontend-query-mutation/references/contract-patterns.md b/.agents/skills/frontend-query-mutation/references/contract-patterns.md new file mode 100644 index 0000000000..08016ed2cc --- /dev/null +++ b/.agents/skills/frontend-query-mutation/references/contract-patterns.md @@ -0,0 +1,98 @@ +# Contract Patterns + +## Table of Contents + +- Intent +- Minimal structure +- Core workflow +- Query usage decision rule +- Mutation usage decision rule +- Anti-patterns +- Contract rules +- Type export + +## Intent + +- Keep contract as the single source of truth in `web/contract/*`. +- Default query usage to call-site `useQuery(consoleQuery|marketplaceQuery.xxx.queryOptions(...))` when endpoint behavior maps 1:1 to the contract. +- Keep abstractions minimal and preserve TypeScript inference. + +## Minimal Structure + +```text +web/contract/ +├── base.ts +├── router.ts +├── marketplace.ts +└── console/ + ├── billing.ts + └── ...other domains +web/service/client.ts +``` + +## Core Workflow + +1. Define contract in `web/contract/console/{domain}.ts` or `web/contract/marketplace.ts`. + - Use `base.route({...}).output(type<...>())` as the baseline. + - Add `.input(type<...>())` only when the request has `params`, `query`, or `body`. + - For `GET` without input, omit `.input(...)`; do not use `.input(type())`. +2. Register contract in `web/contract/router.ts`. + - Import directly from domain files and nest by API prefix. +3. Consume from UI call sites via oRPC query utilities. + +```typescript +import { useQuery } from '@tanstack/react-query' +import { consoleQuery } from '@/service/client' + +const invoiceQuery = useQuery(consoleQuery.billing.invoices.queryOptions({ + staleTime: 5 * 60 * 1000, + throwOnError: true, + select: invoice => invoice.url, +})) +``` + +## Query Usage Decision Rule + +1. Default to direct `*.queryOptions(...)` usage at the call site. +2. If 3 or more call sites share the same extra options, extract a small query helper, not a `use-*` passthrough hook. +3. Create `web/service/use-{domain}.ts` only for orchestration. + - Combine multiple queries or mutations. + - Share domain-level derived state or invalidation helpers. + +```typescript +const invoicesBaseQueryOptions = () => + consoleQuery.billing.invoices.queryOptions({ retry: false }) + +const invoiceQuery = useQuery({ + ...invoicesBaseQueryOptions(), + throwOnError: true, +}) +``` + +## Mutation Usage Decision Rule + +1. Default to mutation helpers from `consoleQuery` or `marketplaceQuery`, for example `useMutation(consoleQuery.billing.bindPartnerStack.mutationOptions(...))`. +2. If the mutation flow is heavily custom, use oRPC clients as `mutationFn`, for example `consoleClient.xxx` or `marketplaceClient.xxx`, instead of handwritten non-oRPC mutation logic. + +## Anti-Patterns + +- Do not wrap `useQuery` with `options?: Partial`. +- Do not split local `queryKey` and `queryFn` when oRPC `queryOptions` already exists and fits the use case. +- Do not create thin `use-*` passthrough hooks for a single endpoint. +- These patterns can degrade inference, especially around `throwOnError` and `select`, and add unnecessary indirection. + +## Contract Rules + +- Input structure: always use `{ params, query?, body? }`. +- No-input `GET`: omit `.input(...)`; do not use `.input(type())`. +- Path params: use `{paramName}` in the path and match it in the `params` object. +- Router nesting: group by API prefix, for example `/billing/*` becomes `billing: {}`. +- No barrel files: import directly from specific files. +- Types: import from `@/types/` and use the `type()` helper. +- Mutations: prefer `mutationOptions`; use explicit `mutationKey` mainly for defaults, filtering, and devtools. + +## Type Export + +```typescript +export type ConsoleInputs = InferContractRouterInputs +``` diff --git a/.agents/skills/frontend-query-mutation/references/runtime-rules.md b/.agents/skills/frontend-query-mutation/references/runtime-rules.md new file mode 100644 index 0000000000..02e8b9c2b6 --- /dev/null +++ b/.agents/skills/frontend-query-mutation/references/runtime-rules.md @@ -0,0 +1,133 @@ +# Runtime Rules + +## Table of Contents + +- Conditional queries +- Cache invalidation +- Key API guide +- `mutate` vs `mutateAsync` +- Legacy migration + +## Conditional Queries + +Prefer contract-shaped `queryOptions(...)`. +When required input is missing, prefer `input: skipToken` instead of placeholder params or non-null assertions. +Use `enabled` only for extra business gating after the input itself is already valid. + +```typescript +import { skipToken, useQuery } from '@tanstack/react-query' + +// Disable the query by skipping input construction. +function useAccessMode(appId: string | undefined) { + return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({ + input: appId + ? { params: { appId } } + : skipToken, + })) +} + +// Avoid runtime-only guards that bypass type checking. +function useBadAccessMode(appId: string | undefined) { + return useQuery(consoleQuery.accessControl.appAccessMode.queryOptions({ + input: { params: { appId: appId! } }, + enabled: !!appId, + })) +} +``` + +## Cache Invalidation + +Bind invalidation in the service-layer mutation definition. +Components may add UI feedback in call-site callbacks, but they should not decide which queries to invalidate. + +Use: + +- `.key()` for namespace or prefix invalidation +- `.queryKey(...)` only for exact cache reads or writes such as `getQueryData` and `setQueryData` +- `queryClient.invalidateQueries(...)` in mutation `onSuccess` + +Do not use deprecated `useInvalid` from `use-base.ts`. + +```typescript +// Service layer owns cache invalidation. +export const useUpdateAccessMode = () => { + const queryClient = useQueryClient() + + return useMutation(consoleQuery.accessControl.updateAccessMode.mutationOptions({ + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(), + }) + }, + })) +} + +// Component only adds UI behavior. +updateAccessMode({ appId, mode }, { + onSuccess: () => Toast.notify({ type: 'success', message: '...' }), +}) + +// Avoid putting invalidation knowledge in the component. +mutate({ appId, mode }, { + onSuccess: () => { + queryClient.invalidateQueries({ + queryKey: consoleQuery.accessControl.appWhitelistSubjects.key(), + }) + }, +}) +``` + +## Key API Guide + +- `.key(...)` + - Use for partial matching operations. + - Prefer it for invalidation, refetch, and cancel patterns. + - Example: `queryClient.invalidateQueries({ queryKey: consoleQuery.billing.key() })` +- `.queryKey(...)` + - Use for a specific query's full key. + - Prefer it for exact cache addressing and direct reads or writes. +- `.mutationKey(...)` + - Use for a specific mutation's full key. + - Prefer it for mutation defaults registration, mutation-status filtering, and devtools grouping. + +## `mutate` vs `mutateAsync` + +Prefer `mutate` by default. +Use `mutateAsync` only when Promise semantics are truly required, such as parallel mutations or sequential steps with result dependencies. + +Rules: + +- Event handlers should usually call `mutate(...)` with `onSuccess` or `onError`. +- Every `await mutateAsync(...)` must be wrapped in `try/catch`. +- Do not use `mutateAsync` when callbacks already express the flow clearly. + +```typescript +// Default case. +mutation.mutate(data, { + onSuccess: result => router.push(result.url), +}) + +// Promise semantics are required. +try { + const order = await createOrder.mutateAsync(orderData) + await confirmPayment.mutateAsync({ orderId: order.id, token }) + router.push(`/orders/${order.id}`) +} +catch (error) { + Toast.notify({ + type: 'error', + message: error instanceof Error ? error.message : 'Unknown error', + }) +} +``` + +## Legacy Migration + +When touching old code, migrate it toward these rules: + +| Old pattern | New pattern | +|---|---| +| `useInvalid(key)` in service layer | `queryClient.invalidateQueries(...)` inside mutation `onSuccess` | +| component-triggered invalidation after mutation | move invalidation into the service-layer mutation definition | +| imperative fetch plus manual invalidation | wrap it in `useMutation(...mutationOptions(...))` | +| `await mutateAsync()` without `try/catch` | switch to `mutate(...)` or add `try/catch` | diff --git a/.agents/skills/frontend-testing/SKILL.md b/.agents/skills/frontend-testing/SKILL.md index 280fcb6341..4da070bdbf 100644 --- a/.agents/skills/frontend-testing/SKILL.md +++ b/.agents/skills/frontend-testing/SKILL.md @@ -63,7 +63,8 @@ pnpm analyze-component --review ### File Naming -- Test files: `ComponentName.spec.tsx` (same directory as component) +- Test files: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory +- Placement rule: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`. - Integration tests: `web/__tests__/` directory ## Test Structure Template @@ -204,6 +205,16 @@ When assigned to test a directory/path, test **ALL content** within that path: > See [Test Structure Template](#test-structure-template) for correct import/mock patterns. +### `nuqs` Query State Testing (Required for URL State Hooks) + +When a component or hook uses `useQueryState` / `useQueryStates`: + +- ✅ Use `NuqsTestingAdapter` (prefer shared helpers in `web/test/nuqs-testing.tsx`) +- ✅ Assert URL synchronization via `onUrlUpdate` (`searchParams`, `options.history`) +- ✅ For custom parsers (`createParser`), keep `parse` and `serialize` bijective and add round-trip edge cases (`%2F`, `%25`, spaces, legacy encoded values) +- ✅ Verify default-clearing behavior (default values should be removed from URL when applicable) +- ⚠️ Only mock `nuqs` directly when URL behavior is explicitly out of scope for the test + ## Core Principles ### 1. AAA Pattern (Arrange-Act-Assert) diff --git a/.agents/skills/frontend-testing/assets/component-test.template.tsx b/.agents/skills/frontend-testing/assets/component-test.template.tsx index 6b7803bd4b..ff38f88d23 100644 --- a/.agents/skills/frontend-testing/assets/component-test.template.tsx +++ b/.agents/skills/frontend-testing/assets/component-test.template.tsx @@ -41,7 +41,7 @@ import userEvent from '@testing-library/user-event' // Router (if component uses useRouter, usePathname, useSearchParams) // WHY: Isolates tests from Next.js routing, enables testing navigation behavior // const mockPush = vi.fn() -// vi.mock('next/navigation', () => ({ +// vi.mock('@/next/navigation', () => ({ // useRouter: () => ({ push: mockPush }), // usePathname: () => '/test-path', // })) diff --git a/.agents/skills/frontend-testing/references/checklist.md b/.agents/skills/frontend-testing/references/checklist.md index 1ff2b27bbb..10b8fb66f9 100644 --- a/.agents/skills/frontend-testing/references/checklist.md +++ b/.agents/skills/frontend-testing/references/checklist.md @@ -80,6 +80,9 @@ Use this checklist when generating or reviewing tests for Dify frontend componen - [ ] Router mocks match actual Next.js API - [ ] Mocks reflect actual component conditional behavior - [ ] Only mock: API services, complex context providers, third-party libs +- [ ] For `nuqs` URL-state tests, wrap with `NuqsTestingAdapter` (prefer `web/test/nuqs-testing.tsx`) +- [ ] For `nuqs` URL-state tests, assert `onUrlUpdate` payload (`searchParams`, `options.history`) +- [ ] If custom `nuqs` parser exists, add round-trip tests for encoded edge cases (`%2F`, `%25`, spaces, legacy encoded values) ### Queries diff --git a/.agents/skills/frontend-testing/references/mocking.md b/.agents/skills/frontend-testing/references/mocking.md index 86bd375987..f58377c4a5 100644 --- a/.agents/skills/frontend-testing/references/mocking.md +++ b/.agents/skills/frontend-testing/references/mocking.md @@ -125,6 +125,31 @@ describe('Component', () => { }) ``` +### 2.1 `nuqs` Query State (Preferred: Testing Adapter) + +For tests that validate URL query behavior, use `NuqsTestingAdapter` instead of mocking `nuqs` directly. + +```typescript +import { renderHookWithNuqs } from '@/test/nuqs-testing' + +it('should sync query to URL with push history', async () => { + const { result, onUrlUpdate } = renderHookWithNuqs(() => useMyQueryState(), { + searchParams: '?page=1', + }) + + act(() => { + result.current.setQuery({ page: 2 }) + }) + + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('push') + expect(update.searchParams.get('page')).toBe('2') +}) +``` + +Use direct `vi.mock('nuqs')` only when URL synchronization is intentionally out of scope. + ### 3. Portal Components (with Shared State) ```typescript diff --git a/.agents/skills/orpc-contract-first/SKILL.md b/.agents/skills/orpc-contract-first/SKILL.md deleted file mode 100644 index 4e3bfc7a37..0000000000 --- a/.agents/skills/orpc-contract-first/SKILL.md +++ /dev/null @@ -1,46 +0,0 @@ ---- -name: orpc-contract-first -description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories. ---- - -# oRPC Contract-First Development - -## Project Structure - -``` -web/contract/ -├── base.ts # Base contract (inputStructure: 'detailed') -├── router.ts # Router composition & type exports -├── marketplace.ts # Marketplace contracts -└── console/ # Console contracts by domain - ├── system.ts - └── billing.ts -``` - -## Workflow - -1. **Create contract** in `web/contract/console/{domain}.ts` - - Import `base` from `../base` and `type` from `@orpc/contract` - - Define route with `path`, `method`, `input`, `output` - -2. **Register in router** at `web/contract/router.ts` - - Import directly from domain file (no barrel files) - - Nest by API prefix: `billing: { invoices, bindPartnerStack }` - -3. **Create hooks** in `web/service/use-{domain}.ts` - - Use `consoleQuery.{group}.{contract}.queryKey()` for query keys - - Use `consoleClient.{group}.{contract}()` for API calls - -## Key Rules - -- **Input structure**: Always use `{ params, query?, body? }` format -- **Path params**: Use `{paramName}` in path, match in `params` object -- **Router nesting**: Group by API prefix (e.g., `/billing/*` → `billing: {}`) -- **No barrel files**: Import directly from specific files -- **Types**: Import from `@/types/`, use `type()` helper - -## Type Export - -```typescript -export type ConsoleInputs = InferContractRouterInputs -``` diff --git a/.claude/skills/frontend-query-mutation b/.claude/skills/frontend-query-mutation new file mode 120000 index 0000000000..197eed2e64 --- /dev/null +++ b/.claude/skills/frontend-query-mutation @@ -0,0 +1 @@ +../../.agents/skills/frontend-query-mutation \ No newline at end of file diff --git a/.claude/skills/orpc-contract-first b/.claude/skills/orpc-contract-first deleted file mode 120000 index da47b335c7..0000000000 --- a/.claude/skills/orpc-contract-first +++ /dev/null @@ -1 +0,0 @@ -../../.agents/skills/orpc-contract-first \ No newline at end of file diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 637593b9de..b92d4c35a8 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -7,7 +7,7 @@ cd web && pnpm install pipx install uv echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc -echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc +echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index bfb1c85436..1bb7d06232 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,7 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/core/model_runtime/ @laipz8200 @QuantumGhost +/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/actions/setup-web/action.yml b/.github/actions/setup-web/action.yml new file mode 100644 index 0000000000..6f3b3c08b4 --- /dev/null +++ b/.github/actions/setup-web/action.yml @@ -0,0 +1,13 @@ +name: Setup Web Environment + +runs: + using: composite + steps: + - name: Setup Vite+ + uses: voidzero-dev/setup-vp@4a524139920f87f9f7080d3b8545acac019e1852 # v1.0.0 + with: + node-version-file: web/.nvmrc + cache: true + cache-dependency-path: web/pnpm-lock.yaml + run-install: | + cwd: ./web diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 1a57bb0050..a183f0b58c 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -1,25 +1,212 @@ version: 2 -multi-ecosystem-groups: - python: - schedule: - interval: "weekly" # or whatever schedule you want - updates: - package-ecosystem: "pip" directory: "/api" - open-pull-requests-limit: 2 - patterns: ["*"] + open-pull-requests-limit: 10 schedule: interval: "weekly" + groups: + flask: + patterns: + - "flask" + - "flask-*" + - "werkzeug" + - "gunicorn" + google: + patterns: + - "google-*" + - "googleapis-*" + opentelemetry: + patterns: + - "opentelemetry-*" + pydantic: + patterns: + - "pydantic" + - "pydantic-*" + llm: + patterns: + - "langfuse" + - "langsmith" + - "litellm" + - "mlflow*" + - "opik" + - "weave*" + - "arize*" + - "tiktoken" + - "transformers" + database: + patterns: + - "sqlalchemy" + - "psycopg2*" + - "psycogreen" + - "redis*" + - "alembic*" + storage: + patterns: + - "boto3*" + - "botocore*" + - "azure-*" + - "bce-*" + - "cos-python-*" + - "esdk-obs-*" + - "google-cloud-storage" + - "opendal" + - "oss2" + - "supabase*" + - "tos*" + vdb: + patterns: + - "alibabacloud*" + - "chromadb" + - "clickhouse-*" + - "clickzetta-*" + - "couchbase" + - "elasticsearch" + - "opensearch-py" + - "oracledb" + - "pgvect*" + - "pymilvus" + - "pymochow" + - "pyobvector" + - "qdrant-client" + - "intersystems-*" + - "tablestore" + - "tcvectordb" + - "tidb-vector" + - "upstash-*" + - "volcengine-*" + - "weaviate-*" + - "xinference-*" + - "mo-vector" + - "mysql-connector-*" + dev: + patterns: + - "coverage" + - "dotenv-linter" + - "faker" + - "lxml-stubs" + - "basedpyright" + - "ruff" + - "pytest*" + - "types-*" + - "boto3-stubs" + - "hypothesis" + - "pandas-stubs" + - "scipy-stubs" + - "import-linter" + - "celery-types" + - "mypy*" + - "pyrefly" + python-packages: + patterns: + - "*" - package-ecosystem: "uv" directory: "/api" - open-pull-requests-limit: 2 - patterns: ["*"] + open-pull-requests-limit: 10 schedule: interval: "weekly" - - package-ecosystem: "npm" - directory: "/web" + groups: + flask: + patterns: + - "flask" + - "flask-*" + - "werkzeug" + - "gunicorn" + google: + patterns: + - "google-*" + - "googleapis-*" + opentelemetry: + patterns: + - "opentelemetry-*" + pydantic: + patterns: + - "pydantic" + - "pydantic-*" + llm: + patterns: + - "langfuse" + - "langsmith" + - "litellm" + - "mlflow*" + - "opik" + - "weave*" + - "arize*" + - "tiktoken" + - "transformers" + database: + patterns: + - "sqlalchemy" + - "psycopg2*" + - "psycogreen" + - "redis*" + - "alembic*" + storage: + patterns: + - "boto3*" + - "botocore*" + - "azure-*" + - "bce-*" + - "cos-python-*" + - "esdk-obs-*" + - "google-cloud-storage" + - "opendal" + - "oss2" + - "supabase*" + - "tos*" + vdb: + patterns: + - "alibabacloud*" + - "chromadb" + - "clickhouse-*" + - "clickzetta-*" + - "couchbase" + - "elasticsearch" + - "opensearch-py" + - "oracledb" + - "pgvect*" + - "pymilvus" + - "pymochow" + - "pyobvector" + - "qdrant-client" + - "intersystems-*" + - "tablestore" + - "tcvectordb" + - "tidb-vector" + - "upstash-*" + - "volcengine-*" + - "weaviate-*" + - "xinference-*" + - "mo-vector" + - "mysql-connector-*" + dev: + patterns: + - "coverage" + - "dotenv-linter" + - "faker" + - "lxml-stubs" + - "basedpyright" + - "ruff" + - "pytest*" + - "types-*" + - "boto3-stubs" + - "hypothesis" + - "pandas-stubs" + - "scipy-stubs" + - "import-linter" + - "celery-types" + - "mypy*" + - "pyrefly" + python-packages: + patterns: + - "*" + - package-ecosystem: "github-actions" + directory: "/" + open-pull-requests-limit: 5 schedule: interval: "weekly" - open-pull-requests-limit: 2 + groups: + github-actions-dependencies: + patterns: + - "*" diff --git a/.github/workflows/anti-slop.yml b/.github/workflows/anti-slop.yml new file mode 100644 index 0000000000..b0f0a36bc9 --- /dev/null +++ b/.github/workflows/anti-slop.yml @@ -0,0 +1,19 @@ +name: Anti-Slop PR Check + +on: + pull_request_target: + types: [opened, edited, synchronize] + +permissions: + pull-requests: write + contents: read + +jobs: + anti-slop: + runs-on: ubuntu-latest + steps: + - uses: peakoss/anti-slop@85daca1880e9e1af197fc06ea03349daf08f4202 # v0.2.1 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + close-pr: false + failure-add-pr-labels: "needs-revision" diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 52e3272f99..6b87946221 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -2,6 +2,12 @@ name: Run Pytest on: workflow_call: + secrets: + CODECOV_TOKEN: + required: false + +permissions: + contents: read concurrency: group: api-tests-${{ github.head_ref || github.run_id }} @@ -11,6 +17,8 @@ jobs: test: name: API Tests runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: run: shell: bash @@ -22,12 +30,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -51,7 +60,7 @@ jobs: run: sh .github/workflows/expose_service_ports.sh - name: Set up Sandbox - uses: hoverkraft-tech/compose-action@v2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -79,21 +88,12 @@ jobs: api/tests/test_containers_integration_tests \ api/tests/unit_tests - - name: Coverage Summary - run: | - set -x - # Extract coverage percentage and create a summary - TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])') - - # Create a detailed coverage summary - echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY - echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY - { - echo "" - echo "
File-level coverage (click to expand)" - echo "" - echo '```' - uv run --project api coverage report -m - echo '```' - echo "
" - } >> $GITHUB_STEP_SUMMARY + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }} + uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 + with: + files: ./coverage.xml + disable_search: true + flags: api + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index 4571fd1cd1..be6186980e 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -12,22 +12,34 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check Docker Compose inputs id: docker-compose-changes - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | docker/generate_docker_compose docker/.env.example docker/docker-compose-template.yaml docker/docker-compose.yaml - - uses: actions/setup-python@v6 + - name: Check web inputs + id: web-changes + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + with: + files: | + web/** + - name: Check api inputs + id: api-changes + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 + with: + files: | + api/** + - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" - - uses: astral-sh/setup-uv@v7 + - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Generate Docker Compose if: steps.docker-compose-changes.outputs.any_changed == 'true' @@ -35,7 +47,8 @@ jobs: cd docker ./generate_docker_compose - - run: | + - if: steps.api-changes.outputs.any_changed == 'true' + run: | cd api uv sync --dev # fmt first to avoid line too long @@ -46,11 +59,13 @@ jobs: uv run ruff format .. - name: count migration progress + if: steps.api-changes.outputs.any_changed == 'true' run: | cd api ./cnt_base.sh - name: ast-grep + if: steps.api-changes.outputs.any_changed == 'true' run: | # ast-grep exits 1 if no matches are found; allow idempotent runs. uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true @@ -79,9 +94,14 @@ jobs: find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \; find . -name "*.py.bak" -type f -delete - # mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter. - - name: mdformat - run: | - uvx --python 3.13 mdformat . --exclude ".agents/skills/**" + - name: Setup web environment + if: steps.web-changes.outputs.any_changed == 'true' + uses: ./.github/actions/setup-web - - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27 + - name: ESLint autofix + if: steps.web-changes.outputs.any_changed == 'true' + run: | + cd web + vp exec eslint --concurrency=2 --prune-suppressions --quiet || true + + - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index ac7f3a6b48..1ae8d44482 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -53,26 +53,26 @@ jobs: echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV - name: Login to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Extract metadata for Docker id: meta - uses: docker/metadata-action@v5 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0 with: images: ${{ env[matrix.image_name_env] }} - name: Build Docker image id: build - uses: docker/build-push-action@v6 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: context: "{{defaultContext}}:${{ matrix.context }}" platforms: ${{ matrix.platform }} @@ -91,7 +91,7 @@ jobs: touch "/tmp/digests/${sanitized_digest}" - name: Upload digest - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: digests-${{ matrix.context }}-${{ env.PLATFORM_PAIR }} path: /tmp/digests/* @@ -113,21 +113,21 @@ jobs: context: "web" steps: - name: Download digests - uses: actions/download-artifact@v7 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: path: /tmp/digests pattern: digests-${{ matrix.context }}-* merge-multiple: true - name: Login to Docker Hub - uses: docker/login-action@v3 + uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # v4.0.0 with: username: ${{ env.DOCKERHUB_USER }} password: ${{ env.DOCKERHUB_TOKEN }} - name: Extract metadata for Docker id: meta - uses: docker/metadata-action@v5 + uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # v6.0.0 with: images: ${{ env[matrix.image_name_env] }} tags: | diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index e20cf9850b..ffb9734e48 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -13,13 +13,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: "3.12" @@ -40,7 +40,7 @@ jobs: cp middleware.env.example middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.middleware.yaml @@ -63,13 +63,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 persist-credentials: false - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: "3.12" @@ -94,7 +94,7 @@ jobs: sed -i 's/DB_USERNAME=postgres/DB_USERNAME=mysql/' middleware.env - name: Set up Middlewares - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.middleware.yaml diff --git a/.github/workflows/deploy-agent-dev.yml b/.github/workflows/deploy-agent-dev.yml index dd759f7ba5..cd5fe9242e 100644 --- a/.github/workflows/deploy-agent-dev.yml +++ b/.github/workflows/deploy-agent-dev.yml @@ -19,7 +19,7 @@ jobs: github.event.workflow_run.head_branch == 'deploy/agent-dev' steps: - name: Deploy to server - uses: appleboy/ssh-action@v1 + uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: host: ${{ secrets.AGENT_DEV_SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/deploy-dev.yml b/.github/workflows/deploy-dev.yml index 38fa0b9a7f..954537663a 100644 --- a/.github/workflows/deploy-dev.yml +++ b/.github/workflows/deploy-dev.yml @@ -16,7 +16,7 @@ jobs: github.event.workflow_run.head_branch == 'deploy/dev' steps: - name: Deploy to server - uses: appleboy/ssh-action@v1 + uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: host: ${{ secrets.SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/deploy-hitl.yml b/.github/workflows/deploy-hitl.yml index a3fd52afc6..c6f1cc7e6f 100644 --- a/.github/workflows/deploy-hitl.yml +++ b/.github/workflows/deploy-hitl.yml @@ -16,7 +16,7 @@ jobs: github.event.workflow_run.head_branch == 'build/feat/hitl' steps: - name: Deploy to server - uses: appleboy/ssh-action@v1 + uses: appleboy/ssh-action@0ff4204d59e8e51228ff73bce53f80d53301dee2 # v1.2.5 with: host: ${{ secrets.HITL_SSH_HOST }} username: ${{ secrets.SSH_USER }} diff --git a/.github/workflows/docker-build.yml b/.github/workflows/docker-build.yml index cadc1b5507..340b380dc9 100644 --- a/.github/workflows/docker-build.yml +++ b/.github/workflows/docker-build.yml @@ -32,13 +32,13 @@ jobs: context: "web" steps: - name: Set up QEMU - uses: docker/setup-qemu-action@v3 + uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4.0.0 - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 + uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # v4.0.0 - name: Build Docker Image - uses: docker/build-push-action@v6 + uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # v7.0.0 with: push: false context: "{{defaultContext}}:${{ matrix.context }}" diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 06782b53c1..278e10bc04 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -9,6 +9,6 @@ jobs: pull-requests: write runs-on: ubuntu-latest steps: - - uses: actions/labeler@v6 + - uses: actions/labeler@634933edcd8ababfe52f92936142cc22ac488b1b # v6.0.1 with: sync-labels: true diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index d6653de950..69023c24cc 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -27,8 +27,8 @@ jobs: vdb-changed: ${{ steps.changes.outputs.vdb }} migration-changed: ${{ steps.changes.outputs.migration }} steps: - - uses: actions/checkout@v6 - - uses: dorny/paths-filter@v3 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1 id: changes with: filters: | @@ -39,6 +39,7 @@ jobs: web: - 'web/**' - '.github/workflows/web-tests.yml' + - '.github/actions/setup-web/**' vdb: - 'api/core/rag/datasource/**' - 'docker/**' @@ -55,12 +56,14 @@ jobs: needs: check-changes if: needs.check-changes.outputs.api-changed == 'true' uses: ./.github/workflows/api-tests.yml + secrets: inherit web-tests: name: Web Tests needs: check-changes if: needs.check-changes.outputs.web-changed == 'true' uses: ./.github/workflows/web-tests.yml + secrets: inherit style-check: name: Style Check diff --git a/.github/workflows/pyrefly-diff-comment.yml b/.github/workflows/pyrefly-diff-comment.yml index f9fbcba465..0278e1e0d3 100644 --- a/.github/workflows/pyrefly-diff-comment.yml +++ b/.github/workflows/pyrefly-diff-comment.yml @@ -21,7 +21,7 @@ jobs: if: ${{ github.event.workflow_run.conclusion == 'success' && github.event.workflow_run.pull_requests[0].head.repo.full_name != github.repository }} steps: - name: Download pyrefly diff artifact - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | @@ -49,7 +49,7 @@ jobs: run: unzip -o pyrefly_diff.zip - name: Post comment - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/pyrefly-diff.yml b/.github/workflows/pyrefly-diff.yml index 14338e85b3..a00f469bbe 100644 --- a/.github/workflows/pyrefly-diff.yml +++ b/.github/workflows/pyrefly-diff.yml @@ -17,12 +17,12 @@ jobs: pull-requests: write steps: - name: Checkout PR branch - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 - name: Setup Python & UV - uses: astral-sh/setup-uv@v5 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true @@ -55,7 +55,7 @@ jobs: echo ${{ github.event.pull_request.number }} > pr_number.txt - name: Upload pyrefly diff - uses: actions/upload-artifact@v4 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: pyrefly_diff path: | @@ -64,7 +64,7 @@ jobs: - name: Comment PR with pyrefly diff if: ${{ github.event.pull_request.head.repo.full_name == github.repository }} - uses: actions/github-script@v8 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index b15c26a096..c21331ec0d 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -16,6 +16,6 @@ jobs: runs-on: ubuntu-latest steps: - name: Check title - uses: amannn/action-semantic-pull-request@v6.1.1 + uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index b6df1d7e93..5cf52daed2 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -18,7 +18,7 @@ jobs: pull-requests: write steps: - - uses: actions/stale@v10 + - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 with: days-before-issue-stale: 15 days-before-issue-close: 3 diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index cbd6edf94b..657a481f74 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -19,13 +19,13 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | api/** @@ -33,7 +33,7 @@ jobs: - name: Setup UV and Python if: steps.changed-files.outputs.any_changed == 'true' - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: false python-version: "3.12" @@ -67,42 +67,28 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** .github/workflows/style.yml + .github/actions/setup-web/** - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup NodeJS - uses: actions/setup-node@v6 + - name: Setup web environment if: steps.changed-files.outputs.any_changed == 'true' - with: - node-version: 24 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Web dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm install --frozen-lockfile + uses: ./.github/actions/setup-web - name: Web style check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web run: | - pnpm run lint:ci + vp run lint:ci # pnpm run lint:report # continue-on-error: true @@ -116,17 +102,17 @@ jobs: - name: Web tsslint if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run lint:tss + run: vp run lint:tss - name: Web type check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run type-check + run: vp run type-check - name: Web dead code check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run knip + run: vp run knip superlinter: name: SuperLinter @@ -134,14 +120,14 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | **.sh @@ -152,7 +138,7 @@ jobs: .editorconfig - name: Super-linter - uses: super-linter/super-linter/slim@v8 + uses: super-linter/super-linter/slim@61abc07d755095a68f4987d1c2c3d1d64408f1f9 # v8.5.0 if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning diff --git a/.github/workflows/tool-test-sdks.yaml b/.github/workflows/tool-test-sdks.yaml index ec392cb3b2..3fc351c0c2 100644 --- a/.github/workflows/tool-test-sdks.yaml +++ b/.github/workflows/tool-test-sdks.yaml @@ -21,14 +21,14 @@ jobs: working-directory: sdks/nodejs-client steps: - - uses: actions/checkout@v6 + - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Use Node.js - uses: actions/setup-node@v6 + uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # v6.3.0 with: - node-version: 24 + node-version: 22 cache: '' cache-dependency-path: 'pnpm-lock.yaml' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 5d9440ff35..84f8000a01 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -38,7 +38,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 token: ${{ secrets.GITHUB_TOKEN }} @@ -48,18 +48,8 @@ jobs: git config --global user.name "github-actions[bot]" git config --global user.email "github-actions[bot]@users.noreply.github.com" - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Set up Node.js - uses: actions/setup-node@v6 - with: - node-version: 24 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml + - name: Setup web environment + uses: ./.github/actions/setup-web - name: Detect changed files and generate diff id: detect_changes @@ -130,7 +120,7 @@ jobs: - name: Run Claude Code for Translation Sync if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@v1 + uses: anthropics/claude-code-action@6062f3709600659be5e47fcddf2cf76993c235c2 # v1.0.76 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml index 66a29453b4..1caaddd47a 100644 --- a/.github/workflows/trigger-i18n-sync.yml +++ b/.github/workflows/trigger-i18n-sync.yml @@ -21,7 +21,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: fetch-depth: 0 @@ -59,7 +59,7 @@ jobs: - name: Trigger i18n sync workflow if: steps.detect.outputs.has_changes == 'true' - uses: peter-evans/repository-dispatch@v3 + uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1 with: token: ${{ secrets.GITHUB_TOKEN }} event-type: i18n-sync diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index 7735afdaca..f45f2137d6 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -19,19 +19,19 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Free Disk Space - uses: endersonmenezes/free-disk-space@v3 + uses: endersonmenezes/free-disk-space@7901478139cff6e9d44df5972fd8ab8fcade4db1 # v3.2.2 with: remove_dotnet: true remove_haskell: true remove_tool_cache: true - name: Setup UV and Python - uses: astral-sh/setup-uv@v7 + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 with: enable-cache: true python-version: ${{ matrix.python-version }} @@ -60,7 +60,7 @@ jobs: # tiflash - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase) - uses: hoverkraft-tech/compose-action@v2.0.2 + uses: hoverkraft-tech/compose-action@4894d2492015c1774ee5a13a95b1072093087ec3 # v2.5.0 with: compose-file: | docker/docker-compose.yaml diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index f50689636b..d40cd4bfeb 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -2,6 +2,9 @@ name: Web Tests on: workflow_call: + secrets: + CODECOV_TOKEN: + required: false permissions: contents: read @@ -14,11 +17,13 @@ jobs: test: name: Web Tests (${{ matrix.shardIndex }}/${{ matrix.shardTotal }}) runs-on: ubuntu-latest + env: + VITEST_COVERAGE_SCOPE: app-components strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4] - shardTotal: [4] + shardIndex: [1, 2, 3, 4, 5, 6] + shardTotal: [6] defaults: run: shell: bash @@ -26,32 +31,19 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: 24 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Install dependencies - run: pnpm install --frozen-lockfile + - name: Setup web environment + uses: ./.github/actions/setup-web - name: Run tests - run: pnpm vitest run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage + run: vp test run --reporter=blob --shard=${{ matrix.shardIndex }}/${{ matrix.shardTotal }} --coverage - name: Upload blob report if: ${{ !cancelled() }} - uses: actions/upload-artifact@v6 + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 with: name: blob-report-${{ matrix.shardIndex }} path: web/.vitest-reports/* @@ -63,6 +55,8 @@ jobs: if: ${{ !cancelled() }} needs: [test] runs-on: ubuntu-latest + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} defaults: run: shell: bash @@ -70,361 +64,32 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: + fetch-depth: 0 persist-credentials: false - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup Node.js - uses: actions/setup-node@v6 - with: - node-version: 24 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Install dependencies - run: pnpm install --frozen-lockfile + - name: Setup web environment + uses: ./.github/actions/setup-web - name: Download blob reports - uses: actions/download-artifact@v6 + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 with: path: web/.vitest-reports pattern: blob-report-* merge-multiple: true - name: Merge reports - run: pnpm vitest --merge-reports --coverage --silent=passed-only + run: vp test --merge-reports --coverage --silent=passed-only - - name: Coverage Summary - if: always() - id: coverage-summary - run: | - set -eo pipefail - - COVERAGE_FILE="coverage/coverage-final.json" - COVERAGE_SUMMARY_FILE="coverage/coverage-summary.json" - - if [ ! -f "$COVERAGE_FILE" ] && [ ! -f "$COVERAGE_SUMMARY_FILE" ]; then - echo "has_coverage=false" >> "$GITHUB_OUTPUT" - echo "### 🚨 Test Coverage Report :test_tube:" >> "$GITHUB_STEP_SUMMARY" - echo "Coverage data not found. Ensure Vitest runs with coverage enabled." >> "$GITHUB_STEP_SUMMARY" - exit 0 - fi - - echo "has_coverage=true" >> "$GITHUB_OUTPUT" - - node <<'NODE' >> "$GITHUB_STEP_SUMMARY" - const fs = require('fs'); - const path = require('path'); - let libCoverage = null; - - try { - libCoverage = require('istanbul-lib-coverage'); - } catch (error) { - libCoverage = null; - } - - const summaryPath = path.join('coverage', 'coverage-summary.json'); - const finalPath = path.join('coverage', 'coverage-final.json'); - - const hasSummary = fs.existsSync(summaryPath); - const hasFinal = fs.existsSync(finalPath); - - if (!hasSummary && !hasFinal) { - console.log('### Test Coverage Summary :test_tube:'); - console.log(''); - console.log('No coverage data found.'); - process.exit(0); - } - - const summary = hasSummary - ? JSON.parse(fs.readFileSync(summaryPath, 'utf8')) - : null; - const coverage = hasFinal - ? JSON.parse(fs.readFileSync(finalPath, 'utf8')) - : null; - - const getLineCoverageFromStatements = (statementMap, statementHits) => { - const lineHits = {}; - - if (!statementMap || !statementHits) { - return lineHits; - } - - Object.entries(statementMap).forEach(([key, statement]) => { - const line = statement?.start?.line; - if (!line) { - return; - } - const hits = statementHits[key] ?? 0; - const previous = lineHits[line]; - lineHits[line] = previous === undefined ? hits : Math.max(previous, hits); - }); - - return lineHits; - }; - - const getFileCoverage = (entry) => ( - libCoverage ? libCoverage.createFileCoverage(entry) : null - ); - - const getLineHits = (entry, fileCoverage) => { - const lineHits = entry.l ?? {}; - if (Object.keys(lineHits).length > 0) { - return lineHits; - } - if (fileCoverage) { - return fileCoverage.getLineCoverage(); - } - return getLineCoverageFromStatements(entry.statementMap ?? {}, entry.s ?? {}); - }; - - const getUncoveredLines = (entry, fileCoverage, lineHits) => { - if (lineHits && Object.keys(lineHits).length > 0) { - return Object.entries(lineHits) - .filter(([, count]) => count === 0) - .map(([line]) => Number(line)) - .sort((a, b) => a - b); - } - if (fileCoverage) { - return fileCoverage.getUncoveredLines(); - } - return []; - }; - - const totals = { - lines: { covered: 0, total: 0 }, - statements: { covered: 0, total: 0 }, - branches: { covered: 0, total: 0 }, - functions: { covered: 0, total: 0 }, - }; - const fileSummaries = []; - - if (summary) { - const totalEntry = summary.total ?? {}; - ['lines', 'statements', 'branches', 'functions'].forEach((key) => { - if (totalEntry[key]) { - totals[key].covered = totalEntry[key].covered ?? 0; - totals[key].total = totalEntry[key].total ?? 0; - } - }); - - Object.entries(summary) - .filter(([file]) => file !== 'total') - .forEach(([file, data]) => { - fileSummaries.push({ - file, - pct: data.lines?.pct ?? data.statements?.pct ?? 0, - lines: { - covered: data.lines?.covered ?? 0, - total: data.lines?.total ?? 0, - }, - }); - }); - } else if (coverage) { - Object.entries(coverage).forEach(([file, entry]) => { - const fileCoverage = getFileCoverage(entry); - const lineHits = getLineHits(entry, fileCoverage); - const statementHits = entry.s ?? {}; - const branchHits = entry.b ?? {}; - const functionHits = entry.f ?? {}; - - const lineTotal = Object.keys(lineHits).length; - const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; - - const statementTotal = Object.keys(statementHits).length; - const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; - - const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); - const branchCovered = Object.values(branchHits).reduce( - (acc, branches) => acc + branches.filter((n) => n > 0).length, - 0, - ); - - const functionTotal = Object.keys(functionHits).length; - const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; - - totals.lines.total += lineTotal; - totals.lines.covered += lineCovered; - totals.statements.total += statementTotal; - totals.statements.covered += statementCovered; - totals.branches.total += branchTotal; - totals.branches.covered += branchCovered; - totals.functions.total += functionTotal; - totals.functions.covered += functionCovered; - - const pct = (covered, tot) => (tot > 0 ? (covered / tot) * 100 : 0); - - fileSummaries.push({ - file, - pct: pct(lineCovered || statementCovered, lineTotal || statementTotal), - lines: { - covered: lineCovered || statementCovered, - total: lineTotal || statementTotal, - }, - }); - }); - } - - const pct = (covered, tot) => (tot > 0 ? ((covered / tot) * 100).toFixed(2) : '0.00'); - - console.log('### Test Coverage Summary :test_tube:'); - console.log(''); - console.log('| Metric | Coverage | Covered / Total |'); - console.log('|--------|----------|-----------------|'); - console.log(`| Lines | ${pct(totals.lines.covered, totals.lines.total)}% | ${totals.lines.covered} / ${totals.lines.total} |`); - console.log(`| Statements | ${pct(totals.statements.covered, totals.statements.total)}% | ${totals.statements.covered} / ${totals.statements.total} |`); - console.log(`| Branches | ${pct(totals.branches.covered, totals.branches.total)}% | ${totals.branches.covered} / ${totals.branches.total} |`); - console.log(`| Functions | ${pct(totals.functions.covered, totals.functions.total)}% | ${totals.functions.covered} / ${totals.functions.total} |`); - - console.log(''); - console.log('
File coverage (lowest lines first)'); - console.log(''); - console.log('```'); - fileSummaries - .sort((a, b) => (a.pct - b.pct) || (b.lines.total - a.lines.total)) - .slice(0, 25) - .forEach(({ file, pct, lines }) => { - console.log(`${pct.toFixed(2)}%\t${lines.covered}/${lines.total}\t${file}`); - }); - console.log('```'); - console.log('
'); - - if (coverage) { - const pctValue = (covered, tot) => { - if (tot === 0) { - return '0'; - } - return ((covered / tot) * 100) - .toFixed(2) - .replace(/\.?0+$/, ''); - }; - - const formatLineRanges = (lines) => { - if (lines.length === 0) { - return ''; - } - const ranges = []; - let start = lines[0]; - let end = lines[0]; - - for (let i = 1; i < lines.length; i += 1) { - const current = lines[i]; - if (current === end + 1) { - end = current; - continue; - } - ranges.push(start === end ? `${start}` : `${start}-${end}`); - start = current; - end = current; - } - ranges.push(start === end ? `${start}` : `${start}-${end}`); - return ranges.join(','); - }; - - const tableTotals = { - statements: { covered: 0, total: 0 }, - branches: { covered: 0, total: 0 }, - functions: { covered: 0, total: 0 }, - lines: { covered: 0, total: 0 }, - }; - const tableRows = Object.entries(coverage) - .map(([file, entry]) => { - const fileCoverage = getFileCoverage(entry); - const lineHits = getLineHits(entry, fileCoverage); - const statementHits = entry.s ?? {}; - const branchHits = entry.b ?? {}; - const functionHits = entry.f ?? {}; - - const lineTotal = Object.keys(lineHits).length; - const lineCovered = Object.values(lineHits).filter((n) => n > 0).length; - const statementTotal = Object.keys(statementHits).length; - const statementCovered = Object.values(statementHits).filter((n) => n > 0).length; - const branchTotal = Object.values(branchHits).reduce((acc, branches) => acc + branches.length, 0); - const branchCovered = Object.values(branchHits).reduce( - (acc, branches) => acc + branches.filter((n) => n > 0).length, - 0, - ); - const functionTotal = Object.keys(functionHits).length; - const functionCovered = Object.values(functionHits).filter((n) => n > 0).length; - - tableTotals.lines.total += lineTotal; - tableTotals.lines.covered += lineCovered; - tableTotals.statements.total += statementTotal; - tableTotals.statements.covered += statementCovered; - tableTotals.branches.total += branchTotal; - tableTotals.branches.covered += branchCovered; - tableTotals.functions.total += functionTotal; - tableTotals.functions.covered += functionCovered; - - const uncoveredLines = getUncoveredLines(entry, fileCoverage, lineHits); - - const filePath = entry.path ?? file; - const relativePath = path.isAbsolute(filePath) - ? path.relative(process.cwd(), filePath) - : filePath; - - return { - file: relativePath || file, - statements: pctValue(statementCovered, statementTotal), - branches: pctValue(branchCovered, branchTotal), - functions: pctValue(functionCovered, functionTotal), - lines: pctValue(lineCovered, lineTotal), - uncovered: formatLineRanges(uncoveredLines), - }; - }) - .sort((a, b) => a.file.localeCompare(b.file)); - - const columns = [ - { key: 'file', header: 'File', align: 'left' }, - { key: 'statements', header: '% Stmts', align: 'right' }, - { key: 'branches', header: '% Branch', align: 'right' }, - { key: 'functions', header: '% Funcs', align: 'right' }, - { key: 'lines', header: '% Lines', align: 'right' }, - { key: 'uncovered', header: 'Uncovered Line #s', align: 'left' }, - ]; - - const allFilesRow = { - file: 'All files', - statements: pctValue(tableTotals.statements.covered, tableTotals.statements.total), - branches: pctValue(tableTotals.branches.covered, tableTotals.branches.total), - functions: pctValue(tableTotals.functions.covered, tableTotals.functions.total), - lines: pctValue(tableTotals.lines.covered, tableTotals.lines.total), - uncovered: '', - }; - - const rowsForOutput = [allFilesRow, ...tableRows]; - const formatRow = (row) => `| ${columns - .map(({ key }) => String(row[key] ?? '')) - .join(' | ')} |`; - const headerRow = `| ${columns.map(({ header }) => header).join(' | ')} |`; - const dividerRow = `| ${columns - .map(({ align }) => (align === 'right' ? '---:' : ':---')) - .join(' | ')} |`; - - console.log(''); - console.log('
Vitest coverage table'); - console.log(''); - console.log(headerRow); - console.log(dividerRow); - rowsForOutput.forEach((row) => console.log(formatRow(row))); - console.log('
'); - } - NODE - - - name: Upload Coverage Artifact - if: steps.coverage-summary.outputs.has_coverage == 'true' - uses: actions/upload-artifact@v6 + - name: Report coverage + if: ${{ env.CODECOV_TOKEN != '' }} + uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 with: - name: web-coverage-report - path: web/coverage - retention-days: 30 - if-no-files-found: error + directory: web/coverage + flags: web + env: + CODECOV_TOKEN: ${{ env.CODECOV_TOKEN }} web-build: name: Web Build @@ -435,38 +100,24 @@ jobs: steps: - name: Checkout code - uses: actions/checkout@v6 + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: persist-credentials: false - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v47 + uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** .github/workflows/web-tests.yml + .github/actions/setup-web/** - - name: Install pnpm - uses: pnpm/action-setup@v4 - with: - package_json_file: web/package.json - run_install: false - - - name: Setup NodeJS - uses: actions/setup-node@v6 + - name: Setup web environment if: steps.changed-files.outputs.any_changed == 'true' - with: - node-version: 24 - cache: pnpm - cache-dependency-path: ./web/pnpm-lock.yaml - - - name: Web dependencies - if: steps.changed-files.outputs.any_changed == 'true' - working-directory: ./web - run: pnpm install --frozen-lockfile + uses: ./.github/actions/setup-web - name: Web build check if: steps.changed-files.outputs.any_changed == 'true' working-directory: ./web - run: pnpm run build + run: vp run build diff --git a/.gitignore b/.gitignore index 7bd919f095..aaca9f2b0a 100644 --- a/.gitignore +++ b/.gitignore @@ -222,6 +222,7 @@ mise.toml # AI Assistant .roo/ +/.claude/worktrees/ api/.env.backup /clickzetta @@ -236,3 +237,6 @@ scripts/stress-test/reports/ # settings *.local.json *.local.md + +# Code Agent Folder +.qoder/* \ No newline at end of file diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template index 700b815c3b..c3e2c50c52 100644 --- a/.vscode/launch.json.template +++ b/.vscode/launch.json.template @@ -37,7 +37,7 @@ "-c", "1", "-Q", - "dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution", + "dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution", "--loglevel", "INFO" ], diff --git a/AGENTS.md b/AGENTS.md index 51fa6e4527..d25d2eed96 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -29,7 +29,7 @@ The codebase is split into: ## Language Style -- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). +- **Python**: Keep type hints on functions and attributes, and implement relevant special methods (e.g., `__repr__`, `__str__`). Prefer `TypedDict` over `dict` or `Mapping` for type safety and better code documentation. - **TypeScript**: Use the strict config, rely on ESLint (`pnpm lint:fix` preferred) plus `pnpm type-check:tsgo`, and avoid `any` types. ## General Practices diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index d7f007af67..775401bfa5 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -97,3 +97,8 @@ Feel free to reach out if you encounter any issues during the setup process. ## Getting Help If you ever get stuck or get a burning question while contributing, simply shoot your queries our way via the related GitHub issue, or hop onto our [Discord](https://discord.gg/8Tpq4AcN9c) for a quick chat. + +## Automated Agent Contributions + +> [!NOTE] +> If you are an automated agent, just add 🤖🤖🤖 to the end of the PR title to opt-in. diff --git a/Makefile b/Makefile index 0aff26b3e5..55871c86a7 100644 --- a/Makefile +++ b/Makefile @@ -68,8 +68,9 @@ lint: @echo "✅ Linting complete" type-check: - @echo "📝 Running type checks (basedpyright + mypy)..." + @echo "📝 Running type checks (basedpyright + pyrefly + mypy)..." @./dev/basedpyright-check $(PATH_TO_CHECK) + @./dev/pyrefly-check-local @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . @echo "✅ Type checks complete" @@ -131,7 +132,7 @@ help: @echo " make format - Format code with ruff" @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" - @echo " make type-check - Run type checks (basedpyright, mypy)" + @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/README.md b/README.md index 90961a5346..bef8f6b782 100644 --- a/README.md +++ b/README.md @@ -56,7 +56,7 @@ README in বাংলা

-Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. +Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features: ## Quick start @@ -133,7 +133,7 @@ Star Dify on GitHub and be instantly notified of new releases. ### Custom configurations -If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker-compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). +If you need to customize the configuration, please refer to the comments in our [.env.example](docker/.env.example) file and update the corresponding values in your `.env` file. Additionally, you might need to make adjustments to the `docker-compose.yaml` file itself, such as changing image versions, port mappings, or volume mounts, based on your specific deployment environment and requirements. After making any changes, please re-run `docker compose up -d`. You can find the full list of available environment variables [here](https://docs.dify.ai/getting-started/install-self-hosted/environments). #### Customizing Suggested Questions diff --git a/api/.env.example b/api/.env.example index 38a096da0a..40e1c2dfdf 100644 --- a/api/.env.example +++ b/api/.env.example @@ -22,10 +22,10 @@ APP_WEB_URL=http://localhost:3000 # Files URL FILES_URL=http://localhost:5001 -# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. -# Set this to the internal Docker service URL for proper plugin file access. -# Example: INTERNAL_FILES_URL=http://api:5001 -INTERNAL_FILES_URL=http://127.0.0.1:5001 +# INTERNAL_FILES_URL is used by services running in Docker to reach the API file endpoints. +# For Docker Desktop (Mac/Windows), use http://host.docker.internal:5001 when the API runs on the host. +# For Docker Compose on Linux, use http://api:5001 when the API runs inside the Docker network. +INTERNAL_FILES_URL=http://host.docker.internal:5001 # TRIGGER URL TRIGGER_URL=http://localhost:5001 @@ -42,6 +42,8 @@ REFRESH_TOKEN_EXPIRE_DAYS=30 # redis configuration REDIS_HOST=localhost REDIS_PORT=6379 +# Optional: limit total connections in connection pool (unset for default) +# REDIS_MAX_CONNECTIONS=200 REDIS_USERNAME= REDIS_PASSWORD=difyai123456 REDIS_USE_SSL=false @@ -178,7 +180,7 @@ CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* COOKIE_DOMAIN= # Vector database configuration -# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. +# Supported values are `weaviate`, `oceanbase`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`, `hologres`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -186,7 +188,6 @@ VECTOR_INDEX_NAME_PREFIX=Vector_index # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih -WEAVIATE_GRPC_ENABLED=false WEAVIATE_BATCH_SIZE=100 WEAVIATE_TOKENIZATION=word @@ -216,6 +217,20 @@ COUCHBASE_PASSWORD=password COUCHBASE_BUCKET_NAME=Embeddings COUCHBASE_SCOPE_NAME=_default +# Hologres configuration +# access_key_id is used as the PG username, access_key_secret is used as the PG password +HOLOGRES_HOST= +HOLOGRES_PORT=80 +HOLOGRES_DATABASE= +HOLOGRES_ACCESS_KEY_ID= +HOLOGRES_ACCESS_KEY_SECRET= +HOLOGRES_SCHEMA=public +HOLOGRES_TOKENIZER=jieba +HOLOGRES_DISTANCE_METHOD=Cosine +HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq +HOLOGRES_MAX_DEGREE=64 +HOLOGRES_EF_CONSTRUCTION=400 + # Milvus configuration MILVUS_URI=http://127.0.0.1:19530 MILVUS_TOKEN= @@ -722,24 +737,25 @@ SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000 -# Redis URL used for PubSub between API and +# Redis URL used for event bus between API and # celery worker # defaults to url constructed from `REDIS_*` # configurations -PUBSUB_REDIS_URL= -# Pub/sub channel type for streaming events. -# valid options are: +EVENT_BUS_REDIS_URL= +# Event transport type. Options are: # -# - pubsub: for normal Pub/Sub -# - sharded: for sharded Pub/Sub +# - pubsub: normal Pub/Sub (at-most-once) +# - sharded: sharded Pub/Sub (at-most-once) +# - streams: Redis Streams (at-least-once, recommended to avoid subscriber races) # -# It's highly recommended to use sharded Pub/Sub AND redis cluster -# for large deployments. -PUBSUB_REDIS_CHANNEL_TYPE=pubsub -# Whether to use Redis cluster mode while running -# PubSub. +# Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs. +# Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce +# the risk of data loss from Redis auto-eviction under memory pressure. +# Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE. +EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub +# Whether to use Redis cluster mode while use redis as event bus. # It's highly recommended to enable this for large deployments. -PUBSUB_REDIS_USE_CLUSTERS=false +EVENT_BUS_REDIS_USE_CLUSTERS=false # Whether to Enable human input timeout check task ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true diff --git a/api/.importlinter b/api/.importlinter index f74a1b667d..a836d09088 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,6 +1,7 @@ [importlinter] root_packages = core + dify_graph configs controllers extensions @@ -21,49 +22,36 @@ layers = runtime entities containers = - core.workflow + dify_graph ignore_imports = - core.workflow.nodes.base.node -> core.workflow.graph_events - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_events - core.workflow.nodes.loop.loop_node -> core.workflow.graph_events + dify_graph.nodes.base.node -> dify_graph.graph_events + dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events + dify_graph.nodes.loop.loop_node -> dify_graph.graph_events - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory - core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota - core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota - - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph - core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine.command_channels - core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine - core.workflow.nodes.loop.loop_node -> core.workflow.graph - core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels + dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine + dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine # TODO(QuantumGhost): fix the import violation later - core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities + dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities [importlinter:contract:workflow-infrastructure-dependencies] name = Workflow Infrastructure Dependencies type = forbidden source_modules = - core.workflow + dify_graph forbidden_modules = extensions.ext_database extensions.ext_redis allow_indirect_imports = True ignore_imports = - core.workflow.nodes.agent.agent_node -> extensions.ext_database - core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database - core.workflow.nodes.llm.file_saver -> extensions.ext_database - core.workflow.nodes.llm.node -> extensions.ext_database - core.workflow.nodes.tool.tool_node -> extensions.ext_database - # TODO(QuantumGhost): use DI to avoid depending on global DB. - core.workflow.nodes.human_input.human_input_node -> extensions.ext_database + dify_graph.nodes.llm.node -> extensions.ext_database + dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis [importlinter:contract:workflow-external-imports] name = Workflow External Imports type = forbidden source_modules = - core.workflow + dify_graph forbidden_modules = configs controllers @@ -101,136 +89,45 @@ forbidden_modules = core.trigger core.variables ignore_imports = - core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory - core.workflow.workflow_entry -> core.app.workflow.layers.observability - core.workflow.nodes.agent.agent_node -> core.model_manager - core.workflow.nodes.agent.agent_node -> core.provider_manager - core.workflow.nodes.agent.agent_node -> core.tools.tool_manager - core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory - core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota - core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory - core.workflow.nodes.llm.llm_utils -> core.model_manager - core.workflow.nodes.llm.protocols -> core.model_manager - core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model - core.workflow.nodes.llm.node -> core.tools.signature - core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - core.workflow.nodes.tool.tool_node -> core.tools.tool_engine - core.workflow.nodes.tool.tool_node -> core.tools.tool_manager - core.workflow.workflow_entry -> configs - core.workflow.workflow_entry -> models.workflow - core.workflow.nodes.agent.agent_node -> core.agent.entities - core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities - core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities - core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities - core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_runtime.model_providers.__base.large_language_model - core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - core.workflow.workflow_entry -> core.app.apps.exc - core.workflow.workflow_entry -> core.app.entities.app_invoke_entities - core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota - core.workflow.workflow_entry -> core.app.workflow.node_factory - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager - core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer - core.workflow.nodes.tool.tool_node -> models - core.workflow.nodes.agent.agent_node -> models.model - core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy - core.workflow.nodes.llm.node -> core.helper.code_executor - core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor - core.workflow.nodes.llm.node -> core.llm_generator.output_parser.errors - core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output - core.workflow.nodes.llm.node -> core.model_manager - core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.llm.node -> core.prompt.utils.prompt_message_util - core.workflow.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - core.workflow.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods - core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods - core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset - core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service - core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task - core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor - core.workflow.nodes.llm.node -> models.dataset - core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer - core.workflow.nodes.llm.file_saver -> core.tools.signature - core.workflow.nodes.llm.file_saver -> core.tools.tool_file_manager - core.workflow.nodes.tool.tool_node -> core.tools.errors - core.workflow.nodes.agent.agent_node -> extensions.ext_database - core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database - core.workflow.nodes.llm.file_saver -> extensions.ext_database - core.workflow.nodes.llm.node -> extensions.ext_database - core.workflow.nodes.tool.tool_node -> extensions.ext_database - core.workflow.nodes.human_input.human_input_node -> extensions.ext_database - core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository - core.workflow.workflow_entry -> extensions.otel.runtime - core.workflow.nodes.agent.agent_node -> models - core.workflow.nodes.base.node -> models.enums - core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota - core.workflow.nodes.llm.node -> models.model - core.workflow.workflow_entry -> models.enums - core.workflow.nodes.agent.agent_node -> services - core.workflow.nodes.tool.tool_node -> services - -[importlinter:contract:model-runtime-no-internal-imports] -name = Model Runtime Internal Imports -type = forbidden -source_modules = - core.model_runtime -forbidden_modules = - configs - controllers - extensions - models - services - tasks - core.agent - core.app - core.base - core.callback_handler - core.datasource - core.db - core.entities - core.errors - core.extension - core.external_data_tool - core.file - core.helper - core.hosting_configuration - core.indexing_runner - core.llm_generator - core.logging - core.mcp - core.memory - core.model_manager - core.moderation - core.ops - core.plugin - core.prompt - core.provider_manager - core.rag - core.repositories - core.schemas - core.tools - core.trigger - core.variables - core.workflow -ignore_imports = - core.model_runtime.model_providers.__base.ai_model -> configs - core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - core.model_runtime.model_providers.__base.large_language_model -> configs - core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - core.model_runtime.model_providers.model_provider_factory -> configs - core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - core.model_runtime.model_providers.model_provider_factory -> models.provider_ids + dify_graph.nodes.llm.llm_utils -> core.model_manager + dify_graph.nodes.llm.protocols -> core.model_manager + dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model + dify_graph.nodes.llm.node -> core.tools.signature + dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler + dify_graph.nodes.tool.tool_node -> core.tools.tool_engine + dify_graph.nodes.tool.tool_node -> core.tools.tool_manager + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model + dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager + dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager + dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer + dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors + dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output + dify_graph.nodes.llm.node -> core.model_manager + dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util + dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util + dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities + dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util + dify_graph.nodes.llm.node -> models.dataset + dify_graph.nodes.llm.file_saver -> core.tools.signature + dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager + dify_graph.nodes.tool.tool_node -> core.tools.errors + dify_graph.nodes.llm.node -> extensions.ext_database + dify_graph.nodes.llm.node -> models.model + dify_graph.nodes.tool.tool_node -> services + dify_graph.model_runtime.model_providers.__base.ai_model -> configs + dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis + dify_graph.model_runtime.model_providers.__base.large_language_model -> configs + dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type + dify_graph.model_runtime.model_providers.model_provider_factory -> configs + dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis + dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids [importlinter:contract:rsc] name = RSC @@ -239,7 +136,7 @@ layers = graph_engine response_coordinator containers = - core.workflow.graph_engine + dify_graph.graph_engine [importlinter:contract:worker] name = Worker @@ -248,7 +145,7 @@ layers = graph_engine worker containers = - core.workflow.graph_engine + dify_graph.graph_engine [importlinter:contract:graph-engine-architecture] name = Graph Engine Architecture @@ -264,28 +161,28 @@ layers = worker_management domain containers = - core.workflow.graph_engine + dify_graph.graph_engine [importlinter:contract:domain-isolation] name = Domain Model Isolation type = forbidden source_modules = - core.workflow.graph_engine.domain + dify_graph.graph_engine.domain forbidden_modules = - core.workflow.graph_engine.worker_management - core.workflow.graph_engine.command_channels - core.workflow.graph_engine.layers - core.workflow.graph_engine.protocols + dify_graph.graph_engine.worker_management + dify_graph.graph_engine.command_channels + dify_graph.graph_engine.layers + dify_graph.graph_engine.protocols [importlinter:contract:worker-management] name = Worker Management type = forbidden source_modules = - core.workflow.graph_engine.worker_management + dify_graph.graph_engine.worker_management forbidden_modules = - core.workflow.graph_engine.orchestration - core.workflow.graph_engine.command_processing - core.workflow.graph_engine.event_management + dify_graph.graph_engine.orchestration + dify_graph.graph_engine.command_processing + dify_graph.graph_engine.event_management [importlinter:contract:graph-traversal-components] @@ -295,11 +192,11 @@ layers = edge_processor skip_propagator containers = - core.workflow.graph_engine.graph_traversal + dify_graph.graph_engine.graph_traversal [importlinter:contract:command-channels] name = Command Channels Independence type = independence modules = - core.workflow.graph_engine.command_channels.in_memory_channel - core.workflow.graph_engine.command_channels.redis_channel + dify_graph.graph_engine.command_channels.in_memory_channel + dify_graph.graph_engine.command_channels.redis_channel diff --git a/api/.ruff.toml b/api/.ruff.toml index 3301452ad9..b0947eb619 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -100,7 +100,7 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"core/model_runtime/callbacks/base_callback.py" = ["T201"] +"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"] "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name diff --git a/api/AGENTS.md b/api/AGENTS.md index 13adb42276..8e5d9f600d 100644 --- a/api/AGENTS.md +++ b/api/AGENTS.md @@ -62,7 +62,23 @@ This is the default standard for backend code in this repo. Follow it for new co - Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values). - Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason. -- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance: +- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`. +- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional). +- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown. + +```python +from datetime import datetime +from typing import NotRequired, TypedDict + + +class UserProfile(TypedDict): + user_id: str + email: str + created_at: datetime + nickname: NotRequired[str] +``` + +- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance: ```python from datetime import datetime diff --git a/api/Dockerfile b/api/Dockerfile index a08d4e3aab..7e0a439954 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}" # Download nltk data RUN mkdir -p /usr/local/share/nltk_data \ - && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \ + && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \ && chmod -R 755 /usr/local/share/nltk_data ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache diff --git a/api/app_factory.py b/api/app_factory.py index dcbc821687..066eb2ae2c 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,16 +1,45 @@ import logging import time +from flask import request from opentelemetry.trace import get_current_span from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID from configs import dify_config from contexts.wrapper import RecyclableContextVar +from controllers.console.error import UnauthorizedAndForceLogout from core.logging.context import init_request_context from dify_app import DifyApp +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import LicenseStatus logger = logging.getLogger(__name__) +# Console bootstrap APIs exempt from license check. +# Defined at module level to avoid per-request tuple construction. +# - system-features: license status for expiry UI (GlobalPublicStoreProvider) +# - setup: install/setup status check (AppInitializer) +# - init: init password validation for fresh install (InitPasswordPopup) +# - login: auto-login after setup completion (InstallForm) +# - features: billing/plan features (ProviderContextProvider) +# - account/profile: login check + user profile (AppContextProvider, useIsLogin) +# - workspaces/current: workspace + model providers (AppContextProvider) +# - version: version check (AppContextProvider) +# - activate/check: invitation link validation (signin page) +# Without these exemptions, the signin page triggers location.reload() +# on unauthorized_and_force_logout, causing an infinite loop. +_CONSOLE_EXEMPT_PREFIXES = ( + "/console/api/system-features", + "/console/api/setup", + "/console/api/init", + "/console/api/login", + "/console/api/features", + "/console/api/account/profile", + "/console/api/workspaces/current", + "/console/api/version", + "/console/api/activate/check", +) + # ---------------------------- # Application Factory Function @@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp: init_request_context() RecyclableContextVar.increment_thread_recycles() + # Enterprise license validation for API endpoints (both console and webapp) + # When license expires, block all API access except bootstrap endpoints needed + # for the frontend to load the license expiration page without infinite reloads. + if dify_config.ENTERPRISE_ENABLED: + is_console_api = request.path.startswith("/console/api/") + is_webapp_api = request.path.startswith("/api/") + + if is_console_api or is_webapp_api: + if is_console_api: + is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES) + else: # webapp API + is_exempt = request.path.startswith("/api/system-features") + + if not is_exempt: + try: + # Check license status (cached — see EnterpriseService for TTL details) + license_status = EnterpriseService.get_cached_license_status() + if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST): + raise UnauthorizedAndForceLogout( + f"Enterprise license is {license_status}. Please contact your administrator." + ) + if license_status is None: + raise UnauthorizedAndForceLogout( + "Unable to verify enterprise license. Please contact your administrator." + ) + except UnauthorizedAndForceLogout: + raise + except Exception: + logger.exception("Failed to check enterprise license status") + raise UnauthorizedAndForceLogout( + "Unable to verify enterprise license. Please contact your administrator." + ) + # add after request hook for injecting trace headers from OpenTelemetry span context # Only adds headers when OTEL is enabled and has valid context @dify_app.after_request diff --git a/api/commands.py b/api/commands.py deleted file mode 100644 index 75b17df78e..0000000000 --- a/api/commands.py +++ /dev/null @@ -1,2670 +0,0 @@ -import base64 -import datetime -import json -import logging -import secrets -import time -from typing import Any - -import click -import sqlalchemy as sa -from flask import current_app -from pydantic import TypeAdapter -from sqlalchemy import select -from sqlalchemy.exc import SQLAlchemyError -from sqlalchemy.orm import sessionmaker - -from configs import dify_config -from constants.languages import languages -from core.helper import encrypter -from core.plugin.entities.plugin_daemon import CredentialType -from core.plugin.impl.plugin import PluginInstaller -from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.datasource.vdb.vector_type import VectorType -from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.rag.models.document import ChildDocument, Document -from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params -from events.app_event import app_was_created -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from extensions.ext_storage import storage -from extensions.storage.opendal_storage import OpenDALStorage -from extensions.storage.storage_type import StorageType -from libs.db_migration_lock import DbMigrationAutoRenewLock -from libs.helper import email as email_validate -from libs.password import hash_password, password_pattern, valid_password -from libs.rsa import generate_key_pair -from models import Tenant -from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment -from models.dataset import Document as DatasetDocument -from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile -from models.oauth import DatasourceOauthParamConfig, DatasourceProvider -from models.provider import Provider, ProviderModel -from models.provider_ids import DatasourceProviderID, ToolProviderID -from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding -from models.tools import ToolOAuthSystemClient -from services.account_service import AccountService, RegisterService, TenantService -from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs -from services.plugin.data_migration import PluginDataMigration -from services.plugin.plugin_migration import PluginMigration -from services.plugin.plugin_service import PluginService -from services.retention.conversation.messages_clean_policy import create_message_clean_policy -from services.retention.conversation.messages_clean_service import MessagesCleanService -from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup -from tasks.remove_app_and_related_data_task import delete_draft_variables_batch - -logger = logging.getLogger(__name__) - -DB_UPGRADE_LOCK_TTL_SECONDS = 60 - - -@click.command("reset-password", help="Reset the account password.") -@click.option("--email", prompt=True, help="Account email to reset password for") -@click.option("--new-password", prompt=True, help="New password") -@click.option("--password-confirm", prompt=True, help="Confirm new password") -def reset_password(email, new_password, password_confirm): - """ - Reset password of owner account - Only available in SELF_HOSTED mode - """ - if str(new_password).strip() != str(password_confirm).strip(): - click.echo(click.style("Passwords do not match.", fg="red")) - return - normalized_email = email.strip().lower() - - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) - - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return - - try: - valid_password(new_password) - except: - click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) - return - - # generate password salt - salt = secrets.token_bytes(16) - base64_salt = base64.b64encode(salt).decode() - - # encrypt password with salt - password_hashed = hash_password(new_password, salt) - base64_password_hashed = base64.b64encode(password_hashed).decode() - account.password = base64_password_hashed - account.password_salt = base64_salt - AccountService.reset_login_error_rate_limit(normalized_email) - click.echo(click.style("Password reset successfully.", fg="green")) - - -@click.command("reset-email", help="Reset the account email.") -@click.option("--email", prompt=True, help="Current account email") -@click.option("--new-email", prompt=True, help="New email") -@click.option("--email-confirm", prompt=True, help="Confirm new email") -def reset_email(email, new_email, email_confirm): - """ - Replace account email - :return: - """ - if str(new_email).strip() != str(email_confirm).strip(): - click.echo(click.style("New emails do not match.", fg="red")) - return - normalized_new_email = new_email.strip().lower() - - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) - - if not account: - click.echo(click.style(f"Account not found for email: {email}", fg="red")) - return - - try: - email_validate(normalized_new_email) - except: - click.echo(click.style(f"Invalid email: {new_email}", fg="red")) - return - - account.email = normalized_new_email - click.echo(click.style("Email updated successfully.", fg="green")) - - -@click.command( - "reset-encrypt-key-pair", - help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " - "After the reset, all LLM credentials will become invalid, " - "requiring re-entry." - "Only support SELF_HOSTED mode.", -) -@click.confirmation_option( - prompt=click.style( - "Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red" - ) -) -def reset_encrypt_key_pair(): - """ - Reset the encrypted key pair of workspace for encrypt LLM credentials. - After the reset, all LLM credentials will become invalid, requiring re-entry. - Only support SELF_HOSTED mode. - """ - if dify_config.EDITION != "SELF_HOSTED": - click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) - return - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - tenants = session.query(Tenant).all() - for tenant in tenants: - if not tenant: - click.echo(click.style("No workspaces found. Run /install first.", fg="red")) - return - - tenant.encrypt_public_key = generate_key_pair(tenant.id) - - session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() - session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete() - - click.echo( - click.style( - f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", - fg="green", - ) - ) - - -@click.command("vdb-migrate", help="Migrate vector db.") -@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") -def vdb_migrate(scope: str): - if scope in {"knowledge", "all"}: - migrate_knowledge_vector_database() - if scope in {"annotation", "all"}: - migrate_annotation_vector_database() - - -def migrate_annotation_vector_database(): - """ - Migrate annotation datas to target vector database . - """ - click.echo(click.style("Starting annotation data migration.", fg="green")) - create_count = 0 - skipped_count = 0 - total_count = 0 - page = 1 - while True: - try: - # get apps info - per_page = 50 - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - apps = ( - session.query(App) - .where(App.status == "normal") - .order_by(App.created_at.desc()) - .limit(per_page) - .offset((page - 1) * per_page) - .all() - ) - if not apps: - break - except SQLAlchemyError: - raise - - page += 1 - for app in apps: - total_count = total_count + 1 - click.echo( - f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." - ) - try: - click.echo(f"Creating app annotation index: {app.id}") - with sessionmaker(db.engine, expire_on_commit=False).begin() as session: - app_annotation_setting = ( - session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first() - ) - - if not app_annotation_setting: - skipped_count = skipped_count + 1 - click.echo(f"App annotation setting disabled: {app.id}") - continue - # get dataset_collection_binding info - dataset_collection_binding = ( - session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) - .first() - ) - if not dataset_collection_binding: - click.echo(f"App annotation collection binding not found: {app.id}") - continue - annotations = session.scalars( - select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) - ).all() - dataset = Dataset( - id=app.id, - tenant_id=app.tenant_id, - indexing_technique="high_quality", - embedding_model_provider=dataset_collection_binding.provider_name, - embedding_model=dataset_collection_binding.model_name, - collection_binding_id=dataset_collection_binding.id, - ) - documents = [] - if annotations: - for annotation in annotations: - document = Document( - page_content=annotation.question_text, - metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, - ) - documents.append(document) - - vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) - click.echo(f"Migrating annotations for app: {app.id}.") - - try: - vector.delete() - click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green")) - except Exception as e: - click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) - raise e - if documents: - try: - click.echo( - click.style( - f"Creating vector index with {len(documents)} annotations for app {app.id}.", - fg="green", - ) - ) - vector.create(documents) - click.echo(click.style(f"Created vector index for app {app.id}.", fg="green")) - except Exception as e: - click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) - raise e - click.echo(f"Successfully migrated app annotation {app.id}.") - create_count += 1 - except Exception as e: - click.echo( - click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red") - ) - continue - - click.echo( - click.style( - f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.", - fg="green", - ) - ) - - -def migrate_knowledge_vector_database(): - """ - Migrate vector database datas to target vector database . - """ - click.echo(click.style("Starting vector database migration.", fg="green")) - create_count = 0 - skipped_count = 0 - total_count = 0 - vector_type = dify_config.VECTOR_STORE - upper_collection_vector_types = { - VectorType.MILVUS, - VectorType.PGVECTOR, - VectorType.VASTBASE, - VectorType.RELYT, - VectorType.WEAVIATE, - VectorType.ORACLE, - VectorType.ELASTICSEARCH, - VectorType.OPENGAUSS, - VectorType.TABLESTORE, - VectorType.MATRIXONE, - } - lower_collection_vector_types = { - VectorType.ANALYTICDB, - VectorType.CHROMA, - VectorType.MYSCALE, - VectorType.PGVECTO_RS, - VectorType.TIDB_VECTOR, - VectorType.OPENSEARCH, - VectorType.TENCENT, - VectorType.BAIDU, - VectorType.VIKINGDB, - VectorType.UPSTASH, - VectorType.COUCHBASE, - VectorType.OCEANBASE, - } - page = 1 - while True: - try: - stmt = ( - select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) - ) - - datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - if not datasets.items: - break - except SQLAlchemyError: - raise - - page += 1 - for dataset in datasets: - total_count = total_count + 1 - click.echo( - f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." - ) - try: - click.echo(f"Creating dataset vector database index: {dataset.id}") - if dataset.index_struct_dict: - if dataset.index_struct_dict["type"] == vector_type: - skipped_count = skipped_count + 1 - continue - collection_name = "" - dataset_id = dataset.id - if vector_type in upper_collection_vector_types: - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - elif vector_type == VectorType.QDRANT: - if dataset.collection_binding_id: - dataset_collection_binding = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == dataset.collection_binding_id) - .one_or_none() - ) - if dataset_collection_binding: - collection_name = dataset_collection_binding.collection_name - else: - raise ValueError("Dataset Collection Binding not found") - else: - collection_name = Dataset.gen_collection_name_by_id(dataset_id) - - elif vector_type in lower_collection_vector_types: - collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() - else: - raise ValueError(f"Vector store {vector_type} is not supported.") - - index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} - dataset.index_struct = json.dumps(index_struct_dict) - vector = Vector(dataset) - click.echo(f"Migrating dataset {dataset.id}.") - - try: - vector.delete() - click.echo( - click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green") - ) - except Exception as e: - click.echo( - click.style( - f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" - ) - ) - raise e - - dataset_documents = db.session.scalars( - select(DatasetDocument).where( - DatasetDocument.dataset_id == dataset.id, - DatasetDocument.indexing_status == "completed", - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ) - ).all() - - documents = [] - segments_count = 0 - for dataset_document in dataset_documents: - segments = db.session.scalars( - select(DocumentSegment).where( - DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - ) - ).all() - - for segment in segments: - document = Document( - page_content=segment.content, - metadata={ - "doc_id": segment.index_node_id, - "doc_hash": segment.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - if dataset_document.doc_form == "hierarchical_model": - child_chunks = segment.get_child_chunks() - if child_chunks: - child_documents = [] - for child_chunk in child_chunks: - child_document = ChildDocument( - page_content=child_chunk.content, - metadata={ - "doc_id": child_chunk.index_node_id, - "doc_hash": child_chunk.index_node_hash, - "document_id": segment.document_id, - "dataset_id": segment.dataset_id, - }, - ) - child_documents.append(child_document) - document.children = child_documents - - documents.append(document) - segments_count = segments_count + 1 - - if documents: - try: - click.echo( - click.style( - f"Creating vector index with {len(documents)} documents of {segments_count}" - f" segments for dataset {dataset.id}.", - fg="green", - ) - ) - all_child_documents = [] - for doc in documents: - if doc.children: - all_child_documents.extend(doc.children) - vector.create(documents) - if all_child_documents: - vector.create(all_child_documents) - click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green")) - except Exception as e: - click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) - raise e - db.session.add(dataset) - db.session.commit() - click.echo(f"Successfully migrated dataset {dataset.id}.") - create_count += 1 - except Exception as e: - db.session.rollback() - click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red")) - continue - - click.echo( - click.style( - f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green" - ) - ) - - -@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") -def convert_to_agent_apps(): - """ - Convert Agent Assistant to Agent App. - """ - click.echo(click.style("Starting convert to agent apps.", fg="green")) - - proceeded_app_ids = [] - - while True: - # fetch first 1000 apps - sql_query = """SELECT a.id AS id FROM apps a - INNER JOIN app_model_configs am ON a.app_model_config_id=am.id - WHERE a.mode = 'chat' - AND am.agent_mode is not null - AND ( - am.agent_mode like '%"strategy": "function_call"%' - OR am.agent_mode like '%"strategy": "react"%' - ) - AND ( - am.agent_mode like '{"enabled": true%' - OR am.agent_mode like '{"max_iteration": %' - ) ORDER BY a.created_at DESC LIMIT 1000 - """ - - with db.engine.begin() as conn: - rs = conn.execute(sa.text(sql_query)) - - apps = [] - for i in rs: - app_id = str(i.id) - if app_id not in proceeded_app_ids: - proceeded_app_ids.append(app_id) - app = db.session.query(App).where(App.id == app_id).first() - if app is not None: - apps.append(app) - - if len(apps) == 0: - break - - for app in apps: - click.echo(f"Converting app: {app.id}") - - try: - app.mode = AppMode.AGENT_CHAT - db.session.commit() - - # update conversation mode to agent - db.session.query(Conversation).where(Conversation.app_id == app.id).update( - {Conversation.mode: AppMode.AGENT_CHAT} - ) - - db.session.commit() - click.echo(click.style(f"Converted app: {app.id}", fg="green")) - except Exception as e: - click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red")) - - click.echo(click.style(f"Conversion complete. Converted {len(proceeded_app_ids)} agent apps.", fg="green")) - - -@click.command("add-qdrant-index", help="Add Qdrant index.") -@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.") -def add_qdrant_index(field: str): - click.echo(click.style("Starting Qdrant index creation.", fg="green")) - - create_count = 0 - - try: - bindings = db.session.query(DatasetCollectionBinding).all() - if not bindings: - click.echo(click.style("No dataset collection bindings found.", fg="red")) - return - import qdrant_client - from qdrant_client.http.exceptions import UnexpectedResponse - from qdrant_client.http.models import PayloadSchemaType - - from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig - - for binding in bindings: - if dify_config.QDRANT_URL is None: - raise ValueError("Qdrant URL is required.") - qdrant_config = QdrantConfig( - endpoint=dify_config.QDRANT_URL, - api_key=dify_config.QDRANT_API_KEY, - root_path=current_app.root_path, - timeout=dify_config.QDRANT_CLIENT_TIMEOUT, - grpc_port=dify_config.QDRANT_GRPC_PORT, - prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, - ) - try: - params = qdrant_config.to_qdrant_params() - # Check the type before using - if isinstance(params, PathQdrantParams): - # PathQdrantParams case - client = qdrant_client.QdrantClient(path=params.path) - else: - # UrlQdrantParams case - params is UrlQdrantParams - client = qdrant_client.QdrantClient( - url=params.url, - api_key=params.api_key, - timeout=int(params.timeout), - verify=params.verify, - grpc_port=params.grpc_port, - prefer_grpc=params.prefer_grpc, - ) - # create payload index - client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) - create_count += 1 - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red")) - continue - # Some other error occurred, so re-raise the exception - else: - click.echo( - click.style( - f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red" - ) - ) - - except Exception: - click.echo(click.style("Failed to create Qdrant client.", fg="red")) - - click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) - - -@click.command("old-metadata-migration", help="Old metadata migration.") -def old_metadata_migration(): - """ - Old metadata migration. - """ - click.echo(click.style("Starting old metadata migration.", fg="green")) - - page = 1 - while True: - try: - stmt = ( - select(DatasetDocument) - .where(DatasetDocument.doc_metadata.is_not(None)) - .order_by(DatasetDocument.created_at.desc()) - ) - documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) - except SQLAlchemyError: - raise - if not documents: - break - for document in documents: - if document.doc_metadata: - doc_metadata = document.doc_metadata - for key in doc_metadata: - for field in BuiltInField: - if field.value == key: - break - else: - dataset_metadata = ( - db.session.query(DatasetMetadata) - .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) - .first() - ) - if not dataset_metadata: - dataset_metadata = DatasetMetadata( - tenant_id=document.tenant_id, - dataset_id=document.dataset_id, - name=key, - type="string", - created_by=document.created_by, - ) - db.session.add(dataset_metadata) - db.session.flush() - dataset_metadata_binding = DatasetMetadataBinding( - tenant_id=document.tenant_id, - dataset_id=document.dataset_id, - metadata_id=dataset_metadata.id, - document_id=document.id, - created_by=document.created_by, - ) - db.session.add(dataset_metadata_binding) - else: - dataset_metadata_binding = ( - db.session.query(DatasetMetadataBinding) # type: ignore - .where( - DatasetMetadataBinding.dataset_id == document.dataset_id, - DatasetMetadataBinding.document_id == document.id, - DatasetMetadataBinding.metadata_id == dataset_metadata.id, - ) - .first() - ) - if not dataset_metadata_binding: - dataset_metadata_binding = DatasetMetadataBinding( - tenant_id=document.tenant_id, - dataset_id=document.dataset_id, - metadata_id=dataset_metadata.id, - document_id=document.id, - created_by=document.created_by, - ) - db.session.add(dataset_metadata_binding) - db.session.commit() - page += 1 - click.echo(click.style("Old metadata migration completed.", fg="green")) - - -@click.command("create-tenant", help="Create account and tenant.") -@click.option("--email", prompt=True, help="Tenant account email.") -@click.option("--name", prompt=True, help="Workspace name.") -@click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant(email: str, language: str | None = None, name: str | None = None): - """ - Create tenant account - """ - if not email: - click.echo(click.style("Email is required.", fg="red")) - return - - # Create account - email = email.strip().lower() - - if "@" not in email: - click.echo(click.style("Invalid email address.", fg="red")) - return - - account_name = email.split("@")[0] - - if language not in languages: - language = "en-US" - - # Validates name encoding for non-Latin characters. - name = name.strip().encode("utf-8").decode("utf-8") if name else None - - # generate random password - new_password = secrets.token_urlsafe(16) - - # register account - account = RegisterService.register( - email=email, - name=account_name, - password=new_password, - language=language, - create_workspace_required=False, - ) - TenantService.create_owner_tenant_if_not_exist(account, name) - - click.echo( - click.style( - f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}", - fg="green", - ) - ) - - -@click.command("upgrade-db", help="Upgrade the database") -def upgrade_db(): - click.echo("Preparing database migration...") - lock = DbMigrationAutoRenewLock( - redis_client=redis_client, - name="db_upgrade_lock", - ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS, - logger=logger, - log_context="db_migration", - ) - if lock.acquire(blocking=False): - migration_succeeded = False - try: - click.echo(click.style("Starting database migration.", fg="green")) - - # run db migration - import flask_migrate - - flask_migrate.upgrade() - - migration_succeeded = True - click.echo(click.style("Database migration successful!", fg="green")) - - except Exception as e: - logger.exception("Failed to execute database migration") - click.echo(click.style(f"Database migration failed: {e}", fg="red")) - raise SystemExit(1) - finally: - status = "successful" if migration_succeeded else "failed" - lock.release_safely(status=status) - else: - click.echo("Database migration skipped") - - -@click.command("fix-app-site-missing", help="Fix app related site missing issue.") -def fix_app_site_missing(): - """ - Fix app related site missing issue. - """ - click.echo(click.style("Starting fix for missing app-related sites.", fg="green")) - - failed_app_ids = [] - while True: - sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id -where sites.id is null limit 1000""" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(sql)) - - processed_count = 0 - for i in rs: - processed_count += 1 - app_id = str(i.id) - - if app_id in failed_app_ids: - continue - - try: - app = db.session.query(App).where(App.id == app_id).first() - if not app: - logger.info("App %s not found", app_id) - continue - - tenant = app.tenant - if tenant: - accounts = tenant.get_accounts() - if not accounts: - logger.info("Fix failed for app %s", app.id) - continue - - account = accounts[0] - logger.info("Fixing missing site for app %s", app.id) - app_was_created.send(app, account=account) - except Exception: - failed_app_ids.append(app_id) - click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red")) - logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id) - continue - - if not processed_count: - break - - click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) - - -@click.command("migrate-data-for-plugin", help="Migrate data for plugin.") -def migrate_data_for_plugin(): - """ - Migrate data for plugin. - """ - click.echo(click.style("Starting migrate data for plugin.", fg="white")) - - PluginDataMigration.migrate() - - click.echo(click.style("Migrate data for plugin completed.", fg="green")) - - -@click.command("extract-plugins", help="Extract plugins.") -@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl") -@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10) -def extract_plugins(output_file: str, workers: int): - """ - Extract plugins. - """ - click.echo(click.style("Starting extract plugins.", fg="white")) - - PluginMigration.extract_plugins(output_file, workers) - - click.echo(click.style("Extract plugins completed.", fg="green")) - - -@click.command("extract-unique-identifiers", help="Extract unique identifiers.") -@click.option( - "--output_file", - prompt=True, - help="The file to store the extracted unique identifiers.", - default="unique_identifiers.json", -) -@click.option( - "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" -) -def extract_unique_plugins(output_file: str, input_file: str): - """ - Extract unique plugins. - """ - click.echo(click.style("Starting extract unique plugins.", fg="white")) - - PluginMigration.extract_unique_plugins_to_file(input_file, output_file) - - click.echo(click.style("Extract unique plugins completed.", fg="green")) - - -@click.command("install-plugins", help="Install plugins.") -@click.option( - "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" -) -@click.option( - "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" -) -@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) -def install_plugins(input_file: str, output_file: str, workers: int): - """ - Install plugins. - """ - click.echo(click.style("Starting install plugins.", fg="white")) - - PluginMigration.install_plugins(input_file, output_file, workers) - - click.echo(click.style("Install plugins completed.", fg="green")) - - -@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.") -@click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30) -@click.option("--batch", prompt=True, help="The batch size to clear free plan tenant expired logs.", default=100) -@click.option( - "--tenant_ids", - prompt=True, - multiple=True, - help="The tenant ids to clear free plan tenant expired logs.", -) -def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[str]): - """ - Clear free plan tenant expired logs. - """ - click.echo(click.style("Starting clear free plan tenant expired logs.", fg="white")) - - ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) - - click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) - - -@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") -@click.option( - "--before-days", - "--days", - default=30, - show_default=True, - type=click.IntRange(min=0), - help="Delete workflow runs created before N days ago.", -) -@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") -@click.option( - "--from-days-ago", - default=None, - type=click.IntRange(min=0), - help="Lower bound in days ago (older). Must be paired with --to-days-ago.", -) -@click.option( - "--to-days-ago", - default=None, - type=click.IntRange(min=0), - help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", -) -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", -) -@click.option( - "--dry-run", - is_flag=True, - help="Preview cleanup results without deleting any workflow run data.", -) -def clean_workflow_runs( - before_days: int, - batch_size: int, - from_days_ago: int | None, - to_days_ago: int | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - dry_run: bool, -): - """ - Clean workflow runs and related workflow data for free tenants. - """ - if (start_from is None) ^ (end_before is None): - raise click.UsageError("--start-from and --end-before must be provided together.") - - if (from_days_ago is None) ^ (to_days_ago is None): - raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.") - - if from_days_ago is not None and to_days_ago is not None: - if start_from or end_before: - raise click.UsageError("Choose either day offsets or explicit dates, not both.") - if from_days_ago <= to_days_ago: - raise click.UsageError("--from-days-ago must be greater than --to-days-ago.") - now = datetime.datetime.now() - start_from = now - datetime.timedelta(days=from_days_ago) - end_before = now - datetime.timedelta(days=to_days_ago) - before_days = 0 - - start_time = datetime.datetime.now(datetime.UTC) - click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) - - WorkflowRunCleanup( - days=before_days, - batch_size=batch_size, - start_from=start_from, - end_before=end_before, - dry_run=dry_run, - ).run() - - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time - click.echo( - click.style( - f"Workflow run cleanup completed. start={start_time.isoformat()} " - f"end={end_time.isoformat()} duration={elapsed}", - fg="green", - ) - ) - - -@click.command( - "archive-workflow-runs", - help="Archive workflow runs for paid plan tenants to S3-compatible storage.", -) -@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") -@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") -@click.option( - "--from-days-ago", - default=None, - type=click.IntRange(min=0), - help="Lower bound in days ago (older). Must be paired with --to-days-ago.", -) -@click.option( - "--to-days-ago", - default=None, - type=click.IntRange(min=0), - help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", -) -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Archive runs created at or after this timestamp (UTC if no timezone).", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Archive runs created before this timestamp (UTC if no timezone).", -) -@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") -@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") -@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") -@click.option("--dry-run", is_flag=True, help="Preview without archiving.") -@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") -def archive_workflow_runs( - tenant_ids: str | None, - before_days: int, - from_days_ago: int | None, - to_days_ago: int | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - batch_size: int, - workers: int, - limit: int | None, - dry_run: bool, - delete_after_archive: bool, -): - """ - Archive workflow runs for paid plan tenants older than the specified days. - - This command archives the following tables to storage: - - workflow_node_executions - - workflow_node_execution_offload - - workflow_pauses - - workflow_pause_reasons - - workflow_trigger_logs - - The workflow_runs and workflow_app_logs tables are preserved for UI listing. - """ - from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver - - run_started_at = datetime.datetime.now(datetime.UTC) - click.echo( - click.style( - f"Starting workflow run archiving at {run_started_at.isoformat()}.", - fg="white", - ) - ) - - if (start_from is None) ^ (end_before is None): - click.echo(click.style("start-from and end-before must be provided together.", fg="red")) - return - - if (from_days_ago is None) ^ (to_days_ago is None): - click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) - return - - if from_days_ago is not None and to_days_ago is not None: - if start_from or end_before: - click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) - return - if from_days_ago <= to_days_ago: - click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) - return - now = datetime.datetime.now() - start_from = now - datetime.timedelta(days=from_days_ago) - end_before = now - datetime.timedelta(days=to_days_ago) - before_days = 0 - - if start_from and end_before and start_from >= end_before: - click.echo(click.style("start-from must be earlier than end-before.", fg="red")) - return - if workers < 1: - click.echo(click.style("workers must be at least 1.", fg="red")) - return - - archiver = WorkflowRunArchiver( - days=before_days, - batch_size=batch_size, - start_from=start_from, - end_before=end_before, - workers=workers, - tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, - limit=limit, - dry_run=dry_run, - delete_after_archive=delete_after_archive, - ) - summary = archiver.run() - click.echo( - click.style( - f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " - f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " - f"time={summary.total_elapsed_time:.2f}s", - fg="cyan", - ) - ) - - run_finished_at = datetime.datetime.now(datetime.UTC) - elapsed = run_finished_at - run_started_at - click.echo( - click.style( - f"Workflow run archiving completed. start={run_started_at.isoformat()} " - f"end={run_finished_at.isoformat()} duration={elapsed}", - fg="green", - ) - ) - - -@click.command( - "restore-workflow-runs", - help="Restore archived workflow runs from S3-compatible storage.", -) -@click.option( - "--tenant-ids", - required=False, - help="Tenant IDs (comma-separated).", -) -@click.option("--run-id", required=False, help="Workflow run ID to restore.") -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", -) -@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") -@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") -@click.option("--dry-run", is_flag=True, help="Preview without restoring.") -def restore_workflow_runs( - tenant_ids: str | None, - run_id: str | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - workers: int, - limit: int, - dry_run: bool, -): - """ - Restore an archived workflow run from storage to the database. - - This restores the following tables: - - workflow_node_executions - - workflow_node_execution_offload - - workflow_pauses - - workflow_pause_reasons - - workflow_trigger_logs - """ - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - parsed_tenant_ids = None - if tenant_ids: - parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] - if not parsed_tenant_ids: - raise click.BadParameter("tenant-ids must not be empty") - - if (start_from is None) ^ (end_before is None): - raise click.UsageError("--start-from and --end-before must be provided together.") - if run_id is None and (start_from is None or end_before is None): - raise click.UsageError("--start-from and --end-before are required for batch restore.") - if workers < 1: - raise click.BadParameter("workers must be at least 1") - - start_time = datetime.datetime.now(datetime.UTC) - click.echo( - click.style( - f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", - fg="white", - ) - ) - - restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) - if run_id: - results = [restorer.restore_by_run_id(run_id)] - else: - assert start_from is not None - assert end_before is not None - results = restorer.restore_batch( - parsed_tenant_ids, - start_date=start_from, - end_date=end_before, - limit=limit, - ) - - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time - - successes = sum(1 for result in results if result.success) - failures = len(results) - successes - - if failures == 0: - click.echo( - click.style( - f"Restore completed successfully. success={successes} duration={elapsed}", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}", - fg="red", - ) - ) - - -@click.command( - "delete-archived-workflow-runs", - help="Delete archived workflow runs from the database.", -) -@click.option( - "--tenant-ids", - required=False, - help="Tenant IDs (comma-separated).", -) -@click.option("--run-id", required=False, help="Workflow run ID to delete.") -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - default=None, - help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", -) -@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") -@click.option("--dry-run", is_flag=True, help="Preview without deleting.") -def delete_archived_workflow_runs( - tenant_ids: str | None, - run_id: str | None, - start_from: datetime.datetime | None, - end_before: datetime.datetime | None, - limit: int, - dry_run: bool, -): - """ - Delete archived workflow runs from the database. - """ - from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion - - parsed_tenant_ids = None - if tenant_ids: - parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] - if not parsed_tenant_ids: - raise click.BadParameter("tenant-ids must not be empty") - - if (start_from is None) ^ (end_before is None): - raise click.UsageError("--start-from and --end-before must be provided together.") - if run_id is None and (start_from is None or end_before is None): - raise click.UsageError("--start-from and --end-before are required for batch delete.") - - start_time = datetime.datetime.now(datetime.UTC) - target_desc = f"workflow run {run_id}" if run_id else "workflow runs" - click.echo( - click.style( - f"Starting delete of {target_desc} at {start_time.isoformat()}.", - fg="white", - ) - ) - - deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) - if run_id: - results = [deleter.delete_by_run_id(run_id)] - else: - assert start_from is not None - assert end_before is not None - results = deleter.delete_batch( - parsed_tenant_ids, - start_date=start_from, - end_date=end_before, - limit=limit, - ) - - for result in results: - if result.success: - click.echo( - click.style( - f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " - f"workflow run {result.run_id} (tenant={result.tenant_id})", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Failed to delete workflow run {result.run_id}: {result.error}", - fg="red", - ) - ) - - end_time = datetime.datetime.now(datetime.UTC) - elapsed = end_time - start_time - - successes = sum(1 for result in results if result.success) - failures = len(results) - successes - - if failures == 0: - click.echo( - click.style( - f"Delete completed successfully. success={successes} duration={elapsed}", - fg="green", - ) - ) - else: - click.echo( - click.style( - f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}", - fg="red", - ) - ) - - -@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") -@click.command("clear-orphaned-file-records", help="Clear orphaned file records.") -def clear_orphaned_file_records(force: bool): - """ - Clear orphaned file records in the database. - """ - - # define tables and columns to process - files_tables = [ - {"table": "upload_files", "id_column": "id", "key_column": "key"}, - {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, - ] - ids_tables = [ - {"type": "uuid", "table": "message_files", "column": "upload_file_id"}, - {"type": "text", "table": "documents", "column": "data_source_info"}, - {"type": "text", "table": "document_segments", "column": "content"}, - {"type": "text", "table": "messages", "column": "answer"}, - {"type": "text", "table": "workflow_node_executions", "column": "inputs"}, - {"type": "text", "table": "workflow_node_executions", "column": "process_data"}, - {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, - {"type": "text", "table": "conversations", "column": "introduction"}, - {"type": "text", "table": "conversations", "column": "system_instruction"}, - {"type": "text", "table": "accounts", "column": "avatar"}, - {"type": "text", "table": "apps", "column": "icon"}, - {"type": "text", "table": "sites", "column": "icon"}, - {"type": "json", "table": "messages", "column": "inputs"}, - {"type": "json", "table": "messages", "column": "message"}, - ] - - # notify user and ask for confirmation - click.echo( - click.style( - "This command will first find and delete orphaned file records from the message_files table,", fg="yellow" - ) - ) - click.echo( - click.style( - "and then it will find and delete orphaned file records in the following tables:", - fg="yellow", - ) - ) - for files_table in files_tables: - click.echo(click.style(f"- {files_table['table']}", fg="yellow")) - click.echo( - click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow") - ) - for ids_table in ids_tables: - click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")) - click.echo("") - - click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) - click.echo( - click.style( - ( - "Since not all patterns have been fully tested, " - "please note that this command may delete unintended file records." - ), - fg="yellow", - ) - ) - click.echo( - click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow") - ) - click.echo( - click.style( - ( - "It is also recommended to run this during the maintenance window, " - "as this may cause high load on your instance." - ), - fg="yellow", - ) - ) - if not force: - click.confirm("Do you want to proceed?", abort=True) - - # start the cleanup process - click.echo(click.style("Starting orphaned file records cleanup.", fg="white")) - - # clean up the orphaned records in the message_files table where message_id doesn't exist in messages table - try: - click.echo( - click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white") - ) - query = ( - "SELECT mf.id, mf.message_id " - "FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id " - "WHERE m.id IS NULL" - ) - orphaned_message_files = [] - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) - - if orphaned_message_files: - click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white")) - for record in orphaned_message_files: - click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black")) - - if not force: - click.confirm( - ( - f"Do you want to proceed " - f"to delete all {len(orphaned_message_files)} orphaned message_files records?" - ), - abort=True, - ) - - click.echo(click.style("- Deleting orphaned message_files records", fg="white")) - query = "DELETE FROM message_files WHERE id IN :ids" - with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) - click.echo( - click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") - ) - else: - click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red")) - - # clean up the orphaned records in the rest of the *_files tables - try: - # fetch file id and keys from each table - all_files_in_tables = [] - for files_table in files_tables: - click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) - query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) - click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) - - # fetch referred table and columns - guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" - all_ids_in_tables = [] - for ids_table in ids_tables: - query = "" - match ids_table["type"]: - case "uuid": - click.echo( - click.style( - f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", - fg="white", - ) - ) - c = ids_table["column"] - query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) - case "text": - t = ids_table["table"] - click.echo( - click.style( - f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", - fg="white", - ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - case "json": - click.echo( - click.style( - ( - f"- Listing file-id-like JSON string in column {ids_table['column']} " - f"in table {ids_table['table']}" - ), - fg="white", - ) - ) - query = ( - f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " - f"FROM {ids_table['table']}" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - for j in i[0]: - all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - case _: - pass - click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) - - except Exception as e: - click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) - return - - # find orphaned files - all_files = [file["id"] for file in all_files_in_tables] - all_ids = [file["id"] for file in all_ids_in_tables] - orphaned_files = list(set(all_files) - set(all_ids)) - if not orphaned_files: - click.echo(click.style("No orphaned file records found. There is nothing to delete.", fg="green")) - return - click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")) - for file in orphaned_files: - click.echo(click.style(f"- orphaned file id: {file}", fg="black")) - if not force: - click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True) - - # delete orphaned records for each file - try: - for files_table in files_tables: - click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) - query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" - with db.engine.begin() as conn: - conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) - except Exception as e: - click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) - return - click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")) - - -@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") -@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.") -def remove_orphaned_files_on_storage(force: bool): - """ - Remove orphaned files on the storage. - """ - - # define tables and columns to process - files_tables = [ - {"table": "upload_files", "key_column": "key"}, - {"table": "tool_files", "key_column": "file_key"}, - ] - storage_paths = ["image_files", "tools", "upload_files"] - - # notify user and ask for confirmation - click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow")) - click.echo( - click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow") - ) - for files_table in files_tables: - click.echo(click.style(f"- {files_table['table']}", fg="yellow")) - click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow")) - for storage_path in storage_paths: - click.echo(click.style(f"- {storage_path}", fg="yellow")) - click.echo("") - - click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) - click.echo( - click.style( - "Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow" - ) - ) - click.echo( - click.style( - "Since not all patterns have been fully tested, please note that this command may delete unintended files.", - fg="yellow", - ) - ) - click.echo( - click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow") - ) - click.echo( - click.style( - ( - "It is also recommended to run this during the maintenance window, " - "as this may cause high load on your instance." - ), - fg="yellow", - ) - ) - if not force: - click.confirm("Do you want to proceed?", abort=True) - - # start the cleanup process - click.echo(click.style("Starting orphaned files cleanup.", fg="white")) - - # fetch file id and keys from each table - all_files_in_tables = [] - try: - for files_table in files_tables: - click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) - query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for i in rs: - all_files_in_tables.append(str(i[0])) - click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) - except Exception as e: - click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) - return - - all_files_on_storage = [] - for storage_path in storage_paths: - try: - click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white")) - files = storage.scan(path=storage_path, files=True, directories=False) - all_files_on_storage.extend(files) - except FileNotFoundError as e: - click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow")) - continue - except Exception as e: - click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red")) - continue - click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")) - - # find orphaned files - orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) - if not orphaned_files: - click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green")) - return - click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) - for file in orphaned_files: - click.echo(click.style(f"- orphaned file: {file}", fg="black")) - if not force: - click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True) - - # delete orphaned files - removed_files = 0 - error_files = 0 - for file in orphaned_files: - try: - storage.delete(file) - removed_files += 1 - click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) - except Exception as e: - error_files += 1 - click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red")) - continue - if error_files == 0: - click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) - else: - click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) - - -@click.command("file-usage", help="Query file usages and show where files are referenced.") -@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") -@click.option("--key", type=str, default=None, help="Filter by storage key.") -@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") -@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") -@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") -@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") -def file_usage( - file_id: str | None, - key: str | None, - src: str | None, - limit: int, - offset: int, - output_json: bool, -): - """ - Query file usages and show where files are referenced in the database. - - This command reuses the same reference checking logic as clear-orphaned-file-records - and displays detailed information about where each file is referenced. - """ - # define tables and columns to process - files_tables = [ - {"table": "upload_files", "id_column": "id", "key_column": "key"}, - {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, - ] - ids_tables = [ - {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, - {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, - {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, - {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, - {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, - {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, - {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, - {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, - {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, - {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, - {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, - {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, - {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, - {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, - ] - - # Stream file usages with pagination to avoid holding all results in memory - paginated_usages = [] - total_count = 0 - - # First, build a mapping of file_id -> storage_key from the base tables - file_key_map = {} - for files_table in files_tables: - query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" - - # If filtering by key or file_id, verify it exists - if file_id and file_id not in file_key_map: - if output_json: - click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) - else: - click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) - return - - if key: - valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} - matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] - if not matching_file_ids: - if output_json: - click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) - else: - click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) - return - - guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" - - # For each reference table/column, find matching file IDs and record the references - for ids_table in ids_tables: - src_filter = f"{ids_table['table']}.{ids_table['column']}" - - # Skip if src filter doesn't match (use fnmatch for wildcard patterns) - if src: - if "%" in src or "_" in src: - import fnmatch - - # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) - pattern = src.replace("%", "*").replace("_", "?") - if not fnmatch.fnmatch(src_filter, pattern): - continue - else: - if src_filter != src: - continue - - match ids_table["type"]: - case "uuid": - # Direct UUID match - query = ( - f"SELECT {ids_table['pk_column']}, {ids_table['column']} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - ref_file_id = str(row[1]) - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - - case "text" | "json": - # Extract UUIDs from text/json content - column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] - query = ( - f"SELECT {ids_table['pk_column']}, {column_cast} " - f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" - ) - with db.engine.begin() as conn: - rs = conn.execute(sa.text(query)) - for row in rs: - record_id = str(row[0]) - content = str(row[1]) - - # Find all UUIDs in the content - import re - - uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) - matches = uuid_pattern.findall(content) - - for ref_file_id in matches: - if ref_file_id not in file_key_map: - continue - storage_key = file_key_map[ref_file_id] - - # Apply filters - if file_id and ref_file_id != file_id: - continue - if key and not storage_key.endswith(key): - continue - - # Only collect items within the requested page range - if offset <= total_count < offset + limit: - paginated_usages.append( - { - "src": f"{ids_table['table']}.{ids_table['column']}", - "record_id": record_id, - "file_id": ref_file_id, - "key": storage_key, - } - ) - total_count += 1 - case _: - pass - - # Output results - if output_json: - result = { - "total": total_count, - "offset": offset, - "limit": limit, - "usages": paginated_usages, - } - click.echo(json.dumps(result, indent=2)) - else: - click.echo( - click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") - ) - click.echo("") - - if not paginated_usages: - click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) - return - - # Print table header - click.echo( - click.style( - f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", - fg="cyan", - ) - ) - click.echo(click.style("-" * 190, fg="white")) - - # Print each usage - for usage in paginated_usages: - click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") - - # Show pagination info - if offset + limit < total_count: - click.echo("") - click.echo( - click.style( - f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" - ) - ) - click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) - - -@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") -@click.option("--provider", prompt=True, help="Provider name") -@click.option("--client-params", prompt=True, help="Client Params") -def setup_system_tool_oauth_client(provider, client_params): - """ - Setup system tool oauth client - """ - provider_id = ToolProviderID(provider) - provider_name = provider_id.provider_name - plugin_id = provider_id.plugin_id - - try: - # json validate - click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) - click.echo(click.style("Client params validated successfully.", fg="green")) - - click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) - click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) - click.echo(click.style("Client params encrypted successfully.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - - deleted_count = ( - db.session.query(ToolOAuthSystemClient) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) - if deleted_count > 0: - click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) - - oauth_client = ToolOAuthSystemClient( - provider=provider_name, - plugin_id=plugin_id, - encrypted_oauth_params=oauth_client_params, - ) - db.session.add(oauth_client) - db.session.commit() - click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) - - -@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.") -@click.option("--provider", prompt=True, help="Provider name") -@click.option("--client-params", prompt=True, help="Client Params") -def setup_system_trigger_oauth_client(provider, client_params): - """ - Setup system trigger oauth client - """ - from models.provider_ids import TriggerProviderID - from models.trigger import TriggerOAuthSystemClient - - provider_id = TriggerProviderID(provider) - provider_name = provider_id.provider_name - plugin_id = provider_id.plugin_id - - try: - # json validate - click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) - click.echo(click.style("Client params validated successfully.", fg="green")) - - click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) - click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) - oauth_client_params = encrypt_system_oauth_params(client_params_dict) - click.echo(click.style("Client params encrypted successfully.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - - deleted_count = ( - db.session.query(TriggerOAuthSystemClient) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) - if deleted_count > 0: - click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) - - oauth_client = TriggerOAuthSystemClient( - provider=provider_name, - plugin_id=plugin_id, - encrypted_oauth_params=oauth_client_params, - ) - db.session.add(oauth_client) - db.session.commit() - click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) - - -def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: - """ - Find draft variables that reference non-existent apps. - - Args: - batch_size: Maximum number of orphaned app IDs to return - - Returns: - List of app IDs that have draft variables but don't exist in the apps table - """ - query = """ - SELECT DISTINCT wdv.app_id - FROM workflow_draft_variables AS wdv - WHERE NOT EXISTS( - SELECT 1 FROM apps WHERE apps.id = wdv.app_id - ) - LIMIT :batch_size - """ - - with db.engine.connect() as conn: - result = conn.execute(sa.text(query), {"batch_size": batch_size}) - return [row[0] for row in result] - - -def _count_orphaned_draft_variables() -> dict[str, Any]: - """ - Count orphaned draft variables by app, including associated file counts. - - Returns: - Dictionary with statistics about orphaned variables and files - """ - # Count orphaned variables by app - variables_query = """ - SELECT - wdv.app_id, - COUNT(*) as variable_count, - COUNT(wdv.file_id) as file_count - FROM workflow_draft_variables AS wdv - WHERE NOT EXISTS( - SELECT 1 FROM apps WHERE apps.id = wdv.app_id - ) - GROUP BY wdv.app_id - ORDER BY variable_count DESC - """ - - with db.engine.connect() as conn: - result = conn.execute(sa.text(variables_query)) - orphaned_by_app = {} - total_files = 0 - - for row in result: - app_id, variable_count, file_count = row - orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count} - total_files += file_count - - total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values()) - app_count = len(orphaned_by_app) - - return { - "total_orphaned_variables": total_orphaned, - "total_orphaned_files": total_files, - "orphaned_app_count": app_count, - "orphaned_by_app": orphaned_by_app, - } - - -@click.command() -@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting") -@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)") -@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)") -@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") -def cleanup_orphaned_draft_variables( - dry_run: bool, - batch_size: int, - max_apps: int | None, - force: bool = False, -): - """ - Clean up orphaned draft variables from the database. - - This script finds and removes draft variables that belong to apps - that no longer exist in the database. - """ - logger = logging.getLogger(__name__) - - # Get statistics - stats = _count_orphaned_draft_variables() - - logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"]) - logger.info("Found %s associated offload files", stats["total_orphaned_files"]) - logger.info("Across %s non-existent apps", stats["orphaned_app_count"]) - - if stats["total_orphaned_variables"] == 0: - logger.info("No orphaned draft variables found. Exiting.") - return - - if dry_run: - logger.info("DRY RUN: Would delete the following:") - for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[ - :10 - ]: # Show top 10 - logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"]) - if len(stats["orphaned_by_app"]) > 10: - logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10) - return - - # Confirm deletion - if not force: - click.confirm( - f"Are you sure you want to delete {stats['total_orphaned_variables']} " - f"orphaned draft variables and {stats['total_orphaned_files']} associated files " - f"from {stats['orphaned_app_count']} apps?", - abort=True, - ) - - total_deleted = 0 - processed_apps = 0 - - while True: - if max_apps and processed_apps >= max_apps: - logger.info("Reached maximum app limit (%s). Stopping.", max_apps) - break - - orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10) - if not orphaned_app_ids: - logger.info("No more orphaned draft variables found.") - break - - for app_id in orphaned_app_ids: - if max_apps and processed_apps >= max_apps: - break - - try: - deleted_count = delete_draft_variables_batch(app_id, batch_size) - total_deleted += deleted_count - processed_apps += 1 - - logger.info("Deleted %s variables for app %s", deleted_count, app_id) - - except Exception: - logger.exception("Error processing app %s", app_id) - continue - - logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) - - -@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.") -@click.option("--provider", prompt=True, help="Provider name") -@click.option("--client-params", prompt=True, help="Client Params") -def setup_datasource_oauth_client(provider, client_params): - """ - Setup datasource oauth client - """ - provider_id = DatasourceProviderID(provider) - provider_name = provider_id.provider_name - plugin_id = provider_id.plugin_id - - try: - # json validate - click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) - client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) - click.echo(click.style("Client params validated successfully.", fg="green")) - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - - click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) - deleted_count = ( - db.session.query(DatasourceOauthParamConfig) - .filter_by( - provider=provider_name, - plugin_id=plugin_id, - ) - .delete() - ) - if deleted_count > 0: - click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) - - click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow")) - oauth_client = DatasourceOauthParamConfig( - provider=provider_name, - plugin_id=plugin_id, - system_credentials=client_params_dict, - ) - db.session.add(oauth_client) - db.session.commit() - click.echo(click.style(f"provider: {provider_name}", fg="green")) - click.echo(click.style(f"plugin_id: {plugin_id}", fg="green")) - click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green")) - click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green")) - - -@click.command("transform-datasource-credentials", help="Transform datasource credentials.") -@click.option( - "--environment", prompt=True, help="the environment to transform datasource credentials", default="online" -) -def transform_datasource_credentials(environment: str): - """ - Transform datasource credentials - """ - try: - installer_manager = PluginInstaller() - plugin_migration = PluginMigration() - - notion_plugin_id = "langgenius/notion_datasource" - firecrawl_plugin_id = "langgenius/firecrawl_datasource" - jina_plugin_id = "langgenius/jina_datasource" - if environment == "online": - notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage] - firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage] - jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage] - else: - notion_plugin_unique_identifier = None - firecrawl_plugin_unique_identifier = None - jina_plugin_unique_identifier = None - oauth_credential_type = CredentialType.OAUTH2 - api_key_credential_type = CredentialType.API_KEY - - # deal notion credentials - deal_notion_count = 0 - notion_credentials = db.session.query(DataSourceOauthBinding).filter_by(provider="notion").all() - if notion_credentials: - notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} - for notion_credential in notion_credentials: - tenant_id = notion_credential.tenant_id - if tenant_id not in notion_credentials_tenant_mapping: - notion_credentials_tenant_mapping[tenant_id] = [] - notion_credentials_tenant_mapping[tenant_id].append(notion_credential) - for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() - if not tenant: - continue - try: - # check notion plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if notion_plugin_id not in installed_plugins_ids: - if notion_plugin_unique_identifier: - # install notion plugin - PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) - auth_count = 0 - for notion_tenant_credential in notion_tenant_credentials: - auth_count += 1 - # get credential oauth params - access_token = notion_tenant_credential.access_token - # notion info - notion_info = notion_tenant_credential.source_info - workspace_id = notion_info.get("workspace_id") - workspace_name = notion_info.get("workspace_name") - workspace_icon = notion_info.get("workspace_icon") - new_credentials = { - "integration_secret": encrypter.encrypt_token(tenant_id, access_token), - "workspace_id": workspace_id, - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - } - datasource_provider = DatasourceProvider( - provider="notion_datasource", - tenant_id=tenant_id, - plugin_id=notion_plugin_id, - auth_type=oauth_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url=workspace_icon or "default", - is_default=False, - ) - db.session.add(datasource_provider) - deal_notion_count += 1 - except Exception as e: - click.echo( - click.style( - f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" - ) - ) - continue - db.session.commit() - # deal firecrawl credentials - deal_firecrawl_count = 0 - firecrawl_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="firecrawl").all() - if firecrawl_credentials: - firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} - for firecrawl_credential in firecrawl_credentials: - tenant_id = firecrawl_credential.tenant_id - if tenant_id not in firecrawl_credentials_tenant_mapping: - firecrawl_credentials_tenant_mapping[tenant_id] = [] - firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) - for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() - if not tenant: - continue - try: - # check firecrawl plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if firecrawl_plugin_id not in installed_plugins_ids: - if firecrawl_plugin_unique_identifier: - # install firecrawl plugin - PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) - - auth_count = 0 - for firecrawl_tenant_credential in firecrawl_tenant_credentials: - auth_count += 1 - if not firecrawl_tenant_credential.credentials: - click.echo( - click.style( - f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.", - fg="yellow", - ) - ) - continue - # get credential api key - credentials_json = json.loads(firecrawl_tenant_credential.credentials) - api_key = credentials_json.get("config", {}).get("api_key") - base_url = credentials_json.get("config", {}).get("base_url") - new_credentials = { - "firecrawl_api_key": api_key, - "base_url": base_url, - } - datasource_provider = DatasourceProvider( - provider="firecrawl", - tenant_id=tenant_id, - plugin_id=firecrawl_plugin_id, - auth_type=api_key_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url="default", - is_default=False, - ) - db.session.add(datasource_provider) - deal_firecrawl_count += 1 - except Exception as e: - click.echo( - click.style( - f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" - ) - ) - continue - db.session.commit() - # deal jina credentials - deal_jina_count = 0 - jina_credentials = db.session.query(DataSourceApiKeyAuthBinding).filter_by(provider="jinareader").all() - if jina_credentials: - jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} - for jina_credential in jina_credentials: - tenant_id = jina_credential.tenant_id - if tenant_id not in jina_credentials_tenant_mapping: - jina_credentials_tenant_mapping[tenant_id] = [] - jina_credentials_tenant_mapping[tenant_id].append(jina_credential) - for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): - tenant = db.session.query(Tenant).filter_by(id=tenant_id).first() - if not tenant: - continue - try: - # check jina plugin is installed - installed_plugins = installer_manager.list_plugins(tenant_id) - installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] - if jina_plugin_id not in installed_plugins_ids: - if jina_plugin_unique_identifier: - # install jina plugin - logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) - PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) - - auth_count = 0 - for jina_tenant_credential in jina_tenant_credentials: - auth_count += 1 - if not jina_tenant_credential.credentials: - click.echo( - click.style( - f"Skipping jina credential for tenant {tenant_id} due to missing credentials.", - fg="yellow", - ) - ) - continue - # get credential api key - credentials_json = json.loads(jina_tenant_credential.credentials) - api_key = credentials_json.get("config", {}).get("api_key") - new_credentials = { - "integration_secret": api_key, - } - datasource_provider = DatasourceProvider( - provider="jinareader", - tenant_id=tenant_id, - plugin_id=jina_plugin_id, - auth_type=api_key_credential_type.value, - encrypted_credentials=new_credentials, - name=f"Auth {auth_count}", - avatar_url="default", - is_default=False, - ) - db.session.add(datasource_provider) - deal_jina_count += 1 - except Exception as e: - click.echo( - click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red") - ) - continue - db.session.commit() - except Exception as e: - click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) - return - click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green")) - click.echo( - click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green") - ) - click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green")) - - -@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.") -@click.option( - "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" -) -@click.option( - "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" -) -@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) -def install_rag_pipeline_plugins(input_file, output_file, workers): - """ - Install rag pipeline plugins - """ - click.echo(click.style("Installing rag pipeline plugins", fg="yellow")) - plugin_migration = PluginMigration() - plugin_migration.install_rag_pipeline_plugins( - input_file, - output_file, - workers, - ) - click.echo(click.style("Installing rag pipeline plugins successfully", fg="green")) - - -@click.command( - "migrate-oss", - help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).", -) -@click.option( - "--path", - "paths", - multiple=True, - help="Storage path prefixes to migrate (repeatable). Defaults: privkeys, upload_files, image_files," - " tools, website_files, keyword_files, ops_trace", -) -@click.option( - "--source", - type=click.Choice(["local", "opendal"], case_sensitive=False), - default="opendal", - show_default=True, - help="Source storage type to read from", -) -@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists") -@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading") -@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts") -@click.option( - "--update-db/--no-update-db", - default=True, - help="Update upload_files.storage_type from source type to current storage after migration", -) -def migrate_oss( - paths: tuple[str, ...], - source: str, - overwrite: bool, - dry_run: bool, - force: bool, - update_db: bool, -): - """ - Copy all files under selected prefixes from a source storage - (Local filesystem or OpenDAL-backed) into the currently configured - destination storage backend, then optionally update DB records. - - Expected usage: set STORAGE_TYPE (and its credentials) to your target backend. - """ - # Ensure target storage is not local/opendal - if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL): - click.echo( - click.style( - "Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n" - "Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n" - "volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.", - fg="red", - ) - ) - return - - # Default paths if none specified - default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace") - path_list = list(paths) if paths else list(default_paths) - is_source_local = source.lower() == "local" - - click.echo(click.style("Preparing migration to target storage.", fg="yellow")) - click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white")) - if is_source_local: - src_root = dify_config.STORAGE_LOCAL_PATH - click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white")) - else: - click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white")) - click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white")) - click.echo("") - - if not force: - click.confirm("Proceed with migration?", abort=True) - - # Instantiate source storage - try: - if is_source_local: - src_root = dify_config.STORAGE_LOCAL_PATH - source_storage = OpenDALStorage(scheme="fs", root=src_root) - else: - source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME) - except Exception as e: - click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red")) - return - - total_files = 0 - copied_files = 0 - skipped_files = 0 - errored_files = 0 - copied_upload_file_keys: list[str] = [] - - for prefix in path_list: - click.echo(click.style(f"Scanning source path: {prefix}", fg="white")) - try: - keys = source_storage.scan(path=prefix, files=True, directories=False) - except FileNotFoundError: - click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow")) - continue - except NotImplementedError: - click.echo(click.style(" -> Source storage does not support scanning.", fg="red")) - return - except Exception as e: - click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red")) - continue - - click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white")) - - for key in keys: - total_files += 1 - - # check destination existence - if not overwrite: - try: - if storage.exists(key): - skipped_files += 1 - continue - except Exception as e: - # existence check failures should not block migration attempt - # but should be surfaced to user as a warning for visibility - click.echo( - click.style( - f" -> Warning: failed target existence check for {key}: {str(e)}", - fg="yellow", - ) - ) - - if dry_run: - copied_files += 1 - continue - - # read from source and write to destination - try: - data = source_storage.load_once(key) - except FileNotFoundError: - errored_files += 1 - click.echo(click.style(f" -> Missing on source: {key}", fg="yellow")) - continue - except Exception as e: - errored_files += 1 - click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red")) - continue - - try: - storage.save(key, data) - copied_files += 1 - if prefix == "upload_files": - copied_upload_file_keys.append(key) - except Exception as e: - errored_files += 1 - click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red")) - continue - - click.echo("") - click.echo(click.style("Migration summary:", fg="yellow")) - click.echo(click.style(f" Total: {total_files}", fg="white")) - click.echo(click.style(f" Copied: {copied_files}", fg="green")) - click.echo(click.style(f" Skipped: {skipped_files}", fg="white")) - if errored_files: - click.echo(click.style(f" Errors: {errored_files}", fg="red")) - - if dry_run: - click.echo(click.style("Dry-run complete. No changes were made.", fg="green")) - return - - if errored_files: - click.echo( - click.style( - "Some files failed to migrate. Review errors above before updating DB records.", - fg="yellow", - ) - ) - if update_db and not force: - if not click.confirm("Proceed to update DB storage_type despite errors?", default=False): - update_db = False - - # Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files) - if update_db: - if not copied_upload_file_keys: - click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow")) - else: - try: - source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL - updated = ( - db.session.query(UploadFile) - .where( - UploadFile.storage_type == source_storage_type, - UploadFile.key.in_(copied_upload_file_keys), - ) - .update({UploadFile.storage_type: dify_config.STORAGE_TYPE}, synchronize_session=False) - ) - db.session.commit() - click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green")) - except Exception as e: - db.session.rollback() - click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) - - -@click.command("clean-expired-messages", help="Clean expired messages.") -@click.option( - "--start-from", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - required=True, - help="Lower bound (inclusive) for created_at.", -) -@click.option( - "--end-before", - type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), - required=True, - help="Upper bound (exclusive) for created_at.", -) -@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") -@click.option( - "--graceful-period", - default=21, - show_default=True, - help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.", -) -@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting") -def clean_expired_messages( - batch_size: int, - graceful_period: int, - start_from: datetime.datetime, - end_before: datetime.datetime, - dry_run: bool, -): - """ - Clean expired messages and related data for tenants based on clean policy. - """ - click.echo(click.style("clean_messages: start clean messages.", fg="green")) - - start_at = time.perf_counter() - - try: - # Create policy based on billing configuration - # NOTE: graceful_period will be ignored when billing is disabled. - policy = create_message_clean_policy(graceful_period_days=graceful_period) - - # Create and run the cleanup service - service = MessagesCleanService.from_time_range( - policy=policy, - start_from=start_from, - end_before=end_before, - batch_size=batch_size, - dry_run=dry_run, - ) - stats = service.run() - - end_at = time.perf_counter() - click.echo( - click.style( - f"clean_messages: completed successfully\n" - f" - Latency: {end_at - start_at:.2f}s\n" - f" - Batches processed: {stats['batches']}\n" - f" - Total messages scanned: {stats['total_messages']}\n" - f" - Messages filtered: {stats['filtered_messages']}\n" - f" - Messages deleted: {stats['total_deleted']}", - fg="green", - ) - ) - except Exception as e: - end_at = time.perf_counter() - logger.exception("clean_messages failed") - click.echo( - click.style( - f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", - fg="red", - ) - ) - raise - - click.echo(click.style("messages cleanup completed.", fg="green")) diff --git a/api/commands/__init__.py b/api/commands/__init__.py new file mode 100644 index 0000000000..d62d0dbd7c --- /dev/null +++ b/api/commands/__init__.py @@ -0,0 +1,71 @@ +""" +CLI command modules extracted from `commands.py`. +""" + +from .account import create_tenant, reset_email, reset_password +from .plugin import ( + extract_plugins, + extract_unique_plugins, + install_plugins, + install_rag_pipeline_plugins, + migrate_data_for_plugin, + setup_datasource_oauth_client, + setup_system_tool_oauth_client, + setup_system_trigger_oauth_client, + transform_datasource_credentials, +) +from .retention import ( + archive_workflow_runs, + clean_expired_messages, + clean_workflow_runs, + cleanup_orphaned_draft_variables, + clear_free_plan_tenant_expired_logs, + delete_archived_workflow_runs, + export_app_messages, + restore_workflow_runs, +) +from .storage import clear_orphaned_file_records, file_usage, migrate_oss, remove_orphaned_files_on_storage +from .system import convert_to_agent_apps, fix_app_site_missing, reset_encrypt_key_pair, upgrade_db +from .vector import ( + add_qdrant_index, + migrate_annotation_vector_database, + migrate_knowledge_vector_database, + old_metadata_migration, + vdb_migrate, +) + +__all__ = [ + "add_qdrant_index", + "archive_workflow_runs", + "clean_expired_messages", + "clean_workflow_runs", + "cleanup_orphaned_draft_variables", + "clear_free_plan_tenant_expired_logs", + "clear_orphaned_file_records", + "convert_to_agent_apps", + "create_tenant", + "delete_archived_workflow_runs", + "export_app_messages", + "extract_plugins", + "extract_unique_plugins", + "file_usage", + "fix_app_site_missing", + "install_plugins", + "install_rag_pipeline_plugins", + "migrate_annotation_vector_database", + "migrate_data_for_plugin", + "migrate_knowledge_vector_database", + "migrate_oss", + "old_metadata_migration", + "remove_orphaned_files_on_storage", + "reset_email", + "reset_encrypt_key_pair", + "reset_password", + "restore_workflow_runs", + "setup_datasource_oauth_client", + "setup_system_tool_oauth_client", + "setup_system_trigger_oauth_client", + "transform_datasource_credentials", + "upgrade_db", + "vdb_migrate", +] diff --git a/api/commands/account.py b/api/commands/account.py new file mode 100644 index 0000000000..84af7a5ae6 --- /dev/null +++ b/api/commands/account.py @@ -0,0 +1,130 @@ +import base64 +import secrets + +import click +from sqlalchemy.orm import sessionmaker + +from constants.languages import languages +from extensions.ext_database import db +from libs.helper import email as email_validate +from libs.password import hash_password, password_pattern, valid_password +from services.account_service import AccountService, RegisterService, TenantService + + +@click.command("reset-password", help="Reset the account password.") +@click.option("--email", prompt=True, help="Account email to reset password for") +@click.option("--new-password", prompt=True, help="New password") +@click.option("--password-confirm", prompt=True, help="Confirm new password") +def reset_password(email, new_password, password_confirm): + """ + Reset password of owner account + Only available in SELF_HOSTED mode + """ + if str(new_password).strip() != str(password_confirm).strip(): + click.echo(click.style("Passwords do not match.", fg="red")) + return + normalized_email = email.strip().lower() + + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return + + try: + valid_password(new_password) + except: + click.echo(click.style(f"Invalid password. Must match {password_pattern}", fg="red")) + return + + # generate password salt + salt = secrets.token_bytes(16) + base64_salt = base64.b64encode(salt).decode() + + # encrypt password with salt + password_hashed = hash_password(new_password, salt) + base64_password_hashed = base64.b64encode(password_hashed).decode() + account.password = base64_password_hashed + account.password_salt = base64_salt + AccountService.reset_login_error_rate_limit(normalized_email) + click.echo(click.style("Password reset successfully.", fg="green")) + + +@click.command("reset-email", help="Reset the account email.") +@click.option("--email", prompt=True, help="Current account email") +@click.option("--new-email", prompt=True, help="New email") +@click.option("--email-confirm", prompt=True, help="Confirm new email") +def reset_email(email, new_email, email_confirm): + """ + Replace account email + :return: + """ + if str(new_email).strip() != str(email_confirm).strip(): + click.echo(click.style("New emails do not match.", fg="red")) + return + normalized_new_email = new_email.strip().lower() + + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session) + + if not account: + click.echo(click.style(f"Account not found for email: {email}", fg="red")) + return + + try: + email_validate(normalized_new_email) + except: + click.echo(click.style(f"Invalid email: {new_email}", fg="red")) + return + + account.email = normalized_new_email + click.echo(click.style("Email updated successfully.", fg="green")) + + +@click.command("create-tenant", help="Create account and tenant.") +@click.option("--email", prompt=True, help="Tenant account email.") +@click.option("--name", prompt=True, help="Workspace name.") +@click.option("--language", prompt=True, help="Account language, default: en-US.") +def create_tenant(email: str, language: str | None = None, name: str | None = None): + """ + Create tenant account + """ + if not email: + click.echo(click.style("Email is required.", fg="red")) + return + + # Create account + email = email.strip().lower() + + if "@" not in email: + click.echo(click.style("Invalid email address.", fg="red")) + return + + account_name = email.split("@")[0] + + if language not in languages: + language = "en-US" + + # Validates name encoding for non-Latin characters. + name = name.strip().encode("utf-8").decode("utf-8") if name else None + + # generate random password + new_password = secrets.token_urlsafe(16) + + # register account + account = RegisterService.register( + email=email, + name=account_name, + password=new_password, + language=language, + create_workspace_required=False, + ) + TenantService.create_owner_tenant_if_not_exist(account, name) + + click.echo( + click.style( + f"Account and tenant created.\nAccount: {email}\nPassword: {new_password}", + fg="green", + ) + ) diff --git a/api/commands/plugin.py b/api/commands/plugin.py new file mode 100644 index 0000000000..c34391025a --- /dev/null +++ b/api/commands/plugin.py @@ -0,0 +1,478 @@ +import json +import logging +from typing import Any, cast + +import click +from pydantic import TypeAdapter +from sqlalchemy import delete, select +from sqlalchemy.engine import CursorResult + +from configs import dify_config +from core.helper import encrypter +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.plugin import PluginInstaller +from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params +from extensions.ext_database import db +from models import Tenant +from models.oauth import DatasourceOauthParamConfig, DatasourceProvider +from models.provider_ids import DatasourceProviderID, ToolProviderID +from models.source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding +from models.tools import ToolOAuthSystemClient +from services.plugin.data_migration import PluginDataMigration +from services.plugin.plugin_migration import PluginMigration +from services.plugin.plugin_service import PluginService + +logger = logging.getLogger(__name__) + + +@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_tool_oauth_client(provider, client_params): + """ + Setup system tool oauth client + """ + provider_id = ToolProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = cast( + CursorResult, + db.session.execute( + delete(ToolOAuthSystemClient).where( + ToolOAuthSystemClient.provider == provider_name, + ToolOAuthSystemClient.plugin_id == plugin_id, + ) + ), + ).rowcount + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = ToolOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) + + +@click.command("setup-system-trigger-oauth-client", help="Setup system trigger oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_trigger_oauth_client(provider, client_params): + """ + Setup system trigger oauth client + """ + from models.provider_ids import TriggerProviderID + from models.trigger import TriggerOAuthSystemClient + + provider_id = TriggerProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = cast( + CursorResult, + db.session.execute( + delete(TriggerOAuthSystemClient).where( + TriggerOAuthSystemClient.provider == provider_name, + TriggerOAuthSystemClient.plugin_id == plugin_id, + ) + ), + ).rowcount + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = TriggerOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) + + +@click.command("setup-datasource-oauth-client", help="Setup datasource oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_datasource_oauth_client(provider, client_params): + """ + Setup datasource oauth client + """ + provider_id = DatasourceProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + click.echo(click.style(f"Ready to delete existing oauth client params: {provider_name}", fg="yellow")) + deleted_count = cast( + CursorResult, + db.session.execute( + delete(DatasourceOauthParamConfig).where( + DatasourceOauthParamConfig.provider == provider_name, + DatasourceOauthParamConfig.plugin_id == plugin_id, + ) + ), + ).rowcount + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + click.echo(click.style(f"Ready to setup datasource oauth client: {provider_name}", fg="yellow")) + oauth_client = DatasourceOauthParamConfig( + provider=provider_name, + plugin_id=plugin_id, + system_credentials=client_params_dict, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"provider: {provider_name}", fg="green")) + click.echo(click.style(f"plugin_id: {plugin_id}", fg="green")) + click.echo(click.style(f"params: {json.dumps(client_params_dict, indent=2, ensure_ascii=False)}", fg="green")) + click.echo(click.style(f"Datasource oauth client setup successfully. id: {oauth_client.id}", fg="green")) + + +@click.command("transform-datasource-credentials", help="Transform datasource credentials.") +@click.option( + "--environment", prompt=True, help="the environment to transform datasource credentials", default="online" +) +def transform_datasource_credentials(environment: str): + """ + Transform datasource credentials + """ + try: + installer_manager = PluginInstaller() + plugin_migration = PluginMigration() + + notion_plugin_id = "langgenius/notion_datasource" + firecrawl_plugin_id = "langgenius/firecrawl_datasource" + jina_plugin_id = "langgenius/jina_datasource" + if environment == "online": + notion_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(notion_plugin_id) # pyright: ignore[reportPrivateUsage] + firecrawl_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(firecrawl_plugin_id) # pyright: ignore[reportPrivateUsage] + jina_plugin_unique_identifier = plugin_migration._fetch_plugin_unique_identifier(jina_plugin_id) # pyright: ignore[reportPrivateUsage] + else: + notion_plugin_unique_identifier = None + firecrawl_plugin_unique_identifier = None + jina_plugin_unique_identifier = None + oauth_credential_type = CredentialType.OAUTH2 + api_key_credential_type = CredentialType.API_KEY + + # deal notion credentials + deal_notion_count = 0 + notion_credentials = db.session.scalars( + select(DataSourceOauthBinding).where(DataSourceOauthBinding.provider == "notion") + ).all() + if notion_credentials: + notion_credentials_tenant_mapping: dict[str, list[DataSourceOauthBinding]] = {} + for notion_credential in notion_credentials: + tenant_id = notion_credential.tenant_id + if tenant_id not in notion_credentials_tenant_mapping: + notion_credentials_tenant_mapping[tenant_id] = [] + notion_credentials_tenant_mapping[tenant_id].append(notion_credential) + for tenant_id, notion_tenant_credentials in notion_credentials_tenant_mapping.items(): + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + try: + # check notion plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if notion_plugin_id not in installed_plugins_ids: + if notion_plugin_unique_identifier: + # install notion plugin + PluginService.install_from_marketplace_pkg(tenant_id, [notion_plugin_unique_identifier]) + auth_count = 0 + for notion_tenant_credential in notion_tenant_credentials: + auth_count += 1 + # get credential oauth params + access_token = notion_tenant_credential.access_token + # notion info + notion_info = notion_tenant_credential.source_info + workspace_id = notion_info.get("workspace_id") + workspace_name = notion_info.get("workspace_name") + workspace_icon = notion_info.get("workspace_icon") + new_credentials = { + "integration_secret": encrypter.encrypt_token(tenant_id, access_token), + "workspace_id": workspace_id, + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + } + datasource_provider = DatasourceProvider( + provider="notion_datasource", + tenant_id=tenant_id, + plugin_id=notion_plugin_id, + auth_type=oauth_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url=workspace_icon or "default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_notion_count += 1 + except Exception as e: + click.echo( + click.style( + f"Error transforming notion credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" + ) + ) + continue + db.session.commit() + # deal firecrawl credentials + deal_firecrawl_count = 0 + firecrawl_credentials = db.session.scalars( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "firecrawl") + ).all() + if firecrawl_credentials: + firecrawl_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} + for firecrawl_credential in firecrawl_credentials: + tenant_id = firecrawl_credential.tenant_id + if tenant_id not in firecrawl_credentials_tenant_mapping: + firecrawl_credentials_tenant_mapping[tenant_id] = [] + firecrawl_credentials_tenant_mapping[tenant_id].append(firecrawl_credential) + for tenant_id, firecrawl_tenant_credentials in firecrawl_credentials_tenant_mapping.items(): + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + try: + # check firecrawl plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if firecrawl_plugin_id not in installed_plugins_ids: + if firecrawl_plugin_unique_identifier: + # install firecrawl plugin + PluginService.install_from_marketplace_pkg(tenant_id, [firecrawl_plugin_unique_identifier]) + + auth_count = 0 + for firecrawl_tenant_credential in firecrawl_tenant_credentials: + auth_count += 1 + if not firecrawl_tenant_credential.credentials: + click.echo( + click.style( + f"Skipping firecrawl credential for tenant {tenant_id} due to missing credentials.", + fg="yellow", + ) + ) + continue + # get credential api key + credentials_json = json.loads(firecrawl_tenant_credential.credentials) + api_key = credentials_json.get("config", {}).get("api_key") + base_url = credentials_json.get("config", {}).get("base_url") + new_credentials = { + "firecrawl_api_key": api_key, + "base_url": base_url, + } + datasource_provider = DatasourceProvider( + provider="firecrawl", + tenant_id=tenant_id, + plugin_id=firecrawl_plugin_id, + auth_type=api_key_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url="default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_firecrawl_count += 1 + except Exception as e: + click.echo( + click.style( + f"Error transforming firecrawl credentials: {str(e)}, tenant_id: {tenant_id}", fg="red" + ) + ) + continue + db.session.commit() + # deal jina credentials + deal_jina_count = 0 + jina_credentials = db.session.scalars( + select(DataSourceApiKeyAuthBinding).where(DataSourceApiKeyAuthBinding.provider == "jinareader") + ).all() + if jina_credentials: + jina_credentials_tenant_mapping: dict[str, list[DataSourceApiKeyAuthBinding]] = {} + for jina_credential in jina_credentials: + tenant_id = jina_credential.tenant_id + if tenant_id not in jina_credentials_tenant_mapping: + jina_credentials_tenant_mapping[tenant_id] = [] + jina_credentials_tenant_mapping[tenant_id].append(jina_credential) + for tenant_id, jina_tenant_credentials in jina_credentials_tenant_mapping.items(): + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) + if not tenant: + continue + try: + # check jina plugin is installed + installed_plugins = installer_manager.list_plugins(tenant_id) + installed_plugins_ids = [plugin.plugin_id for plugin in installed_plugins] + if jina_plugin_id not in installed_plugins_ids: + if jina_plugin_unique_identifier: + # install jina plugin + logger.debug("Installing Jina plugin %s", jina_plugin_unique_identifier) + PluginService.install_from_marketplace_pkg(tenant_id, [jina_plugin_unique_identifier]) + + auth_count = 0 + for jina_tenant_credential in jina_tenant_credentials: + auth_count += 1 + if not jina_tenant_credential.credentials: + click.echo( + click.style( + f"Skipping jina credential for tenant {tenant_id} due to missing credentials.", + fg="yellow", + ) + ) + continue + # get credential api key + credentials_json = json.loads(jina_tenant_credential.credentials) + api_key = credentials_json.get("config", {}).get("api_key") + new_credentials = { + "integration_secret": api_key, + } + datasource_provider = DatasourceProvider( + provider="jinareader", + tenant_id=tenant_id, + plugin_id=jina_plugin_id, + auth_type=api_key_credential_type.value, + encrypted_credentials=new_credentials, + name=f"Auth {auth_count}", + avatar_url="default", + is_default=False, + ) + db.session.add(datasource_provider) + deal_jina_count += 1 + except Exception as e: + click.echo( + click.style(f"Error transforming jina credentials: {str(e)}, tenant_id: {tenant_id}", fg="red") + ) + continue + db.session.commit() + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + click.echo(click.style(f"Transforming notion successfully. deal_notion_count: {deal_notion_count}", fg="green")) + click.echo( + click.style(f"Transforming firecrawl successfully. deal_firecrawl_count: {deal_firecrawl_count}", fg="green") + ) + click.echo(click.style(f"Transforming jina successfully. deal_jina_count: {deal_jina_count}", fg="green")) + + +@click.command("migrate-data-for-plugin", help="Migrate data for plugin.") +def migrate_data_for_plugin(): + """ + Migrate data for plugin. + """ + click.echo(click.style("Starting migrate data for plugin.", fg="white")) + + PluginDataMigration.migrate() + + click.echo(click.style("Migrate data for plugin completed.", fg="green")) + + +@click.command("extract-plugins", help="Extract plugins.") +@click.option("--output_file", prompt=True, help="The file to store the extracted plugins.", default="plugins.jsonl") +@click.option("--workers", prompt=True, help="The number of workers to extract plugins.", default=10) +def extract_plugins(output_file: str, workers: int): + """ + Extract plugins. + """ + click.echo(click.style("Starting extract plugins.", fg="white")) + + PluginMigration.extract_plugins(output_file, workers) + + click.echo(click.style("Extract plugins completed.", fg="green")) + + +@click.command("extract-unique-identifiers", help="Extract unique identifiers.") +@click.option( + "--output_file", + prompt=True, + help="The file to store the extracted unique identifiers.", + default="unique_identifiers.json", +) +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +def extract_unique_plugins(output_file: str, input_file: str): + """ + Extract unique plugins. + """ + click.echo(click.style("Starting extract unique plugins.", fg="white")) + + PluginMigration.extract_unique_plugins_to_file(input_file, output_file) + + click.echo(click.style("Extract unique plugins completed.", fg="green")) + + +@click.command("install-plugins", help="Install plugins.") +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +@click.option( + "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" +) +@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) +def install_plugins(input_file: str, output_file: str, workers: int): + """ + Install plugins. + """ + click.echo(click.style("Starting install plugins.", fg="white")) + + PluginMigration.install_plugins(input_file, output_file, workers) + + click.echo(click.style("Install plugins completed.", fg="green")) + + +@click.command("install-rag-pipeline-plugins", help="Install rag pipeline plugins.") +@click.option( + "--input_file", prompt=True, help="The file to store the extracted unique identifiers.", default="plugins.jsonl" +) +@click.option( + "--output_file", prompt=True, help="The file to store the installed plugins.", default="installed_plugins.jsonl" +) +@click.option("--workers", prompt=True, help="The number of workers to install plugins.", default=100) +def install_rag_pipeline_plugins(input_file, output_file, workers): + """ + Install rag pipeline plugins + """ + click.echo(click.style("Installing rag pipeline plugins", fg="yellow")) + plugin_migration = PluginMigration() + plugin_migration.install_rag_pipeline_plugins( + input_file, + output_file, + workers, + ) + click.echo(click.style("Installing rag pipeline plugins successfully", fg="green")) diff --git a/api/commands/retention.py b/api/commands/retention.py new file mode 100644 index 0000000000..82a77ea77a --- /dev/null +++ b/api/commands/retention.py @@ -0,0 +1,857 @@ +import datetime +import logging +import time +from typing import Any + +import click +import sqlalchemy as sa + +from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now +from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs +from services.retention.conversation.messages_clean_policy import create_message_clean_policy +from services.retention.conversation.messages_clean_service import MessagesCleanService +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup +from tasks.remove_app_and_related_data_task import delete_draft_variables_batch + +logger = logging.getLogger(__name__) + + +@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.") +@click.option("--days", prompt=True, help="The days to clear free plan tenant expired logs.", default=30) +@click.option("--batch", prompt=True, help="The batch size to clear free plan tenant expired logs.", default=100) +@click.option( + "--tenant_ids", + prompt=True, + multiple=True, + help="The tenant ids to clear free plan tenant expired logs.", +) +def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[str]): + """ + Clear free plan tenant expired logs. + """ + click.echo(click.style("Starting clear free plan tenant expired logs.", fg="white")) + + ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) + + click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) + + +@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.") +@click.option( + "--before-days", + "--days", + default=30, + show_default=True, + type=click.IntRange(min=0), + help="Delete workflow runs created before N days ago.", +) +@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option( + "--dry-run", + is_flag=True, + help="Preview cleanup results without deleting any workflow run data.", +) +def clean_workflow_runs( + before_days: int, + batch_size: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + dry_run: bool, +): + """ + Clean workflow runs and related workflow data for free tenants. + """ + from extensions.otel.runtime import flush_telemetry + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + + if (from_days_ago is None) ^ (to_days_ago is None): + raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.") + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + raise click.UsageError("Choose either day offsets or explicit dates, not both.") + if from_days_ago <= to_days_ago: + raise click.UsageError("--from-days-ago must be greater than --to-days-ago.") + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if from_days_ago is not None and to_days_ago is not None: + task_label = f"{from_days_ago}to{to_days_ago}" + elif start_from is None: + task_label = f"before-{before_days}" + else: + task_label = "custom" + + start_time = datetime.datetime.now(datetime.UTC) + click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white")) + + try: + WorkflowRunCleanup( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + dry_run=dry_run, + task_label=task_label, + ).run() + finally: + flush_telemetry() + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + click.echo( + click.style( + f"Workflow run cleanup completed. start={start_time.isoformat()} " + f"end={end_time.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "archive-workflow-runs", + help="Archive workflow runs for paid plan tenants to S3-compatible storage.", +) +@click.option("--tenant-ids", default=None, help="Optional comma-separated tenant IDs for grayscale rollout.") +@click.option("--before-days", default=90, show_default=True, help="Archive runs older than N days.") +@click.option( + "--from-days-ago", + default=None, + type=click.IntRange(min=0), + help="Lower bound in days ago (older). Must be paired with --to-days-ago.", +) +@click.option( + "--to-days-ago", + default=None, + type=click.IntRange(min=0), + help="Upper bound in days ago (newer). Must be paired with --from-days-ago.", +) +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created at or after this timestamp (UTC if no timezone).", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Archive runs created before this timestamp (UTC if no timezone).", +) +@click.option("--batch-size", default=100, show_default=True, help="Batch size for processing.") +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to archive.") +@click.option("--limit", default=None, type=int, help="Maximum number of runs to archive.") +@click.option("--dry-run", is_flag=True, help="Preview without archiving.") +@click.option("--delete-after-archive", is_flag=True, help="Delete runs and related data after archiving.") +def archive_workflow_runs( + tenant_ids: str | None, + before_days: int, + from_days_ago: int | None, + to_days_ago: int | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + batch_size: int, + workers: int, + limit: int | None, + dry_run: bool, + delete_after_archive: bool, +): + """ + Archive workflow runs for paid plan tenants older than the specified days. + + This command archives the following tables to storage: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + + The workflow_runs and workflow_app_logs tables are preserved for UI listing. + """ + from services.retention.workflow_run.archive_paid_plan_workflow_run import WorkflowRunArchiver + + run_started_at = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting workflow run archiving at {run_started_at.isoformat()}.", + fg="white", + ) + ) + + if (start_from is None) ^ (end_before is None): + click.echo(click.style("start-from and end-before must be provided together.", fg="red")) + return + + if (from_days_ago is None) ^ (to_days_ago is None): + click.echo(click.style("from-days-ago and to-days-ago must be provided together.", fg="red")) + return + + if from_days_ago is not None and to_days_ago is not None: + if start_from or end_before: + click.echo(click.style("Choose either day offsets or explicit dates, not both.", fg="red")) + return + if from_days_ago <= to_days_ago: + click.echo(click.style("from-days-ago must be greater than to-days-ago.", fg="red")) + return + now = datetime.datetime.now() + start_from = now - datetime.timedelta(days=from_days_ago) + end_before = now - datetime.timedelta(days=to_days_ago) + before_days = 0 + + if start_from and end_before and start_from >= end_before: + click.echo(click.style("start-from must be earlier than end-before.", fg="red")) + return + if workers < 1: + click.echo(click.style("workers must be at least 1.", fg="red")) + return + + archiver = WorkflowRunArchiver( + days=before_days, + batch_size=batch_size, + start_from=start_from, + end_before=end_before, + workers=workers, + tenant_ids=[tid.strip() for tid in tenant_ids.split(",")] if tenant_ids else None, + limit=limit, + dry_run=dry_run, + delete_after_archive=delete_after_archive, + ) + summary = archiver.run() + click.echo( + click.style( + f"Summary: processed={summary.total_runs_processed}, archived={summary.runs_archived}, " + f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, " + f"time={summary.total_elapsed_time:.2f}s", + fg="cyan", + ) + ) + + run_finished_at = datetime.datetime.now(datetime.UTC) + elapsed = run_finished_at - run_started_at + click.echo( + click.style( + f"Workflow run archiving completed. start={run_started_at.isoformat()} " + f"end={run_finished_at.isoformat()} duration={elapsed}", + fg="green", + ) + ) + + +@click.command( + "restore-workflow-runs", + help="Restore archived workflow runs from S3-compatible storage.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to restore.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--workers", default=1, show_default=True, type=int, help="Concurrent workflow runs to restore.") +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to restore.") +@click.option("--dry-run", is_flag=True, help="Preview without restoring.") +def restore_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + workers: int, + limit: int, + dry_run: bool, +): + """ + Restore an archived workflow run from storage to the database. + + This restores the following tables: + - workflow_node_executions + - workflow_node_execution_offload + - workflow_pauses + - workflow_pause_reasons + - workflow_trigger_logs + """ + from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch restore.") + if workers < 1: + raise click.BadParameter("workers must be at least 1") + + start_time = datetime.datetime.now(datetime.UTC) + click.echo( + click.style( + f"Starting restore of workflow run {run_id} at {start_time.isoformat()}.", + fg="white", + ) + ) + + restorer = WorkflowRunRestore(dry_run=dry_run, workers=workers) + if run_id: + results = [restorer.restore_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = restorer.restore_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Restore completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Restore completed with failures. success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +@click.command( + "delete-archived-workflow-runs", + help="Delete archived workflow runs from the database.", +) +@click.option( + "--tenant-ids", + required=False, + help="Tenant IDs (comma-separated).", +) +@click.option("--run-id", required=False, help="Workflow run ID to delete.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.", +) +@click.option("--limit", type=int, default=100, show_default=True, help="Maximum number of runs to delete.") +@click.option("--dry-run", is_flag=True, help="Preview without deleting.") +def delete_archived_workflow_runs( + tenant_ids: str | None, + run_id: str | None, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + limit: int, + dry_run: bool, +): + """ + Delete archived workflow runs from the database. + """ + from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion + + parsed_tenant_ids = None + if tenant_ids: + parsed_tenant_ids = [tid.strip() for tid in tenant_ids.split(",") if tid.strip()] + if not parsed_tenant_ids: + raise click.BadParameter("tenant-ids must not be empty") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("--start-from and --end-before must be provided together.") + if run_id is None and (start_from is None or end_before is None): + raise click.UsageError("--start-from and --end-before are required for batch delete.") + + start_time = datetime.datetime.now(datetime.UTC) + target_desc = f"workflow run {run_id}" if run_id else "workflow runs" + click.echo( + click.style( + f"Starting delete of {target_desc} at {start_time.isoformat()}.", + fg="white", + ) + ) + + deleter = ArchivedWorkflowRunDeletion(dry_run=dry_run) + if run_id: + results = [deleter.delete_by_run_id(run_id)] + else: + assert start_from is not None + assert end_before is not None + results = deleter.delete_batch( + parsed_tenant_ids, + start_date=start_from, + end_date=end_before, + limit=limit, + ) + + for result in results: + if result.success: + click.echo( + click.style( + f"{'[DRY RUN] Would delete' if dry_run else 'Deleted'} " + f"workflow run {result.run_id} (tenant={result.tenant_id})", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Failed to delete workflow run {result.run_id}: {result.error}", + fg="red", + ) + ) + + end_time = datetime.datetime.now(datetime.UTC) + elapsed = end_time - start_time + + successes = sum(1 for result in results if result.success) + failures = len(results) - successes + + if failures == 0: + click.echo( + click.style( + f"Delete completed successfully. success={successes} duration={elapsed}", + fg="green", + ) + ) + else: + click.echo( + click.style( + f"Delete completed with failures. success={successes} failed={failures} duration={elapsed}", + fg="red", + ) + ) + + +def _find_orphaned_draft_variables(batch_size: int = 1000) -> list[str]: + """ + Find draft variables that reference non-existent apps. + + Args: + batch_size: Maximum number of orphaned app IDs to return + + Returns: + List of app IDs that have draft variables but don't exist in the apps table + """ + query = """ + SELECT DISTINCT wdv.app_id + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + LIMIT :batch_size + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(query), {"batch_size": batch_size}) + return [row[0] for row in result] + + +def _count_orphaned_draft_variables() -> dict[str, Any]: + """ + Count orphaned draft variables by app, including associated file counts. + + Returns: + Dictionary with statistics about orphaned variables and files + """ + # Count orphaned variables by app + variables_query = """ + SELECT + wdv.app_id, + COUNT(*) as variable_count, + COUNT(wdv.file_id) as file_count + FROM workflow_draft_variables AS wdv + WHERE NOT EXISTS( + SELECT 1 FROM apps WHERE apps.id = wdv.app_id + ) + GROUP BY wdv.app_id + ORDER BY variable_count DESC + """ + + with db.engine.connect() as conn: + result = conn.execute(sa.text(variables_query)) + orphaned_by_app = {} + total_files = 0 + + for row in result: + app_id, variable_count, file_count = row + orphaned_by_app[app_id] = {"variables": variable_count, "files": file_count} + total_files += file_count + + total_orphaned = sum(app_data["variables"] for app_data in orphaned_by_app.values()) + app_count = len(orphaned_by_app) + + return { + "total_orphaned_variables": total_orphaned, + "total_orphaned_files": total_files, + "orphaned_app_count": app_count, + "orphaned_by_app": orphaned_by_app, + } + + +@click.command() +@click.option("--dry-run", is_flag=True, help="Show what would be deleted without actually deleting") +@click.option("--batch-size", default=1000, help="Number of records to process per batch (default 1000)") +@click.option("--max-apps", default=None, type=int, help="Maximum number of apps to process (default: no limit)") +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +def cleanup_orphaned_draft_variables( + dry_run: bool, + batch_size: int, + max_apps: int | None, + force: bool = False, +): + """ + Clean up orphaned draft variables from the database. + + This script finds and removes draft variables that belong to apps + that no longer exist in the database. + """ + logger = logging.getLogger(__name__) + + # Get statistics + stats = _count_orphaned_draft_variables() + + logger.info("Found %s orphaned draft variables", stats["total_orphaned_variables"]) + logger.info("Found %s associated offload files", stats["total_orphaned_files"]) + logger.info("Across %s non-existent apps", stats["orphaned_app_count"]) + + if stats["total_orphaned_variables"] == 0: + logger.info("No orphaned draft variables found. Exiting.") + return + + if dry_run: + logger.info("DRY RUN: Would delete the following:") + for app_id, data in sorted(stats["orphaned_by_app"].items(), key=lambda x: x[1]["variables"], reverse=True)[ + :10 + ]: # Show top 10 + logger.info(" App %s: %s variables, %s files", app_id, data["variables"], data["files"]) + if len(stats["orphaned_by_app"]) > 10: + logger.info(" ... and %s more apps", len(stats["orphaned_by_app"]) - 10) + return + + # Confirm deletion + if not force: + click.confirm( + f"Are you sure you want to delete {stats['total_orphaned_variables']} " + f"orphaned draft variables and {stats['total_orphaned_files']} associated files " + f"from {stats['orphaned_app_count']} apps?", + abort=True, + ) + + total_deleted = 0 + processed_apps = 0 + + while True: + if max_apps and processed_apps >= max_apps: + logger.info("Reached maximum app limit (%s). Stopping.", max_apps) + break + + orphaned_app_ids = _find_orphaned_draft_variables(batch_size=10) + if not orphaned_app_ids: + logger.info("No more orphaned draft variables found.") + break + + for app_id in orphaned_app_ids: + if max_apps and processed_apps >= max_apps: + break + + try: + deleted_count = delete_draft_variables_batch(app_id, batch_size) + total_deleted += deleted_count + processed_apps += 1 + + logger.info("Deleted %s variables for app %s", deleted_count, app_id) + + except Exception: + logger.exception("Error processing app %s", app_id) + continue + + logger.info("Cleanup completed. Total deleted: %s variables across %s apps", total_deleted, processed_apps) + + +@click.command("clean-expired-messages", help="Clean expired messages.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=False, + default=None, + help="Lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=False, + default=None, + help="Upper bound (exclusive) for created_at.", +) +@click.option( + "--from-days-ago", + type=int, + default=None, + help="Relative lower bound in days ago (inclusive). Must be used with --before-days.", +) +@click.option( + "--before-days", + type=int, + default=None, + help="Relative upper bound in days ago (exclusive). Required for relative mode.", +) +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.") +@click.option( + "--graceful-period", + default=21, + show_default=True, + help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.", +) +@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting") +def clean_expired_messages( + batch_size: int, + graceful_period: int, + start_from: datetime.datetime | None, + end_before: datetime.datetime | None, + from_days_ago: int | None, + before_days: int | None, + dry_run: bool, +): + """ + Clean expired messages and related data for tenants based on clean policy. + """ + from extensions.otel.runtime import flush_telemetry + + click.echo(click.style("clean_messages: start clean messages.", fg="green")) + + start_at = time.perf_counter() + + try: + abs_mode = start_from is not None and end_before is not None + rel_mode = before_days is not None + + if abs_mode and rel_mode: + raise click.UsageError( + "Options are mutually exclusive: use either (--start-from,--end-before) " + "or (--from-days-ago,--before-days)." + ) + + if from_days_ago is not None and before_days is None: + raise click.UsageError("--from-days-ago must be used together with --before-days.") + + if (start_from is None) ^ (end_before is None): + raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.") + + if not abs_mode and not rel_mode: + raise click.UsageError( + "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])." + ) + + if rel_mode: + assert before_days is not None + if before_days < 0: + raise click.UsageError("--before-days must be >= 0.") + if from_days_ago is not None: + if from_days_ago < 0: + raise click.UsageError("--from-days-ago must be >= 0.") + if from_days_ago <= before_days: + raise click.UsageError("--from-days-ago must be greater than --before-days.") + + # Create policy based on billing configuration + # NOTE: graceful_period will be ignored when billing is disabled. + policy = create_message_clean_policy(graceful_period_days=graceful_period) + + if from_days_ago is not None and before_days is not None: + task_label = f"{from_days_ago}to{before_days}" + elif start_from is None and before_days is not None: + task_label = f"before-{before_days}" + else: + task_label = "custom" + + # Create and run the cleanup service + if abs_mode: + assert start_from is not None + assert end_before is not None + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=batch_size, + dry_run=dry_run, + task_label=task_label, + ) + elif from_days_ago is None: + assert before_days is not None + service = MessagesCleanService.from_days( + policy=policy, + days=before_days, + batch_size=batch_size, + dry_run=dry_run, + task_label=task_label, + ) + else: + assert before_days is not None + assert from_days_ago is not None + now = naive_utc_now() + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=now - datetime.timedelta(days=from_days_ago), + end_before=now - datetime.timedelta(days=before_days), + batch_size=batch_size, + dry_run=dry_run, + task_label=task_label, + ) + stats = service.run() + + end_at = time.perf_counter() + click.echo( + click.style( + f"clean_messages: completed successfully\n" + f" - Latency: {end_at - start_at:.2f}s\n" + f" - Batches processed: {stats['batches']}\n" + f" - Total messages scanned: {stats['total_messages']}\n" + f" - Messages filtered: {stats['filtered_messages']}\n" + f" - Messages deleted: {stats['total_deleted']}", + fg="green", + ) + ) + except Exception as e: + end_at = time.perf_counter() + logger.exception("clean_messages failed") + click.echo( + click.style( + f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}", + fg="red", + ) + ) + raise + finally: + flush_telemetry() + + click.echo(click.style("messages cleanup completed.", fg="green")) + + +@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.") +@click.option("--app-id", required=True, help="Application ID to export messages for.") +@click.option( + "--start-from", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + default=None, + help="Optional lower bound (inclusive) for created_at.", +) +@click.option( + "--end-before", + type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]), + required=True, + help="Upper bound (exclusive) for created_at.", +) +@click.option( + "--filename", + required=True, + help="Base filename (relative path). Do not include suffix like .jsonl.gz.", +) +@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.") +@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.") +@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.") +def export_app_messages( + app_id: str, + start_from: datetime.datetime | None, + end_before: datetime.datetime, + filename: str, + use_cloud_storage: bool, + batch_size: int, + dry_run: bool, +): + if start_from and start_from >= end_before: + raise click.UsageError("--start-from must be before --end-before.") + + from services.retention.conversation.message_export_service import AppMessageExportService + + try: + validated_filename = AppMessageExportService.validate_export_filename(filename) + except ValueError as e: + raise click.BadParameter(str(e), param_hint="--filename") from e + + click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green")) + start_at = time.perf_counter() + + try: + service = AppMessageExportService( + app_id=app_id, + end_before=end_before, + filename=validated_filename, + start_from=start_from, + batch_size=batch_size, + use_cloud_storage=use_cloud_storage, + dry_run=dry_run, + ) + stats = service.run() + + elapsed = time.perf_counter() - start_at + click.echo( + click.style( + f"export_app_messages: completed in {elapsed:.2f}s\n" + f" - Batches: {stats.batches}\n" + f" - Total messages: {stats.total_messages}\n" + f" - Messages with feedback: {stats.messages_with_feedback}\n" + f" - Total feedbacks: {stats.total_feedbacks}", + fg="green", + ) + ) + except Exception as e: + elapsed = time.perf_counter() - start_at + logger.exception("export_app_messages failed") + click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red")) + raise diff --git a/api/commands/storage.py b/api/commands/storage.py new file mode 100644 index 0000000000..f23b17680a --- /dev/null +++ b/api/commands/storage.py @@ -0,0 +1,761 @@ +import json +from typing import cast + +import click +import sqlalchemy as sa +from sqlalchemy import update +from sqlalchemy.engine import CursorResult + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_storage import storage +from extensions.storage.opendal_storage import OpenDALStorage +from extensions.storage.storage_type import StorageType +from models.model import UploadFile + + +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +@click.command("clear-orphaned-file-records", help="Clear orphaned file records.") +def clear_orphaned_file_records(force: bool): + """ + Clear orphaned file records in the database. + """ + + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id"}, + {"type": "text", "table": "documents", "column": "data_source_info"}, + {"type": "text", "table": "document_segments", "column": "content"}, + {"type": "text", "table": "messages", "column": "answer"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, + {"type": "text", "table": "conversations", "column": "introduction"}, + {"type": "text", "table": "conversations", "column": "system_instruction"}, + {"type": "text", "table": "accounts", "column": "avatar"}, + {"type": "text", "table": "apps", "column": "icon"}, + {"type": "text", "table": "sites", "column": "icon"}, + {"type": "json", "table": "messages", "column": "inputs"}, + {"type": "json", "table": "messages", "column": "message"}, + ] + + # notify user and ask for confirmation + click.echo( + click.style( + "This command will first find and delete orphaned file records from the message_files table,", fg="yellow" + ) + ) + click.echo( + click.style( + "and then it will find and delete orphaned file records in the following tables:", + fg="yellow", + ) + ) + for files_table in files_tables: + click.echo(click.style(f"- {files_table['table']}", fg="yellow")) + click.echo( + click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow") + ) + for ids_table in ids_tables: + click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")) + click.echo("") + + click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) + click.echo( + click.style( + ( + "Since not all patterns have been fully tested, " + "please note that this command may delete unintended file records." + ), + fg="yellow", + ) + ) + click.echo( + click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow") + ) + click.echo( + click.style( + ( + "It is also recommended to run this during the maintenance window, " + "as this may cause high load on your instance." + ), + fg="yellow", + ) + ) + if not force: + click.confirm("Do you want to proceed?", abort=True) + + # start the cleanup process + click.echo(click.style("Starting orphaned file records cleanup.", fg="white")) + + # clean up the orphaned records in the message_files table where message_id doesn't exist in messages table + try: + click.echo( + click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white") + ) + query = ( + "SELECT mf.id, mf.message_id " + "FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id " + "WHERE m.id IS NULL" + ) + orphaned_message_files = [] + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) + + if orphaned_message_files: + click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white")) + for record in orphaned_message_files: + click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black")) + + if not force: + click.confirm( + ( + f"Do you want to proceed " + f"to delete all {len(orphaned_message_files)} orphaned message_files records?" + ), + abort=True, + ) + + click.echo(click.style("- Deleting orphaned message_files records", fg="white")) + query = "DELETE FROM message_files WHERE id IN :ids" + with db.engine.begin() as conn: + conn.execute(sa.text(query), {"ids": tuple(record["id"] for record in orphaned_message_files)}) + click.echo( + click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green") + ) + else: + click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red")) + + # clean up the orphaned records in the rest of the *_files tables + try: + # fetch file id and keys from each table + all_files_in_tables = [] + for files_table in files_tables: + click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) + + # fetch referred table and columns + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + all_ids_in_tables = [] + for ids_table in ids_tables: + query = "" + match ids_table["type"]: + case "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) + ) + c = ids_table["column"] + query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + case "text": + t = ids_table["table"] + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {t}", + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + case _: + pass + click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) + + except Exception as e: + click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + return + + # find orphaned files + all_files = [file["id"] for file in all_files_in_tables] + all_ids = [file["id"] for file in all_ids_in_tables] + orphaned_files = list(set(all_files) - set(all_ids)) + if not orphaned_files: + click.echo(click.style("No orphaned file records found. There is nothing to delete.", fg="green")) + return + click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")) + for file in orphaned_files: + click.echo(click.style(f"- orphaned file id: {file}", fg="black")) + if not force: + click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True) + + # delete orphaned records for each file + try: + for files_table in files_tables: + click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) + query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" + with db.engine.begin() as conn: + conn.execute(sa.text(query), {"ids": tuple(orphaned_files)}) + except Exception as e: + click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) + return + click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")) + + +@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.") +@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.") +def remove_orphaned_files_on_storage(force: bool): + """ + Remove orphaned files on the storage. + """ + + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "key_column": "key"}, + {"table": "tool_files", "key_column": "file_key"}, + ] + storage_paths = ["image_files", "tools", "upload_files"] + + # notify user and ask for confirmation + click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow")) + click.echo( + click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow") + ) + for files_table in files_tables: + click.echo(click.style(f"- {files_table['table']}", fg="yellow")) + click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow")) + for storage_path in storage_paths: + click.echo(click.style(f"- {storage_path}", fg="yellow")) + click.echo("") + + click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) + click.echo( + click.style( + "Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow" + ) + ) + click.echo( + click.style( + "Since not all patterns have been fully tested, please note that this command may delete unintended files.", + fg="yellow", + ) + ) + click.echo( + click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow") + ) + click.echo( + click.style( + ( + "It is also recommended to run this during the maintenance window, " + "as this may cause high load on your instance." + ), + fg="yellow", + ) + ) + if not force: + click.confirm("Do you want to proceed?", abort=True) + + # start the cleanup process + click.echo(click.style("Starting orphaned files cleanup.", fg="white")) + + # fetch file id and keys from each table + all_files_in_tables = [] + try: + for files_table in files_tables: + click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) + query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for i in rs: + all_files_in_tables.append(str(i[0])) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) + except Exception as e: + click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + return + + all_files_on_storage = [] + for storage_path in storage_paths: + try: + click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white")) + files = storage.scan(path=storage_path, files=True, directories=False) + all_files_on_storage.extend(files) + except FileNotFoundError: + click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow")) + continue + except Exception as e: + click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red")) + continue + click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")) + + # find orphaned files + orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) + if not orphaned_files: + click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green")) + return + click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) + for file in orphaned_files: + click.echo(click.style(f"- orphaned file: {file}", fg="black")) + if not force: + click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True) + + # delete orphaned files + removed_files = 0 + error_files = 0 + for file in orphaned_files: + try: + storage.delete(file) + removed_files += 1 + click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) + except Exception as e: + error_files += 1 + click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red")) + continue + if error_files == 0: + click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) + else: + click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) + + +@click.command("file-usage", help="Query file usages and show where files are referenced.") +@click.option("--file-id", type=str, default=None, help="Filter by file UUID.") +@click.option("--key", type=str, default=None, help="Filter by storage key.") +@click.option("--src", type=str, default=None, help="Filter by table.column pattern (e.g., 'documents.%' or '%.icon').") +@click.option("--limit", type=int, default=100, help="Limit number of results (default: 100).") +@click.option("--offset", type=int, default=0, help="Offset for pagination (default: 0).") +@click.option("--json", "output_json", is_flag=True, help="Output results in JSON format.") +def file_usage( + file_id: str | None, + key: str | None, + src: str | None, + limit: int, + offset: int, + output_json: bool, +): + """ + Query file usages and show where files are referenced in the database. + + This command reuses the same reference checking logic as clear-orphaned-file-records + and displays detailed information about where each file is referenced. + """ + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id", "pk_column": "id"}, + {"type": "text", "table": "documents", "column": "data_source_info", "pk_column": "id"}, + {"type": "text", "table": "document_segments", "column": "content", "pk_column": "id"}, + {"type": "text", "table": "messages", "column": "answer", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data", "pk_column": "id"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "introduction", "pk_column": "id"}, + {"type": "text", "table": "conversations", "column": "system_instruction", "pk_column": "id"}, + {"type": "text", "table": "accounts", "column": "avatar", "pk_column": "id"}, + {"type": "text", "table": "apps", "column": "icon", "pk_column": "id"}, + {"type": "text", "table": "sites", "column": "icon", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "inputs", "pk_column": "id"}, + {"type": "json", "table": "messages", "column": "message", "pk_column": "id"}, + ] + + # Stream file usages with pagination to avoid holding all results in memory + paginated_usages = [] + total_count = 0 + + # First, build a mapping of file_id -> storage_key from the base tables + file_key_map = {} + for files_table in files_tables: + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + file_key_map[str(row[0])] = f"{files_table['table']}:{row[1]}" + + # If filtering by key or file_id, verify it exists + if file_id and file_id not in file_key_map: + if output_json: + click.echo(json.dumps({"error": f"File ID {file_id} not found in base tables"})) + else: + click.echo(click.style(f"File ID {file_id} not found in base tables.", fg="red")) + return + + if key: + valid_prefixes = {f"upload_files:{key}", f"tool_files:{key}"} + matching_file_ids = [fid for fid, fkey in file_key_map.items() if fkey in valid_prefixes] + if not matching_file_ids: + if output_json: + click.echo(json.dumps({"error": f"Key {key} not found in base tables"})) + else: + click.echo(click.style(f"Key {key} not found in base tables.", fg="red")) + return + + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + + # For each reference table/column, find matching file IDs and record the references + for ids_table in ids_tables: + src_filter = f"{ids_table['table']}.{ids_table['column']}" + + # Skip if src filter doesn't match (use fnmatch for wildcard patterns) + if src: + if "%" in src or "_" in src: + import fnmatch + + # Convert SQL LIKE wildcards to fnmatch wildcards (% -> *, _ -> ?) + pattern = src.replace("%", "*").replace("_", "?") + if not fnmatch.fnmatch(src_filter, pattern): + continue + else: + if src_filter != src: + continue + + match ids_table["type"]: + case "uuid": + # Direct UUID match + query = ( + f"SELECT {ids_table['pk_column']}, {ids_table['column']} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + ref_file_id = str(row[1]) + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + + case "text" | "json": + # Extract UUIDs from text/json content + column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"] + query = ( + f"SELECT {ids_table['pk_column']}, {column_cast} " + f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(sa.text(query)) + for row in rs: + record_id = str(row[0]) + content = str(row[1]) + + # Find all UUIDs in the content + import re + + uuid_pattern = re.compile(guid_regexp, re.IGNORECASE) + matches = uuid_pattern.findall(content) + + for ref_file_id in matches: + if ref_file_id not in file_key_map: + continue + storage_key = file_key_map[ref_file_id] + + # Apply filters + if file_id and ref_file_id != file_id: + continue + if key and not storage_key.endswith(key): + continue + + # Only collect items within the requested page range + if offset <= total_count < offset + limit: + paginated_usages.append( + { + "src": f"{ids_table['table']}.{ids_table['column']}", + "record_id": record_id, + "file_id": ref_file_id, + "key": storage_key, + } + ) + total_count += 1 + case _: + pass + + # Output results + if output_json: + result = { + "total": total_count, + "offset": offset, + "limit": limit, + "usages": paginated_usages, + } + click.echo(json.dumps(result, indent=2)) + else: + click.echo( + click.style(f"Found {total_count} file usages (showing {len(paginated_usages)} results)", fg="white") + ) + click.echo("") + + if not paginated_usages: + click.echo(click.style("No file usages found matching the specified criteria.", fg="yellow")) + return + + # Print table header + click.echo( + click.style( + f"{'Src (Table.Column)':<50} {'Record ID':<40} {'File ID':<40} {'Storage Key':<60}", + fg="cyan", + ) + ) + click.echo(click.style("-" * 190, fg="white")) + + # Print each usage + for usage in paginated_usages: + click.echo(f"{usage['src']:<50} {usage['record_id']:<40} {usage['file_id']:<40} {usage['key']:<60}") + + # Show pagination info + if offset + limit < total_count: + click.echo("") + click.echo( + click.style( + f"Showing {offset + 1}-{offset + len(paginated_usages)} of {total_count} results", fg="white" + ) + ) + click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white")) + + +@click.command( + "migrate-oss", + help="Migrate files from Local or OpenDAL source to a cloud OSS storage (destination must NOT be local/opendal).", +) +@click.option( + "--path", + "paths", + multiple=True, + help="Storage path prefixes to migrate (repeatable). Defaults: privkeys, upload_files, image_files," + " tools, website_files, keyword_files, ops_trace", +) +@click.option( + "--source", + type=click.Choice(["local", "opendal"], case_sensitive=False), + default="opendal", + show_default=True, + help="Source storage type to read from", +) +@click.option("--overwrite", is_flag=True, default=False, help="Overwrite destination if file already exists") +@click.option("--dry-run", is_flag=True, default=False, help="Show what would be migrated without uploading") +@click.option("-f", "--force", is_flag=True, help="Skip confirmation and run without prompts") +@click.option( + "--update-db/--no-update-db", + default=True, + help="Update upload_files.storage_type from source type to current storage after migration", +) +def migrate_oss( + paths: tuple[str, ...], + source: str, + overwrite: bool, + dry_run: bool, + force: bool, + update_db: bool, +): + """ + Copy all files under selected prefixes from a source storage + (Local filesystem or OpenDAL-backed) into the currently configured + destination storage backend, then optionally update DB records. + + Expected usage: set STORAGE_TYPE (and its credentials) to your target backend. + """ + # Ensure target storage is not local/opendal + if dify_config.STORAGE_TYPE in (StorageType.LOCAL, StorageType.OPENDAL): + click.echo( + click.style( + "Target STORAGE_TYPE must be a cloud OSS (not 'local' or 'opendal').\n" + "Please set STORAGE_TYPE to one of: s3, aliyun-oss, azure-blob, google-storage, tencent-cos, \n" + "volcengine-tos, supabase, oci-storage, huawei-obs, baidu-obs, clickzetta-volume.", + fg="red", + ) + ) + return + + # Default paths if none specified + default_paths = ("privkeys", "upload_files", "image_files", "tools", "website_files", "keyword_files", "ops_trace") + path_list = list(paths) if paths else list(default_paths) + is_source_local = source.lower() == "local" + + click.echo(click.style("Preparing migration to target storage.", fg="yellow")) + click.echo(click.style(f"Target storage type: {dify_config.STORAGE_TYPE}", fg="white")) + if is_source_local: + src_root = dify_config.STORAGE_LOCAL_PATH + click.echo(click.style(f"Source: local fs, root: {src_root}", fg="white")) + else: + click.echo(click.style(f"Source: opendal scheme={dify_config.OPENDAL_SCHEME}", fg="white")) + click.echo(click.style(f"Paths to migrate: {', '.join(path_list)}", fg="white")) + click.echo("") + + if not force: + click.confirm("Proceed with migration?", abort=True) + + # Instantiate source storage + try: + if is_source_local: + src_root = dify_config.STORAGE_LOCAL_PATH + source_storage = OpenDALStorage(scheme="fs", root=src_root) + else: + source_storage = OpenDALStorage(scheme=dify_config.OPENDAL_SCHEME) + except Exception as e: + click.echo(click.style(f"Failed to initialize source storage: {str(e)}", fg="red")) + return + + total_files = 0 + copied_files = 0 + skipped_files = 0 + errored_files = 0 + copied_upload_file_keys: list[str] = [] + + for prefix in path_list: + click.echo(click.style(f"Scanning source path: {prefix}", fg="white")) + try: + keys = source_storage.scan(path=prefix, files=True, directories=False) + except FileNotFoundError: + click.echo(click.style(f" -> Skipping missing path: {prefix}", fg="yellow")) + continue + except NotImplementedError: + click.echo(click.style(" -> Source storage does not support scanning.", fg="red")) + return + except Exception as e: + click.echo(click.style(f" -> Error scanning '{prefix}': {str(e)}", fg="red")) + continue + + click.echo(click.style(f"Found {len(keys)} files under {prefix}", fg="white")) + + for key in keys: + total_files += 1 + + # check destination existence + if not overwrite: + try: + if storage.exists(key): + skipped_files += 1 + continue + except Exception as e: + # existence check failures should not block migration attempt + # but should be surfaced to user as a warning for visibility + click.echo( + click.style( + f" -> Warning: failed target existence check for {key}: {str(e)}", + fg="yellow", + ) + ) + + if dry_run: + copied_files += 1 + continue + + # read from source and write to destination + try: + data = source_storage.load_once(key) + except FileNotFoundError: + errored_files += 1 + click.echo(click.style(f" -> Missing on source: {key}", fg="yellow")) + continue + except Exception as e: + errored_files += 1 + click.echo(click.style(f" -> Error reading {key}: {str(e)}", fg="red")) + continue + + try: + storage.save(key, data) + copied_files += 1 + if prefix == "upload_files": + copied_upload_file_keys.append(key) + except Exception as e: + errored_files += 1 + click.echo(click.style(f" -> Error writing {key} to target: {str(e)}", fg="red")) + continue + + click.echo("") + click.echo(click.style("Migration summary:", fg="yellow")) + click.echo(click.style(f" Total: {total_files}", fg="white")) + click.echo(click.style(f" Copied: {copied_files}", fg="green")) + click.echo(click.style(f" Skipped: {skipped_files}", fg="white")) + if errored_files: + click.echo(click.style(f" Errors: {errored_files}", fg="red")) + + if dry_run: + click.echo(click.style("Dry-run complete. No changes were made.", fg="green")) + return + + if errored_files: + click.echo( + click.style( + "Some files failed to migrate. Review errors above before updating DB records.", + fg="yellow", + ) + ) + if update_db and not force: + if not click.confirm("Proceed to update DB storage_type despite errors?", default=False): + update_db = False + + # Optionally update DB records for upload_files.storage_type (only for successfully copied upload_files) + if update_db: + if not copied_upload_file_keys: + click.echo(click.style("No upload_files copied. Skipping DB storage_type update.", fg="yellow")) + else: + try: + source_storage_type = StorageType.LOCAL if is_source_local else StorageType.OPENDAL + updated = cast( + CursorResult, + db.session.execute( + update(UploadFile) + .where( + UploadFile.storage_type == source_storage_type, + UploadFile.key.in_(copied_upload_file_keys), + ) + .values(storage_type=dify_config.STORAGE_TYPE) + ), + ).rowcount + db.session.commit() + click.echo(click.style(f"Updated storage_type for {updated} upload_files records.", fg="green")) + except Exception as e: + db.session.rollback() + click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red")) diff --git a/api/commands/system.py b/api/commands/system.py new file mode 100644 index 0000000000..39b2e991ed --- /dev/null +++ b/api/commands/system.py @@ -0,0 +1,205 @@ +import logging + +import click +import sqlalchemy as sa +from sqlalchemy import delete, select, update +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from events.app_event import app_was_created +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from libs.db_migration_lock import DbMigrationAutoRenewLock +from libs.rsa import generate_key_pair +from models import Tenant +from models.model import App, AppMode, Conversation +from models.provider import Provider, ProviderModel + +logger = logging.getLogger(__name__) + +DB_UPGRADE_LOCK_TTL_SECONDS = 60 + + +@click.command( + "reset-encrypt-key-pair", + help="Reset the asymmetric key pair of workspace for encrypt LLM credentials. " + "After the reset, all LLM credentials will become invalid, " + "requiring re-entry." + "Only support SELF_HOSTED mode.", +) +@click.confirmation_option( + prompt=click.style( + "Are you sure you want to reset encrypt key pair? This operation cannot be rolled back!", fg="red" + ) +) +def reset_encrypt_key_pair(): + """ + Reset the encrypted key pair of workspace for encrypt LLM credentials. + After the reset, all LLM credentials will become invalid, requiring re-entry. + Only support SELF_HOSTED mode. + """ + if dify_config.EDITION != "SELF_HOSTED": + click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) + return + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + tenants = session.scalars(select(Tenant)).all() + for tenant in tenants: + if not tenant: + click.echo(click.style("No workspaces found. Run /install first.", fg="red")) + return + + tenant.encrypt_public_key = generate_key_pair(tenant.id) + + session.execute(delete(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id)) + session.execute(delete(ProviderModel).where(ProviderModel.tenant_id == tenant.id)) + + click.echo( + click.style( + f"Congratulations! The asymmetric key pair of workspace {tenant.id} has been reset.", + fg="green", + ) + ) + + +@click.command("convert-to-agent-apps", help="Convert Agent Assistant to Agent App.") +def convert_to_agent_apps(): + """ + Convert Agent Assistant to Agent App. + """ + click.echo(click.style("Starting convert to agent apps.", fg="green")) + + proceeded_app_ids = [] + + while True: + # fetch first 1000 apps + sql_query = """SELECT a.id AS id FROM apps a + INNER JOIN app_model_configs am ON a.app_model_config_id=am.id + WHERE a.mode = 'chat' + AND am.agent_mode is not null + AND ( + am.agent_mode like '%"strategy": "function_call"%' + OR am.agent_mode like '%"strategy": "react"%' + ) + AND ( + am.agent_mode like '{"enabled": true%' + OR am.agent_mode like '{"max_iteration": %' + ) ORDER BY a.created_at DESC LIMIT 1000 + """ + + with db.engine.begin() as conn: + rs = conn.execute(sa.text(sql_query)) + + apps = [] + for i in rs: + app_id = str(i.id) + if app_id not in proceeded_app_ids: + proceeded_app_ids.append(app_id) + app = db.session.scalar(select(App).where(App.id == app_id)) + if app is not None: + apps.append(app) + + if len(apps) == 0: + break + + for app in apps: + click.echo(f"Converting app: {app.id}") + + try: + app.mode = AppMode.AGENT_CHAT + db.session.commit() + + # update conversation mode to agent + db.session.execute( + update(Conversation).where(Conversation.app_id == app.id).values(mode=AppMode.AGENT_CHAT) + ) + + db.session.commit() + click.echo(click.style(f"Converted app: {app.id}", fg="green")) + except Exception as e: + click.echo(click.style(f"Convert app error: {e.__class__.__name__} {str(e)}", fg="red")) + + click.echo(click.style(f"Conversion complete. Converted {len(proceeded_app_ids)} agent apps.", fg="green")) + + +@click.command("upgrade-db", help="Upgrade the database") +def upgrade_db(): + click.echo("Preparing database migration...") + lock = DbMigrationAutoRenewLock( + redis_client=redis_client, + name="db_upgrade_lock", + ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS, + logger=logger, + log_context="db_migration", + ) + if lock.acquire(blocking=False): + migration_succeeded = False + try: + click.echo(click.style("Starting database migration.", fg="green")) + + # run db migration + import flask_migrate + + flask_migrate.upgrade() + + migration_succeeded = True + click.echo(click.style("Database migration successful!", fg="green")) + + except Exception as e: + logger.exception("Failed to execute database migration") + click.echo(click.style(f"Database migration failed: {e}", fg="red")) + raise SystemExit(1) + finally: + status = "successful" if migration_succeeded else "failed" + lock.release_safely(status=status) + else: + click.echo("Database migration skipped") + + +@click.command("fix-app-site-missing", help="Fix app related site missing issue.") +def fix_app_site_missing(): + """ + Fix app related site missing issue. + """ + click.echo(click.style("Starting fix for missing app-related sites.", fg="green")) + + failed_app_ids = [] + while True: + sql = """select apps.id as id from apps left join sites on sites.app_id=apps.id +where sites.id is null limit 1000""" + with db.engine.begin() as conn: + rs = conn.execute(sa.text(sql)) + + processed_count = 0 + for i in rs: + processed_count += 1 + app_id = str(i.id) + + if app_id in failed_app_ids: + continue + + try: + app = db.session.scalar(select(App).where(App.id == app_id)) + if not app: + logger.info("App %s not found", app_id) + continue + + tenant = app.tenant + if tenant: + accounts = tenant.get_accounts() + if not accounts: + logger.info("Fix failed for app %s", app.id) + continue + + account = accounts[0] + logger.info("Fixing missing site for app %s", app.id) + app_was_created.send(app, account=account) + except Exception: + failed_app_ids.append(app_id) + click.echo(click.style(f"Failed to fix missing site for app {app_id}", fg="red")) + logger.exception("Failed to fix app related site missing issue, app_id: %s", app_id) + continue + + if not processed_count: + break + + click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) diff --git a/api/commands/vector.py b/api/commands/vector.py new file mode 100644 index 0000000000..4cf11c9ad1 --- /dev/null +++ b/api/commands/vector.py @@ -0,0 +1,467 @@ +import json + +import click +from flask import current_app +from sqlalchemy import select +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.rag.datasource.vdb.vector_factory import Vector +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.models.document import ChildDocument, Document +from extensions.ext_database import db +from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment +from models.dataset import Document as DatasetDocument +from models.enums import DatasetMetadataType, IndexingStatus, SegmentStatus +from models.model import App, AppAnnotationSetting, MessageAnnotation + + +@click.command("vdb-migrate", help="Migrate vector db.") +@click.option("--scope", default="all", prompt=False, help="The scope of vector database to migrate, Default is All.") +def vdb_migrate(scope: str): + if scope in {"knowledge", "all"}: + migrate_knowledge_vector_database() + if scope in {"annotation", "all"}: + migrate_annotation_vector_database() + + +def migrate_annotation_vector_database(): + """ + Migrate annotation datas to target vector database . + """ + click.echo(click.style("Starting annotation data migration.", fg="green")) + create_count = 0 + skipped_count = 0 + total_count = 0 + page = 1 + while True: + try: + # get apps info + per_page = 50 + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + apps = session.scalars( + select(App) + .where(App.status == "normal") + .order_by(App.created_at.desc()) + .limit(per_page) + .offset((page - 1) * per_page) + ).all() + if not apps: + break + except SQLAlchemyError: + raise + + page += 1 + for app in apps: + total_count = total_count + 1 + click.echo( + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." + ) + try: + click.echo(f"Creating app annotation index: {app.id}") + with sessionmaker(db.engine, expire_on_commit=False).begin() as session: + app_annotation_setting = session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).limit(1) + ) + + if not app_annotation_setting: + skipped_count = skipped_count + 1 + click.echo(f"App annotation setting disabled: {app.id}") + continue + # get dataset_collection_binding info + dataset_collection_binding = session.scalar( + select(DatasetCollectionBinding).where( + DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id + ) + ) + if not dataset_collection_binding: + click.echo(f"App annotation collection binding not found: {app.id}") + continue + annotations = session.scalars( + select(MessageAnnotation).where(MessageAnnotation.app_id == app.id) + ).all() + dataset = Dataset( + id=app.id, + tenant_id=app.tenant_id, + indexing_technique="high_quality", + embedding_model_provider=dataset_collection_binding.provider_name, + embedding_model=dataset_collection_binding.model_name, + collection_binding_id=dataset_collection_binding.id, + ) + documents = [] + if annotations: + for annotation in annotations: + document = Document( + page_content=annotation.question_text, + metadata={"annotation_id": annotation.id, "app_id": app.id, "doc_id": annotation.id}, + ) + documents.append(document) + + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) + click.echo(f"Migrating annotations for app: {app.id}.") + + try: + vector.delete() + click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green")) + except Exception as e: + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) + raise e + if documents: + try: + click.echo( + click.style( + f"Creating vector index with {len(documents)} annotations for app {app.id}.", + fg="green", + ) + ) + vector.create(documents) + click.echo(click.style(f"Created vector index for app {app.id}.", fg="green")) + except Exception as e: + click.echo(click.style(f"Failed to created vector index for app {app.id}.", fg="red")) + raise e + click.echo(f"Successfully migrated app annotation {app.id}.") + create_count += 1 + except Exception as e: + click.echo( + click.style(f"Error creating app annotation index: {e.__class__.__name__} {str(e)}", fg="red") + ) + continue + + click.echo( + click.style( + f"Migration complete. Created {create_count} app annotation indexes. Skipped {skipped_count} apps.", + fg="green", + ) + ) + + +def migrate_knowledge_vector_database(): + """ + Migrate vector database datas to target vector database . + """ + click.echo(click.style("Starting vector database migration.", fg="green")) + create_count = 0 + skipped_count = 0 + total_count = 0 + vector_type = dify_config.VECTOR_STORE + upper_collection_vector_types = { + VectorType.MILVUS, + VectorType.PGVECTOR, + VectorType.VASTBASE, + VectorType.RELYT, + VectorType.WEAVIATE, + VectorType.ORACLE, + VectorType.ELASTICSEARCH, + VectorType.OPENGAUSS, + VectorType.TABLESTORE, + VectorType.MATRIXONE, + } + lower_collection_vector_types = { + VectorType.ANALYTICDB, + VectorType.HOLOGRES, + VectorType.CHROMA, + VectorType.MYSCALE, + VectorType.PGVECTO_RS, + VectorType.TIDB_VECTOR, + VectorType.OPENSEARCH, + VectorType.TENCENT, + VectorType.BAIDU, + VectorType.VIKINGDB, + VectorType.UPSTASH, + VectorType.COUCHBASE, + VectorType.OCEANBASE, + } + page = 1 + while True: + try: + stmt = ( + select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) + ) + + datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) + if not datasets.items: + break + except SQLAlchemyError: + raise + + page += 1 + for dataset in datasets: + total_count = total_count + 1 + click.echo( + f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." + ) + try: + click.echo(f"Creating dataset vector database index: {dataset.id}") + if dataset.index_struct_dict: + if dataset.index_struct_dict["type"] == vector_type: + skipped_count = skipped_count + 1 + continue + collection_name = "" + dataset_id = dataset.id + if vector_type in upper_collection_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + elif vector_type == VectorType.QDRANT: + if dataset.collection_binding_id: + dataset_collection_binding = db.session.execute( + select(DatasetCollectionBinding).where( + DatasetCollectionBinding.id == dataset.collection_binding_id + ) + ).scalar_one_or_none() + if dataset_collection_binding: + collection_name = dataset_collection_binding.collection_name + else: + raise ValueError("Dataset Collection Binding not found") + else: + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + + elif vector_type in lower_collection_vector_types: + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() + else: + raise ValueError(f"Vector store {vector_type} is not supported.") + + index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}} + dataset.index_struct = json.dumps(index_struct_dict) + vector = Vector(dataset) + click.echo(f"Migrating dataset {dataset.id}.") + + try: + vector.delete() + click.echo( + click.style(f"Deleted vector index {collection_name} for dataset {dataset.id}.", fg="green") + ) + except Exception as e: + click.echo( + click.style( + f"Failed to delete vector index {collection_name} for dataset {dataset.id}.", fg="red" + ) + ) + raise e + + dataset_documents = db.session.scalars( + select(DatasetDocument).where( + DatasetDocument.dataset_id == dataset.id, + DatasetDocument.indexing_status == IndexingStatus.COMPLETED, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + ).all() + + documents = [] + segments_count = 0 + for dataset_document in dataset_documents: + segments = db.session.scalars( + select(DocumentSegment).where( + DocumentSegment.document_id == dataset_document.id, + DocumentSegment.status == SegmentStatus.COMPLETED, + DocumentSegment.enabled == True, + ) + ).all() + + for segment in segments: + document = Document( + page_content=segment.content, + metadata={ + "doc_id": segment.index_node_id, + "doc_hash": segment.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + if dataset_document.doc_form == "hierarchical_model": + child_chunks = segment.get_child_chunks() + if child_chunks: + child_documents = [] + for child_chunk in child_chunks: + child_document = ChildDocument( + page_content=child_chunk.content, + metadata={ + "doc_id": child_chunk.index_node_id, + "doc_hash": child_chunk.index_node_hash, + "document_id": segment.document_id, + "dataset_id": segment.dataset_id, + }, + ) + child_documents.append(child_document) + document.children = child_documents + + documents.append(document) + segments_count = segments_count + 1 + + if documents: + try: + click.echo( + click.style( + f"Creating vector index with {len(documents)} documents of {segments_count}" + f" segments for dataset {dataset.id}.", + fg="green", + ) + ) + all_child_documents = [] + for doc in documents: + if doc.children: + all_child_documents.extend(doc.children) + vector.create(documents) + if all_child_documents: + vector.create(all_child_documents) + click.echo(click.style(f"Created vector index for dataset {dataset.id}.", fg="green")) + except Exception as e: + click.echo(click.style(f"Failed to created vector index for dataset {dataset.id}.", fg="red")) + raise e + db.session.add(dataset) + db.session.commit() + click.echo(f"Successfully migrated dataset {dataset.id}.") + create_count += 1 + except Exception as e: + db.session.rollback() + click.echo(click.style(f"Error creating dataset index: {e.__class__.__name__} {str(e)}", fg="red")) + continue + + click.echo( + click.style( + f"Migration complete. Created {create_count} dataset indexes. Skipped {skipped_count} datasets.", fg="green" + ) + ) + + +@click.command("add-qdrant-index", help="Add Qdrant index.") +@click.option("--field", default="metadata.doc_id", prompt=False, help="Index field , default is metadata.doc_id.") +def add_qdrant_index(field: str): + click.echo(click.style("Starting Qdrant index creation.", fg="green")) + + create_count = 0 + + try: + bindings = db.session.scalars(select(DatasetCollectionBinding)).all() + if not bindings: + click.echo(click.style("No dataset collection bindings found.", fg="red")) + return + import qdrant_client + from qdrant_client.http.exceptions import UnexpectedResponse + from qdrant_client.http.models import PayloadSchemaType + + from core.rag.datasource.vdb.qdrant.qdrant_vector import PathQdrantParams, QdrantConfig + + for binding in bindings: + if dify_config.QDRANT_URL is None: + raise ValueError("Qdrant URL is required.") + qdrant_config = QdrantConfig( + endpoint=dify_config.QDRANT_URL, + api_key=dify_config.QDRANT_API_KEY, + root_path=current_app.root_path, + timeout=dify_config.QDRANT_CLIENT_TIMEOUT, + grpc_port=dify_config.QDRANT_GRPC_PORT, + prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + ) + try: + params = qdrant_config.to_qdrant_params() + # Check the type before using + if isinstance(params, PathQdrantParams): + # PathQdrantParams case + client = qdrant_client.QdrantClient(path=params.path) + else: + # UrlQdrantParams case - params is UrlQdrantParams + client = qdrant_client.QdrantClient( + url=params.url, + api_key=params.api_key, + timeout=int(params.timeout), + verify=params.verify, + grpc_port=params.grpc_port, + prefer_grpc=params.prefer_grpc, + ) + # create payload index + client.create_payload_index(binding.collection_name, field, field_schema=PayloadSchemaType.KEYWORD) + create_count += 1 + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + click.echo(click.style(f"Collection not found: {binding.collection_name}.", fg="red")) + continue + # Some other error occurred, so re-raise the exception + else: + click.echo( + click.style( + f"Failed to create Qdrant index for collection: {binding.collection_name}.", fg="red" + ) + ) + + except Exception: + click.echo(click.style("Failed to create Qdrant client.", fg="red")) + + click.echo(click.style(f"Index creation complete. Created {create_count} collection indexes.", fg="green")) + + +@click.command("old-metadata-migration", help="Old metadata migration.") +def old_metadata_migration(): + """ + Old metadata migration. + """ + click.echo(click.style("Starting old metadata migration.", fg="green")) + + page = 1 + while True: + try: + stmt = ( + select(DatasetDocument) + .where(DatasetDocument.doc_metadata.is_not(None)) + .order_by(DatasetDocument.created_at.desc()) + ) + documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) + except SQLAlchemyError: + raise + if not documents: + break + for document in documents: + if document.doc_metadata: + doc_metadata = document.doc_metadata + for key in doc_metadata: + for field in BuiltInField: + if field.value == key: + break + else: + dataset_metadata = db.session.scalar( + select(DatasetMetadata) + .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key) + .limit(1) + ) + if not dataset_metadata: + dataset_metadata = DatasetMetadata( + tenant_id=document.tenant_id, + dataset_id=document.dataset_id, + name=key, + type=DatasetMetadataType.STRING, + created_by=document.created_by, + ) + db.session.add(dataset_metadata) + db.session.flush() + dataset_metadata_binding: DatasetMetadataBinding | None = DatasetMetadataBinding( + tenant_id=document.tenant_id, + dataset_id=document.dataset_id, + metadata_id=dataset_metadata.id, + document_id=document.id, + created_by=document.created_by, + ) + db.session.add(dataset_metadata_binding) + else: + dataset_metadata_binding = db.session.scalar( + select(DatasetMetadataBinding) + .where( + DatasetMetadataBinding.dataset_id == document.dataset_id, + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == dataset_metadata.id, + ) + .limit(1) + ) + if not dataset_metadata_binding: + dataset_metadata_binding = DatasetMetadataBinding( + tenant_id=document.tenant_id, + dataset_id=document.dataset_id, + metadata_id=dataset_metadata.id, + document_id=document.id, + created_by=document.created_by, + ) + db.session.add(dataset_metadata_binding) + db.session.commit() + page += 1 + click.echo(click.style("Old metadata migration completed.", fg="green")) diff --git a/api/configs/enterprise/__init__.py b/api/configs/enterprise/__init__.py index eda6345e14..f8447c6979 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -18,3 +18,7 @@ class EnterpriseFeatureConfig(BaseSettings): description="Allow customization of the enterprise logo.", default=False, ) + + ENTERPRISE_REQUEST_TIMEOUT: int = Field( + ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 + ) diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 0532a42371..15ac8bf0bf 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -26,6 +26,7 @@ from .vdb.chroma_config import ChromaConfig from .vdb.clickzetta_config import ClickzettaConfig from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig +from .vdb.hologres_config import HologresConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig from .vdb.iris_config import IrisVectorConfig from .vdb.lindorm_config import LindormConfig @@ -347,6 +348,7 @@ class MiddlewareConfig( AnalyticdbConfig, ChromaConfig, ClickzettaConfig, + HologresConfig, HuaweiCloudConfig, IrisVectorConfig, MilvusConfig, diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 4705b28c69..3b91207545 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -1,4 +1,4 @@ -from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt +from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, field_validator from pydantic_settings import BaseSettings @@ -111,3 +111,18 @@ class RedisConfig(BaseSettings): description="Enable client side cache in redis", default=False, ) + + REDIS_MAX_CONNECTIONS: PositiveInt | None = Field( + description="Maximum connections in the Redis connection pool (unset for library default)", + default=None, + ) + + @field_validator("REDIS_MAX_CONNECTIONS", mode="before") + @classmethod + def _empty_string_to_none_for_max_conns(cls, v): + """Allow empty string in env/.env to mean 'unset' (None).""" + if v is None: + return None + if isinstance(v, str) and v.strip() == "": + return None + return v diff --git a/api/configs/middleware/cache/redis_pubsub_config.py b/api/configs/middleware/cache/redis_pubsub_config.py index a72e1dd28f..0a166818b3 100644 --- a/api/configs/middleware/cache/redis_pubsub_config.py +++ b/api/configs/middleware/cache/redis_pubsub_config.py @@ -1,7 +1,7 @@ -from typing import Literal, Protocol +from typing import Literal, Protocol, cast from urllib.parse import quote_plus, urlunparse -from pydantic import Field +from pydantic import AliasChoices, Field from pydantic_settings import BaseSettings @@ -12,54 +12,66 @@ class RedisConfigDefaults(Protocol): REDIS_PASSWORD: str | None REDIS_DB: int REDIS_USE_SSL: bool - REDIS_USE_SENTINEL: bool | None - REDIS_USE_CLUSTERS: bool -class RedisConfigDefaultsMixin: - def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults: - return self +def _redis_defaults(config: object) -> RedisConfigDefaults: + return cast(RedisConfigDefaults, config) -class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): +class RedisPubSubConfig(BaseSettings): """ - Configuration settings for Redis pub/sub streaming. + Configuration settings for event transport between API and workers. + + Supported transports: + - pubsub: Redis PUBLISH/SUBSCRIBE (at-most-once) + - sharded: Redis 7+ Sharded Pub/Sub (at-most-once, better scaling) + - streams: Redis Streams (at-least-once, supports late subscribers) """ PUBSUB_REDIS_URL: str | None = Field( - alias="PUBSUB_REDIS_URL", + validation_alias=AliasChoices("EVENT_BUS_REDIS_URL", "PUBSUB_REDIS_URL"), description=( - "Redis connection URL for pub/sub streaming events between API " - "and celery worker, defaults to url constructed from " - "`REDIS_*` configurations" + "Redis connection URL for streaming events between API and celery worker; " + "defaults to URL constructed from `REDIS_*` configurations. Also accepts ENV: EVENT_BUS_REDIS_URL." ), default=None, ) PUBSUB_REDIS_USE_CLUSTERS: bool = Field( + validation_alias=AliasChoices("EVENT_BUS_REDIS_USE_CLUSTERS", "PUBSUB_REDIS_USE_CLUSTERS"), description=( - "Enable Redis Cluster mode for pub/sub streaming. It's highly " - "recommended to enable this for large deployments." + "Enable Redis Cluster mode for pub/sub or streams transport. Recommended for large deployments. " + "Also accepts ENV: EVENT_BUS_REDIS_USE_CLUSTERS." ), default=False, ) - PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field( + PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded", "streams"] = Field( + validation_alias=AliasChoices("EVENT_BUS_REDIS_CHANNEL_TYPE", "PUBSUB_REDIS_CHANNEL_TYPE"), description=( - "Pub/sub channel type for streaming events. " - "Valid options are:\n" - "\n" - " - pubsub: for normal Pub/Sub\n" - " - sharded: for sharded Pub/Sub\n" - "\n" - "It's highly recommended to use sharded Pub/Sub AND redis cluster " - "for large deployments." + "Event transport type. Options are:\n\n" + " - pubsub: normal Pub/Sub (at-most-once)\n" + " - sharded: sharded Pub/Sub (at-most-once)\n" + " - streams: Redis Streams (at-least-once, recommended to avoid subscriber races)\n\n" + "Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs.\n" + "Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce\n" + "the risk of data loss from Redis auto-eviction under memory pressure.\n" + "Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE." ), default="pubsub", ) + PUBSUB_STREAMS_RETENTION_SECONDS: int = Field( + validation_alias=AliasChoices("EVENT_BUS_STREAMS_RETENTION_SECONDS", "PUBSUB_STREAMS_RETENTION_SECONDS"), + description=( + "When using 'streams', expire each stream key this many seconds after the last event is published. " + "Also accepts ENV: EVENT_BUS_STREAMS_RETENTION_SECONDS." + ), + default=600, + ) + def _build_default_pubsub_url(self) -> str: - defaults = self._redis_defaults() + defaults = _redis_defaults(self) if not defaults.REDIS_HOST or not defaults.REDIS_PORT: raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed") @@ -76,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin): if userinfo: userinfo = f"{userinfo}@" - host = defaults.REDIS_HOST - port = defaults.REDIS_PORT db = defaults.REDIS_DB - netloc = f"{userinfo}{host}:{port}" + netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}" return urlunparse((scheme, netloc, f"/{db}", "", "", "")) @property diff --git a/api/configs/middleware/vdb/hologres_config.py b/api/configs/middleware/vdb/hologres_config.py new file mode 100644 index 0000000000..9812cce268 --- /dev/null +++ b/api/configs/middleware/vdb/hologres_config.py @@ -0,0 +1,68 @@ +from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType +from pydantic import Field +from pydantic_settings import BaseSettings + + +class HologresConfig(BaseSettings): + """ + Configuration settings for Hologres vector database. + + Hologres is compatible with PostgreSQL protocol. + access_key_id is used as the PostgreSQL username, + and access_key_secret is used as the PostgreSQL password. + """ + + HOLOGRES_HOST: str | None = Field( + description="Hostname or IP address of the Hologres instance.", + default=None, + ) + + HOLOGRES_PORT: int = Field( + description="Port number for connecting to the Hologres instance.", + default=80, + ) + + HOLOGRES_DATABASE: str | None = Field( + description="Name of the Hologres database to connect to.", + default=None, + ) + + HOLOGRES_ACCESS_KEY_ID: str | None = Field( + description="Alibaba Cloud AccessKey ID, also used as the PostgreSQL username.", + default=None, + ) + + HOLOGRES_ACCESS_KEY_SECRET: str | None = Field( + description="Alibaba Cloud AccessKey Secret, also used as the PostgreSQL password.", + default=None, + ) + + HOLOGRES_SCHEMA: str = Field( + description="Schema name in the Hologres database.", + default="public", + ) + + HOLOGRES_TOKENIZER: TokenizerType = Field( + description="Tokenizer for full-text search index (e.g., 'jieba', 'ik', 'standard', 'simple').", + default="jieba", + ) + + HOLOGRES_DISTANCE_METHOD: DistanceType = Field( + description="Distance method for vector index (e.g., 'Cosine', 'Euclidean', 'InnerProduct').", + default="Cosine", + ) + + HOLOGRES_BASE_QUANTIZATION_TYPE: BaseQuantizationType = Field( + description="Base quantization type for vector index (e.g., 'rabitq', 'sq8', 'fp16', 'fp32').", + default="rabitq", + ) + + HOLOGRES_MAX_DEGREE: int = Field( + description="Max degree (M) parameter for HNSW vector index.", + default=64, + ) + + HOLOGRES_EF_CONSTRUCTION: int = Field( + description="ef_construction parameter for HNSW vector index.", + default=400, + ) diff --git a/api/configs/middleware/vdb/weaviate_config.py b/api/configs/middleware/vdb/weaviate_config.py index 6f4fccaa7f..2d1216c0d1 100644 --- a/api/configs/middleware/vdb/weaviate_config.py +++ b/api/configs/middleware/vdb/weaviate_config.py @@ -17,11 +17,6 @@ class WeaviateConfig(BaseSettings): default=None, ) - WEAVIATE_GRPC_ENABLED: bool = Field( - description="Whether to enable gRPC for Weaviate connection (True for gRPC, False for HTTP)", - default=True, - ) - WEAVIATE_GRPC_ENDPOINT: str | None = Field( description="URL of the Weaviate gRPC server (e.g., 'grpc://localhost:50051' or 'grpcs://weaviate.example.com:443')", default=None, diff --git a/api/context/__init__.py b/api/context/__init__.py index aebf9750ce..969e5f583d 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -12,7 +12,7 @@ or any other web framework. import contextvars from collections.abc import Callable -from core.workflow.context.execution_context import ( +from dify_graph.context.execution_context import ( ExecutionContext, IExecutionContext, NullAppContext, diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 2d465c8cf4..324a9ee8b4 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,8 +10,8 @@ from typing import Any, final from flask import Flask, current_app, g -from core.workflow.context import register_context_capturer -from core.workflow.context.execution_context import ( +from dify_graph.context import register_context_capturer +from dify_graph.context.execution_context import ( AppContext, IExecutionContext, ) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 9b30db8b75..ff5326dade 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -4,7 +4,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, computed_field -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 5f038ac73e..475280716e 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -39,6 +39,7 @@ from . import ( feature, human_input_form, init_validate, + notification, ping, setup, spec, @@ -192,6 +193,7 @@ __all__ = [ "model_config", "model_providers", "models", + "notification", "oauth", "oauth_server", "ops_trace", diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 03b602f6e8..6c3a6a8c1f 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -1,3 +1,5 @@ +import csv +import io from collections.abc import Callable from functools import wraps from typing import ParamSpec, TypeVar @@ -6,7 +8,7 @@ from flask import request from flask_restx import Resource from pydantic import BaseModel, Field, field_validator from sqlalchemy import select -from werkzeug.exceptions import NotFound, Unauthorized +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized from configs import dify_config from constants.languages import supported_language @@ -16,6 +18,7 @@ from core.db.session_factory import session_factory from extensions.ext_database import db from libs.token import extract_access_token from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp +from services.billing_service import BillingService P = ParamSpec("P") R = TypeVar("R") @@ -277,3 +280,168 @@ class DeleteExploreBannerApi(Resource): db.session.commit() return {"result": "success"}, 204 + + +class LangContentPayload(BaseModel): + lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'") + title: str = Field(...) + subtitle: str | None = Field(default=None) + body: str = Field(...) + title_pic_url: str | None = Field(default=None) + + +class UpsertNotificationPayload(BaseModel): + notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update") + contents: list[LangContentPayload] = Field(..., min_length=1) + start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z") + end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z") + frequency: str = Field(default="once", description="'once' | 'every_page_load'") + status: str = Field(default="active", description="'active' | 'inactive'") + + +class BatchAddNotificationAccountsPayload(BaseModel): + notification_id: str = Field(...) + user_email: list[str] = Field(..., description="List of account email addresses") + + +console_ns.schema_model( + UpsertNotificationPayload.__name__, + UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + +console_ns.schema_model( + BatchAddNotificationAccountsPayload.__name__, + BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0), +) + + +@console_ns.route("/admin/upsert_notification") +class UpsertNotificationApi(Resource): + @console_ns.doc("upsert_notification") + @console_ns.doc( + description=( + "Create or update an in-product notification. " + "Supply notification_id to update an existing one; omit it to create a new one. " + "Pass at least one language variant in contents (zh / en / jp)." + ) + ) + @console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__]) + @console_ns.response(200, "Notification upserted successfully") + @only_edition_cloud + @admin_required + def post(self): + payload = UpsertNotificationPayload.model_validate(console_ns.payload) + result = BillingService.upsert_notification( + contents=[c.model_dump() for c in payload.contents], + frequency=payload.frequency, + status=payload.status, + notification_id=payload.notification_id, + start_time=payload.start_time, + end_time=payload.end_time, + ) + return {"result": "success", "notification_id": result.get("notificationId")}, 200 + + +@console_ns.route("/admin/batch_add_notification_accounts") +class BatchAddNotificationAccountsApi(Resource): + @console_ns.doc("batch_add_notification_accounts") + @console_ns.doc( + description=( + "Register target accounts for a notification by email address. " + 'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. ' + "File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) " + "plus a 'notification_id' field. " + "Emails that do not match any account are silently skipped." + ) + ) + @console_ns.response(200, "Accounts added successfully") + @only_edition_cloud + @admin_required + def post(self): + from models.account import Account + + if "file" in request.files: + notification_id = request.form.get("notification_id", "").strip() + if not notification_id: + raise BadRequest("notification_id is required.") + emails = self._parse_emails_from_file() + else: + payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload) + notification_id = payload.notification_id + emails = payload.user_email + + if not emails: + raise BadRequest("No valid email addresses provided.") + + # Resolve emails → account IDs in chunks to avoid large IN-clause + account_ids: list[str] = [] + chunk_size = 500 + for i in range(0, len(emails), chunk_size): + chunk = emails[i : i + chunk_size] + rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all() + account_ids.extend(str(row.id) for row in rows) + + if not account_ids: + raise BadRequest("None of the provided emails matched an existing account.") + + # Send to dify-saas in batches of 1000 + total_count = 0 + batch_size = 1000 + for i in range(0, len(account_ids), batch_size): + batch = account_ids[i : i + batch_size] + result = BillingService.batch_add_notification_accounts( + notification_id=notification_id, + account_ids=batch, + ) + total_count += result.get("count", 0) + + return { + "result": "success", + "emails_provided": len(emails), + "accounts_matched": len(account_ids), + "count": total_count, + }, 200 + + @staticmethod + def _parse_emails_from_file() -> list[str]: + """Parse email addresses from an uploaded CSV or TXT file.""" + file = request.files["file"] + if not file.filename: + raise BadRequest("Uploaded file has no filename.") + + filename_lower = file.filename.lower() + if not filename_lower.endswith((".csv", ".txt")): + raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.") + + try: + content = file.read().decode("utf-8") + except UnicodeDecodeError: + try: + file.seek(0) + content = file.read().decode("gbk") + except UnicodeDecodeError: + raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.") + + emails: list[str] = [] + if filename_lower.endswith(".csv"): + reader = csv.reader(io.StringIO(content)) + for row in reader: + for cell in row: + cell = cell.strip() + if cell: + emails.append(cell) + else: + for line in content.splitlines(): + line = line.strip() + if line: + emails.append(line) + + # Deduplicate while preserving order + seen: set[str] = set() + unique_emails: list[str] = [] + for email in emails: + if email.lower() not in seen: + seen.add(email.lower()) + unique_emails.append(email) + + return unique_emails diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py index b6d1df319e..6c54be84a8 100644 --- a/api/controllers/console/apikey.py +++ b/api/controllers/console/apikey.py @@ -1,7 +1,7 @@ import flask_restx from flask_restx import Resource, fields, marshal_with from flask_restx._http import HTTPStatus -from sqlalchemy import select +from sqlalchemy import delete, func, select from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -33,16 +33,10 @@ api_key_list_model = console_ns.model( def _get_resource(resource_id, tenant_id, resource_model): - if resource_model == App: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() - else: - with Session(db.engine) as session: - resource = session.execute( - select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) - ).scalar_one_or_none() + with Session(db.engine) as session: + resource = session.execute( + select(resource_model).filter_by(id=resource_id, tenant_id=tenant_id) + ).scalar_one_or_none() if resource is None: flask_restx.abort(HTTPStatus.NOT_FOUND, message=f"{resource_model.__name__} not found.") @@ -80,10 +74,13 @@ class BaseApiKeyListResource(Resource): resource_id = str(resource_id) _, current_tenant_id = current_account_with_tenant() _get_resource(resource_id, current_tenant_id, self.resource_model) - current_key_count = ( - db.session.query(ApiToken) - .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id) - .count() + current_key_count: int = ( + db.session.scalar( + select(func.count(ApiToken.id)).where( + ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id + ) + ) + or 0 ) if current_key_count >= self.max_keys: @@ -119,14 +116,14 @@ class BaseApiKeyResource(Resource): if not current_user.is_admin_or_owner: raise Forbidden() - key = ( - db.session.query(ApiToken) + key = db.session.scalar( + select(ApiToken) .where( getattr(ApiToken, self.resource_id_field) == resource_id, ApiToken.type == self.resource_type, ApiToken.id == api_key_id, ) - .first() + .limit(1) ) if key is None: @@ -137,7 +134,7 @@ class BaseApiKeyResource(Resource): assert key is not None # nosec - for type checker only ApiTokenCache.delete(key.token, key.type) - db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete() + db.session.execute(delete(ApiToken).where(ApiToken.id == api_key_id)) db.session.commit() return {"result": "success"}, 204 diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index e799e98d3e..5ac0e342e6 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -25,8 +25,9 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.enums import NodeType, WorkflowExecutionStatus -from core.workflow.file import helpers as file_helpers +from core.trigger.constants import TRIGGER_NODE_TYPES +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow @@ -508,11 +509,7 @@ class AppListApi(Resource): .scalars() .all() ) - trigger_node_types = { - NodeType.TRIGGER_WEBHOOK, - NodeType.TRIGGER_SCHEDULE, - NodeType.TRIGGER_PLUGIN, - } + trigger_node_types = TRIGGER_NODE_TYPES for workflow in draft_workflows: node_id = None try: diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 941db325bf..2c5e8d29ee 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -22,7 +22,7 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 2922121a54..4d7ddfea13 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -26,7 +26,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from libs.login import current_user, login_required diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 5eb61493c3..74750981dd 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -5,7 +5,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, or_ -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import selectinload from werkzeug.exceptions import NotFound from controllers.console import console_ns @@ -376,8 +376,12 @@ class CompletionConversationApi(Resource): # FIXME, the type ignore in this file if args.annotation_status == "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + query = ( + query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type] + .join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + .distinct() ) elif args.annotation_status == "not_annotated": query = ( @@ -511,8 +515,12 @@ class ChatConversationApi(Resource): match args.annotation_status: case "annotated": - query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore - MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + query = ( + query.options(selectinload(Conversation.message_annotations)) # type: ignore[arg-type] + .join( # type: ignore + MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id + ) + .distinct() ) case "not_annotated": query = ( diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 1ac55b5e8d..af4ac450bb 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -18,7 +18,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index dd982b6d7b..4b20418b53 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -1,5 +1,4 @@ import json -from enum import StrEnum from flask_restx import Resource, marshal_with from pydantic import BaseModel, Field @@ -11,6 +10,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm from extensions.ext_database import db from fields.app_fields import app_server_fields from libs.login import current_account_with_tenant, login_required +from models.enums import AppMCPServerStatus from models.model import AppMCPServer DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -19,11 +19,6 @@ DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" app_server_model = console_ns.model("AppServer", app_server_fields) -class AppMCPServerStatus(StrEnum): - ACTIVE = "active" - INACTIVE = "inactive" - - class MCPServerCreatePayload(BaseModel): description: str | None = Field(default=None, description="Server description") parameters: dict = Field(..., description="Server parameters configuration") @@ -108,18 +103,19 @@ class AppMCPServerController(Resource): raise NotFound() description = payload.description - if description is None: - pass - elif not description: + if description is None or not description: server.description = app_model.description or "" else: server.description = description + server.name = app_model.name + server.parameters = json.dumps(payload.parameters, ensure_ascii=False) if payload.status: - if payload.status not in [status.value for status in AppMCPServerStatus]: + try: + server.status = AppMCPServerStatus(payload.status) + except ValueError: raise ValueError("Invalid status") - server.status = payload.status db.session.commit() return server diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 0bea777870..736e7dbe17 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -4,7 +4,7 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator -from sqlalchemy import exists, select +from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound from controllers.common.schema import register_schema_models @@ -24,12 +24,13 @@ from controllers.console.wraps import ( ) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.raws import FilesContainedField from libs.helper import TimestampField, uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import current_account_with_tenant, login_required +from models.enums import FeedbackFromSource, FeedbackRating from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -243,27 +244,25 @@ class ChatMessageListApi(Resource): def get(self, app_model): args = ChatMessagesQuery.model_validate(request.args.to_dict()) - conversation = ( - db.session.query(Conversation) + conversation = db.session.scalar( + select(Conversation) .where(Conversation.id == args.conversation_id, Conversation.app_id == app_model.id) - .first() + .limit(1) ) if not conversation: raise NotFound("Conversation Not Exists.") if args.first_id: - first_message = ( - db.session.query(Message) - .where(Message.conversation_id == conversation.id, Message.id == args.first_id) - .first() + first_message = db.session.scalar( + select(Message).where(Message.conversation_id == conversation.id, Message.id == args.first_id).limit(1) ) if not first_message: raise NotFound("First message not found") - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where( Message.conversation_id == conversation.id, Message.created_at < first_message.created_at, @@ -271,16 +270,14 @@ class ChatMessageListApi(Resource): ) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() else: - history_messages = ( - db.session.query(Message) + history_messages = db.session.scalars( + select(Message) .where(Message.conversation_id == conversation.id) .order_by(Message.created_at.desc()) .limit(args.limit) - .all() - ) + ).all() # Initialize has_more based on whether we have a full page if len(history_messages) == args.limit: @@ -325,7 +322,9 @@ class MessageFeedbackApi(Resource): message_id = str(args.message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") @@ -335,7 +334,7 @@ class MessageFeedbackApi(Resource): if not args.rating and feedback: db.session.delete(feedback) elif args.rating and feedback: - feedback.rating = args.rating + feedback.rating = FeedbackRating(args.rating) feedback.content = args.content elif not args.rating and not feedback: raise ValueError("rating cannot be None when feedback not exists") @@ -347,9 +346,9 @@ class MessageFeedbackApi(Resource): app_id=app_model.id, conversation_id=message.conversation_id, message_id=message.id, - rating=rating_value, + rating=FeedbackRating(rating_value), content=args.content, - from_source="admin", + from_source=FeedbackFromSource.ADMIN, from_account_id=current_user.id, ) db.session.add(feedback) @@ -374,7 +373,9 @@ class MessageAnnotationCountApi(Resource): @login_required @account_initialization_required def get(self, app_model): - count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count() + count = db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.app_id == app_model.id) + ) return {"count": count} @@ -478,7 +479,9 @@ class MessageApi(Resource): def get(self, app_model, message_id: str): message_id = str(message_id) - message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_id, Message.app_id == app_model.id).limit(1) + ) if not message: raise NotFound("Message Not Exists.") diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index a66e9543ff..d59aa44718 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -7,7 +7,7 @@ from flask import abort, request from flask_restx import Resource, fields, marshal_with from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.console import console_ns @@ -21,17 +21,18 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginInvokeError +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE from core.trigger.debug.event_selectors import ( TriggerDebugEvent, TriggerDebugEventPoller, create_event_poller, select_trigger_debug_events, ) -from core.workflow.enums import NodeType -from core.workflow.file.models import File -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.enums import NodeType +from dify_graph.file.models import File +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory @@ -45,13 +46,14 @@ from models import App from models.model import AppMode from models.workflow import Workflow from services.app_generate_service import AppGenerateService -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" +RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" # Register models for flask_restx to avoid dict type issues in Swagger # Register in dependency order: base models first, then dependent models @@ -283,7 +285,9 @@ class DraftWorkflowApi(Resource): workflow_service = WorkflowService() try: - environment_variables_list = args.get("environment_variables") or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + args.get("environment_variables") or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -993,6 +997,43 @@ class PublishedAllWorkflowApi(Resource): } +@console_ns.route("/apps//workflows//restore") +class DraftWorkflowRestoreApi(Resource): + @console_ns.doc("restore_workflow_to_draft") + @console_ns.doc(description="Restore a published workflow version into the draft workflow") + @console_ns.doc(params={"app_id": "Application ID", "workflow_id": "Published workflow ID"}) + @console_ns.response(200, "Workflow restored successfully") + @console_ns.response(400, "Source workflow must be published") + @console_ns.response(404, "Workflow not found") + @setup_required + @login_required + @account_initialization_required + @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + @edit_permission_required + def post(self, app_model: App, workflow_id: str): + current_user, _ = current_account_with_tenant() + workflow_service = WorkflowService() + + try: + workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app_model, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + except ValueError as exc: + raise BadRequest(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/apps//workflows/") class WorkflowByIdApi(Resource): @console_ns.doc("update_workflow_by_id") @@ -1209,7 +1250,7 @@ class DraftWorkflowTriggerNodeApi(Resource): node_type: NodeType = draft_workflow.get_node_type_from_node_config(node_config) event: TriggerDebugEvent | None = None # for schedule trigger, when run single node, just execute directly - if node_type == NodeType.TRIGGER_SCHEDULE: + if node_type == TRIGGER_SCHEDULE_NODE_TYPE: event = TriggerDebugEvent( workflow_args={}, node_id=node_id, diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 6736f24a2e..9b148c3f18 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -9,7 +9,7 @@ from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import ( build_workflow_app_log_pagination_model, diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index f37598fb31..b78d97a382 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,15 +15,15 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.file import helpers as file_helpers -from core.workflow.variables.segment_group import SegmentGroup -from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.workflow.variables.types import SegmentType +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.file import helpers as file_helpers +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment +from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type -from libs.login import login_required +from libs.login import current_user, login_required from models import App, AppMode from models.workflow import WorkflowDraftVariable from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService @@ -100,6 +100,18 @@ def _serialize_full_content(variable: WorkflowDraftVariable) -> dict | None: } +def _ensure_variable_access( + variable: WorkflowDraftVariable | None, + app_id: str, + variable_id: str, +) -> WorkflowDraftVariable: + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_id or variable.user_id != current_user.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + return variable + + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda model: model.get_variable_type()), @@ -238,6 +250,7 @@ class WorkflowVariableCollectionApi(Resource): app_id=app_model.id, page=args.page, limit=args.limit, + user_id=current_user.id, ) return workflow_vars @@ -250,7 +263,7 @@ class WorkflowVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - draft_var_srv.delete_workflow_variables(app_model.id) + draft_var_srv.delete_user_workflow_variables(app_model.id, user_id=current_user.id) db.session.commit() return Response("", 204) @@ -287,7 +300,7 @@ class NodeVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=session, ) - node_vars = draft_var_srv.list_node_variables(app_model.id, node_id) + node_vars = draft_var_srv.list_node_variables(app_model.id, node_id, user_id=current_user.id) return node_vars @@ -298,7 +311,7 @@ class NodeVariableCollectionApi(Resource): def delete(self, app_model: App, node_id: str): validate_node_id(node_id) srv = WorkflowDraftVariableService(db.session()) - srv.delete_node_variables(app_model.id, node_id) + srv.delete_node_variables(app_model.id, node_id, user_id=current_user.id) db.session.commit() return Response("", 204) @@ -319,11 +332,11 @@ class VariableApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - variable = draft_var_srv.get_variable(variable_id=variable_id) - if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") - if variable.app_id != app_model.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=app_model.id, + variable_id=variable_id, + ) return variable @console_ns.doc("update_variable") @@ -360,11 +373,11 @@ class VariableApi(Resource): ) args_model = WorkflowDraftVariableUpdatePayload.model_validate(console_ns.payload or {}) - variable = draft_var_srv.get_variable(variable_id=variable_id) - if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") - if variable.app_id != app_model.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=app_model.id, + variable_id=variable_id, + ) new_name = args_model.name raw_value = args_model.value @@ -397,11 +410,11 @@ class VariableApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - variable = draft_var_srv.get_variable(variable_id=variable_id) - if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") - if variable.app_id != app_model.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=app_model.id, + variable_id=variable_id, + ) draft_var_srv.delete_variable(variable) db.session.commit() return Response("", 204) @@ -427,11 +440,11 @@ class VariableResetApi(Resource): raise NotFoundError( f"Draft workflow not found, app_id={app_model.id}", ) - variable = draft_var_srv.get_variable(variable_id=variable_id) - if variable is None: - raise NotFoundError(description=f"variable not found, id={variable_id}") - if variable.app_id != app_model.id: - raise NotFoundError(description=f"variable not found, id={variable_id}") + variable = _ensure_variable_access( + variable=draft_var_srv.get_variable(variable_id=variable_id), + app_id=app_model.id, + variable_id=variable_id, + ) resetted = draft_var_srv.reset_variable(draft_workflow, variable) db.session.commit() @@ -447,11 +460,15 @@ def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: session=session, ) if node_id == CONVERSATION_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_conversation_variables(app_model.id) + draft_vars = draft_var_srv.list_conversation_variables(app_model.id, user_id=current_user.id) elif node_id == SYSTEM_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_system_variables(app_model.id) + draft_vars = draft_var_srv.list_system_variables(app_model.id, user_id=current_user.id) else: - draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id) + draft_vars = draft_var_srv.list_node_variables( + app_id=app_model.id, + node_id=node_id, + user_id=current_user.id, + ) return draft_vars @@ -472,7 +489,7 @@ class ConversationVariableCollectionApi(Resource): if draft_workflow is None: raise NotFoundError(description=f"draft workflow not found, id={app_model.id}") draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=current_user.id) db.session.commit() return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID) diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 9ac45cf2da..7ac653395e 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -12,8 +12,8 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 112e152432..5c9023f27b 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,4 +1,5 @@ import logging +import urllib.parse import httpx from flask import current_app, redirect, request @@ -112,6 +113,9 @@ class OAuthCallback(Resource): error_text = e.response.text logger.exception("An error occurred during the OAuth process with %s: %s", provider, error_text) return {"error": "OAuth process failed"}, 400 + except ValueError as e: + logger.warning("OAuth error with %s", provider, exc_info=True) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message={urllib.parse.quote(str(e))}") if invite_token and RegisterService.is_valid_invite_token(invite_token): invitation = RegisterService.get_invitation_by_token(token=invite_token) diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 38ea5d2dae..6e59d4203c 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index cc55a2dba4..ef65b87923 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -29,12 +29,12 @@ from controllers.console.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.evaluation.entities.evaluation_entity import EvaluationCategory, EvaluationConfigData, EvaluationRunRequest from core.indexing_runner import IndexingRunner -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.retrieval.retrieval_methods import RetrievalMethod +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_storage import storage from fields.app_fields import app_detail_kernel_fields, related_app_list @@ -58,7 +58,8 @@ from fields.dataset_fields import ( from fields.document_fields import document_status_fields from libs.login import current_account_with_tenant, login_required from models import ApiToken, Dataset, Document, DocumentSegment, EvaluationRun, EvaluationTargetType, UploadFile -from models.dataset import DatasetPermissionEnum +from models.dataset import DatasetPermission, DatasetPermissionEnum +from models.enums import SegmentStatus from models.provider_ids import ModelProviderID from services.api_token_service import ApiTokenCache from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService @@ -131,6 +132,14 @@ def _validate_indexing_technique(value: str | None) -> str | None: return value +def _validate_doc_form(value: str | None) -> str | None: + if value is None: + return value + if value not in Dataset.DOC_FORM_LIST: + raise ValueError("Invalid doc_form.") + return value + + class DatasetCreatePayload(BaseModel): name: str = Field(..., min_length=1, max_length=40) description: str = Field("", max_length=400) @@ -191,6 +200,14 @@ class IndexingEstimatePayload(BaseModel): raise ValueError("indexing_technique is required.") return result + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + result = _validate_doc_form(value) + if result is None: + return "text_model" + return result + class ConsoleDatasetListQuery(BaseModel): page: int = Field(default=1, description="Page number") @@ -259,6 +276,7 @@ def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool VectorType.BAIDU, VectorType.ALIBABACLOUD_MYSQL, VectorType.IRIS, + VectorType.HOLOGRES, } semantic_methods = {"retrieval_method": [RetrievalMethod.SEMANTIC_SEARCH.value]} @@ -335,6 +353,18 @@ class DatasetListApi(Resource): model_names.append(f"{embedding_model.model}:{embedding_model.provider.provider}") data = cast(list[dict[str, Any]], marshal(datasets, dataset_detail_fields)) + dataset_ids = [item["id"] for item in data if item.get("permission") == "partial_members"] + partial_members_map: dict[str, list[str]] = {} + if dataset_ids: + permissions = db.session.execute( + select(DatasetPermission.dataset_id, DatasetPermission.account_id).where( + DatasetPermission.dataset_id.in_(dataset_ids) + ) + ).all() + + for dataset_id, account_id in permissions: + partial_members_map.setdefault(dataset_id, []).append(account_id) + for item in data: # convert embedding_model_provider to plugin standard format if item["indexing_technique"] == "high_quality" and item["embedding_model_provider"]: @@ -348,8 +378,7 @@ class DatasetListApi(Resource): item["embedding_available"] = True if item.get("permission") == "partial_members": - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(item["id"]) - item.update({"partial_member_list": part_users_list}) + item.update({"partial_member_list": partial_members_map.get(item["id"], [])}) else: item.update({"partial_member_list": []}) @@ -725,13 +754,15 @@ class DatasetIndexingStatusApi(Resource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields @@ -792,7 +823,7 @@ class DatasetApiKeyApi(Resource): console_ns.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code="max_keys_exceeded", + custom="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index bf097d374a..bc90c4ffbd 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -24,11 +24,11 @@ from core.errors.error import ( ) from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from fields.dataset_fields import dataset_fields from fields.document_fields import ( @@ -42,6 +42,7 @@ from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant, login_required from models import DatasetProcessRule, Document, DocumentSegment, UploadFile from models.dataset import DocumentPipelineExecutionLog +from models.enums import IndexingStatus, SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel from services.file_service import FileService @@ -297,6 +298,7 @@ class DatasetDocumentListApi(Resource): if sort == "hit_count": sub_query = ( sa.select(DocumentSegment.document_id, sa.func.sum(DocumentSegment.hit_count).label("total_hit_count")) + .where(DocumentSegment.dataset_id == str(dataset_id)) .group_by(DocumentSegment.document_id) .subquery() ) @@ -332,13 +334,16 @@ class DatasetDocumentListApi(Resource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) .count() ) document.completed_segments = completed_segments @@ -503,7 +508,7 @@ class DocumentIndexingEstimateApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - if document.indexing_status in {"completed", "error"}: + if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() data_process_rule = document.dataset_process_rule @@ -573,7 +578,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource): data_process_rule_dict = data_process_rule.to_dict() if data_process_rule else {} extract_settings = [] for document in documents: - if document.indexing_status in {"completed", "error"}: + if document.indexing_status in {IndexingStatus.COMPLETED, IndexingStatus.ERROR}: raise DocumentAlreadyFinishedError() data_source_info = document.data_source_info_dict match document.data_source_type: @@ -671,19 +676,21 @@ class DocumentBatchIndexingStatusApi(DocumentResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, - "indexing_status": "paused" if document.is_paused else document.indexing_status, + "indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status, "processing_started_at": document.processing_started_at, "parsing_completed_at": document.parsing_completed_at, "cleaning_completed_at": document.cleaning_completed_at, @@ -720,20 +727,20 @@ class DocumentIndexingStatusApi(DocumentResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document_id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != SegmentStatus.RE_SEGMENT) .count() ) # Create a dictionary with document attributes and additional fields document_dict = { "id": document.id, - "indexing_status": "paused" if document.is_paused else document.indexing_status, + "indexing_status": IndexingStatus.PAUSED if document.is_paused else document.indexing_status, "processing_started_at": document.processing_started_at, "parsing_completed_at": document.parsing_completed_at, "cleaning_completed_at": document.cleaning_completed_at, @@ -955,7 +962,7 @@ class DocumentProcessingApi(DocumentResource): match action: case "pause": - if document.indexing_status != "indexing": + if document.indexing_status != IndexingStatus.INDEXING: raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id @@ -964,7 +971,7 @@ class DocumentProcessingApi(DocumentResource): db.session.commit() case "resume": - if document.indexing_status not in {"paused", "error"}: + if document.indexing_status not in {IndexingStatus.PAUSED, IndexingStatus.ERROR}: raise InvalidActionError("Document not in paused or error state.") document.paused_by = None @@ -1169,7 +1176,7 @@ class DocumentRetryApi(DocumentResource): raise ArchivedDocumentImmutableError() # 400 if document is completed - if document.indexing_status == "completed": + if document.indexing_status == IndexingStatus.COMPLETED: raise DocumentAlreadyFinishedError() retry_documents.append(document) except Exception: diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 23a668112d..3fd0f3b712 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -26,7 +26,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index db1a874437..cd568cf835 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -19,11 +19,12 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields from libs.login import current_user from models.account import Account from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.hit_testing_service import HitTestingService logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ logger = logging.getLogger(__name__) class HitTestingPayload(BaseModel): query: str = Field(max_length=250) - retrieval_model: dict[str, Any] | None = None + retrieval_model: RetrievalModel | None = None external_retrieval_model: dict[str, Any] | None = None attachment_ids: list[str] | None = None diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index 1a47e226e5..a4498005d8 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -9,9 +9,9 @@ from configs import dify_config from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.oauth import OAuthHandler +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py index 6e0cd31b8d..4f31093cfe 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline.py @@ -46,6 +46,8 @@ class PipelineTemplateDetailApi(Resource): type = request.args.get("type", default="built-in", type=str) rag_pipeline_service = RagPipelineService() pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type) + if pipeline_template is None: + return {"error": "Pipeline template not found from upstream service."}, 404 return pipeline_template, 200 diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index 7e285c8da9..c5dadb75f5 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,8 +21,8 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.variables.types import SegmentType +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type @@ -102,6 +102,7 @@ class RagPipelineVariableCollectionApi(Resource): app_id=pipeline.id, page=query.page, limit=query.limit, + user_id=current_user.id, ) return workflow_vars @@ -111,7 +112,7 @@ class RagPipelineVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=db.session(), ) - draft_var_srv.delete_workflow_variables(pipeline.id) + draft_var_srv.delete_user_workflow_variables(pipeline.id, user_id=current_user.id) db.session.commit() return Response("", 204) @@ -144,7 +145,7 @@ class RagPipelineNodeVariableCollectionApi(Resource): draft_var_srv = WorkflowDraftVariableService( session=session, ) - node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id) + node_vars = draft_var_srv.list_node_variables(pipeline.id, node_id, user_id=current_user.id) return node_vars @@ -152,7 +153,7 @@ class RagPipelineNodeVariableCollectionApi(Resource): def delete(self, pipeline: Pipeline, node_id: str): validate_node_id(node_id) srv = WorkflowDraftVariableService(db.session()) - srv.delete_node_variables(pipeline.id, node_id) + srv.delete_node_variables(pipeline.id, node_id, user_id=current_user.id) db.session.commit() return Response("", 204) @@ -283,11 +284,11 @@ def _get_variable_list(pipeline: Pipeline, node_id) -> WorkflowDraftVariableList session=session, ) if node_id == CONVERSATION_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_conversation_variables(pipeline.id) + draft_vars = draft_var_srv.list_conversation_variables(pipeline.id, user_id=current_user.id) elif node_id == SYSTEM_VARIABLE_NODE_ID: - draft_vars = draft_var_srv.list_system_variables(pipeline.id) + draft_vars = draft_var_srv.list_system_variables(pipeline.id, user_id=current_user.id) else: - draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id) + draft_vars = draft_var_srv.list_node_variables(app_id=pipeline.id, node_id=node_id, user_id=current_user.id) return draft_vars diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 29b6b64b94..3912cc73ca 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -6,7 +6,7 @@ from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore from pydantic import BaseModel, Field from sqlalchemy.orm import Session -from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound import services from controllers.common.schema import register_schema_models @@ -16,7 +16,11 @@ from controllers.console.app.error import ( DraftWorkflowNotExist, DraftWorkflowNotSync, ) -from controllers.console.app.workflow import workflow_model, workflow_pagination_model +from controllers.console.app.workflow import ( + RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE, + workflow_model, + workflow_pagination_model, +) from controllers.console.app.workflow_run import ( workflow_run_detail_model, workflow_run_node_execution_list_model, @@ -33,7 +37,7 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper @@ -42,7 +46,8 @@ from libs.login import current_account_with_tenant, current_user, login_required from models import Account from models.dataset import Pipeline from models.model import EndUser -from services.errors.app import WorkflowHashNotEqualError +from models.workflow import Workflow +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError from services.rag_pipeline.pipeline_generate_service import PipelineGenerateService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -203,9 +208,12 @@ class DraftRagPipelineApi(Resource): abort(415) payload = DraftWorkflowSyncPayload.model_validate(payload_dict) + rag_pipeline_service = RagPipelineService() try: - environment_variables_list = payload.environment_variables or [] + environment_variables_list = Workflow.normalize_environment_variable_mappings( + payload.environment_variables or [], + ) environment_variables = [ variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] @@ -213,7 +221,6 @@ class DraftRagPipelineApi(Resource): conversation_variables = [ variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] - rag_pipeline_service = RagPipelineService() workflow = rag_pipeline_service.sync_draft_workflow( pipeline=pipeline, graph=payload.graph, @@ -705,6 +712,36 @@ class PublishedAllRagPipelineApi(Resource): } +@console_ns.route("/rag/pipelines//workflows//restore") +class RagPipelineDraftWorkflowRestoreApi(Resource): + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_rag_pipeline + def post(self, pipeline: Pipeline, workflow_id: str): + current_user, _ = current_account_with_tenant() + rag_pipeline_service = RagPipelineService() + + try: + workflow = rag_pipeline_service.restore_published_workflow_to_draft( + pipeline=pipeline, + workflow_id=workflow_id, + account=current_user, + ) + except IsDraftWorkflowError as exc: + # Use a stable, predefined message to keep the 400 response consistent + raise BadRequest(RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE) from exc + except WorkflowNotFoundError as exc: + raise NotFound(str(exc)) from exc + + return { + "result": "success", + "hash": workflow.unique_hash, + "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at), + } + + @console_ns.route("/rag/pipelines//workflows/") class RagPipelineByIdApi(Resource): @setup_required diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index 0311db1584..ffb9e5bb6e 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -19,7 +19,7 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/banner.py b/api/controllers/console/explore/banner.py index da306fbc9d..757061d8dd 100644 --- a/api/controllers/console/explore/banner.py +++ b/api/controllers/console/explore/banner.py @@ -1,9 +1,11 @@ from flask import request from flask_restx import Resource +from sqlalchemy import select from controllers.console import api from controllers.console.explore.wraps import explore_banner_enabled from extensions.ext_database import db +from models.enums import BannerStatus from models.model import ExporleBanner @@ -16,14 +18,18 @@ class BannerApi(Resource): language = request.args.get("language", "en-US") # Build base query for enabled banners - base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled") + base_query = select(ExporleBanner).where(ExporleBanner.status == BannerStatus.ENABLED) # Try to get banners in the requested language - banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort) + ).all() # Fallback to en-US if no banners found and language is not en-US if not banners and language != "en-US": - banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all() + banners = db.session.scalars( + base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort) + ).all() # Convert banners to serializable format result = [] for banner in banners: diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index a6e5b2822a..fcd52d2818 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -24,7 +24,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index aca766567f..0740dd0e24 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -133,13 +133,15 @@ class InstalledAppsListApi(Resource): def post(self): payload = InstalledAppCreatePayload.model_validate(console_ns.payload or {}) - recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).first() + recommended_app = db.session.scalar( + select(RecommendedApp).where(RecommendedApp.app_id == payload.app_id).limit(1) + ) if recommended_app is None: raise NotFound("Recommended app not found") _, current_tenant_id = current_account_with_tenant() - app = db.session.query(App).where(App.id == payload.app_id).first() + app = db.session.get(App, payload.app_id) if app is None: raise NotFound("App entity not found") @@ -147,10 +149,10 @@ class InstalledAppsListApi(Resource): if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(and_(InstalledApp.app_id == payload.app_id, InstalledApp.tenant_id == current_tenant_id)) - .first() + .limit(1) ) if installed_app is None: diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 88487ac96f..15e1aea361 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -21,12 +21,13 @@ from controllers.console.explore.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from libs import helper from libs.helper import UUIDStrOrEmpty from libs.login import current_account_with_tenant +from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource): app_model=app_model, message_id=message_id, user=current_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py index 660a4d5aea..0f29627746 100644 --- a/api/controllers/console/explore/parameter.py +++ b/api/controllers/console/explore/parameter.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from controllers.common import fields from controllers.console import console_ns from controllers.console.app.error import AppUnavailableError @@ -23,14 +25,14 @@ class AppParameterApi(InstalledAppResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index f6f731df36..a8d8036f0f 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -4,6 +4,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -41,8 +42,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.app_fields import ( @@ -476,7 +477,7 @@ class TrialSitApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() @@ -541,13 +542,7 @@ class AppWorkflowApi(Resource): if not app_model.workflow_id: raise AppUnavailableError() - workflow = ( - db.session.query(Workflow) - .where( - Workflow.id == app_model.workflow_id, - ) - .first() - ) + workflow = db.session.get(Workflow, app_model.workflow_id) return workflow diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index b841bda323..7801cee473 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -21,8 +21,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from libs.login import current_account_with_tenant diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 03edb871e6..9d9337e63e 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -4,6 +4,7 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import abort from flask_restx import Resource +from sqlalchemy import select from werkzeug.exceptions import NotFound from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed @@ -24,10 +25,10 @@ def installed_app_required(view: Callable[Concatenate[InstalledApp, P], R] | Non @wraps(view) def decorated(installed_app_id: str, *args: P.args, **kwargs: P.kwargs): _, current_tenant_id = current_account_with_tenant() - installed_app = ( - db.session.query(InstalledApp) + installed_app = db.session.scalar( + select(InstalledApp) .where(InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_tenant_id) - .first() + .limit(1) ) if installed_app is None: @@ -78,7 +79,7 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs): current_user, _ = current_account_with_tenant() - trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first() + trial_app = db.session.scalar(select(TrialApp).where(TrialApp.app_id == str(app_id)).limit(1)) if trial_app is None: raise TrialAppNotAllowed() @@ -87,10 +88,10 @@ def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None): if app is None: raise TrialAppNotAllowed() - account_trial_app_record = ( - db.session.query(AccountTrialAppRecord) + account_trial_app_record = db.session.scalar( + select(AccountTrialAppRecord) .where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id) - .first() + .limit(1) ) if account_trial_app_record: if account_trial_app_record.count >= trial_app.trial_limit: diff --git a/api/controllers/console/notification.py b/api/controllers/console/notification.py new file mode 100644 index 0000000000..53e4aa3d86 --- /dev/null +++ b/api/controllers/console/notification.py @@ -0,0 +1,90 @@ +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field + +from controllers.console import console_ns +from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required +from libs.login import current_account_with_tenant, login_required +from services.billing_service import BillingService + +# Notification content is stored under three lang tags. +_FALLBACK_LANG = "en-US" + + +def _pick_lang_content(contents: dict, lang: str) -> dict: + """Return the single LangContent for *lang*, falling back to English.""" + return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {}) + + +class DismissNotificationPayload(BaseModel): + notification_id: str = Field(...) + + +@console_ns.route("/notification") +class NotificationApi(Resource): + @console_ns.doc("get_notification") + @console_ns.doc( + description=( + "Return the active in-product notification for the current user " + "in their interface language (falls back to English if unavailable). " + "The notification is NOT marked as seen here; call POST /notification/dismiss " + "when the user explicitly closes the modal." + ), + responses={ + 200: "Success — inspect should_show to decide whether to render the modal", + 401: "Unauthorized", + }, + ) + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def get(self): + current_user, _ = current_account_with_tenant() + + result = BillingService.get_account_notification(str(current_user.id)) + + # Proto JSON uses camelCase field names (Kratos default marshaling). + if not result.get("shouldShow"): + return {"should_show": False, "notifications": []}, 200 + + lang = current_user.interface_language or _FALLBACK_LANG + + notifications = [] + for notification in result.get("notifications") or []: + contents: dict = notification.get("contents") or {} + lang_content = _pick_lang_content(contents, lang) + notifications.append( + { + "notification_id": notification.get("notificationId"), + "frequency": notification.get("frequency"), + "lang": lang_content.get("lang", lang), + "title": lang_content.get("title", ""), + "subtitle": lang_content.get("subtitle", ""), + "body": lang_content.get("body", ""), + "title_pic_url": lang_content.get("titlePicUrl", ""), + } + ) + + return {"should_show": bool(notifications), "notifications": notifications}, 200 + + +@console_ns.route("/notification/dismiss") +class NotificationDismissApi(Resource): + @console_ns.doc("dismiss_notification") + @console_ns.doc( + description="Mark a notification as dismissed for the current user.", + responses={200: "Success", 401: "Unauthorized"}, + ) + @setup_required + @login_required + @account_initialization_required + @only_edition_cloud + def post(self): + current_user, _ = current_account_with_tenant() + payload = DismissNotificationPayload.model_validate(request.get_json()) + BillingService.dismiss_notification( + notification_id=payload.notification_id, + account_id=str(current_user.id), + ) + return {"result": "success"}, 200 diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index f3738319df..49162d4dae 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -13,7 +13,7 @@ from controllers.common.errors import ( ) from controllers.console import console_ns from core.helper import ssrf_proxy -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py index e099fe0f32..279e4ec502 100644 --- a/api/controllers/console/setup.py +++ b/api/controllers/console/setup.py @@ -2,6 +2,7 @@ from typing import Literal from flask import request from pydantic import BaseModel, Field, field_validator +from sqlalchemy import select from configs import dify_config from controllers.fastopenapi import console_router @@ -100,6 +101,6 @@ def setup_system(payload: SetupRequestPayload) -> SetupResponse: def get_setup_status() -> DifySetup | bool | None: if dify_config.EDITION == "SELF_HOSTED": - return db.session.query(DifySetup).first() + return db.session.scalar(select(DifySetup).limit(1)) return True diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 708df62642..6f93ff1e70 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -43,6 +43,7 @@ from libs.datetime_utils import naive_utc_now from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone from libs.login import current_account_with_tenant, login_required from models import AccountIntegrate, InvitationCode +from models.account import AccountStatus, InvitationCodeStatus from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -211,19 +212,19 @@ class AccountInitApi(Resource): raise ValueError("invitation_code is required") # check invitation code - invitation_code = ( - db.session.query(InvitationCode) + invitation_code = db.session.scalar( + select(InvitationCode) .where( InvitationCode.code == args.invitation_code, - InvitationCode.status == "unused", + InvitationCode.status == InvitationCodeStatus.UNUSED, ) - .first() + .limit(1) ) if not invitation_code: raise InvalidInvitationCodeError() - invitation_code.status = "used" + invitation_code.status = InvitationCodeStatus.USED invitation_code.used_at = naive_utc_now() invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id @@ -231,7 +232,7 @@ class AccountInitApi(Resource): account.interface_language = args.interface_language account.timezone = args.timezone account.interface_theme = "light" - account.status = "active" + account.status = AccountStatus.ACTIVE account.initialized_at = naive_utc_now() db.session.commit() diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index 9527fe782e..e2b504751b 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -2,7 +2,7 @@ from flask_restx import Resource, fields from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 1897cbdca7..538c5fb561 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginPermissionDeniedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index ccb60b1461..0a9e54de99 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -5,8 +5,8 @@ from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index dd302b90d6..e3bf4c95b8 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -171,7 +171,7 @@ class MemberCancelInviteApi(Resource): current_user, _ = current_account_with_tenant() if not current_user.current_tenant: raise ValueError("No current tenant") - member = db.session.query(Account).where(Account.id == str(member_id)).first() + member = db.session.get(Account, str(member_id)) if member is None: abort(404) else: diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index 7bada2fa12..db3b02ae94 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -7,9 +7,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index 583e3e3057..d7eceb656c 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -8,9 +8,9 @@ from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index d1485bc1c0..ee537367c7 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -5,6 +5,7 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource from pydantic import BaseModel, Field +from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden from configs import dify_config @@ -12,8 +13,8 @@ from controllers.common.schema import register_enum_models, register_schema_mode from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.impl.exc import PluginDaemonClientSideError +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -169,6 +170,20 @@ register_enum_models( ) +def _read_upload_content(file: FileStorage, max_size: int) -> bytes: + """ + Read the uploaded file and validate its actual size before delegating to the plugin service. + + FileStorage.content_length is not reliable for multipart test uploads and may be zero even when + content exists, so the controllers validate against the loaded bytes instead. + """ + content = file.read() + if len(content) > max_size: + raise ValueError("File size exceeds the maximum allowed size") + + return content + + @console_ns.route("/workspaces/current/plugin/debugging-key") class PluginDebuggingKeyApi(Resource): @setup_required @@ -284,12 +299,7 @@ class PluginUploadFromPkgApi(Resource): _, tenant_id = current_account_with_tenant() file = request.files["pkg"] - - # check file size - if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE: - raise ValueError("File size exceeds the maximum allowed size") - - content = file.read() + content = _read_upload_content(file, dify_config.PLUGIN_MAX_PACKAGE_SIZE) try: response = PluginService.upload_pkg(tenant_id, content) except PluginDaemonClientSideError as e: @@ -328,12 +338,7 @@ class PluginUploadFromBundleApi(Resource): _, tenant_id = current_account_with_tenant() file = request.files["bundle"] - - # check file size - if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE: - raise ValueError("File size exceeds the maximum allowed size") - - content = file.read() + content = _read_upload_content(file, dify_config.PLUGIN_MAX_BUNDLE_SIZE) try: response = PluginService.upload_bundle(tenant_id, content) except PluginDaemonClientSideError as e: diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5bfa895849..b38f05795a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -23,10 +23,10 @@ from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.error import MCPAuthError, MCPError, MCPRefreshTokenError from core.mcp.mcp_client import MCPClient -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index 6b642af613..ad78d2a623 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -10,11 +10,11 @@ from werkzeug.exceptions import BadRequest, Forbidden from configs import dify_config from controllers.common.schema import register_schema_models from controllers.web.error import NotFoundError -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.login import current_user, login_required from models.account import Account diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 94be81d94f..88fd2c010f 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -7,6 +7,7 @@ from sqlalchemy import select from werkzeug.exceptions import Unauthorized import services +from configs import dify_config from controllers.common.errors import ( FilenameNotExistsError, FileTooLargeError, @@ -29,6 +30,7 @@ from libs.helper import TimestampField from libs.login import current_account_with_tenant, login_required from models.account import Tenant, TenantStatus from services.account_service import TenantService +from services.billing_service import BillingService, SubscriptionPlan from services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService from services.file_service import FileService @@ -108,9 +110,29 @@ class TenantListApi(Resource): current_user, current_tenant_id = current_account_with_tenant() tenants = TenantService.get_join_tenants(current_user) tenant_dicts = [] + is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED + is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED + tenant_plans: dict[str, SubscriptionPlan] = {} + + if is_saas: + tenant_ids = [tenant.id for tenant in tenants] + if tenant_ids: + tenant_plans = BillingService.get_plan_bulk(tenant_ids) + if not tenant_plans: + logger.warning("get_plan_bulk returned empty result, falling back to legacy feature path") for tenant in tenants: - features = FeatureService.get_features(tenant.id) + plan: str = CloudPlan.SANDBOX + if is_saas: + tenant_plan = tenant_plans.get(tenant.id) + if tenant_plan: + plan = tenant_plan["plan"] or CloudPlan.SANDBOX + else: + features = FeatureService.get_features(tenant.id) + plan = features.billing.subscription.plan or CloudPlan.SANDBOX + elif not is_enterprise_only: + features = FeatureService.get_features(tenant.id) + plan = features.billing.subscription.plan or CloudPlan.SANDBOX # Create a dictionary with tenant attributes tenant_dict = { @@ -118,7 +140,7 @@ class TenantListApi(Resource): "name": tenant.name, "status": tenant.status, "created_at": tenant.created_at, - "plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX, + "plan": plan, "current": tenant.id == current_tenant_id if current_tenant_id else False, } @@ -198,7 +220,7 @@ class SwitchWorkspaceApi(Resource): except Exception: raise AccountNotLinkTenantError("Account not link tenant") - new_tenant = db.session.query(Tenant).get(args.tenant_id) # Get new tenant + new_tenant = db.session.get(Tenant, args.tenant_id) # Get new tenant if new_tenant is None: raise ValueError("Tenant not found") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 014f4c4132..6785ba0c34 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -7,6 +7,7 @@ from functools import wraps from typing import ParamSpec, TypeVar from flask import abort, request +from sqlalchemy import select from configs import dify_config from controllers.console.auth.error import AuthenticationFailedError, EmailCodeError @@ -218,13 +219,9 @@ def setup_required(view: Callable[P, R]) -> Callable[P, R]: @wraps(view) def decorated(*args: P.args, **kwargs: P.kwargs) -> R: # check setup - if ( - dify_config.EDITION == "SELF_HOSTED" - and os.environ.get("INIT_PASSWORD") - and not db.session.query(DifySetup).first() - ): - raise NotInitValidateError() - elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first(): + if dify_config.EDITION == "SELF_HOSTED" and not db.session.scalar(select(DifySetup).limit(1)): + if os.environ.get("INIT_PASSWORD"): + raise NotInitValidateError() raise NotSetupError() return view(*args, **kwargs) diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index f6032a8e49..9e3fb3a90b 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -10,7 +10,6 @@ from controllers.common.file_response import enforce_download_for_html from controllers.files import files_ns from core.tools.signature import verify_tool_file_signature from core.tools.tool_file_manager import ToolFileManager -from extensions.ext_database import db as global_db DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -57,7 +56,7 @@ class ToolFileApi(Resource): raise Forbidden("Invalid request.") try: - tool_file_manager = ToolFileManager(engine=global_db.engine) + tool_file_manager = ToolFileManager() stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id( file_id, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index b34412ef6d..52690a12e1 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -8,7 +8,7 @@ from werkzeug.exceptions import Forbidden import services from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file.helpers import verify_plugin_file_signature +from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 4cd1c4745f..9b8b3950e6 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -4,7 +4,6 @@ from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data from controllers.inner_api.wraps import plugin_inner_api_only -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse from core.plugin.backwards_invocation.encrypt import PluginEncrypter @@ -29,7 +28,8 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.file.helpers import get_signed_file_url_for_plugin +from dify_graph.file.helpers import get_signed_file_url_for_plugin +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index edf3ac393c..d6e3ebfbcd 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -5,6 +5,7 @@ from typing import ParamSpec, TypeVar from flask import current_app, request from flask_login import user_logged_in from pydantic import BaseModel +from sqlalchemy import select from sqlalchemy.orm import Session from extensions.ext_database import db @@ -36,23 +37,16 @@ def get_user(tenant_id: str, user_id: str | None) -> EndUser: user_model = None if is_anonymous: - user_model = ( - session.query(EndUser) + user_model = session.scalar( + select(EndUser) .where( EndUser.session_id == user_id, EndUser.tenant_id == tenant_id, ) - .first() + .limit(1) ) else: - user_model = ( - session.query(EndUser) - .where( - EndUser.id == user_id, - EndUser.tenant_id == tenant_id, - ) - .first() - ) + user_model = session.get(EndUser, user_id) if not user_model: user_model = EndUser( @@ -85,16 +79,7 @@ def get_user_tenant(view_func: Callable[P, R]): if not user_id: user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID - try: - tenant_model = ( - db.session.query(Tenant) - .where( - Tenant.id == tenant_id, - ) - .first() - ) - except Exception: - raise ValueError("tenant not found") + tenant_model = db.session.get(Tenant, tenant_id) if not tenant_model: raise ValueError("tenant not found") @@ -114,6 +99,7 @@ def get_user_tenant(view_func: Callable[P, R]): def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]): def decorator(view_func: Callable[P, R]): + @wraps(view_func) def decorated_view(*args: P.args, **kwargs: P.kwargs): try: data = request.get_json() diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py index a5746abafa..ef0a46db63 100644 --- a/api/controllers/inner_api/workspace/workspace.py +++ b/api/controllers/inner_api/workspace/workspace.py @@ -2,6 +2,7 @@ import json from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from controllers.common.schema import register_schema_models from controllers.console.wraps import setup_required @@ -42,7 +43,7 @@ class EnterpriseWorkspace(Resource): def post(self): args = WorkspaceCreatePayload.model_validate(inner_api_ns.payload or {}) - account = db.session.query(Account).filter_by(email=args.owner_email).first() + account = db.session.scalar(select(Account).where(Account.email == args.owner_email).limit(1)) if account is None: return {"message": "owner account not found."}, 404 diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py index 4bdcc6832a..7c60b316e8 100644 --- a/api/controllers/inner_api/wraps.py +++ b/api/controllers/inner_api/wraps.py @@ -75,7 +75,7 @@ def enterprise_inner_api_user_auth(view: Callable[P, R]): if signature_base64 != token: return view(*args, **kwargs) - kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first() + kwargs["user"] = db.session.get(EndUser, user_id) return view(*args, **kwargs) diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 991a9166c7..9ddaaa315b 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -6,13 +6,13 @@ from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session from controllers.common.schema import register_schema_model -from controllers.console.app.mcp_server import AppMCPServerStatus from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db from libs import helper +from models.enums import AppMCPServerStatus from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index ef254ca357..c22190cbc9 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -185,4 +185,4 @@ class AnnotationUpdateDeleteApi(Resource): def delete(self, app_model: App, annotation_id: str): """Delete an annotation.""" AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) - return {"result": "success"}, 204 + return "", 204 diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 562f5e33cc..abcaa0e240 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from flask_restx import Resource from controllers.common.fields import Parameters @@ -33,14 +35,14 @@ class AppParameterApi(Resource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py index e383920460..38d292d0b9 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -21,7 +21,7 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 9d8431f066..98f09c44a1 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -28,7 +28,7 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 8e29c9ff0f..edbf011656 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -14,7 +14,6 @@ from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from fields.conversation_fields import ( - ConversationDelete, ConversationInfiniteScrollPagination, SimpleConversation, ) @@ -163,7 +162,7 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return ConversationDelete(result="success").model_dump(mode="json"), 204 + return "", 204 @service_api_ns.route("/conversations//name") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 2aaf920efb..77fee9c142 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem from libs.helper import UUIDStrOrEmpty +from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.message import ( FirstMessageNotExistsError, @@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource): app_model=app_model, message_id=message_id, user=end_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 2ce8f05f75..35dd22c801 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -27,9 +27,9 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model @@ -132,6 +132,8 @@ class WorkflowRunDetailApi(Resource): app_id=app_model.id, run_id=workflow_run_id, ) + if not workflow_run: + raise NotFound("Workflow run not found.") return workflow_run diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index c06b81b775..83d07087ab 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -14,8 +14,8 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag from libs.login import current_user diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 0aeb4a2d36..d34b4124ae 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -1,10 +1,11 @@ import json +from contextlib import ExitStack from typing import Self from uuid import UUID -from flask import request +from flask import request, send_file from flask_restx import marshal -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from sqlalchemy import desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -35,6 +36,7 @@ from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields from libs.login import current_user from models.dataset import Dataset, Document, DocumentSegment +from models.enums import SegmentStatus from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( KnowledgeConfig, @@ -60,6 +62,13 @@ class DocumentTextCreatePayload(BaseModel): embedding_model: str | None = None embedding_model_provider: str | None = None + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + if value not in Dataset.DOC_FORM_LIST: + raise ValueError("Invalid doc_form.") + return value + DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -72,6 +81,13 @@ class DocumentTextUpdate(BaseModel): doc_language: str = "English" retrieval_model: RetrievalModel | None = None + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + if value not in Dataset.DOC_FORM_LIST: + raise ValueError("Invalid doc_form.") + return value + @model_validator(mode="after") def check_text_and_name(self) -> Self: if self.text is not None and self.name is None: @@ -86,6 +102,15 @@ class DocumentListQuery(BaseModel): status: str | None = Field(default=None, description="Document status filter") +DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100 + + +class DocumentBatchDownloadZipPayload(BaseModel): + """Request payload for bulk downloading uploaded documents as a ZIP archive.""" + + document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS) + + register_enum_models(service_api_ns, RetrievalMethod) register_schema_models( @@ -95,6 +120,7 @@ register_schema_models( DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery, + DocumentBatchDownloadZipPayload, Rule, PreProcessingRule, Segmentation, @@ -526,6 +552,46 @@ class DocumentListApi(DatasetApiResource): return response +@service_api_ns.route("/datasets//documents/download-zip") +class DocumentBatchDownloadZipApi(DatasetApiResource): + """Download multiple uploaded-file documents as a single ZIP archive.""" + + @service_api_ns.expect(service_api_ns.models[DocumentBatchDownloadZipPayload.__name__]) + @service_api_ns.doc("download_documents_as_zip") + @service_api_ns.doc(description="Download selected uploaded documents as a single ZIP archive") + @service_api_ns.doc(params={"dataset_id": "Dataset ID"}) + @service_api_ns.doc( + responses={ + 200: "ZIP archive generated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Document or dataset not found", + } + ) + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def post(self, tenant_id, dataset_id): + payload = DocumentBatchDownloadZipPayload.model_validate(service_api_ns.payload or {}) + + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=str(dataset_id), + document_ids=[str(document_id) for document_id in payload.document_ids], + tenant_id=str(tenant_id), + current_user=current_user, + ) + + with ExitStack() as stack: + zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files)) + response = send_file( + zip_path, + mimetype="application/zip", + as_attachment=True, + download_name=download_name, + ) + cleanup = stack.pop_all() + response.call_on_close(cleanup.close) + return response + + @service_api_ns.route("/datasets//documents//indexing-status") class DocumentIndexingStatusApi(DatasetApiResource): @service_api_ns.doc("get_document_indexing_status") @@ -557,13 +623,15 @@ class DocumentIndexingStatusApi(DatasetApiResource): .where( DocumentSegment.completed_at.isnot(None), DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", + DocumentSegment.status != SegmentStatus.RE_SEGMENT, ) .count() ) total_segments = ( db.session.query(DocumentSegment) - .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .where( + DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + ) .count() ) # Create a dictionary with document attributes and additional fields @@ -586,6 +654,35 @@ class DocumentIndexingStatusApi(DatasetApiResource): return data +@service_api_ns.route("/datasets//documents//download") +class DocumentDownloadApi(DatasetApiResource): + """Return a signed download URL for a document's original uploaded file.""" + + @service_api_ns.doc("get_document_download_url") + @service_api_ns.doc(description="Get a signed download URL for a document's original uploaded file") + @service_api_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"}) + @service_api_ns.doc( + responses={ + 200: "Download URL generated successfully", + 401: "Unauthorized - invalid API token", + 403: "Forbidden - insufficient permissions", + 404: "Document or upload file not found", + } + ) + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") + def get(self, tenant_id, dataset_id, document_id): + dataset = self.get_dataset(str(dataset_id), str(tenant_id)) + document = DocumentService.get_document(dataset.id, str(document_id)) + + if not document: + raise NotFound("Document not found.") + + if document.tenant_id != str(tenant_id): + raise Forbidden("No permission.") + + return {"url": DocumentService.get_document_download_url(document)} + + @service_api_ns.route("/datasets//documents/") class DocumentApi(DatasetApiResource): METADATA_CHOICES = {"all", "only", "without"} diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 4eb4fed29a..2e3b7fd85e 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -17,7 +17,7 @@ from controllers.service_api.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import current_account_with_tenant diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index fffcb47bd4..35aed40a59 100644 --- a/api/controllers/service_api/workspace/models.py +++ b/api/controllers/service_api/workspace/models.py @@ -3,7 +3,7 @@ from flask_restx import Resource from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index cc55c69c48..7aa5b2f092 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -3,7 +3,7 @@ import time from collections.abc import Callable from enum import StrEnum, auto from functools import wraps -from typing import Concatenate, ParamSpec, TypeVar, cast +from typing import Concatenate, ParamSpec, TypeVar, cast, overload from flask import current_app, request from flask_login import user_logged_in @@ -44,10 +44,22 @@ class FetchUserArg(BaseModel): required: bool = False -def validate_app_token(view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None): - def decorator(view_func: Callable[P, R]): +@overload +def validate_app_token(view: Callable[P, R]) -> Callable[P, R]: ... + + +@overload +def validate_app_token( + view: None = None, *, fetch_user_arg: FetchUserArg | None = None +) -> Callable[[Callable[P, R]], Callable[P, R]]: ... + + +def validate_app_token( + view: Callable[P, R] | None = None, *, fetch_user_arg: FetchUserArg | None = None +) -> Callable[P, R] | Callable[[Callable[P, R]], Callable[P, R]]: + def decorator(view_func: Callable[P, R]) -> Callable[P, R]: @wraps(view_func) - def decorated_view(*args: P.args, **kwargs: P.kwargs): + def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("app") app_model = db.session.query(App).where(App.id == api_token.app_id).first() @@ -213,10 +225,20 @@ def cloud_edition_billing_rate_limit_check(resource: str, api_token_type: str): return interceptor -def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): - def decorator(view: Callable[Concatenate[T, P], R]): - @wraps(view) - def decorated(*args: P.args, **kwargs: P.kwargs): +@overload +def validate_dataset_token(view: Callable[Concatenate[T, P], R]) -> Callable[P, R]: ... + + +@overload +def validate_dataset_token(view: None = None) -> Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: ... + + +def validate_dataset_token( + view: Callable[Concatenate[T, P], R] | None = None, +) -> Callable[P, R] | Callable[[Callable[Concatenate[T, P], R]], Callable[P, R]]: + def decorator(view_func: Callable[Concatenate[T, P], R]) -> Callable[P, R]: + @wraps(view_func) + def decorated(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("dataset") # get url path dataset_id from positional args or kwargs @@ -287,7 +309,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None): raise Unauthorized("Tenant owner account does not exist.") else: raise Unauthorized("Tenant does not exist.") - return view(api_token.tenant_id, *args, **kwargs) + return view_func(api_token.tenant_id, *args, **kwargs) # type: ignore[arg-type] return decorated diff --git a/api/controllers/trigger/webhook.py b/api/controllers/trigger/webhook.py index 22b24271c6..eb579da5d4 100644 --- a/api/controllers/trigger/webhook.py +++ b/api/controllers/trigger/webhook.py @@ -70,7 +70,14 @@ def handle_webhook(webhook_id: str): @bp.route("/webhook-debug/", methods=["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"]) def handle_webhook_debug(webhook_id: str): - """Handle webhook debug calls without triggering production workflow execution.""" + """Handle webhook debug calls without triggering production workflow execution. + + The debug webhook endpoint is only for draft inspection flows. It never enqueues + Celery work for the published workflow; instead it dispatches an in-memory debug + event to an active Variable Inspector listener. Returning a clear error when no + listener is registered prevents a misleading 200 response for requests that are + effectively dropped. + """ try: webhook_trigger, _, node_config, webhook_data, error = _prepare_webhook_execution(webhook_id, is_debug=True) if error: @@ -94,11 +101,32 @@ def handle_webhook_debug(webhook_id: str): "method": webhook_data.get("method"), }, ) - TriggerDebugEventBus.dispatch( + dispatch_count = TriggerDebugEventBus.dispatch( tenant_id=webhook_trigger.tenant_id, event=event, pool_key=pool_key, ) + if dispatch_count == 0: + logger.warning( + "Webhook debug request dropped without an active listener for webhook %s (tenant=%s, app=%s, node=%s)", + webhook_trigger.webhook_id, + webhook_trigger.tenant_id, + webhook_trigger.app_id, + webhook_trigger.node_id, + ) + return ( + jsonify( + { + "error": "No active debug listener", + "message": ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ), + "execution_url": webhook_trigger.webhook_url, + } + ), + 409, + ) response_data, status_code = WebhookService.generate_webhook_response(node_config) return jsonify(response_data), status_code diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index 62ea532eac..25bbedce54 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,4 +1,5 @@ import logging +from typing import Any, cast from flask import request from flask_restx import Resource @@ -57,14 +58,14 @@ class AppParameterApi(WebApiResource): if workflow is None: raise AppUnavailableError() - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app_model.app_model_config if app_model_config is None: raise AppUnavailableError() - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 15828cc208..2b8f752668 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -20,7 +20,7 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index a97d745471..8634c1f43c 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -25,7 +25,7 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/human_input_form.py b/api/controllers/web/human_input_form.py index 4e69e56025..36728a47d1 100644 --- a/api/controllers/web/human_input_form.py +++ b/api/controllers/web/human_input_form.py @@ -8,6 +8,7 @@ from datetime import datetime from flask import Response, request from flask_restx import Resource, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config @@ -147,11 +148,11 @@ class HumanInputFormApi(Resource): def _get_app_site_from_form(form: Form) -> tuple[App, Site]: """Resolve App/Site for the form's app and validate tenant status.""" - app_model = db.session.query(App).where(App.id == form.app_id).first() + app_model = db.session.get(App, form.app_id) if app_model is None or app_model.tenant_id != form.tenant_id: raise NotFoundError("Form not found") - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if site is None: raise Forbidden() diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index 80035ba818..aa56292614 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -20,11 +20,12 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper from libs.helper import uuid_value +from models.enums import FeedbackRating from models.model import AppMode from services.app_generate_service import AppGenerateService from services.errors.app import MoreLikeThisDisabledError @@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource): app_model=app_model, message_id=message_id, user=end_user, - rating=payload.rating, + rating=FeedbackRating(payload.rating) if payload.rating else None, content=payload.content, ) except MessageNotExistsError: @@ -239,7 +240,7 @@ class MessageSuggestedQuestionApi(WebApiResource): def get(self, app_model, end_user, message_id): app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}: - raise NotCompletionAppError() + raise NotChatAppError() message_id = str(message_id) diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 1cdae0fe56..6a93ef6748 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -11,7 +11,7 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from core.helper import ssrf_proxy -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py index f957229ece..1a0c6d4252 100644 --- a/api/controllers/web/site.py +++ b/api/controllers/web/site.py @@ -1,6 +1,7 @@ from typing import cast from flask_restx import fields, marshal, marshal_with +from sqlalchemy import select from werkzeug.exceptions import Forbidden from configs import dify_config @@ -72,7 +73,7 @@ class AppSiteApi(WebApiResource): def get(self, app_model, end_user): """Retrieve app site info.""" # get site - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py index a309ef3dad..508d1a756a 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -22,8 +22,8 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from core.model_runtime.errors.invoke import InvokeError -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager +from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from models.model import App, AppMode, EndUser diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 80e180ce96..1bdc8df813 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -19,7 +19,15 @@ from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackH from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ( + ToolParameter, +) +from core.tools.tool_manager import ToolManager +from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, LLMUsage, PromptMessage, @@ -29,17 +37,9 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.model_runtime.entities.model_entities import ModelFeature -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ( - ToolParameter, -) -from core.tools.tool_manager import ToolManager -from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from core.workflow.file import file_manager +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from factories import file_factory from models.enums import CreatorUserRole @@ -441,7 +441,7 @@ class BaseAgentRunner(AppRunner): continue result.append(self.organize_agent_user_prompt(message)) - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + agent_thoughts = message.agent_thoughts if agent_thoughts: for agent_thought in agent_thoughts: tool_names_raw = agent_thought.tool diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 0464afe194..9271ed10bd 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -6,23 +6,23 @@ from typing import Any from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit +from core.agent.errors import AgentMaxIterationError from core.agent.output_parser.cot_output_parser import CotAgentOutputParser from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.ops.ops_trace_manager import TraceQueueManager +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.__base.tool import Tool +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, ToolPromptMessage, UserPromptMessage, ) -from core.ops.ops_trace_manager import TraceQueueManager -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from core.workflow.nodes.agent.exc import AgentMaxIterationError from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index babb463aba..89451a0498 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,16 +1,16 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities import ( +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.file import file_manager +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class CotChatAgentRunner(CotAgentRunner): diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index da9a001d84..3023b9bc4d 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,13 +1,13 @@ import json from core.agent.cot_agent_runner import CotAgentRunner -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class CotCompletionAgentRunner(CotAgentRunner): diff --git a/api/core/agent/errors.py b/api/core/agent/errors.py new file mode 100644 index 0000000000..ed504d500a --- /dev/null +++ b/api/core/agent/errors.py @@ -0,0 +1,9 @@ +class AgentMaxIterationError(Exception): + """Raised when an agent runner exceeds the configured max iteration count.""" + + def __init__(self, max_iteration: int): + self.max_iteration = max_iteration + super().__init__( + f"Agent exceeded the maximum iteration limit of {max_iteration}. " + f"The agent was unable to complete the task within the allowed number of iterations." + ) diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 633609f54f..5e13a13b21 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -5,9 +5,14 @@ from copy import deepcopy from typing import Any, Union from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError from core.app.apps.base_app_queue_manager import PublishFrom from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.model_runtime.entities import ( +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, @@ -20,12 +25,7 @@ from core.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from core.workflow.file import file_manager -from core.workflow.nodes.agent.exc import AgentMaxIterationError +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes from models.model import Message logger = logging.getLogger(__name__) diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 7c8f09e6b9..82676f1ebd 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -4,7 +4,7 @@ from collections.abc import Generator from typing import Union from core.agent.entities import AgentScratchpadUnit -from core.model_runtime.entities.llm_entities import LLMResultChunk +from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/common/parameters_mapping/__init__.py b/api/core/app/app_config/common/parameters_mapping/__init__.py index 6f1a3bf045..460fdfb3ba 100644 --- a/api/core/app/app_config/common/parameters_mapping/__init__.py +++ b/api/core/app/app_config/common/parameters_mapping/__init__.py @@ -1,13 +1,36 @@ from collections.abc import Mapping -from typing import Any +from typing import Any, TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS +class SystemParametersDict(TypedDict): + image_file_size_limit: int + video_file_size_limit: int + audio_file_size_limit: int + file_size_limit: int + workflow_file_upload_limit: int + + +class AppParametersDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: dict[str, Any] + speech_to_text: dict[str, Any] + text_to_speech: dict[str, Any] + retriever_resource: dict[str, Any] + annotation_reply: dict[str, Any] + more_like_this: dict[str, Any] + user_input_form: list[dict[str, Any]] + sensitive_word_avoidance: dict[str, Any] + file_upload: dict[str, Any] + system_parameters: SystemParametersDict + + def get_parameters_from_feature_dict( *, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]] -) -> Mapping[str, Any]: +) -> AppParametersDict: """ Mapping from feature dict to webapp parameters """ diff --git a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py index e925d6dd52..7d1b11c008 100644 --- a/api/core/app/app_config/common/sensitive_word_avoidance/manager.py +++ b/api/core/app/app_config/common/sensitive_word_avoidance/manager.py @@ -1,10 +1,13 @@ +from collections.abc import Mapping +from typing import Any + from core.app.app_config.entities import SensitiveWordAvoidanceEntity from core.moderation.factory import ModerationFactory class SensitiveWordAvoidanceConfigManager: @classmethod - def convert(cls, config: dict) -> SensitiveWordAvoidanceEntity | None: + def convert(cls, config: Mapping[str, Any]) -> SensitiveWordAvoidanceEntity | None: sensitive_word_avoidance_dict = config.get("sensitive_word_avoidance") if not sensitive_word_avoidance_dict: return None @@ -12,7 +15,7 @@ class SensitiveWordAvoidanceConfigManager: if sensitive_word_avoidance_dict.get("enabled"): return SensitiveWordAvoidanceEntity( type=sensitive_word_avoidance_dict.get("type"), - config=sensitive_word_avoidance_dict.get("config"), + config=sensitive_word_avoidance_dict.get("config", {}), ) else: return None diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 9b981dfc09..10db380d1f 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -1,10 +1,13 @@ +from typing import Any, cast + from core.agent.entities import AgentEntity, AgentPromptEntity, AgentToolEntity from core.agent.prompt.template import REACT_PROMPT_TEMPLATES +from models.model import AppModelConfigDict class AgentConfigManager: @classmethod - def convert(cls, config: dict) -> AgentEntity | None: + def convert(cls, config: AppModelConfigDict) -> AgentEntity | None: """ Convert model config to model config @@ -28,17 +31,17 @@ class AgentConfigManager: agent_tools = [] for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) >= 4: - if "enabled" not in tool or not tool["enabled"]: + tool_dict = cast(dict[str, Any], tool) + if len(tool_dict) >= 4: + if "enabled" not in tool_dict or not tool_dict["enabled"]: continue agent_tool_properties = { - "provider_type": tool["provider_type"], - "provider_id": tool["provider_id"], - "tool_name": tool["tool_name"], - "tool_parameters": tool.get("tool_parameters", {}), - "credential_id": tool.get("credential_id", None), + "provider_type": tool_dict["provider_type"], + "provider_id": tool_dict["provider_id"], + "tool_name": tool_dict["tool_name"], + "tool_parameters": tool_dict.get("tool_parameters", {}), + "credential_id": tool_dict.get("credential_id", None), } agent_tools.append(AgentToolEntity.model_validate(agent_tool_properties)) @@ -47,7 +50,8 @@ class AgentConfigManager: "react_router", "router", }: - agent_prompt = agent_dict.get("prompt", None) or {} + agent_prompt_raw = agent_dict.get("prompt", None) + agent_prompt: dict[str, Any] = agent_prompt_raw if isinstance(agent_prompt_raw, dict) else {} # check model mode model_mode = config.get("model", {}).get("mode", "completion") if model_mode == "completion": @@ -75,7 +79,7 @@ class AgentConfigManager: strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get("max_iteration", 10), + max_iteration=cast(int, agent_dict.get("max_iteration", 10)), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index aacafb2dad..f04a8df119 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,5 +1,5 @@ import uuid -from typing import Literal, cast +from typing import Any, Literal, cast from core.app.app_config.entities import ( DatasetEntity, @@ -8,13 +8,14 @@ from core.app.app_config.entities import ( ModelConfig, ) from core.entities.agent_entities import PlanningStrategy -from models.model import AppMode +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from models.model import AppMode, AppModelConfigDict from services.dataset_service import DatasetService class DatasetConfigManager: @classmethod - def convert(cls, config: dict) -> DatasetEntity | None: + def convert(cls, config: AppModelConfigDict) -> DatasetEntity | None: """ Convert model config to model config @@ -25,11 +26,15 @@ class DatasetConfigManager: datasets = config.get("dataset_configs", {}).get("datasets", {"strategy": "router", "datasets": []}) for dataset in datasets.get("datasets", []): + if not isinstance(dataset, dict): + continue keys = list(dataset.keys()) if len(keys) == 0 or keys[0] != "dataset": continue dataset = dataset["dataset"] + if not isinstance(dataset, dict): + continue if "enabled" not in dataset or not dataset["enabled"]: continue @@ -47,15 +52,14 @@ class DatasetConfigManager: agent_dict = config.get("agent_mode", {}) for tool in agent_dict.get("tools", []): - keys = tool.keys() - if len(keys) == 1: + if len(tool) == 1: # old standard key = list(tool.keys())[0] if key != "dataset": continue - tool_item = tool[key] + tool_item = cast(dict[str, Any], tool)[key] if "enabled" not in tool_item or not tool_item["enabled"]: continue @@ -114,8 +118,10 @@ class DatasetConfigManager: score_threshold=float(score_threshold_val) if dataset_configs.get("score_threshold_enabled", False) and score_threshold_val is not None else None, - reranking_model=reranking_model_val if isinstance(reranking_model_val, dict) else None, - weights=weights_val if isinstance(weights_val, dict) else None, + reranking_model=cast(RerankingModelDict, reranking_model_val) + if isinstance(reranking_model_val, dict) + else None, + weights=cast(WeightsDict, weights_val) if isinstance(weights_val, dict) else None, reranking_enabled=bool(dataset_configs.get("reranking_enabled", True)), rerank_mode=dataset_configs.get("reranking_mode", "reranking_model"), metadata_filtering_mode=cast( diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index b816c8d7d0..558b6e69a0 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -4,10 +4,10 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class ModelConfigConverter: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index c391a279b5..0929f52e33 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,15 +2,16 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID class ModelConfigManager: @classmethod - def convert(cls, config: dict) -> ModelConfigEntity: + def convert(cls, config: AppModelConfigDict) -> ModelConfigEntity: """ Convert model config to model config @@ -22,7 +23,7 @@ class ModelConfigManager: if not model_config: raise ValueError("model is required") - completion_params = model_config.get("completion_params") + completion_params = model_config.get("completion_params") or {} stop = [] if "stop" in completion_params: stop = completion_params["stop"] diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 21614c010c..b7073898d6 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,17 +1,19 @@ +from typing import Any + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, ) -from core.model_runtime.entities.message_entities import PromptMessageRole from core.prompt.simple_prompt_transform import ModelMode -from models.model import AppMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from models.model import AppMode, AppModelConfigDict class PromptTemplateConfigManager: @classmethod - def convert(cls, config: dict) -> PromptTemplateEntity: + def convert(cls, config: AppModelConfigDict) -> PromptTemplateEntity: if not config.get("prompt_type"): raise ValueError("prompt_type is required") @@ -40,14 +42,15 @@ class PromptTemplateConfigManager: advanced_completion_prompt_template = None completion_prompt_config = config.get("completion_prompt_config", {}) if completion_prompt_config: - completion_prompt_template_params = { + completion_prompt_template_params: dict[str, Any] = { "prompt": completion_prompt_config["prompt"]["text"], } - if "conversation_histories_role" in completion_prompt_config: + conv_role = completion_prompt_config.get("conversation_histories_role") + if conv_role: completion_prompt_template_params["role_prefix"] = { - "user": completion_prompt_config["conversation_histories_role"]["user_prefix"], - "assistant": completion_prompt_config["conversation_histories_role"]["assistant_prefix"], + "user": conv_role["user_prefix"], + "assistant": conv_role["assistant_prefix"], } advanced_completion_prompt_template = AdvancedCompletionPromptTemplateEntity( diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 22d602a190..8de1224a89 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,8 +1,10 @@ import re +from typing import cast from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( [ @@ -18,7 +20,7 @@ _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( class BasicVariablesConfigManager: @classmethod - def convert(cls, config: dict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: + def convert(cls, config: AppModelConfigDict) -> tuple[list[VariableEntity], list[ExternalDataVariableEntity]]: """ Convert model config to model config @@ -51,7 +53,9 @@ class BasicVariablesConfigManager: external_data_variables.append( ExternalDataVariableEntity( - variable=variable["variable"], type=variable["type"], config=variable["config"] + variable=variable["variable"], + type=variable.get("type", ""), + config=variable.get("config", {}), ) ) elif variable_type in { @@ -64,10 +68,10 @@ class BasicVariablesConfigManager: variable = variables[variable_type] variable_entities.append( VariableEntity( - type=variable_type, - variable=variable.get("variable"), + type=cast(VariableEntityType, variable_type), + variable=variable["variable"], description=variable.get("description") or "", - label=variable.get("label"), + label=variable["label"], required=variable.get("required", False), max_length=variable.get("max_length"), options=variable.get("options") or [], diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 062cc6a0b3..95ea70bc40 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -4,10 +4,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.file import FileUploadConfig -from core.workflow.variables.input_entities import VariableEntity as WorkflowVariableEntity +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from dify_graph.file import FileUploadConfig +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode @@ -194,8 +195,8 @@ class DatasetRetrieveConfigEntity(BaseModel): top_k: int | None = None score_threshold: float | None = 0.0 rerank_mode: str | None = "reranking_model" - reranking_model: dict | None = None - weights: dict | None = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None reranking_enabled: bool | None = True metadata_filtering_mode: Literal["disabled", "automatic", "manual"] | None = "disabled" metadata_model_config: ModelConfig | None = None @@ -281,7 +282,7 @@ class EasyUIBasedAppConfig(AppConfig): app_model_config_from: EasyUIBasedAppModelConfigFrom app_model_config_id: str - app_model_config_dict: dict + app_model_config_dict: dict[str, Any] model: ModelConfigEntity prompt_template: PromptTemplateEntity dataset: DatasetEntity | None = None diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index d69fa85801..0c4266fbeb 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -2,7 +2,7 @@ from collections.abc import Mapping from typing import Any from constants import DEFAULT_FILE_NUMBER_LIMITS -from core.workflow.file import FileUploadConfig +from dify_graph.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index ec7d85a09f..d2a9a73380 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,7 @@ import re from core.app.app_config.entities import RagPipelineVariableEntity -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 2891d3ceeb..5d974335ff 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -31,18 +31,18 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.helper.trace_id_helper import extract_external_trace_id_from_args -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.draft_variable_repository import ( +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from dify_graph.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -330,9 +330,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): engine=db.engine, app_id=application_generate_entity.app_config.app_id, tenant_id=application_generate_entity.app_config.tenant_id, + user_id=user.id, ) draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id) return self._generate( workflow=workflow, @@ -413,9 +414,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): engine=db.engine, app_id=application_generate_entity.app_config.app_id, tenant_id=application_generate_entity.app_config.tenant_id, + user_id=user.id, ) draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id) return self._generate( workflow=workflow, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 18ae75a087..66037696af 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -25,16 +25,16 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration -from core.workflow.enums import WorkflowType -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import VariableLoader -from core.workflow.variables.variables import Variable from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.enums import WorkflowType +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import VariableLoader +from dify_graph.variables.variables import Variable from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span @@ -138,20 +138,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): query = self.application_generate_entity.query # moderation - if self.handle_input_moderation( + stop, new_inputs, new_query = self.handle_input_moderation( app_record=self._app, app_generate_entity=self.application_generate_entity, inputs=inputs, query=query, message_id=self.message.id, - ): + ) + if stop: return + self.application_generate_entity.inputs = new_inputs + self.application_generate_entity.query = new_query + system_inputs.query = new_query + # annotation reply if self.handle_annotation_reply( app_record=self._app, message=self.message, - query=query, + query=new_query, app_generate_entity=self.application_generate_entity, ): return @@ -163,7 +168,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # init variable pool variable_pool = VariablePool( system_variables=system_inputs, - user_inputs=inputs, + user_inputs=new_inputs, environment_variables=self._workflow.environment_variables, # Based on the definition of `Variable`, # `VariableBase` instances can be safely used as `Variable` since they are compatible. @@ -240,10 +245,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): inputs: Mapping[str, Any], query: str, message_id: str, - ) -> bool: + ) -> tuple[bool, Mapping[str, Any], str]: try: # process sensitive_word_avoidance - _, inputs, query = self.moderation_for_inputs( + _, new_inputs, new_query = self.moderation_for_inputs( app_id=app_record.id, tenant_id=app_generate_entity.app_config.tenant_id, app_generate_entity=app_generate_entity, @@ -253,9 +258,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): ) except ModerationError as e: self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION) - return True + return True, inputs, query - return False + return False, new_inputs, new_query def handle_annotation_reply( self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity diff --git a/api/core/app/apps/advanced_chat/generate_response_converter.py b/api/core/app/apps/advanced_chat/generate_response_converter.py index 02ec96f209..5c9bc43992 100644 --- a/api/core/app/apps/advanced_chat/generate_response_converter.py +++ b/api/core/app/apps/advanced_chat/generate_response_converter.py @@ -114,7 +114,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse): diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 534ef6994a..f7b5030d33 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -63,20 +63,20 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes import NodeType -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.runtime import GraphRuntimeState -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.runtime import GraphRuntimeState +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile -from models.enums import CreatorUserRole, MessageStatus +from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent from models.workflow import Workflow @@ -357,7 +357,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): ) -> Generator[StreamResponse, None, None]: """Handle node succeeded events.""" # Record files if it's an answer node or end node - if event.node_type in [NodeType.ANSWER, NodeType.END, NodeType.LLM]: + if event.node_type in [BuiltinNodeTypes.ANSWER, BuiltinNodeTypes.END, BuiltinNodeTypes.LLM]: self._recorded_files.extend( self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) ) @@ -516,8 +516,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): graph_runtime_state=validated_state, ) + yield from self._handle_advanced_chat_message_end_event( + QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state + ) yield workflow_finish_resp - self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) def _handle_workflow_partial_success_event( self, @@ -538,6 +540,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): exceptions_count=event.exceptions_count, ) + yield from self._handle_advanced_chat_message_end_event( + QueueAdvancedChatMessageEndEvent(), graph_runtime_state=validated_state + ) yield workflow_finish_resp def _handle_workflow_paused_event( @@ -735,7 +740,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, tenant_id=self._workflow_tenant_id, ) form = form_repository.get_form(self._workflow_run_id, node_id) @@ -855,6 +859,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): yield from self._handle_workflow_paused_event(event) break + case QueueWorkflowSucceededEvent(): + yield from self._handle_workflow_succeeded_event(event, trace_manager=trace_manager) + break + + case QueueWorkflowPartialSuccessEvent(): + yield from self._handle_workflow_partial_success_event(event, trace_manager=trace_manager) + break + case QueueStopEvent(): yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager) break @@ -927,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): type=file["type"], transfer_method=file["transfer_method"], url=file["remote_url"], - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, upload_file_id=file["related_id"], created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 801619ddbc..f0d81e0c59 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -20,7 +20,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager from core.entities.agent_entities import PlanningStrategy -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation OLD_TOOLS = ["dataset", "google_search", "web_reader", "wikipedia", "current_datetime"] @@ -40,7 +40,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> AgentChatAppConfig: """ Convert app model config to agent chat app config @@ -61,7 +61,9 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = AgentChatAppConfig( @@ -70,7 +72,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -86,7 +88,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: Mapping[str, Any]): + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> AppModelConfigDict: """ Validate for agent chat app model config @@ -157,7 +159,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) @classmethod def validate_agent_mode_and_set_defaults( diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 7bd3b8a56e..76a067d7b6 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -20,8 +20,8 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index 7309113f27..a81da2e91c 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -14,10 +14,10 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/app/apps/agent_chat/generate_response_converter.py b/api/core/app/apps/agent_chat/generate_response_converter.py index e35e9d9408..0c146c388f 100644 --- a/api/core/app/apps/agent_chat/generate_response_converter.py +++ b/api/core/app/apps/agent_chat/generate_response_converter.py @@ -113,7 +113,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index d1e2f16b6f..a92e3dd2ea 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -6,7 +6,7 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) @@ -74,11 +74,22 @@ class AppGenerateResponseConverter(ABC): for resource in metadata["retriever_resources"]: updated_resources.append( { + "dataset_id": resource.get("dataset_id"), + "dataset_name": resource.get("dataset_name"), + "document_id": resource.get("document_id"), "segment_id": resource.get("segment_id", ""), "position": resource["position"], + "data_source_type": resource.get("data_source_type"), "document_name": resource["document_name"], "score": resource["score"], + "hit_count": resource.get("hit_count"), + "word_count": resource.get("word_count"), + "segment_position": resource.get("segment_position"), + "index_node_hash": resource.get("index_node_hash"), "content": resource["content"], + "page": resource.get("page"), + "title": resource.get("title"), + "files": resource.get("files"), "summary": resource.get("summary"), } ) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 81617c5fb2..20e6ac98ea 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -4,21 +4,21 @@ from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.enums import NodeType -from core.workflow.file import File, FileUploadConfig -from core.workflow.repositories.draft_variable_repository import ( +from dify_graph.enums import NodeType +from dify_graph.file import File, FileUploadConfig +from dify_graph.repositories.draft_variable_repository import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) -from core.workflow.variables.input_entities import VariableEntityType +from dify_graph.variables.input_entities import VariableEntityType from factories import file_factory from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl if TYPE_CHECKING: - from core.workflow.variables.input_entities import VariableEntity + from dify_graph.variables.input_entities import VariableEntity class BaseAppGenerator: diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index af1f1d7c66..5addd41815 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -20,7 +20,7 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index b98e85dbe9..11fcbb7561 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -24,27 +24,27 @@ from core.app.features.hosting_moderation.hosting_moderation import HostingModer from core.external_data_tool.external_data_fetch import ExternalDataFetch from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.errors.invoke import InvokeBadRequestError from core.moderation.input_moderation import InputModeration from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file.enums import FileTransferMethod, FileType +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from extensions.ext_database import db -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from core.workflow.file.models import File + from dify_graph.file.models import File _logger = logging.getLogger(__name__) @@ -419,7 +419,7 @@ class AppRunner: message_id=message_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, url=f"/files/tools/{tool_file.id}", upload_file_id=tool_file.id, created_by_role=( diff --git a/api/core/app/apps/chat/app_config_manager.py b/api/core/app/apps/chat/app_config_manager.py index 4b6720a3c3..5f087f6066 100644 --- a/api/core/app/apps/chat/app_config_manager.py +++ b/api/core/app/apps/chat/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -13,7 +15,7 @@ from core.app.app_config.features.suggested_questions_after_answer.manager impor SuggestedQuestionsAfterAnswerConfigManager, ) from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig, Conversation +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict, Conversation class ChatAppConfig(EasyUIBasedAppConfig): @@ -31,7 +33,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_model: App, app_model_config: AppModelConfig, conversation: Conversation | None = None, - override_config_dict: dict | None = None, + override_config_dict: AppModelConfigDict | None = None, ) -> ChatAppConfig: """ Convert app model config to chat app config @@ -64,7 +66,7 @@ class ChatAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -79,7 +81,7 @@ class ChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for chat app model config @@ -145,4 +147,4 @@ class ChatAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index c1251d2feb..91cf54c774 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -19,8 +19,8 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index 4870a56281..f63b38fc86 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -13,10 +13,10 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file import File +from dify_graph.file import File +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Conversation, Message @@ -173,8 +173,10 @@ class ChatAppRunner(AppRunner): memory=memory, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/apps/chat/generate_response_converter.py b/api/core/app/apps/chat/generate_response_converter.py index 3aa1161fd8..f23ee7f89f 100644 --- a/api/core/app/apps/chat/generate_response_converter.py +++ b/api/core/app/apps/chat/generate_response_converter.py @@ -113,7 +113,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter): metadata = sub_stream_response_dict.get("metadata", {}) sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata) response_chunk.update(sub_stream_response_dict) - if isinstance(sub_stream_response, ErrorStreamResponse): + elif isinstance(sub_stream_response, ErrorStreamResponse): data = cls._error_to_stream_response(sub_stream_response.err) response_chunk.update(data) else: diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 0b03149665..6a8e436163 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,7 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index d4e801de13..621b0d8cf3 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping, Sequence from dataclasses import dataclass from datetime import datetime -from typing import Any, NewType, Union +from typing import Any, NewType, TypedDict, Union from sqlalchemy import select from sqlalchemy.orm import Session @@ -48,22 +48,23 @@ from core.app.entities.task_entities import ( from core.plugin.impl.datasource import PluginDatasourceManager from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import ( - NodeType, +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import ( + BuiltinNodeTypes, SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.file import FILE_MODEL_IDENTITY, File -from core.workflow.runtime import GraphRuntimeState -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import ArrayFileSegment, FileSegment, Segment -from core.workflow.workflow_entry import WorkflowEntry -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.file import FILE_MODEL_IDENTITY, File +from dify_graph.runtime import GraphRuntimeState +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser @@ -75,6 +76,20 @@ NodeExecutionId = NewType("NodeExecutionId", str) logger = logging.getLogger(__name__) +class AccountCreatedByDict(TypedDict): + id: str + name: str + email: str + + +class EndUserCreatedByDict(TypedDict): + id: str + user: str + + +CreatedByDict = AccountCreatedByDict | EndUserCreatedByDict + + @dataclass(slots=True) class _NodeSnapshot: """In-memory cache for node metadata between start and completion events.""" @@ -248,19 +263,19 @@ class WorkflowResponseConverter: outputs_mapping = graph_runtime_state.outputs or {} encoded_outputs = WorkflowRuntimeTypeConverter().to_json_encodable(outputs_mapping) - created_by: Mapping[str, object] | None + created_by: CreatedByDict | dict[str, object] = {} user = self._user if isinstance(user, Account): - created_by = { - "id": user.id, - "name": user.name, - "email": user.email, - } - else: - created_by = { - "id": user.id, - "user": user.session_id, - } + created_by = AccountCreatedByDict( + id=user.id, + name=user.name, + email=user.email, + ) + elif isinstance(user, EndUser): + created_by = EndUserCreatedByDict( + id=user.id, + user=user.session_id, + ) return WorkflowFinishStreamResponse( task_id=task_id, @@ -442,7 +457,7 @@ class WorkflowResponseConverter: event: QueueNodeStartedEvent, task_id: str, ) -> NodeStartStreamResponse | None: - if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: + if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}: return None run_id = self._ensure_workflow_run_id() snapshot = self._store_snapshot(event) @@ -464,13 +479,13 @@ class WorkflowResponseConverter: ) try: - if event.node_type == NodeType.TOOL: + if event.node_type == BuiltinNodeTypes.TOOL: response.data.extras["icon"] = ToolManager.get_tool_icon( tenant_id=self._application_generate_entity.app_config.tenant_id, provider_type=ToolProviderType(event.provider_type), provider_id=event.provider_id, ) - elif event.node_type == NodeType.DATASOURCE: + elif event.node_type == BuiltinNodeTypes.DATASOURCE: manager = PluginDatasourceManager() provider_entity = manager.fetch_datasource_provider( self._application_generate_entity.app_config.tenant_id, @@ -479,7 +494,7 @@ class WorkflowResponseConverter: response.data.extras["icon"] = provider_entity.declaration.identity.generate_datasource_icon_url( self._application_generate_entity.app_config.tenant_id ) - elif event.node_type == NodeType.TRIGGER_PLUGIN: + elif event.node_type == TRIGGER_PLUGIN_NODE_TYPE: response.data.extras["icon"] = TriggerManager.get_trigger_plugin_icon( self._application_generate_entity.app_config.tenant_id, event.provider_id, @@ -496,13 +511,13 @@ class WorkflowResponseConverter: event: QueueNodeSucceededEvent | QueueNodeFailedEvent | QueueNodeExceptionEvent, task_id: str, ) -> NodeFinishStreamResponse | None: - if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: + if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}: return None run_id = self._ensure_workflow_run_id() snapshot = self._pop_snapshot(event.node_execution_id) start_at = snapshot.start_at if snapshot else event.start_at - finished_at = naive_utc_now() + finished_at = event.finished_at or naive_utc_now() elapsed_time = (finished_at - start_at).total_seconds() inputs, inputs_truncated = self._truncate_mapping(event.inputs) @@ -554,7 +569,7 @@ class WorkflowResponseConverter: event: QueueNodeRetryEvent, task_id: str, ) -> NodeRetryStreamResponse | None: - if event.node_type in {NodeType.ITERATION, NodeType.LOOP}: + if event.node_type in {BuiltinNodeTypes.ITERATION, BuiltinNodeTypes.LOOP}: return None run_id = self._ensure_workflow_run_id() @@ -612,7 +627,7 @@ class WorkflowResponseConverter: data=IterationNodeStartStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, created_at=int(time.time()), extras={}, @@ -635,7 +650,7 @@ class WorkflowResponseConverter: data=IterationNodeNextStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, index=event.index, created_at=int(time.time()), @@ -662,7 +677,7 @@ class WorkflowResponseConverter: data=IterationNodeCompletedStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, outputs=new_outputs, outputs_truncated=outputs_truncated, @@ -692,7 +707,7 @@ class WorkflowResponseConverter: data=LoopNodeStartStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, created_at=int(time.time()), extras={}, @@ -715,7 +730,7 @@ class WorkflowResponseConverter: data=LoopNodeNextStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, index=event.index, # The `pre_loop_output` field is not utilized by the frontend. @@ -744,7 +759,7 @@ class WorkflowResponseConverter: data=LoopNodeCompletedStreamResponse.Data( id=event.node_id, node_id=event.node_id, - node_type=event.node_type.value, + node_type=event.node_type, title=event.node_title, outputs=new_outputs, outputs_truncated=outputs_truncated, diff --git a/api/core/app/apps/completion/app_config_manager.py b/api/core/app/apps/completion/app_config_manager.py index eb1902f12e..f49e7b8b5e 100644 --- a/api/core/app/apps/completion/app_config_manager.py +++ b/api/core/app/apps/completion/app_config_manager.py @@ -1,3 +1,5 @@ +from typing import Any, cast + from core.app.app_config.base_app_config_manager import BaseAppConfigManager from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager from core.app.app_config.easy_ui_based_app.dataset.manager import DatasetConfigManager @@ -8,7 +10,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.app_config.features.more_like_this.manager import MoreLikeThisConfigManager from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, AppModelConfigDict class CompletionAppConfig(EasyUIBasedAppConfig): @@ -22,7 +24,7 @@ class CompletionAppConfig(EasyUIBasedAppConfig): class CompletionAppConfigManager(BaseAppConfigManager): @classmethod def get_app_config( - cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: dict | None = None + cls, app_model: App, app_model_config: AppModelConfig, override_config_dict: AppModelConfigDict | None = None ) -> CompletionAppConfig: """ Convert app model config to completion app config @@ -40,7 +42,9 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_model_config_dict = app_model_config.to_dict() config_dict = app_model_config_dict.copy() else: - config_dict = override_config_dict or {} + if not override_config_dict: + raise Exception("override_config_dict is required when config_from is ARGS") + config_dict = override_config_dict app_mode = AppMode.value_of(app_model.mode) app_config = CompletionAppConfig( @@ -49,7 +53,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): app_mode=app_mode, app_model_config_from=config_from, app_model_config_id=app_model_config.id, - app_model_config_dict=config_dict, + app_model_config_dict=cast(dict[str, Any], config_dict), model=ModelConfigManager.convert(config=config_dict), prompt_template=PromptTemplateConfigManager.convert(config=config_dict), sensitive_word_avoidance=SensitiveWordAvoidanceConfigManager.convert(config=config_dict), @@ -64,7 +68,7 @@ class CompletionAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict): + def config_validate(cls, tenant_id: str, config: dict) -> AppModelConfigDict: """ Validate for completion app model config @@ -116,4 +120,4 @@ class CompletionAppConfigManager(BaseAppConfigManager): # Filter out extra parameters filtered_config = {key: config.get(key) for key in related_config_keys} - return filtered_config + return cast(AppModelConfigDict, filtered_config) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 843328f904..002b914ef1 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -19,8 +19,8 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message @@ -275,7 +275,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): raise ValueError("Message app_model_config is None") override_model_config_dict = app_model_config.to_dict() model_dict = override_model_config_dict["model"] - completion_params = model_dict.get("completion_params") + completion_params = model_dict.get("completion_params", {}) completion_params["temperature"] = 0.9 model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 30e1a609f8..56a4519879 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -11,10 +11,10 @@ from core.app.entities.app_invoke_entities import ( ) from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import ImagePromptMessageContent from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.file import File +from dify_graph.file import File +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Message @@ -132,8 +132,10 @@ class CompletionAppRunner(AppRunner): hit_callback=hit_callback, message_id=message.id, inputs=inputs, - vision_enabled=application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}).get( - "enabled", False + vision_enabled=bool( + application_generate_entity.app_config.app_model_config_dict.get("file_upload", {}) + .get("image", {}) + .get("enabled", False) ), ) context_files = retrieved_files or [] diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 4e9a191dae..44d10d79b8 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -130,10 +130,10 @@ class MessageBasedAppGenerator(BaseAppGenerator): end_user_id = None account_id = None if application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: - from_source = "api" + from_source = ConversationFromSource.API end_user_id = application_generate_entity.user_id else: - from_source = "console" + from_source = ConversationFromSource.CONSOLE account_id = application_generate_entity.user_id if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity): @@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): message_id=message.id, type=file.type, transfer_method=file.transfer_method, - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, upload_file_id=file.related_id, created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index eca96cb074..19d67eb108 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -33,13 +33,13 @@ from core.datasource.entities.datasource_entities import ( ) from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.index_processor.constant.built_in_field import BuiltInField from core.repositories.factory import DifyCoreRepositoryFactory -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom @@ -419,11 +419,12 @@ class PipelineGenerator(BaseAppGenerator): triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id) var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, tenant_id=application_generate_entity.app_config.tenant_id, + user_id=user.id, ) return self._generate( @@ -514,11 +515,12 @@ class PipelineGenerator(BaseAppGenerator): triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id) var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, tenant_id=application_generate_entity.app_config.tenant_id, + user_id=user.id, ) return self._generate( diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index 02caf8f511..e767766bdb 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -8,23 +8,24 @@ from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import ( InvokeFrom, RagPipelineGenerateEntity, + UserFrom, + build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.enums import WorkflowType -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphEngineEvent, GraphRunFailedEvent -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import VariableLoader -from core.workflow.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.graph_init_params import GraphInitParams +from dify_graph.enums import WorkflowType +from dify_graph.graph import Graph +from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import VariableLoader +from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db from models.dataset import Document, Pipeline -from models.enums import UserFrom from models.model import EndUser from models.workflow import Workflow @@ -257,13 +258,15 @@ class PipelineRunner(WorkflowBasedAppRunner): # init graph # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id=self.application_generate_entity.user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id=self.application_generate_entity.user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) @@ -271,6 +274,8 @@ class PipelineRunner(WorkflowBasedAppRunner): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + if start_node_id is None: + start_node_id = get_default_root_node_id(graph_config) graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=start_node_id) if not graph: diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index dc5852d552..6fbe19a3b2 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -28,15 +28,15 @@ from core.app.entities.task_entities import WorkflowAppBlockingResponse, Workflo from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -414,11 +414,12 @@ class WorkflowAppGenerator(BaseAppGenerator): triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id) var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, tenant_id=application_generate_entity.app_config.tenant_id, + user_id=user.id, ) return self._generate( @@ -497,11 +498,12 @@ class WorkflowAppGenerator(BaseAppGenerator): triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) draft_var_srv = WorkflowDraftVariableService(db.session()) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user.id) var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, tenant_id=application_generate_entity.app_config.tenant_id, + user_id=user.id, ) return self._generate( app_model=app_model, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index a43f7879d6..caea8b6b95 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -8,15 +8,15 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from core.workflow.enums import WorkflowType -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.enums import WorkflowType +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span from libs.datetime_utils import naive_utc_now diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 0a567a4315..96dd8c5445 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -55,11 +55,11 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory -from core.workflow.runtime import GraphRuntimeState -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory +from dify_graph.runtime import GraphRuntimeState +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index c9d7464c17..adc6cce9af 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,8 +3,11 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from pydantic import ValidationError + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.agent_strategy import AgentStrategyInfo +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.entities.queue_entities import ( AppQueueEvent, QueueAgentLogEvent, @@ -29,12 +32,15 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.graph import Graph -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import ( +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.graph import Graph +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, @@ -60,14 +66,10 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.graph import GraphRunAbortedEvent -from core.workflow.nodes import NodeType -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool -from core.workflow.workflow_entry import WorkflowEntry -from models.enums import UserFrom +from dify_graph.graph_events.graph import GraphRunAbortedEvent +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -119,13 +121,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=tenant_id or "", - app_id=self._app_id, workflow_id=workflow_id, graph_config=graph_config, - user_id=user_id, - user_from=user_from, - invoke_from=invoke_from, + run_context=build_dify_run_context( + tenant_id=tenant_id or "", + app_id=self._app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ), call_depth=0, ) @@ -136,6 +140,9 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) + if root_node_id is None: + root_node_id = get_default_root_node_id(graph_config) + # init graph graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) @@ -267,13 +274,15 @@ class WorkflowBasedAppRunner: # Create required parameters for Graph.init graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=self._app_id, workflow_id=workflow.id, graph_config=graph_config, - user_id="", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=self._app_id, + user_id="", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) @@ -300,10 +309,12 @@ class WorkflowBasedAppRunner: if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") + target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) + # Get node class - node_type = NodeType(target_node_config.get("data", {}).get("type")) - node_version = target_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_type = target_node_config["data"].type + node_version = str(target_node_config["data"].version) + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool @@ -331,6 +342,18 @@ class WorkflowBasedAppRunner: return graph, variable_pool + @staticmethod + def _build_agent_strategy_info(event: NodeRunStartedEvent) -> AgentStrategyInfo | None: + raw_agent_strategy = event.extras.get("agent_strategy") + if raw_agent_strategy is None: + return None + + try: + return AgentStrategyInfo.model_validate(raw_agent_strategy) + except ValidationError: + logger.warning("Invalid agent strategy payload for node %s", event.node_id, exc_info=True) + return None + def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent): """ Handle event @@ -416,7 +439,7 @@ class WorkflowBasedAppRunner: start_at=event.start_at, in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, - agent_strategy=event.agent_strategy, + agent_strategy=self._build_agent_strategy_info(event), provider_type=event.provider_type, provider_id=event.provider_id, ) @@ -433,6 +456,7 @@ class WorkflowBasedAppRunner: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, inputs=inputs, process_data=process_data, outputs=outputs, @@ -448,6 +472,7 @@ class WorkflowBasedAppRunner: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, outputs=event.node_run_result.outputs, @@ -464,6 +489,7 @@ class WorkflowBasedAppRunner: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, outputs=event.node_run_result.outputs, @@ -485,7 +511,9 @@ class WorkflowBasedAppRunner: elif isinstance(event, NodeRunRetrieverResourceEvent): self._publish_event( QueueRetrieverResourcesEvent( - retriever_resources=event.retriever_resources, + retriever_resources=[ + RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources + ], in_iteration_id=event.in_iteration_id, in_loop_id=event.in_loop_id, ) diff --git a/api/core/app/entities/__init__.py b/api/core/app/entities/__init__.py index e69de29bb2..8e41acee32 100644 --- a/api/core/app/entities/__init__.py +++ b/api/core/app/entities/__init__.py @@ -0,0 +1,3 @@ +from .agent_strategy import AgentStrategyInfo + +__all__ = ["AgentStrategyInfo"] diff --git a/api/core/app/entities/agent_strategy.py b/api/core/app/entities/agent_strategy.py new file mode 100644 index 0000000000..b063a12f4f --- /dev/null +++ b/api/core/app/entities/agent_strategy.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel, ConfigDict + + +class AgentStrategyInfo(BaseModel): + name: str + icon: str | None = None + + model_config = ConfigDict(extra="forbid") diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 65919e89e1..ecbb1cf2f3 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,78 +7,75 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from core.model_runtime.entities.model_entities import AIModelEntity -from core.workflow.file import File, FileUploadConfig +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.file import File, FileUploadConfig +from dify_graph.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +class UserFrom(StrEnum): + ACCOUNT = "account" + END_USER = "end-user" + + class InvokeFrom(StrEnum): - """ - Invoke From. - """ - - # SERVICE_API indicates that this invocation is from an API call to Dify app. - # - # Description of service api in Dify docs: - # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis SERVICE_API = "service-api" - - # WEB_APP indicates that this invocation is from - # the web app of the workflow (or chatflow). - # - # Description of web app in Dify docs: - # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README WEB_APP = "web-app" - - # TRIGGER indicates that this invocation is from a trigger. - # this is used for plugin trigger and webhook trigger. TRIGGER = "trigger" - - # EXPLORE indicates that this invocation is from - # the workflow (or chatflow) explore page. EXPLORE = "explore" - # DEBUGGER indicates that this invocation is from - # the workflow (or chatflow) edit page. DEBUGGER = "debugger" - # PUBLISHED_PIPELINE indicates that this invocation runs a published RAG pipeline workflow. PUBLISHED_PIPELINE = "published" - - # VALIDATION indicates that this invocation is from validation. VALIDATION = "validation" @classmethod - def value_of(cls, value: str): - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid invoke from value {value}") + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) def to_source(self) -> str: - """ - Get source of invoke from. + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") - :return: source - """ - if self == InvokeFrom.WEB_APP: - return "web_app" - elif self == InvokeFrom.DEBUGGER: - return "dev" - elif self == InvokeFrom.EXPLORE: - return "explore_app" - elif self == InvokeFrom.TRIGGER: - return "trigger" - elif self == InvokeFrom.SERVICE_API: - return "api" - return "dev" +class DifyRunContext(BaseModel): + tenant_id: str + app_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + +def build_dify_run_context( + *, + tenant_id: str, + app_id: str, + user_id: str, + user_from: UserFrom, + invoke_from: InvokeFrom, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + """ + Build graph run_context with the reserved Dify runtime payload. + + `extra_context` can carry user-defined context keys. The reserved `_dify` + payload is always overwritten by this function to keep one canonical source. + """ + run_context = dict(extra_context) if extra_context else {} + run_context[DIFY_RUN_CONTEXT_KEY] = DifyRunContext( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + return run_context class ModelConfigWithCredentialsEntity(BaseModel): diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 5b2fa29b56..d2a36f2a0d 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -5,13 +5,12 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowNodeExecutionMetadataKey -from core.workflow.nodes import NodeType +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): @@ -314,7 +313,7 @@ class QueueNodeStartedEvent(AppQueueEvent): in_iteration_id: str | None = None in_loop_id: str | None = None start_at: datetime - agent_strategy: AgentNodeStrategyInit | None = None + agent_strategy: AgentStrategyInfo | None = None # FIXME(-LAN-): only for ToolNode, need to refactor provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType @@ -336,6 +335,7 @@ class QueueNodeSucceededEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime + finished_at: datetime | None = None inputs: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict) @@ -391,6 +391,7 @@ class QueueNodeExceptionEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime + finished_at: datetime | None = None inputs: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict) @@ -415,6 +416,7 @@ class QueueNodeFailedEvent(AppQueueEvent): in_loop_id: str | None = None """loop id if node is in loop""" start_at: datetime + finished_at: datetime | None = None inputs: Mapping[str, object] = Field(default_factory=dict) process_data: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index f6096eb683..46a8ab52f2 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -4,12 +4,12 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): @@ -349,7 +349,7 @@ class NodeStartStreamResponse(StreamResponse): extras: dict[str, object] = Field(default_factory=dict) iteration_id: str | None = None loop_id: str | None = None - agent_strategy: AgentNodeStrategyInit | None = None + agent_strategy: AgentStrategyInfo | None = None event: StreamEvent = StreamEvent.NODE_STARTED workflow_run_id: str diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py index 3f9f3da9b2..87d4772815 100644 --- a/api/core/app/features/annotation_reply/annotation_reply.py +++ b/api/core/app/features/annotation_reply/annotation_reply.py @@ -6,6 +6,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.rag.datasource.vdb.vector_factory import Vector from extensions.ext_database import db from models.dataset import Dataset +from models.enums import CollectionBindingType, ConversationFromSource from models.model import App, AppAnnotationSetting, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.dataset_service import DatasetCollectionBindingService @@ -43,7 +44,7 @@ class AnnotationReplyFeature: embedding_model_name = collection_binding_detail.model_name dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" + embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) dataset = Dataset( @@ -67,9 +68,9 @@ class AnnotationReplyFeature: annotation = AppAnnotationService.get_annotation_by_id(annotation_id) if annotation: if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP}: - from_source = "api" + from_source = ConversationFromSource.API else: - from_source = "console" + from_source = ConversationFromSource.CONSOLE # insert annotation history AppAnnotationService.add_annotation_history( diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index a5a5486581..5ed1fadc41 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -2,7 +2,7 @@ import logging from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from core.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index a748d90387..d227e4e904 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,12 @@ import logging -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.conversation_variable_updater import ConversationVariableUpdater -from core.workflow.enums import NodeType -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.variables import VariableBase +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.conversation_variable_updater import ConversationVariableUpdater +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.variables import VariableBase logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): def on_event(self, event: GraphEngineEvent) -> None: if not isinstance(event, NodeRunSucceededEvent): return - if event.node_type != NodeType.VARIABLE_ASSIGNER: + if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: return if self.graph_runtime_state is None: return diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 1c267091a4..4370c01a0b 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -6,9 +6,9 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunPausedEvent +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent +from dify_graph.graph_events.graph import GraphRunPausedEvent from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index 0a107de012..2adaf14a35 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,6 +1,6 @@ -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunPausedEvent +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent +from dify_graph.graph_events.graph import GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index f82397deca..d7ca45f209 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -4,9 +4,9 @@ from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore -from core.workflow.graph_engine.entities.commands import CommandType, GraphEngineCommand -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent +from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a7ea9ef446..a4019a83e1 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -5,9 +5,9 @@ from typing import Any, ClassVar from pydantic import TypeAdapter from core.db.session_factory import session_factory -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events.base import GraphEngineEvent -from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events.base import GraphEngineEvent +from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index ebae830389..a63ff39fa5 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -5,11 +5,11 @@ from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.nodes.llm.entities import ModelConfig +from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory class DifyCredentialsProvider: diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 1c66c8c1ff..7aa3bf15ab 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -6,7 +6,7 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.llm_entities import LLMUsage from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 26c7e60a4c..0d5e0acec6 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -16,8 +16,8 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.errors.error import QuotaExceededError -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.moderation.output_moderation import ModerationRule, OutputModeration +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index a77e5abb30..b530fe1ce4 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -2,7 +2,7 @@ import logging import time from collections.abc import Generator from threading import Thread -from typing import Union, cast +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -44,21 +44,20 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from core.app.task_pipeline.message_file_utils import prepare_file_dict from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.tools.signature import sign_tool_file -from core.workflow.file import helpers as file_helpers -from core.workflow.file.enums import FileTransferMethod +from dify_graph.file.enums import FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -219,14 +218,14 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): tenant_id = self._application_generate_entity.app_config.tenant_id task_id = self._application_generate_entity.task_id publisher = None - text_to_speech_dict = self._app_config.app_model_config_dict.get("text_to_speech") + text_to_speech_dict = cast(dict[str, Any], self._app_config.app_model_config_dict.get("text_to_speech")) if ( text_to_speech_dict and text_to_speech_dict.get("autoPlay") == "enabled" and text_to_speech_dict.get("enabled") ): publisher = AppGeneratorTTSPublisher( - tenant_id, text_to_speech_dict.get("voice", None), text_to_speech_dict.get("language", None) + tenant_id, text_to_speech_dict.get("voice", ""), text_to_speech_dict.get("language", None) ) for response in self._process_stream_response(publisher=publisher, trace_manager=trace_manager): while True: @@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): """ self._task_state.metadata.usage = self._task_state.llm_result.usage metadata_dict = self._task_state.metadata.model_dump() + + # Fetch files associated with this message + files = None + with Session(db.engine, expire_on_commit=False) as session: + message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all() + + if message_files: + # Fetch all required UploadFile objects in a single query to avoid N+1 problem + upload_file_ids = list( + dict.fromkeys( + mf.upload_file_id + for mf in message_files + if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id + ) + ) + upload_files_map = {} + if upload_file_ids: + upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all() + upload_files_map = {uf.id: uf for uf in upload_files} + + files_list = [] + for message_file in message_files: + file_dict = prepare_file_dict(message_file, upload_files_map) + files_list.append(file_dict) + + files = files_list or None + return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, metadata=metadata_dict, + files=files, ) - def _record_files(self): - with Session(db.engine, expire_on_commit=False) as session: - message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all() - if not message_files: - return None - - files_list = [] - upload_file_ids = [ - mf.upload_file_id - for mf in message_files - if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id - ] - upload_files_map = {} - if upload_file_ids: - upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all() - upload_files_map = {uf.id: uf for uf in upload_files} - - for message_file in message_files: - upload_file = None - if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id: - upload_file = upload_files_map.get(message_file.upload_file_id) - - url = None - filename = "file" - mime_type = "application/octet-stream" - size = 0 - extension = "" - - if message_file.transfer_method == FileTransferMethod.REMOTE_URL: - url = message_file.url - if message_file.url: - filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params - elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: - if upload_file: - url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) - filename = upload_file.name - mime_type = upload_file.mime_type or "application/octet-stream" - size = upload_file.size or 0 - extension = f".{upload_file.extension}" if upload_file.extension else "" - elif message_file.upload_file_id: - # Fallback: generate URL even if upload_file not found - url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) - elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: - # For tool files, use URL directly if it's HTTP, otherwise sign it - if message_file.url.startswith("http"): - url = message_file.url - filename = message_file.url.split("/")[-1].split("?")[0] - else: - # Extract tool file id and extension from URL - url_parts = message_file.url.split("/") - if url_parts: - file_part = url_parts[-1].split("?")[0] # Remove query params first - # Use rsplit to correctly handle filenames with multiple dots - if "." in file_part: - tool_file_id, ext = file_part.rsplit(".", 1) - extension = f".{ext}" - else: - tool_file_id = file_part - extension = ".bin" - url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) - filename = file_part - - transfer_method_value = message_file.transfer_method - remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" - file_dict = { - "related_id": message_file.id, - "extension": extension, - "filename": filename, - "size": size, - "mime_type": mime_type, - "transfer_method": transfer_method_value, - "type": message_file.type, - "url": url or "", - "upload_file_id": message_file.upload_file_id or message_file.id, - "remote_url": remote_url, - } - files_list.append(file_dict) - return files_list or None - def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: """ Agent message to stream response. diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index cc4f97ad94..62f27060b4 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -1,7 +1,6 @@ import hashlib import logging -import time -from threading import Thread +from threading import Thread, Timer from typing import Union from flask import Flask, current_app @@ -35,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator from core.tools.signature import sign_tool_file from extensions.ext_database import db from extensions.ext_redis import redis_client +from models.enums import MessageFileBelongsTo from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService @@ -96,9 +96,9 @@ class MessageCycleManager: if auto_generate_conversation_name and is_first_message: # start generate thread # time.sleep not block other logic - time.sleep(1) - thread = Thread( - target=self._generate_conversation_name_worker, + thread = Timer( + 1, + self._generate_conversation_name_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore "conversation_id": conversation_id, @@ -234,7 +234,7 @@ class MessageCycleManager: task_id=self._application_generate_entity.task_id, id=message_file.id, type=message_file.type, - belongs_to=message_file.belongs_to or "user", + belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER, url=url, ) diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py new file mode 100644 index 0000000000..fc8b6c6b5a --- /dev/null +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -0,0 +1,91 @@ +from typing import TypedDict + +from core.tools.signature import sign_tool_file +from dify_graph.file import helpers as file_helpers +from dify_graph.file.enums import FileTransferMethod +from models.model import MessageFile, UploadFile + +MAX_TOOL_FILE_EXTENSION_LENGTH = 10 + + +class MessageFileInfoDict(TypedDict): + related_id: str + extension: str + filename: str + size: int + mime_type: str + transfer_method: str + type: str + url: str + upload_file_id: str + remote_url: str | None + + +def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> MessageFileInfoDict: + """ + Prepare file dictionary for message end stream response. + + :param message_file: MessageFile instance + :param upload_files_map: Dictionary mapping upload_file_id to UploadFile + :return: Dictionary containing file information + """ + upload_file = None + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id: + upload_file = upload_files_map.get(message_file.upload_file_id) + + url = None + filename = "file" + mime_type = "application/octet-stream" + size = 0 + extension = "" + + if message_file.transfer_method == FileTransferMethod.REMOTE_URL: + url = message_file.url + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name + mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: + if message_file.url.startswith(("http://", "https://")): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + if "." in filename: + extension = "." + filename.rsplit(".", 1)[1] + else: + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] + if "." in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH: + extension = ".bin" + else: + tool_file_id = file_part + extension = ".bin" + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + + transfer_method_value = message_file.transfer_method.value + remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" + return { + "related_id": message_file.id, + "extension": extension, + "filename": filename, + "size": size, + "mime_type": mime_type, + "transfer_method": transfer_method_value, + "type": message_file.type, + "url": url or "", + "upload_file_id": message_file.upload_file_id or message_file.id, + "remote_url": remote_url, + } diff --git a/api/core/app/workflow/__init__.py b/api/core/app/workflow/__init__.py index 172ee5d703..3bca7f5c34 100644 --- a/api/core/app/workflow/__init__.py +++ b/api/core/app/workflow/__init__.py @@ -1,3 +1,3 @@ -from .node_factory import DifyNodeFactory +from core.workflow.node_factory import DifyNodeFactory __all__ = ["DifyNodeFactory"] diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index 954638b901..e0f8d27111 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -5,13 +5,13 @@ from collections.abc import Generator from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file -from core.workflow.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from core.workflow.file.runtime import set_workflow_file_runtime +from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from dify_graph.file.runtime import set_workflow_file_runtime from extensions.ext_storage import storage class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``core.workflow.file``.""" + """Production runtime wiring for ``dify_graph.file``.""" @property def files_url(self) -> str: diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index 45fb84c81f..faf1516c40 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -12,17 +12,17 @@ from typing_extensions import override from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from core.workflow.enums import NodeType -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.nodes.base.node import Node +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.nodes.base.node import Node if TYPE_CHECKING: - from core.workflow.nodes.llm.node import LLMNode - from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + from dify_graph.nodes.llm.node import LLMNode + from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -75,8 +75,9 @@ class LLMQuotaLayer(GraphEngineLayer): return try: + dify_ctx = node.require_dify_context() deduct_llm_quota( - tenant_id=node.tenant_id, + tenant_id=dify_ctx.tenant_id, model_instance=model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -112,11 +113,11 @@ class LLMQuotaLayer(GraphEngineLayer): def _extract_model_instance(node: Node) -> ModelInstance | None: try: match node.node_type: - case NodeType.LLM: + case BuiltinNodeTypes.LLM: return cast("LLMNode", node).model_instance - case NodeType.PARAMETER_EXTRACTOR: + case BuiltinNodeTypes.PARAMETER_EXTRACTOR: return cast("ParameterExtractorNode", node).model_instance - case NodeType.QUESTION_CLASSIFIER: + case BuiltinNodeTypes.QUESTION_CLASSIFIER: return cast("QuestionClassifierNode", node).model_instance case _: return None diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 94839c8ae3..4b20477a7f 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -16,10 +16,10 @@ from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_ from typing_extensions import override from configs import dify_config -from core.workflow.enums import NodeType -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node from extensions.otel.parser import ( DefaultNodeOTelParser, LLMNodeOTelParser, @@ -74,16 +74,13 @@ class ObservabilityLayer(GraphEngineLayer): def _build_parser_registry(self) -> None: """Initialize parser registry for node types.""" self._parsers = { - NodeType.TOOL: ToolNodeOTelParser(), - NodeType.LLM: LLMNodeOTelParser(), - NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(), + BuiltinNodeTypes.TOOL: ToolNodeOTelParser(), + BuiltinNodeTypes.LLM: LLMNodeOTelParser(), + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(), } def _get_parser(self, node: Node) -> NodeOTelParser: - node_type = getattr(node, "node_type", None) - if isinstance(node_type, NodeType): - return self._parsers.get(node_type, self._default_parser) - return self._default_parser + return self._parsers.get(node.node_type, self._default_parser) @override def on_graph_start(self) -> None: diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index 41052b4f52..99b64b3ab5 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -17,17 +17,17 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution -from core.workflow.enums import ( +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution +from dify_graph.enums import ( SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, WorkflowType, ) -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import ( +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, @@ -42,9 +42,9 @@ from core.workflow.graph_events import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.node_events import NodeRunResult +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.datetime_utils import naive_utc_now @@ -268,7 +268,12 @@ class WorkflowPersistenceLayer(GraphEngineLayer): def _handle_node_succeeded(self, event: NodeRunSucceededEvent) -> None: domain_execution = self._get_node_execution(event.id) - self._update_node_execution(domain_execution, event.node_run_result, WorkflowNodeExecutionStatus.SUCCEEDED) + self._update_node_execution( + domain_execution, + event.node_run_result, + WorkflowNodeExecutionStatus.SUCCEEDED, + finished_at=event.finished_at, + ) def _handle_node_failed(self, event: NodeRunFailedEvent) -> None: domain_execution = self._get_node_execution(event.id) @@ -277,6 +282,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): event.node_run_result, WorkflowNodeExecutionStatus.FAILED, error=event.error, + finished_at=event.finished_at, ) def _handle_node_exception(self, event: NodeRunExceptionEvent) -> None: @@ -286,6 +292,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer): event.node_run_result, WorkflowNodeExecutionStatus.EXCEPTION, error=event.error, + finished_at=event.finished_at, ) def _handle_node_pause_requested(self, event: NodeRunPauseRequestedEvent) -> None: @@ -352,13 +359,14 @@ class WorkflowPersistenceLayer(GraphEngineLayer): *, error: str | None = None, update_outputs: bool = True, + finished_at: datetime | None = None, ) -> None: - finished_at = naive_utc_now() + actual_finished_at = finished_at or naive_utc_now() snapshot = self._node_snapshots.get(domain_execution.id) start_at = snapshot.created_at if snapshot else domain_execution.created_at domain_execution.status = status - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = max((finished_at - start_at).total_seconds(), 0.0) + domain_execution.finished_at = actual_finished_at + domain_execution.elapsed_time = max((actual_finished_at - start_at).total_seconds(), 0.0) if error: domain_execution.error = error diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py deleted file mode 100644 index 3a82f0a45e..0000000000 --- a/api/core/app/workflow/node_factory.py +++ /dev/null @@ -1,337 +0,0 @@ -from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, cast, final - -from sqlalchemy import select -from sqlalchemy.orm import Session -from typing_extensions import override - -from configs import dify_config -from core.app.llm.model_access import build_dify_model_access -from core.datasource.datasource_manager import DatasourceManager -from core.helper.code_executor.code_executor import ( - CodeExecutionError, - CodeExecutor, -) -from core.helper.ssrf_proxy import ssrf_proxy -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities.graph_config import NodeConfigDict -from core.workflow.enums import NodeType, SystemVariableKey -from core.workflow.file.file_manager import file_manager -from core.workflow.graph.graph import NodeFactory -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.code.code_node import CodeNode, WorkflowCodeExecutor -from core.workflow.nodes.code.entities import CodeLanguage -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.nodes.datasource import DatasourceNode -from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig -from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config -from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.nodes.llm.entities import ModelConfig -from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode -from core.workflow.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.variables.segments import StringSegment -from extensions.ext_database import db -from models.model import Conversation - -if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState - - -def fetch_memory( - *, - conversation_id: str | None, - app_id: str, - node_data_memory: MemoryConfig | None, - model_instance: ModelInstance, -) -> TokenBufferMemory | None: - if not node_data_memory or not conversation_id: - return None - - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) - conversation = session.scalar(stmt) - if not conversation: - return None - - return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - - -class DefaultWorkflowCodeExecutor: - def execute( - self, - *, - language: CodeLanguage, - code: str, - inputs: Mapping[str, Any], - ) -> Mapping[str, Any]: - return CodeExecutor.execute_workflow_code_template( - language=language, - code=code, - inputs=inputs, - ) - - def is_execution_error(self, error: Exception) -> bool: - return isinstance(error, CodeExecutionError) - - -@final -class DifyNodeFactory(NodeFactory): - """ - Default implementation of NodeFactory that uses the traditional node mapping. - - This factory creates nodes by looking up their types in NODE_TYPE_CLASSES_MAPPING - and instantiating the appropriate node class. - """ - - def __init__( - self, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ) -> None: - self.graph_init_params = graph_init_params - self.graph_runtime_state = graph_runtime_state - self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() - self._code_limits = CodeNodeLimits( - max_string_length=dify_config.CODE_MAX_STRING_LENGTH, - max_number=dify_config.CODE_MAX_NUMBER, - min_number=dify_config.CODE_MIN_NUMBER, - max_precision=dify_config.CODE_MAX_PRECISION, - max_depth=dify_config.CODE_MAX_DEPTH, - max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, - max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, - max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, - ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer() - self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH - self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager - self._http_request_file_manager = file_manager - self._rag_retrieval = DatasetRetrieval() - self._document_extractor_unstructured_api_config = UnstructuredApiConfig( - api_url=dify_config.UNSTRUCTURED_API_URL, - api_key=dify_config.UNSTRUCTURED_API_KEY or "", - ) - self._http_request_config = build_http_request_config( - max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, - max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, - max_write_timeout=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, - max_binary_size=dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE, - max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE, - ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY, - ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, - ) - - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id) - - @override - def create_node(self, node_config: NodeConfigDict) -> Node: - """ - Create a Node instance from node configuration data using the traditional mapping. - - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or configuration is invalid - """ - # Get node_id from config - node_id = node_config["id"] - - # Get node type from config - node_data = node_config["data"] - try: - node_type = NodeType(node_data["type"]) - except ValueError: - raise ValueError(f"Unknown node type: {node_data['type']}") - - # Get node class - node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type) - if not node_mapping: - raise ValueError(f"No class mapping found for node type: {node_type}") - - latest_node_class = node_mapping.get(LATEST_VERSION) - node_version = str(node_data.get("version", "1")) - matched_node_class = node_mapping.get(node_version) - node_class = matched_node_class or latest_node_class - if not node_class: - raise ValueError(f"No latest version class found for node type: {node_type}") - - # Create node instance - if node_type == NodeType.CODE: - return CodeNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - code_executor=self._code_executor, - code_limits=self._code_limits, - ) - - if node_type == NodeType.TEMPLATE_TRANSFORM: - return TemplateTransformNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - template_renderer=self._template_renderer, - max_output_length=self._template_transform_max_output_length, - ) - - if node_type == NodeType.HTTP_REQUEST: - return HttpRequestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - http_request_config=self._http_request_config, - http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, - file_manager=self._http_request_file_manager, - ) - - if node_type == NodeType.LLM: - model_instance = self._build_model_instance_for_llm_node(node_data) - memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) - return LLMNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - credentials_provider=self._llm_credentials_provider, - model_factory=self._llm_model_factory, - model_instance=model_instance, - memory=memory, - ) - - if node_type == NodeType.DATASOURCE: - return DatasourceNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - datasource_manager=DatasourceManager, - ) - - if node_type == NodeType.KNOWLEDGE_RETRIEVAL: - return KnowledgeRetrievalNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - rag_retrieval=self._rag_retrieval, - ) - - if node_type == NodeType.DOCUMENT_EXTRACTOR: - return DocumentExtractorNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - unstructured_api_config=self._document_extractor_unstructured_api_config, - ) - - if node_type == NodeType.QUESTION_CLASSIFIER: - model_instance = self._build_model_instance_for_llm_node(node_data) - memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) - return QuestionClassifierNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - credentials_provider=self._llm_credentials_provider, - model_factory=self._llm_model_factory, - model_instance=model_instance, - memory=memory, - ) - - if node_type == NodeType.PARAMETER_EXTRACTOR: - model_instance = self._build_model_instance_for_llm_node(node_data) - memory = self._build_memory_for_llm_node(node_data=node_data, model_instance=model_instance) - return ParameterExtractorNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - credentials_provider=self._llm_credentials_provider, - model_factory=self._llm_model_factory, - model_instance=model_instance, - memory=memory, - ) - - return node_class( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - - def _build_model_instance_for_llm_node(self, node_data: Mapping[str, Any]) -> ModelInstance: - node_data_model = ModelConfig.model_validate(node_data["model"]) - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, - ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) - model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - return model_instance - - def _build_memory_for_llm_node( - self, - *, - node_data: Mapping[str, Any], - model_instance: ModelInstance, - ) -> PromptMessageMemory | None: - raw_memory_config = node_data.get("memory") - if raw_memory_config is None: - return None - - node_memory = MemoryConfig.model_validate(raw_memory_config) - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None - ) - return fetch_memory( - conversation_id=conversation_id, - app_id=self.graph_init_params.app_id, - node_data_memory=node_memory, - model_instance=model_instance, - ) diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index f83aaa0006..beda515666 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -15,8 +15,8 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import TextPromptMessageContent -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from dify_graph.model_runtime.entities.model_entities import ModelType class AudioTrunk: diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index d0279349ca..8de5cb1690 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -12,6 +12,7 @@ from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import CreatorUserRole, DatasetQuerySource _logger = logging.getLogger(__name__) @@ -35,10 +36,12 @@ class DatasetIndexToolCallbackHandler: dataset_query = DatasetQuery( dataset_id=dataset_id, content=query, - source="app", + source=DatasetQuerySource.APP, source_app_id=self._app_id, created_by_role=( - "account" if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} else "end_user" + CreatorUserRole.ACCOUNT + if self._invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER ), created_by=self._user_id, ) diff --git a/api/core/datasource/__base/datasource_provider.py b/api/core/datasource/__base/datasource_provider.py index bae39dc8c7..4b47777f0b 100644 --- a/api/core/datasource/__base/datasource_provider.py +++ b/api/core/datasource/__base/datasource_provider.py @@ -59,8 +59,6 @@ class DatasourcePluginProviderController(ABC): :param credentials: the credentials of the tool """ credentials_schema = dict[str, ProviderConfig]() - if credentials_schema is None: - return for credential in self.entity.credentials_schema: credentials_schema[credential.name] = credential diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index f67bfb6ead..24243add17 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -15,6 +15,7 @@ from configs import dify_config from core.helper import ssrf_proxy from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from models.enums import CreatorUserRole from models.model import MessageFile, UploadFile from models.tools import ToolFile @@ -81,7 +82,7 @@ class DatasourceFileManager: upload_file = UploadFile( tenant_id=tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=filepath, name=present_filename, size=len(file_binary), @@ -213,6 +214,6 @@ class DatasourceFileManager: # init tool_file_parser -# from core.workflow.file.datasource_file_parser import datasource_file_manager +# from dify_graph.file.datasource_file_parser import datasource_file_manager # # datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 9c48f755a9..4fa941ae16 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -24,12 +24,12 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import WorkflowNodeExecutionMetadataKey -from core.workflow.file import File -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.repositories.datasource_manager_protocol import DatasourceParameter, OnlineDriveDownloadFileParam +from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from factories import file_factory from models.model import UploadFile from models.tools import ToolFile diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 1179537570..4c9ff64479 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -3,8 +3,8 @@ from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index ab3302bd6e..2881888e27 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,7 +4,7 @@ from mimetypes import guess_extension, guess_type from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from models.tools import ToolFile logger = logging.getLogger(__name__) diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 46006f4381..1343bd8e82 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index 5902c03e27..d214652e9c 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -15,7 +15,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index a123fb0321..3427fc54b1 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -3,9 +3,9 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType, ProviderModel -from core.model_runtime.entities.provider_entities import ProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class ModelStatus(StrEnum): diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 8a26b2e91b..a9f2300ba2 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -19,17 +19,18 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from libs.datetime_utils import naive_utc_now from models.engine import db +from models.enums import CredentialSourceType from models.provider import ( LoadBalancingModelConfig, Provider, @@ -473,9 +474,21 @@ class ProviderConfiguration(BaseModel): self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) else: - # some historical data may have a provider record but not be set as valid provider_record.is_valid = True + if provider_record.credential_id is None: + provider_record.credential_id = new_record.id + provider_record.updated_at = naive_utc_now() + + provider_model_credentials_cache = ProviderCredentialsCache( + tenant_id=self.tenant_id, + identity_id=provider_record.id, + cache_type=ProviderCredentialsCacheType.PROVIDER, + ) + provider_model_credentials_cache.delete() + + self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session) + session.commit() except Exception: session.rollback() @@ -534,7 +547,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="provider", + credential_source=CredentialSourceType.PROVIDER, session=session, ) except Exception: @@ -611,7 +624,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "provider", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.PROVIDER, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() try: @@ -1031,7 +1044,7 @@ class ProviderConfiguration(BaseModel): self._update_load_balancing_configs_with_credential( credential_id=credential_id, credential_record=credential_record, - credential_source="custom_model", + credential_source=CredentialSourceType.CUSTOM_MODEL, session=session, ) except Exception: @@ -1061,7 +1074,7 @@ class ProviderConfiguration(BaseModel): LoadBalancingModelConfig.tenant_id == self.tenant_id, LoadBalancingModelConfig.provider_name.in_(self._get_provider_names()), LoadBalancingModelConfig.credential_id == credential_id, - LoadBalancingModelConfig.credential_source_type == "custom_model", + LoadBalancingModelConfig.credential_source_type == CredentialSourceType.CUSTOM_MODEL, ) lb_configs_using_credential = session.execute(lb_stmt).scalars().all() @@ -1409,12 +1422,12 @@ class ProviderConfiguration(BaseModel): preferred_model_provider = s.execute(stmt).scalars().first() if preferred_model_provider: - preferred_model_provider.preferred_provider_type = provider_type.value + preferred_model_provider.preferred_provider_type = provider_type else: preferred_model_provider = TenantPreferredModelProvider( tenant_id=self.tenant_id, provider_name=self.provider.provider, - preferred_provider_type=provider_type.value, + preferred_provider_type=provider_type, ) s.add(preferred_model_provider) s.commit() @@ -1699,7 +1712,7 @@ class ProviderConfiguration(BaseModel): provider_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "custom_model" + if config.credential_source_type != CredentialSourceType.CUSTOM_MODEL ] load_balancing_enabled = model_setting.load_balancing_enabled @@ -1757,7 +1770,7 @@ class ProviderConfiguration(BaseModel): custom_model_lb_configs = [ config for config in model_setting.load_balancing_configs - if config.credential_source_type != "provider" + if config.credential_source_type != CredentialSourceType.PROVIDER ] load_balancing_enabled = model_setting.load_balancing_enabled diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index 0078ec7e4f..a830f227a9 100644 --- a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -11,8 +11,8 @@ from core.entities.parameter_entities import ( ModelSelectorScope, ToolSelectorScope, ) -from core.model_runtime.entities.model_entities import ModelType from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index d581b3ac39..4251cfd30b 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -13,7 +13,7 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from core.workflow.nodes.code.entities import CodeLanguage +from dify_graph.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 1b56eaba21..c569e066f4 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from core.workflow.variables.utils import dumps_with_segments +from dify_graph.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 86bac4119a..873f6a4093 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,10 +4,10 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeBadRequestError -from core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 370e64e385..600a444357 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -4,7 +4,7 @@ from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 4e3ad7bb75..52776ee626 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -5,6 +5,7 @@ import re import threading import time import uuid +from collections.abc import Mapping from typing import Any from flask import Flask, current_app @@ -15,7 +16,6 @@ from configs import dify_config from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.docstore.dataset_docstore import DatasetDocumentStore @@ -31,14 +31,16 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.web_reader_tool import get_image_upload_file_ids +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper from libs.datetime_utils import naive_utc_now from models import Account -from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment +from models.dataset import AutomaticRulesConfig, ChildChunk, Dataset, DatasetProcessRule, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import DataSourceType, IndexingStatus, ProcessRuleMode, SegmentStatus from models.model import UploadFile from services.feature_service import FeatureService @@ -55,7 +57,7 @@ class IndexingRunner: logger.exception("consume document failed") document = db.session.get(DatasetDocument, document_id) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR error_message = getattr(error, "description", str(error)) document.error = str(error_message) document.stopped_at = naive_utc_now() @@ -218,7 +220,7 @@ class IndexingRunner: if document_segments: for document_segment in document_segments: # transform segment to node - if document_segment.status != "completed": + if document_segment.status != SegmentStatus.COMPLETED: document = Document( page_content=document_segment.content, metadata={ @@ -265,7 +267,7 @@ class IndexingRunner: self, tenant_id: str, extract_settings: list[ExtractSetting], - tmp_processing_rule: dict, + tmp_processing_rule: Mapping[str, Any], doc_form: str | None = None, doc_language: str = "English", dataset_id: str | None = None, @@ -376,12 +378,12 @@ class IndexingRunner: return IndexingEstimate(total_segments=total_segments, preview=preview_texts) def _extract( - self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict + self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: Mapping[str, Any] ) -> list[Document]: data_source_info = dataset_document.data_source_info_dict text_docs = [] match dataset_document.data_source_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: if not data_source_info or "upload_file_id" not in data_source_info: raise ValueError("no upload file found") stmt = select(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]) @@ -394,7 +396,7 @@ class IndexingRunner: document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - case "notion_import": + case DataSourceType.NOTION_IMPORT: if ( not data_source_info or "notion_workspace_id" not in data_source_info @@ -416,7 +418,7 @@ class IndexingRunner: document_model=dataset_document.doc_form, ) text_docs = index_processor.extract(extract_setting, process_rule_mode=process_rule["mode"]) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: if ( not data_source_info or "provider" not in data_source_info @@ -444,7 +446,7 @@ class IndexingRunner: # update document status to splitting self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="splitting", + after_indexing_status=IndexingStatus.SPLITTING, extra_update_params={ DatasetDocument.parsing_completed_at: naive_utc_now(), }, @@ -543,7 +545,8 @@ class IndexingRunner: """ Clean the document text according to the processing rules. """ - if processing_rule.mode == "automatic": + rules: AutomaticRulesConfig | dict[str, Any] + if processing_rule.mode == ProcessRuleMode.AUTOMATIC: rules = DatasetProcessRule.AUTOMATIC_RULES else: rules = json.loads(processing_rule.rules) if processing_rule.rules else {} @@ -634,7 +637,7 @@ class IndexingRunner: # update document status to completed self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="completed", + after_indexing_status=IndexingStatus.COMPLETED, extra_update_params={ DatasetDocument.tokens: tokens, DatasetDocument.completed_at: naive_utc_now(), @@ -657,10 +660,10 @@ class IndexingRunner: DocumentSegment.document_id == document_id, DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing", + DocumentSegment.status == SegmentStatus.INDEXING, ).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.enabled: True, DocumentSegment.completed_at: naive_utc_now(), } @@ -701,10 +704,10 @@ class IndexingRunner: DocumentSegment.document_id == dataset_document.id, DocumentSegment.dataset_id == dataset.id, DocumentSegment.index_node_id.in_(document_ids), - DocumentSegment.status == "indexing", + DocumentSegment.status == SegmentStatus.INDEXING, ).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.enabled: True, DocumentSegment.completed_at: naive_utc_now(), } @@ -723,7 +726,7 @@ class IndexingRunner: @staticmethod def _update_document_index_status( - document_id: str, after_indexing_status: str, extra_update_params: dict | None = None + document_id: str, after_indexing_status: IndexingStatus, extra_update_params: dict | None = None ): """ Update the document indexing status. @@ -756,7 +759,7 @@ class IndexingRunner: dataset: Dataset, text_docs: list[Document], doc_language: str, - process_rule: dict, + process_rule: Mapping[str, Any], current_user: Account | None = None, ) -> list[Document]: # get embedding model instance @@ -801,7 +804,7 @@ class IndexingRunner: cur_time = naive_utc_now() self._update_document_index_status( document_id=dataset_document.id, - after_indexing_status="indexing", + after_indexing_status=IndexingStatus.INDEXING, extra_update_params={ DatasetDocument.cleaning_completed_at: cur_time, DatasetDocument.splitting_completed_at: cur_time, @@ -813,7 +816,7 @@ class IndexingRunner: self._update_segments_by_document( dataset_document_id=dataset_document.id, update_params={ - DocumentSegment.status: "indexing", + DocumentSegment.status: SegmentStatus.INDEXING, DocumentSegment.indexing_at: naive_utc_now(), }, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index 5b2c640265..c8848336d9 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -23,15 +23,15 @@ from core.llm_generator.prompts import ( WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, ) from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -193,7 +193,8 @@ class LLMGenerator: error_step = "generate rule config" except Exception as e: logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + error = str(e) + error_step = "generate rule config" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" @@ -279,7 +280,8 @@ class LLMGenerator: except Exception as e: logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name) - rule_config["error"] = str(e) + error = str(e) + error_step = "handle unexpected exception" rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else "" diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 686529c3ca..77ea1713ea 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -10,22 +10,22 @@ from pydantic import TypeAdapter, ValidationError from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT from core.model_manager import ModelInstance -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import ( +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule class ResponseFormat(StrEnum): diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index aef1afb235..d015769b54 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -55,15 +55,31 @@ def build_protected_resource_metadata_discovery_urls( """ urls = [] + parsed_server_url = urlparse(server_url) + base_url = f"{parsed_server_url.scheme}://{parsed_server_url.netloc}" + path = parsed_server_url.path.rstrip("/") + # First priority: URL from WWW-Authenticate header if www_auth_resource_metadata_url: - urls.append(www_auth_resource_metadata_url) + parsed_metadata_url = urlparse(www_auth_resource_metadata_url) + normalized_metadata_url = None + if parsed_metadata_url.scheme and parsed_metadata_url.netloc: + normalized_metadata_url = www_auth_resource_metadata_url + elif not parsed_metadata_url.scheme and parsed_metadata_url.netloc: + normalized_metadata_url = f"{parsed_server_url.scheme}:{www_auth_resource_metadata_url}" + elif ( + not parsed_metadata_url.scheme + and not parsed_metadata_url.netloc + and parsed_metadata_url.path.startswith("/") + ): + first_segment = parsed_metadata_url.path.lstrip("/").split("/", 1)[0] + if first_segment == ".well-known" or "." not in first_segment: + normalized_metadata_url = urljoin(base_url, parsed_metadata_url.path) + + if normalized_metadata_url: + urls.append(normalized_metadata_url) # Fallback: construct from server URL - parsed = urlparse(server_url) - base_url = f"{parsed.scheme}://{parsed.netloc}" - path = parsed.path.rstrip("/") - # Priority 2: With path insertion (e.g., /.well-known/oauth-protected-resource/public/mcp) if path: path_url = f"{base_url}/.well-known/oauth-protected-resource{path}" diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index da747d2c1f..de68eb268b 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -7,7 +7,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index 84bef7b935..db9cb726d7 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -8,7 +8,7 @@ from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import ErrorData, JSONRPCError -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2b78a705c9..1156a98af1 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -5,7 +5,9 @@ from sqlalchemy.orm import sessionmaker from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -13,9 +15,7 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from core.workflow.file import file_manager +from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 2b3a3be1b9..0f710a8fcf 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -7,20 +7,20 @@ from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult -from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tts_model import TTSModel from core.provider_manager import ProviderManager +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel from extensions.ext_redis import redis_client from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 5cab4841f5..06676f5cf4 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,6 @@ from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult +from dify_graph.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 46c129099d..18f35b5b9c 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -57,8 +57,8 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import WorkflowNodeExecutionTriggeredFrom @@ -302,11 +302,11 @@ class AliyunDataTrace(BaseTraceInstance): self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata ): try: - if node_execution.node_type == NodeType.LLM: + if node_execution.node_type == BuiltinNodeTypes.LLM: node_span = self.build_workflow_llm_span(trace_info, node_execution, trace_metadata) - elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: node_span = self.build_workflow_retrieval_span(trace_info, node_execution, trace_metadata) - elif node_execution.node_type == NodeType.TOOL: + elif node_execution.node_type == BuiltinNodeTypes.TOOL: node_span = self.build_workflow_tool_span(trace_info, node_execution, trace_metadata) else: node_span = self.build_workflow_task_span(trace_info, node_execution, trace_metadata) diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py index 7624586367..0e00e90520 100644 --- a/api/core/ops/aliyun_trace/data_exporter/traceclient.py +++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py @@ -7,7 +7,7 @@ import uuid from collections import deque from collections.abc import Sequence from datetime import datetime -from typing import Final, cast +from typing import Final from urllib.parse import urljoin import httpx @@ -201,7 +201,7 @@ def convert_to_trace_id(uuid_v4: str | None) -> int: raise ValueError("UUID cannot be None") try: uuid_obj = uuid.UUID(uuid_v4) - return cast(int, uuid_obj.int) + return uuid_obj.int except ValueError as e: raise ValueError(f"Invalid UUID input: {uuid_v4}") from e diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 7f68889e92..45319f24c1 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -14,8 +14,8 @@ from core.ops.aliyun_trace.entities.semconv import ( GenAISpanKind, ) from core.rag.models.document import Document -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.ext_database import db from models import EndUser diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index a7b73e032e..f54461e99a 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -155,16 +155,32 @@ def wrap_span_metadata(metadata, **kwargs): return metadata +# Mapping from built-in node type strings to OpenInference span kinds. +# Node types not listed here default to CHAIN. +_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = { + "llm": OpenInferenceSpanKindValues.LLM, + "knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER, + "tool": OpenInferenceSpanKindValues.TOOL, + "agent": OpenInferenceSpanKindValues.AGENT, +} + + +def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues: + """Return the OpenInference span kind for a given workflow node type. + + Covers every built-in node type string. Nodes that do not have a + specialised span kind (e.g. ``start``, ``end``, ``if-else``, + ``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``. + """ + return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN) + + class ArizePhoenixDataTrace(BaseTraceInstance): def __init__( self, arize_phoenix_config: ArizeConfig | PhoenixConfig, ): super().__init__(arize_phoenix_config) - import logging - - logging.basicConfig() - logging.getLogger().setLevel(logging.DEBUG) self.arize_phoenix_config = arize_phoenix_config self.tracer, self.processor = setup_tracer(arize_phoenix_config) self.project = arize_phoenix_config.project @@ -289,9 +305,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Determine the correct span kind based on node type - span_kind = OpenInferenceSpanKindValues.CHAIN + span_kind = _get_node_span_kind(node_execution.node_type) if node_execution.node_type == "llm": - span_kind = OpenInferenceSpanKindValues.LLM provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: @@ -306,12 +321,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance): node_metadata["total_tokens"] = usage_data.get("total_tokens", 0) node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0) node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0) - elif node_execution.node_type == "dataset_retrieval": - span_kind = OpenInferenceSpanKindValues.RETRIEVER - elif node_execution.node_type == "tool": - span_kind = OpenInferenceSpanKindValues.TOOL - else: - span_kind = OpenInferenceSpanKindValues.CHAIN workflow_span_context = set_span_in_context(workflow_span) node_span = self.tracer.start_span( diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 4de4f403ce..6e62387a1f 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -141,7 +141,7 @@ class LangFuseDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 8b8117b24c..32a0c77fe2 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -163,7 +163,7 @@ class LangSmithDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} @@ -197,7 +197,7 @@ class LangSmithDataTrace(BaseTraceInstance): "ls_model_name": process_data.get("model_name", ""), } ) - elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: + elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: run_type = LangSmithRunType.retriever else: run_type = LangSmithRunType.tool diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index df6e016632..ab4a7650ec 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -23,7 +23,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -145,10 +145,10 @@ class MLflowDataTrace(BaseTraceInstance): "app_name": node.title, } - if node.node_type in (NodeType.LLM, NodeType.QUESTION_CLASSIFIER): + if node.node_type in (BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER): inputs, llm_attributes = self._parse_llm_inputs_and_attributes(node) attributes.update(llm_attributes) - elif node.node_type == NodeType.HTTP_REQUEST: + elif node.node_type == BuiltinNodeTypes.HTTP_REQUEST: inputs = node.process_data # contains request URL if not inputs: @@ -180,9 +180,9 @@ class MLflowDataTrace(BaseTraceInstance): # End node span finished_at = node.created_at + timedelta(seconds=node.elapsed_time) outputs = json.loads(node.outputs) if node.outputs else {} - if node.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + if node.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: outputs = self._parse_knowledge_retrieval_outputs(outputs) - elif node.node_type == NodeType.LLM: + elif node.node_type == BuiltinNodeTypes.LLM: outputs = outputs.get("text", outputs) node_span.end( outputs=outputs, @@ -471,13 +471,13 @@ class MLflowDataTrace(BaseTraceInstance): def _get_node_span_type(self, node_type: str) -> str: """Map Dify node types to MLflow span types""" node_type_mapping = { - NodeType.LLM: SpanType.LLM, - NodeType.QUESTION_CLASSIFIER: SpanType.LLM, - NodeType.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER, - NodeType.TOOL: SpanType.TOOL, - NodeType.CODE: SpanType.TOOL, - NodeType.HTTP_REQUEST: SpanType.TOOL, - NodeType.AGENT: SpanType.AGENT, + BuiltinNodeTypes.LLM: SpanType.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER: SpanType.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: SpanType.RETRIEVER, + BuiltinNodeTypes.TOOL: SpanType.TOOL, + BuiltinNodeTypes.CODE: SpanType.TOOL, + BuiltinNodeTypes.HTTP_REQUEST: SpanType.TOOL, + BuiltinNodeTypes.AGENT: SpanType.AGENT, } return node_type_mapping.get(node_type, "CHAIN") # type: ignore[arg-type,call-overload] diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8050c59db9..fb72bc2381 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -1,3 +1,4 @@ +import hashlib import logging import os import uuid @@ -22,7 +23,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -46,6 +47,22 @@ def wrap_metadata(metadata, **kwargs): return metadata +def _seed_to_uuid4(seed: str) -> str: + """Derive a deterministic UUID4-formatted string from an arbitrary seed. + + uuid4_to_uuid7 requires a valid UUID v4 string, but some Dify identifiers + are not UUIDs (e.g. a workflow_run_id with a "-root" suffix appended to + distinguish the root span from the trace). This helper hashes the seed + with MD5 and patches the version/variant bits so the result satisfies the + UUID v4 contract. + """ + raw = hashlib.md5(seed.encode()).digest() + ba = bytearray(raw) + ba[6] = (ba[6] & 0x0F) | 0x40 # version 4 + ba[8] = (ba[8] & 0x3F) | 0x80 # variant 1 + return str(uuid.UUID(bytes=bytes(ba))) + + def prepare_opik_uuid(user_datetime: datetime | None, user_uuid: str | None): """Opik needs UUIDv7 while Dify uses UUIDv4 for identifier of most messages and objects. The type-hints of BaseTraceInfo indicates that @@ -95,60 +112,52 @@ class OpikDataTrace(BaseTraceInstance): self.generate_name_trace(trace_info) def workflow_trace(self, trace_info: WorkflowTraceInfo): - dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id - opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) workflow_metadata = wrap_metadata( trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id ) - root_span_id = None if trace_info.message_id: dify_trace_id = trace_info.trace_id or trace_info.message_id - opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) - - trace_data = { - "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE, - "start_time": trace_info.start_time, - "end_time": trace_info.end_time, - "metadata": workflow_metadata, - "input": wrap_dict("input", trace_info.workflow_run_inputs), - "output": wrap_dict("output", trace_info.workflow_run_outputs), - "thread_id": trace_info.conversation_id, - "tags": ["message", "workflow"], - "project_name": self.project, - } - self.add_trace(trace_data) - - root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) - span_data = { - "id": root_span_id, - "parent_span_id": None, - "trace_id": opik_trace_id, - "name": TraceTaskName.WORKFLOW_TRACE, - "input": wrap_dict("input", trace_info.workflow_run_inputs), - "output": wrap_dict("output", trace_info.workflow_run_outputs), - "start_time": trace_info.start_time, - "end_time": trace_info.end_time, - "metadata": workflow_metadata, - "tags": ["workflow"], - "project_name": self.project, - } - self.add_span(span_data) + trace_name = TraceTaskName.MESSAGE_TRACE + trace_tags = ["message", "workflow"] + root_span_seed = trace_info.workflow_run_id else: - trace_data = { - "id": opik_trace_id, - "name": TraceTaskName.MESSAGE_TRACE, - "start_time": trace_info.start_time, - "end_time": trace_info.end_time, - "metadata": workflow_metadata, - "input": wrap_dict("input", trace_info.workflow_run_inputs), - "output": wrap_dict("output", trace_info.workflow_run_outputs), - "thread_id": trace_info.conversation_id, - "tags": ["workflow"], - "project_name": self.project, - } - self.add_trace(trace_data) + dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id + trace_name = TraceTaskName.WORKFLOW_TRACE + trace_tags = ["workflow"] + root_span_seed = _seed_to_uuid4(trace_info.workflow_run_id + "-root") + + opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) + + trace_data = { + "id": opik_trace_id, + "name": trace_name, + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "thread_id": trace_info.conversation_id, + "tags": trace_tags, + "project_name": self.project, + } + self.add_trace(trace_data) + + root_span_id = prepare_opik_uuid(trace_info.start_time, root_span_seed) + span_data = { + "id": root_span_id, + "parent_span_id": None, + "trace_id": opik_trace_id, + "name": TraceTaskName.WORKFLOW_TRACE, + "input": wrap_dict("input", trace_info.workflow_run_inputs), + "output": wrap_dict("output", trace_info.workflow_run_outputs), + "start_time": trace_info.start_time, + "end_time": trace_info.end_time, + "metadata": workflow_metadata, + "tags": ["workflow"], + "project_name": self.project, + } + self.add_span(span_data) # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) @@ -178,7 +187,7 @@ class OpikDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} @@ -231,15 +240,13 @@ class OpikDataTrace(BaseTraceInstance): else: run_type = "tool" - parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id - if not total_tokens: total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 span_data = { "trace_id": opik_trace_id, "id": prepare_opik_uuid(created_at, node_execution_id), - "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), + "parent_span_id": root_span_id, "name": node_name, "type": run_type, "start_time": created_at, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 177991e645..9ac753240b 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -35,7 +35,7 @@ from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: - from core.workflow.entities import WorkflowExecution + from dify_graph.entities import WorkflowExecution logger = logging.getLogger(__name__) @@ -628,10 +628,10 @@ class TraceTask: if not message_data: return {} conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id) - conversation_mode = db.session.scalars(conversation_mode_stmt).all() - if not conversation_mode or len(conversation_mode) == 0: + conversation_modes = db.session.scalars(conversation_mode_stmt).all() + if not conversation_modes or len(conversation_modes) == 0: return {} - conversation_mode = conversation_mode[0] + conversation_mode = conversation_modes[0] created_at = message_data.created_at inputs = message_data.message diff --git a/api/core/ops/tencent_trace/client.py b/api/core/ops/tencent_trace/client.py index 99ccf00400..c39093bf4c 100644 --- a/api/core/ops/tencent_trace/client.py +++ b/api/core/ops/tencent_trace/client.py @@ -120,7 +120,8 @@ class TencentTraceClient: # Metrics exporter and instruments try: - from opentelemetry.sdk.metrics import Histogram, MeterProvider + from opentelemetry.sdk.metrics import Histogram as SdkHistogram + from opentelemetry.sdk.metrics import MeterProvider from opentelemetry.sdk.metrics.export import AggregationTemporality, PeriodicExportingMetricReader protocol = os.getenv("OTEL_EXPORTER_OTLP_PROTOCOL", "").strip().lower() @@ -128,7 +129,7 @@ class TencentTraceClient: use_http_json = protocol in {"http/json", "http-json"} # Tencent APM works best with delta aggregation temporality - preferred_temporality: dict[type, AggregationTemporality] = {Histogram: AggregationTemporality.DELTA} + preferred_temporality: dict[type, AggregationTemporality] = {SdkHistogram: AggregationTemporality.DELTA} def _create_metric_exporter(exporter_cls, **kwargs): """Create metric exporter with preferred_temporality support""" diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 26e8779e3e..0a6013e244 100644 --- a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -41,7 +41,7 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 93ec186863..7e56b1effa 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -24,10 +24,10 @@ from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from core.workflow.nodes import NodeType +from dify_graph.nodes import BuiltinNodeTypes from extensions.ext_database import db from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom @@ -179,7 +179,7 @@ class TencentDataTrace(BaseTraceInstance): if node_span: self.trace_client.add_span(node_span) - if node_execution.node_type == NodeType.LLM: + if node_execution.node_type == BuiltinNodeTypes.LLM: self._record_llm_metrics(node_execution) except Exception: logger.exception("[Tencent APM] Failed to process node execution: %s", node_execution.id) @@ -192,15 +192,15 @@ class TencentDataTrace(BaseTraceInstance): ) -> SpanData | None: """Build span for different node types""" try: - if node_execution.node_type == NodeType.LLM: + if node_execution.node_type == BuiltinNodeTypes.LLM: return TencentSpanBuilder.build_workflow_llm_span( trace_id, workflow_span_id, trace_info, node_execution ) - elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL: + elif node_execution.node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: return TencentSpanBuilder.build_workflow_retrieval_span( trace_id, workflow_span_id, trace_info, node_execution ) - elif node_execution.node_type == NodeType.TOOL: + elif node_execution.node_type == BuiltinNodeTypes.TOOL: return TencentSpanBuilder.build_workflow_tool_span( trace_id, workflow_span_id, trace_info, node_execution ) diff --git a/api/core/ops/tencent_trace/utils.py b/api/core/ops/tencent_trace/utils.py index 96087951ab..678287ae1d 100644 --- a/api/core/ops/tencent_trace/utils.py +++ b/api/core/ops/tencent_trace/utils.py @@ -6,7 +6,6 @@ import hashlib import random import uuid from datetime import datetime -from typing import cast from opentelemetry.trace import Link, SpanContext, TraceFlags @@ -23,7 +22,7 @@ class TencentTraceUtils: uuid_obj = uuid.UUID(uuid_v4) if uuid_v4 else uuid.uuid4() except Exception as e: raise ValueError(f"Invalid UUID input: {e}") - return cast(int, uuid_obj.int) + return uuid_obj.int @staticmethod def convert_to_span_id(uuid_v4: str | None, span_type: str) -> int: @@ -52,9 +51,9 @@ class TencentTraceUtils: @staticmethod def create_link(trace_id_str: str) -> Link: try: - trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else cast(int, uuid.UUID(trace_id_str).int) + trace_id = int(trace_id_str, 16) if len(trace_id_str) == 32 else uuid.UUID(trace_id_str).int except (ValueError, TypeError): - trace_id = cast(int, uuid.uuid4().int) + trace_id = uuid.uuid4().int span_context = SpanContext( trace_id=trace_id, diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py index a5196d66c0..8b9a2e424a 100644 --- a/api/core/ops/utils.py +++ b/api/core/ops/utils.py @@ -1,6 +1,6 @@ from contextlib import contextmanager from datetime import datetime -from typing import Union +from typing import Any, Union from urllib.parse import urlparse from sqlalchemy import select @@ -9,7 +9,7 @@ from models.engine import db from models.model import Message -def filter_none_values(data: dict): +def filter_none_values(data: dict[str, Any]) -> dict[str, Any]: new_data = {} for key, value in data.items(): if value is None: diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2134be0bce..2a657b672c 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -31,7 +31,7 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -175,7 +175,7 @@ class WeaveDataTrace(BaseTraceInstance): node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == NodeType.LLM: + if node_type == BuiltinNodeTypes.LLM: inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: inputs = node_execution.inputs or {} diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 3c5df2b905..60d08b26c9 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator, Mapping -from typing import Union +from typing import Any, Union, cast from sqlalchemy import select from sqlalchemy.orm import Session @@ -34,14 +34,14 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): if workflow is None: raise ValueError("unexpected app type") - features_dict = workflow.features_dict + features_dict: dict[str, Any] = workflow.features_dict user_input_form = workflow.user_input_form(to_old_structure=True) else: app_model_config = app.app_model_config if app_model_config is None: raise ValueError("unexpected app type") - features_dict = app_model_config.to_dict() + features_dict = cast(dict[str, Any], app_model_config.to_dict()) user_input_form = features_dict.get("user_input_form", []) diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 4ecc22834d..11c9191bac 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -5,18 +5,6 @@ from collections.abc import Generator from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from core.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.entities.request import ( RequestInvokeLLM, @@ -30,6 +18,18 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from dify_graph.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) from models.account import Tenant diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index 9fbcbf55b4..d6aef93fc4 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,17 +1,17 @@ from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from core.workflow.enums import NodeType -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) from services.workflow_service import WorkflowService @@ -52,7 +52,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): instruction=instruction, # instruct with variables are not supported ) node_data_dict = node_data.model_dump() - node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR + node_data_dict["type"] = BuiltinNodeTypes.PARAMETER_EXTRACTOR execution = workflow_service.run_free_workflow_node( node_data_dict, tenant_id=tenant_id, diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index cf1f7ff0dd..81e1e12c5f 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ from pydantic import BaseModel, Field, computed_field, model_validator -from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index bfa662b9f6..ce5813a294 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -191,7 +191,7 @@ def cast_parameter_value(typ: StrEnum, value: Any, /): except ValueError: raise except Exception: - raise ValueError(f"The tool parameter value {value} is not in correct type of {as_normal_type(typ)}.") + raise ValueError(f"The tool parameter value {repr(value)} is not in correct type of {as_normal_type(typ)}.") def init_frontend_parameter(rule: PluginParameter, type: StrEnum, value: Any): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 9e1a9edf82..7a3780f7de 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -8,12 +8,12 @@ from pydantic import BaseModel, Field, field_validator, model_validator from core.agent.plugin_entities import AgentStrategyProviderEntity from core.datasource.entities.datasource_entities import DatasourceProviderEntity -from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 6674228dc0..416e0f6b4d 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -10,14 +10,14 @@ from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity from core.plugin.entities.parameters import PluginParameterOption from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) @@ -157,6 +157,7 @@ class PluginInstallTaskPluginStatus(BaseModel): message: str = Field(description="The message of the install task.") icon: str = Field(description="The icon of the plugin.") labels: I18nObject = Field(description="The labels of the plugin.") + source: str | None = Field(default=None, description="The installation source of the plugin") class PluginInstallTask(BasePluginEntity): diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 73d3b8c89c..c15e9b0385 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -7,7 +7,8 @@ from flask import Response from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.provider_entities import BasicProviderConfig -from core.model_runtime.entities.message_entities import ( +from core.plugin.utils.http_parser import deserialize_response +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -16,18 +17,17 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelType -from core.plugin.utils.http_parser import deserialize_response -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.nodes.parameter_extractor.entities import ( ModelConfig as ParameterExtractorModelConfig, ) -from core.workflow.nodes.parameter_extractor.entities import ( +from dify_graph.nodes.parameter_extractor.entities import ( ParameterConfig, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ClassConfig, ) -from core.workflow.nodes.question_classifier.entities import ( +from dify_graph.nodes.question_classifier.entities import ( ModelConfig as QuestionClassifierModelConfig, ) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 7a6a598a2f..737d204105 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -9,14 +9,6 @@ from pydantic import BaseModel from yarl import URL from configs import dify_config -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( @@ -35,6 +27,14 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 5d70980967..49ee5d79cb 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -2,12 +2,6 @@ import binascii from collections.abc import Generator, Sequence from typing import IO -from core.model_runtime.entities.llm_entities import LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, PluginDaemonInnerError, @@ -19,6 +13,12 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient +from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 3fe1b84dfa..53bcd9e9c6 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,7 @@ from typing import Any from core.tools.entities.tool_entities import ToolSelector -from core.workflow.file.models import File +from dify_graph.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 771b6be332..ce9f7e64b2 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -5,7 +5,12 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from dify_graph.file import file_manager +from dify_graph.file.models import File +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -13,13 +18,8 @@ from core.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file import file_manager -from core.workflow.file.models import File -from core.workflow.runtime import VariablePool +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from dify_graph.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index c1ae47709f..d09a46bfde 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -4,13 +4,13 @@ from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.entities.message_entities import ( PromptMessage, SystemPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 7094633093..667f5ef099 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -2,7 +2,7 @@ from typing import Literal from pydantic import BaseModel -from core.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole class ChatModelMessage(BaseModel): diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 22ef5809bb..951736831f 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -3,9 +3,9 @@ from typing import Any from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from dify_graph.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 936a093488..10c44349ae 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -7,7 +7,11 @@ from typing import TYPE_CHECKING, Any, cast from core.app.app_config.entities import PromptTemplateEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from dify_graph.file import file_manager +from dify_graph.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, @@ -15,14 +19,10 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file import file_manager from models.model import AppMode if TYPE_CHECKING: - from core.workflow.file.models import File + from dify_graph.file.models import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 0a7a467227..85a2201395 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,7 +1,8 @@ from collections.abc import Sequence from typing import Any, cast -from core.model_runtime.entities import ( +from core.prompt.simple_prompt_transform import ModelMode +from dify_graph.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, @@ -10,7 +11,6 @@ from core.model_runtime.entities import ( PromptMessageRole, TextPromptMessageContent, ) -from core.prompt.simple_prompt_transform import ModelMode class PromptMessageUtil: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index fdbfca4330..3c3fbd6dd2 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -28,14 +28,14 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.position_helper import is_filtered -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormType, ProviderEntity, ) -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -195,7 +195,9 @@ class ProviderManager: preferred_provider_type_record = provider_name_to_preferred_model_provider_records_dict.get(provider_name) if preferred_provider_type_record: - preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type) + preferred_provider_type = preferred_provider_type_record.preferred_provider_type + elif dify_config.EDITION == "CLOUD" and system_configuration.enabled: + preferred_provider_type = ProviderType.SYSTEM elif custom_configuration.provider or custom_configuration.models: preferred_provider_type = ProviderType.CUSTOM elif system_configuration.enabled: @@ -305,9 +307,7 @@ class ProviderManager: available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next( - (model for model in available_models if model.model == "gpt-4"), available_models[0] - ) + available_model = available_models[0] default_model = TenantDefaultModel( tenant_id=tenant_id, @@ -627,7 +627,7 @@ class ProviderManager: tenant_id=tenant_id, # TODO: Use provider name with prefix after the data migration. provider_name=ModelProviderID(provider_name).provider_name, - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, quota_type=quota.quota_type, quota_limit=0, # type: ignore quota_used=0, diff --git a/api/core/rag/cleaner/clean_processor.py b/api/core/rag/cleaner/clean_processor.py index e182c35b99..790253053d 100644 --- a/api/core/rag/cleaner/clean_processor.py +++ b/api/core/rag/cleaner/clean_processor.py @@ -1,9 +1,10 @@ import re +from typing import Any class CleanProcessor: @classmethod - def clean(cls, text: str, process_rule: dict) -> str: + def clean(cls, text: str, process_rule: dict[str, Any] | None) -> str: # default clean # remove invalid symbol text = re.sub(r"<\|", "<", text) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index bfa8781e9f..33eb5f963a 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,6 +1,6 @@ +from typing_extensions import TypedDict + from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document @@ -8,6 +8,28 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + +class RerankingModelDict(TypedDict): + reranking_provider_name: str + reranking_model_name: str + + +class VectorSettingDict(TypedDict): + vector_weight: float + embedding_provider_name: str + embedding_model_name: str + + +class KeywordSettingDict(TypedDict): + keyword_weight: float + + +class WeightsDict(TypedDict): + vector_setting: VectorSettingDict + keyword_setting: KeywordSettingDict class DataPostProcessor: @@ -17,8 +39,8 @@ class DataPostProcessor: self, tenant_id: str, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reorder_enabled: bool = False, ): self.rerank_runner = self._get_rerank_runner(reranking_mode, tenant_id, reranking_model, weights) @@ -45,8 +67,8 @@ class DataPostProcessor: self, reranking_mode: str, tenant_id: str, - reranking_model: dict | None = None, - weights: dict | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, ) -> BaseRerankRunner | None: if reranking_mode == RerankMode.WEIGHTED_SCORE and weights: runner = RerankRunnerFactory.create_rerank_runner( @@ -79,12 +101,14 @@ class DataPostProcessor: return ReorderRunner() return None - def _get_rerank_model_instance(self, tenant_id: str, reranking_model: dict | None) -> ModelInstance | None: + def _get_rerank_model_instance( + self, tenant_id: str, reranking_model: RerankingModelDict | None + ) -> ModelInstance | None: if reranking_model: try: model_manager = ModelManager() - reranking_provider_name = reranking_model.get("reranking_provider_name") - reranking_model_name = reranking_model.get("reranking_model_name") + reranking_provider_name = reranking_model["reranking_provider_name"] + reranking_model_name = reranking_model["reranking_model_name"] if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py index 0f19ecadc8..b07dc108be 100644 --- a/api/core/rag/datasource/keyword/jieba/jieba.py +++ b/api/core/rag/datasource/keyword/jieba/jieba.py @@ -4,6 +4,7 @@ from typing import Any import orjson from pydantic import BaseModel from sqlalchemy import select +from typing_extensions import TypedDict from configs import dify_config from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -15,6 +16,11 @@ from extensions.ext_storage import storage from models.dataset import Dataset, DatasetKeywordTable, DocumentSegment +class PreSegmentData(TypedDict): + segment: DocumentSegment + keywords: list[str] + + class KeywordTableConfig(BaseModel): max_keywords_per_chunk: int = 10 @@ -128,7 +134,7 @@ class Jieba(BaseKeyword): file_key = "keyword_files/" + self.dataset.tenant_id + "/" + self.dataset.id + ".txt" storage.delete(file_key) - def _save_dataset_keyword_table(self, keyword_table): + def _save_dataset_keyword_table(self, keyword_table: dict[str, set[str]] | None): keyword_table_dict = { "__type__": "keyword_table", "__data__": {"index_id": self.dataset.id, "summary": None, "table": keyword_table}, @@ -144,7 +150,7 @@ class Jieba(BaseKeyword): storage.delete(file_key) storage.save(file_key, dumps_with_sets(keyword_table_dict).encode("utf-8")) - def _get_dataset_keyword_table(self) -> dict | None: + def _get_dataset_keyword_table(self) -> dict[str, set[str]] | None: dataset_keyword_table = self.dataset.dataset_keyword_table if dataset_keyword_table: keyword_table_dict = dataset_keyword_table.keyword_table_dict @@ -169,14 +175,16 @@ class Jieba(BaseKeyword): return {} - def _add_text_to_keyword_table(self, keyword_table: dict, id: str, keywords: list[str]): + def _add_text_to_keyword_table( + self, keyword_table: dict[str, set[str]], id: str, keywords: list[str] + ) -> dict[str, set[str]]: for keyword in keywords: if keyword not in keyword_table: keyword_table[keyword] = set() keyword_table[keyword].add(id) return keyword_table - def _delete_ids_from_keyword_table(self, keyword_table: dict, ids: list[str]): + def _delete_ids_from_keyword_table(self, keyword_table: dict[str, set[str]], ids: list[str]) -> dict[str, set[str]]: # get set of ids that correspond to node node_idxs_to_delete = set(ids) @@ -193,7 +201,7 @@ class Jieba(BaseKeyword): return keyword_table - def _retrieve_ids_by_query(self, keyword_table: dict, query: str, k: int = 4): + def _retrieve_ids_by_query(self, keyword_table: dict[str, set[str]], query: str, k: int = 4) -> list[str]: keyword_table_handler = JiebaKeywordTableHandler() keywords = keyword_table_handler.extract_keywords(query) @@ -228,7 +236,7 @@ class Jieba(BaseKeyword): keyword_table = self._add_text_to_keyword_table(keyword_table or {}, node_id, keywords) self._save_dataset_keyword_table(keyword_table) - def multi_create_segment_keywords(self, pre_segment_data_list: list): + def multi_create_segment_keywords(self, pre_segment_data_list: list[PreSegmentData]): keyword_table_handler = JiebaKeywordTableHandler() keyword_table = self._get_dataset_keyword_table() for pre_segment_data in pre_segment_data_list: diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 91c16ce079..713319ab9d 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -1,20 +1,20 @@ import concurrent.futures import logging from concurrent.futures import ThreadPoolExecutor -from typing import Any +from typing import Any, NotRequired from flask import Flask, current_app from sqlalchemy import select from sqlalchemy.orm import Session, load_only +from typing_extensions import TypedDict from configs import dify_config from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector -from core.rag.embedding.retrieval import RetrievalChildChunk, RetrievalSegments +from core.rag.embedding.retrieval import AttachmentInfoDict, RetrievalChildChunk, RetrievalSegments from core.rag.entities.metadata_entities import MetadataCondition from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType @@ -23,6 +23,7 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ( ChildChunk, @@ -35,7 +36,49 @@ from models.dataset import Document as DatasetDocument from models.model import UploadFile from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { + +class SegmentAttachmentResult(TypedDict): + attachment_info: AttachmentInfoDict + segment_id: str + + +class SegmentAttachmentInfoResult(TypedDict): + attachment_id: str + attachment_info: AttachmentInfoDict + segment_id: str + + +class ChildChunkDetail(TypedDict): + id: str + content: str + position: int + score: float + + +class SegmentChildMapDetail(TypedDict): + max_score: float + child_chunks: list[ChildChunkDetail] + + +class SegmentRecord(TypedDict): + segment: DocumentSegment + score: NotRequired[float] + child_chunks: NotRequired[list[ChildChunkDetail]] + files: NotRequired[list[AttachmentInfoDict]] + + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod + reranking_enable: bool + reranking_model: RerankingModelDict + reranking_mode: NotRequired[str] + weights: NotRequired[WeightsDict | None] + score_threshold: NotRequired[float] + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -56,11 +99,11 @@ class RetrievalService: query: str, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, - attachment_ids: list | None = None, + attachment_ids: list[str] | None = None, ): if not query and not attachment_ids: return [] @@ -207,8 +250,8 @@ class RetrievalService: dataset_id: str, query: str, top_k: int, - all_documents: list, - exceptions: list, + all_documents: list[Document], + exceptions: list[str], document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): @@ -235,10 +278,10 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, - all_documents: list, + reranking_model: RerankingModelDict | None, + all_documents: list[Document], retrieval_method: RetrievalMethod, - exceptions: list, + exceptions: list[str], document_ids_filter: list[str] | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ): @@ -277,8 +320,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.SEMANTIC_SEARCH ): data_post_processor = DataPostProcessor( @@ -288,8 +331,8 @@ class RetrievalService: model_manager = ModelManager() is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, - provider=reranking_model.get("reranking_provider_name") or "", - model=reranking_model.get("reranking_model_name") or "", + provider=reranking_model["reranking_provider_name"], + model=reranking_model["reranking_model_name"], model_type=ModelType.RERANK, ) if is_support_vision: @@ -329,10 +372,10 @@ class RetrievalService: query: str, top_k: int, score_threshold: float | None, - reranking_model: dict | None, - all_documents: list, + reranking_model: RerankingModelDict | None, + all_documents: list[Document], retrieval_method: str, - exceptions: list, + exceptions: list[str], document_ids_filter: list[str] | None = None, ): with flask_app.app_context(): @@ -349,8 +392,8 @@ class RetrievalService: if documents: if ( reranking_model - and reranking_model.get("reranking_model_name") - and reranking_model.get("reranking_provider_name") + and reranking_model["reranking_model_name"] + and reranking_model["reranking_provider_name"] and retrieval_method == RetrievalMethod.FULL_TEXT_SEARCH ): data_post_processor = DataPostProcessor( @@ -459,7 +502,7 @@ class RetrievalService: segment_ids: list[str] = [] index_node_segments: list[DocumentSegment] = [] segments: list[DocumentSegment] = [] - attachment_map: dict[str, list[dict[str, Any]]] = {} + attachment_map: dict[str, list[AttachmentInfoDict]] = {} child_chunk_map: dict[str, list[ChildChunk]] = {} doc_segment_map: dict[str, list[str]] = {} segment_summary_map: dict[str, str] = {} # Map segment_id to summary content @@ -544,12 +587,12 @@ class RetrievalService: segment_summary_map[summary.chunk_id] = summary.summary_content include_segment_ids = set() - segment_child_map: dict[str, dict[str, Any]] = {} - records: list[dict[str, Any]] = [] + segment_child_map: dict[str, SegmentChildMapDetail] = {} + records: list[SegmentRecord] = [] for segment in segments: child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, []) - attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, []) + attachment_infos: list[AttachmentInfoDict] = attachment_map.get(segment.id, []) ds_dataset_document: DatasetDocument | None = valid_dataset_documents.get(segment.document_id) if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX: @@ -560,14 +603,14 @@ class RetrievalService: max_score = summary_score_map.get(segment.id, 0.0) if child_chunks or attachment_infos: - child_chunk_details = [] + child_chunk_details: list[ChildChunkDetail] = [] for child_chunk in child_chunks: child_document: Document | None = doc_to_document_map.get(child_chunk.index_node_id) if child_document: child_score = child_document.metadata.get("score", 0.0) else: child_score = 0.0 - child_chunk_detail = { + child_chunk_detail: ChildChunkDetail = { "id": child_chunk.id, "content": child_chunk.content, "position": child_chunk.position, @@ -580,7 +623,7 @@ class RetrievalService: if file_document: max_score = max(max_score, file_document.metadata.get("score", 0.0)) - map_detail = { + map_detail: SegmentChildMapDetail = { "max_score": max_score, "child_chunks": child_chunk_details, } @@ -593,7 +636,7 @@ class RetrievalService: "max_score": summary_score, "child_chunks": [], } - record: dict[str, Any] = { + record: SegmentRecord = { "segment": segment, } records.append(record) @@ -617,19 +660,19 @@ class RetrievalService: if file_doc: max_score = max(max_score, file_doc.metadata.get("score", 0.0)) - record = { + another_record: SegmentRecord = { "segment": segment, "score": max_score, } - records.append(record) + records.append(another_record) # Add child chunks information to records for record in records: if record["segment"].id in segment_child_map: - record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore - record["score"] = segment_child_map[record["segment"].id]["max_score"] # type: ignore + record["child_chunks"] = segment_child_map[record["segment"].id]["child_chunks"] + record["score"] = segment_child_map[record["segment"].id]["max_score"] if record["segment"].id in attachment_map: - record["files"] = attachment_map[record["segment"].id] # type: ignore[assignment] + record["files"] = attachment_map[record["segment"].id] result: list[RetrievalSegments] = [] for record in records: @@ -693,9 +736,9 @@ class RetrievalService: query: str | None = None, top_k: int = 4, score_threshold: float | None = 0.0, - reranking_model: dict | None = None, + reranking_model: RerankingModelDict | None = None, reranking_mode: str = "reranking_model", - weights: dict | None = None, + weights: WeightsDict | None = None, document_ids_filter: list[str] | None = None, attachment_id: str | None = None, ): @@ -807,7 +850,7 @@ class RetrievalService: @classmethod def get_segment_attachment_info( cls, dataset_id: str, tenant_id: str, attachment_id: str, session: Session - ) -> dict[str, Any] | None: + ) -> SegmentAttachmentResult | None: upload_file = session.query(UploadFile).where(UploadFile.id == attachment_id).first() if upload_file: attachment_binding = ( @@ -816,7 +859,7 @@ class RetrievalService: .first() ) if attachment_binding: - attachment_info = { + attachment_info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -828,8 +871,10 @@ class RetrievalService: return None @classmethod - def get_segment_attachment_infos(cls, attachment_ids: list[str], session: Session) -> list[dict[str, Any]]: - attachment_infos = [] + def get_segment_attachment_infos( + cls, attachment_ids: list[str], session: Session + ) -> list[SegmentAttachmentInfoResult]: + attachment_infos: list[SegmentAttachmentInfoResult] = [] upload_files = session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).all() if upload_files: upload_file_ids = [upload_file.id for upload_file in upload_files] @@ -843,7 +888,7 @@ class RetrievalService: if attachment_bindings: for upload_file in upload_files: attachment_binding = attachment_binding_map.get(upload_file.id) - attachment_info = { + info: AttachmentInfoDict = { "id": upload_file.id, "name": upload_file.name, "extension": "." + upload_file.extension, @@ -855,7 +900,7 @@ class RetrievalService: attachment_infos.append( { "attachment_id": attachment_binding.attachment_id, - "attachment_info": attachment_info, + "attachment_info": info, "segment_id": attachment_binding.segment_id, } ) diff --git a/api/core/rag/datasource/vdb/chroma/chroma_vector.py b/api/core/rag/datasource/vdb/chroma/chroma_vector.py index de1572410c..cbc846f716 100644 --- a/api/core/rag/datasource/vdb/chroma/chroma_vector.py +++ b/api/core/rag/datasource/vdb/chroma/chroma_vector.py @@ -65,7 +65,7 @@ class ChromaVector(BaseVector): self._client.get_or_create_collection(collection_name) redis_client.set(collection_exist_cache_key, 1, ex=3600) - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: uuids = self._get_uuids(documents) texts = [d.page_content for d in documents] metadatas = [d.metadata for d in documents] @@ -73,6 +73,7 @@ class ChromaVector(BaseVector): collection = self._client.get_or_create_collection(self._collection_name) # FIXME: chromadb using numpy array, fix the type error later collection.upsert(ids=uuids, documents=texts, embeddings=embeddings, metadatas=metadatas) # type: ignore + return uuids def delete_by_metadata_field(self, key: str, value: str): collection = self._client.get_or_create_collection(self._collection_name) diff --git a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py index 91bb71bfa6..8e8120fc10 100644 --- a/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +++ b/api/core/rag/datasource/vdb/clickzetta/clickzetta_vector.py @@ -605,25 +605,36 @@ class ClickzettaVector(BaseVector): logger.warning("Failed to create inverted index: %s", e) # Continue without inverted index - full-text search will fall back to LIKE - def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs) -> list[str]: """Add documents with embeddings to the collection.""" if not documents: - return + return [] batch_size = self._config.batch_size total_batches = (len(documents) + batch_size - 1) // batch_size + added_ids = [] for i in range(0, len(documents), batch_size): batch_docs = documents[i : i + batch_size] batch_embeddings = embeddings[i : i + batch_size] + batch_doc_ids = [] + for doc in batch_docs: + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} + batch_doc_ids.append(self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4())))) + added_ids.extend(batch_doc_ids) # Execute batch insert through write queue - self._execute_write(self._insert_batch, batch_docs, batch_embeddings, i, batch_size, total_batches) + self._execute_write( + self._insert_batch, batch_docs, batch_embeddings, batch_doc_ids, i, batch_size, total_batches + ) + + return added_ids def _insert_batch( self, batch_docs: list[Document], batch_embeddings: list[list[float]], + batch_doc_ids: list[str], batch_index: int, batch_size: int, total_batches: int, @@ -641,14 +652,9 @@ class ClickzettaVector(BaseVector): data_rows = [] vector_dimension = len(batch_embeddings[0]) if batch_embeddings and batch_embeddings[0] else 768 - for doc, embedding in zip(batch_docs, batch_embeddings): + for doc, embedding, doc_id in zip(batch_docs, batch_embeddings, batch_doc_ids): # Optimized: minimal checks for common case, fallback for edge cases - metadata = doc.metadata or {} - - if not isinstance(metadata, dict): - metadata = {} - - doc_id = self._safe_doc_id(metadata.get("doc_id", str(uuid.uuid4()))) + metadata = doc.metadata if isinstance(doc.metadata, dict) else {} # Fast path for JSON serialization try: diff --git a/api/core/model_runtime/__init__.py b/api/core/rag/datasource/vdb/hologres/__init__.py similarity index 100% rename from api/core/model_runtime/__init__.py rename to api/core/rag/datasource/vdb/hologres/__init__.py diff --git a/api/core/rag/datasource/vdb/hologres/hologres_vector.py b/api/core/rag/datasource/vdb/hologres/hologres_vector.py new file mode 100644 index 0000000000..36b259e494 --- /dev/null +++ b/api/core/rag/datasource/vdb/hologres/hologres_vector.py @@ -0,0 +1,361 @@ +import json +import logging +import time +from typing import Any + +import holo_search_sdk as holo # type: ignore +from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType +from psycopg import sql as psql +from pydantic import BaseModel, model_validator + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class HologresVectorConfig(BaseModel): + """ + Configuration for Hologres vector database connection. + + In Hologres, access_key_id is used as the PostgreSQL username, + and access_key_secret is used as the PostgreSQL password. + """ + + host: str + port: int = 80 + database: str + access_key_id: str + access_key_secret: str + schema_name: str = "public" + tokenizer: TokenizerType = "jieba" + distance_method: DistanceType = "Cosine" + base_quantization_type: BaseQuantizationType = "rabitq" + max_degree: int = 64 + ef_construction: int = 400 + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict): + if not values.get("host"): + raise ValueError("config HOLOGRES_HOST is required") + if not values.get("database"): + raise ValueError("config HOLOGRES_DATABASE is required") + if not values.get("access_key_id"): + raise ValueError("config HOLOGRES_ACCESS_KEY_ID is required") + if not values.get("access_key_secret"): + raise ValueError("config HOLOGRES_ACCESS_KEY_SECRET is required") + return values + + +class HologresVector(BaseVector): + """ + Hologres vector storage implementation using holo-search-sdk. + + Supports semantic search (vector), full-text search, and hybrid search. + """ + + def __init__(self, collection_name: str, config: HologresVectorConfig): + super().__init__(collection_name) + self._config = config + self._client = self._init_client(config) + self.table_name = f"embedding_{collection_name}".lower() + + def _init_client(self, config: HologresVectorConfig): + """Initialize and return a holo-search-sdk client.""" + client = holo.connect( + host=config.host, + port=config.port, + database=config.database, + access_key_id=config.access_key_id, + access_key_secret=config.access_key_secret, + schema=config.schema_name, + ) + client.connect() + return client + + def get_type(self) -> str: + return VectorType.HOLOGRES + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + """Create collection table with vector and full-text indexes, then add texts.""" + dimension = len(embeddings[0]) + self._create_collection(dimension) + self.add_texts(texts, embeddings) + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + """Add texts with embeddings to the collection using batch upsert.""" + if not documents: + return [] + + pks: list[str] = [] + batch_size = 100 + for i in range(0, len(documents), batch_size): + batch_docs = documents[i : i + batch_size] + batch_embeddings = embeddings[i : i + batch_size] + + values = [] + column_names = ["id", "text", "meta", "embedding"] + + for j, doc in enumerate(batch_docs): + doc_id = doc.metadata.get("doc_id", "") if doc.metadata else "" + pks.append(doc_id) + values.append( + [ + doc_id, + doc.page_content, + json.dumps(doc.metadata or {}), + batch_embeddings[j], + ] + ) + + table = self._client.open_table(self.table_name) + table.upsert_multi( + index_column="id", + values=values, + column_names=column_names, + update=True, + update_columns=["text", "meta", "embedding"], + ) + + return pks + + def text_exists(self, id: str) -> bool: + """Check if a text with the given doc_id exists in the collection.""" + if not self._client.check_table_exist(self.table_name): + return False + + result = self._client.execute( + psql.SQL("SELECT 1 FROM {} WHERE id = {} LIMIT 1").format( + psql.Identifier(self.table_name), psql.Literal(id) + ), + fetch_result=True, + ) + return bool(result) + + def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None: + """Get document IDs by metadata field key and value.""" + result = self._client.execute( + psql.SQL("SELECT id FROM {} WHERE meta->>{} = {}").format( + psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value) + ), + fetch_result=True, + ) + if result: + return [row[0] for row in result] + return None + + def delete_by_ids(self, ids: list[str]): + """Delete documents by their doc_id list.""" + if not ids: + return + if not self._client.check_table_exist(self.table_name): + return + + self._client.execute( + psql.SQL("DELETE FROM {} WHERE id IN ({})").format( + psql.Identifier(self.table_name), + psql.SQL(", ").join(psql.Literal(id) for id in ids), + ) + ) + + def delete_by_metadata_field(self, key: str, value: str): + """Delete documents by metadata field key and value.""" + if not self._client.check_table_exist(self.table_name): + return + + self._client.execute( + psql.SQL("DELETE FROM {} WHERE meta->>{} = {}").format( + psql.Identifier(self.table_name), psql.Literal(key), psql.Literal(value) + ) + ) + + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + """Search for documents by vector similarity.""" + if not self._client.check_table_exist(self.table_name): + return [] + + top_k = kwargs.get("top_k", 4) + score_threshold = float(kwargs.get("score_threshold") or 0.0) + + table = self._client.open_table(self.table_name) + query = ( + table.search_vector( + vector=query_vector, + column="embedding", + distance_method=self._config.distance_method, + output_name="distance", + ) + .select(["id", "text", "meta"]) + .limit(top_k) + ) + + # Apply document_ids_filter if provided + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + filter_sql = psql.SQL("meta->>'document_id' IN ({})").format( + psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter) + ) + query = query.where(filter_sql) + + results = query.fetchall() + return self._process_vector_results(results, score_threshold) + + def _process_vector_results(self, results: list, score_threshold: float) -> list[Document]: + """Process vector search results into Document objects.""" + docs = [] + for row in results: + # row format: (distance, id, text, meta) + # distance is first because search_vector() adds the computed column before selected columns + distance = row[0] + text = row[2] + meta = row[3] + + if isinstance(meta, str): + meta = json.loads(meta) + + # Convert distance to similarity score (consistent with pgvector) + score = 1 - distance + meta["score"] = score + + if score >= score_threshold: + docs.append(Document(page_content=text, metadata=meta)) + + return docs + + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + """Search for documents by full-text search.""" + if not self._client.check_table_exist(self.table_name): + return [] + + top_k = kwargs.get("top_k", 4) + + table = self._client.open_table(self.table_name) + search_query = table.search_text( + column="text", + expression=query, + return_score=True, + return_score_name="score", + return_all_columns=True, + ).limit(top_k) + + # Apply document_ids_filter if provided + document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: + filter_sql = psql.SQL("meta->>'document_id' IN ({})").format( + psql.SQL(", ").join(psql.Literal(id) for id in document_ids_filter) + ) + search_query = search_query.where(filter_sql) + + results = search_query.fetchall() + return self._process_full_text_results(results) + + def _process_full_text_results(self, results: list) -> list[Document]: + """Process full-text search results into Document objects.""" + docs = [] + for row in results: + # row format: (id, text, meta, embedding, score) + text = row[1] + meta = row[2] + score = row[-1] # score is the last column from return_score + + if isinstance(meta, str): + meta = json.loads(meta) + + meta["score"] = score + docs.append(Document(page_content=text, metadata=meta)) + + return docs + + def delete(self): + """Delete the entire collection table.""" + if self._client.check_table_exist(self.table_name): + self._client.drop_table(self.table_name) + + def _create_collection(self, dimension: int): + """Create the collection table with vector and full-text indexes.""" + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return + + if not self._client.check_table_exist(self.table_name): + # Create table via SQL with CHECK constraint for vector dimension + create_table_sql = psql.SQL(""" + CREATE TABLE IF NOT EXISTS {} ( + id TEXT PRIMARY KEY, + text TEXT NOT NULL, + meta JSONB NOT NULL, + embedding float4[] NOT NULL + CHECK (array_ndims(embedding) = 1 + AND array_length(embedding, 1) = {}) + ); + """).format(psql.Identifier(self.table_name), psql.Literal(dimension)) + self._client.execute(create_table_sql) + + # Wait for table to be fully ready before creating indexes + max_wait_seconds = 30 + poll_interval = 2 + for _ in range(max_wait_seconds // poll_interval): + if self._client.check_table_exist(self.table_name): + break + time.sleep(poll_interval) + else: + raise RuntimeError(f"Table {self.table_name} was not ready after {max_wait_seconds}s") + + # Open table and set vector index + table = self._client.open_table(self.table_name) + table.set_vector_index( + column="embedding", + distance_method=self._config.distance_method, + base_quantization_type=self._config.base_quantization_type, + max_degree=self._config.max_degree, + ef_construction=self._config.ef_construction, + use_reorder=self._config.base_quantization_type == "rabitq", + ) + + # Create full-text search index + table.create_text_index( + index_name=f"ft_idx_{self._collection_name}", + column="text", + tokenizer=self._config.tokenizer, + ) + + redis_client.set(collection_exist_cache_key, 1, ex=3600) + + +class HologresVectorFactory(AbstractVectorFactory): + """Factory class for creating HologresVector instances.""" + + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HologresVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HOLOGRES, collection_name)) + + return HologresVector( + collection_name=collection_name, + config=HologresVectorConfig( + host=dify_config.HOLOGRES_HOST or "", + port=dify_config.HOLOGRES_PORT, + database=dify_config.HOLOGRES_DATABASE or "", + access_key_id=dify_config.HOLOGRES_ACCESS_KEY_ID or "", + access_key_secret=dify_config.HOLOGRES_ACCESS_KEY_SECRET or "", + schema_name=dify_config.HOLOGRES_SCHEMA, + tokenizer=dify_config.HOLOGRES_TOKENIZER, + distance_method=dify_config.HOLOGRES_DISTANCE_METHOD, + base_quantization_type=dify_config.HOLOGRES_BASE_QUANTIZATION_TYPE, + max_degree=dify_config.HOLOGRES_MAX_DEGREE, + ef_construction=dify_config.HOLOGRES_EF_CONSTRUCTION, + ), + ) diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py index b986c79e3a..90d9173409 100644 --- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py @@ -135,8 +135,8 @@ class PGVectoRS(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self._client) as session: - select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>'{key}' = '{value}'; ") - result = session.execute(select_statement).fetchall() + select_statement = sql_text(f"SELECT id FROM {self._collection_name} WHERE meta->>:key = :value") + result = session.execute(select_statement, {"key": key, "value": value}).fetchall() if result: return [item[0] for item in result] else: @@ -172,9 +172,9 @@ class PGVectoRS(BaseVector): def text_exists(self, id: str) -> bool: with Session(self._client) as session: select_statement = sql_text( - f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = '{id}' limit 1; " + f"SELECT id FROM {self._collection_name} WHERE meta->>'doc_id' = :doc_id limit 1" ) - result = session.execute(select_statement).fetchall() + result = session.execute(select_statement, {"doc_id": id}).fetchall() return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/relyt/relyt_vector.py b/api/core/rag/datasource/vdb/relyt/relyt_vector.py index 70857b3e3c..e486375ec2 100644 --- a/api/core/rag/datasource/vdb/relyt/relyt_vector.py +++ b/api/core/rag/datasource/vdb/relyt/relyt_vector.py @@ -154,10 +154,8 @@ class RelytVector(BaseVector): def get_ids_by_metadata_field(self, key: str, value: str): result = None with Session(self.client) as session: - select_statement = sql_text( - f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'{key}' = '{value}'; """ - ) - result = session.execute(select_statement).fetchall() + select_statement = sql_text(f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>:key = :value""") + result = session.execute(select_statement, {"key": key, "value": value}).fetchall() if result: return [item[0] for item in result] else: @@ -201,11 +199,10 @@ class RelytVector(BaseVector): def delete_by_ids(self, ids: list[str]): with Session(self.client) as session: - ids_str = ",".join(f"'{doc_id}'" for doc_id in ids) select_statement = sql_text( - f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' in ({ids_str}); """ + f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = ANY(:doc_ids)""" ) - result = session.execute(select_statement).fetchall() + result = session.execute(select_statement, {"doc_ids": ids}).fetchall() if result: ids = [item[0] for item in result] self.delete_by_uuids(ids) @@ -218,9 +215,9 @@ class RelytVector(BaseVector): def text_exists(self, id: str) -> bool: with Session(self.client) as session: select_statement = sql_text( - f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = '{id}' limit 1; """ + f"""SELECT id FROM "{self._collection_name}" WHERE metadata->>'doc_id' = :doc_id limit 1""" ) - result = session.execute(select_statement).fetchall() + result = session.execute(select_statement, {"doc_id": id}).fetchall() return len(result) > 0 def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 56ffb36a2b..71b6fa0a9b 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -284,27 +284,29 @@ class TidbOnQdrantVector(BaseVector): from qdrant_client.http import models from qdrant_client.http.exceptions import UnexpectedResponse - for node_id in ids: - try: - filter = models.Filter( - must=[ - models.FieldCondition( - key="metadata.doc_id", - match=models.MatchValue(value=node_id), - ), - ], - ) - self._client.delete( - collection_name=self._collection_name, - points_selector=FilterSelector(filter=filter), - ) - except UnexpectedResponse as e: - # Collection does not exist, so return - if e.status_code == 404: - return - # Some other error occurred, so re-raise the exception - else: - raise e + if not ids: + return + + try: + filter = models.Filter( + must=[ + models.FieldCondition( + key="metadata.doc_id", + match=models.MatchAny(any=ids), + ), + ], + ) + self._client.delete( + collection_name=self._collection_name, + points_selector=FilterSelector(filter=filter), + ) + except UnexpectedResponse as e: + # Collection does not exist, so return + if e.status_code == 404: + return + # Some other error occurred, so re-raise the exception + else: + raise e def text_exists(self, id: str) -> bool: all_collection_name = [] diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index b9772b3c08..cd12cd3fae 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -8,13 +8,13 @@ from sqlalchemy import select from configs import dify_config from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.vdb.vector_base import BaseVector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -38,7 +38,7 @@ class AbstractVectorFactory(ABC): class Vector: def __init__(self, dataset: Dataset, attributes: list | None = None): if attributes is None: - attributes = ["doc_id", "dataset_id", "document_id", "doc_hash"] + attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] self._dataset = dataset self._embeddings = self._get_embeddings() self._attributes = attributes @@ -191,6 +191,10 @@ class Vector: from core.rag.datasource.vdb.iris.iris_vector import IrisVectorFactory return IrisVectorFactory + case VectorType.HOLOGRES: + from core.rag.datasource.vdb.hologres.hologres_vector import HologresVectorFactory + + return HologresVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index bd99a31446..9cce8e4c32 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -34,3 +34,4 @@ class VectorType(StrEnum): MATRIXONE = "matrixone" CLICKZETTA = "clickzetta" IRIS = "iris" + HOLOGRES = "hologres" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index b48dd93f04..d29d62c93f 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -5,9 +5,11 @@ This module provides integration with Weaviate vector database for storing and r document embeddings used in retrieval-augmented generation workflows. """ +import atexit import datetime import json import logging +import threading import uuid as _uuid from typing import Any from urllib.parse import urlparse @@ -32,6 +34,35 @@ from models.dataset import Dataset logger = logging.getLogger(__name__) +_weaviate_client: weaviate.WeaviateClient | None = None +_weaviate_client_lock = threading.Lock() + + +def _shutdown_weaviate_client() -> None: + """ + Best-effort shutdown hook to close the module-level Weaviate client. + + This is registered with atexit so that HTTP/gRPC resources are released + when the Python interpreter exits. + """ + global _weaviate_client + + # Ensure thread-safety when accessing the shared client instance + with _weaviate_client_lock: + client = _weaviate_client + _weaviate_client = None + + if client is not None: + try: + client.close() + except Exception: + # Best-effort cleanup; log at debug level and ignore errors. + logger.debug("Failed to close Weaviate client during shutdown", exc_info=True) + + +# Register the shutdown hook once per process. +atexit.register(_shutdown_weaviate_client) + class WeaviateConfig(BaseModel): """ @@ -81,61 +112,58 @@ class WeaviateVector(BaseVector): self._client = self._init_client(config) self._attributes = attributes - def __del__(self): - """ - Destructor to properly close the Weaviate client connection. - Prevents connection leaks and resource warnings. - """ - if hasattr(self, "_client") and self._client is not None: - try: - self._client.close() - except Exception as e: - # Ignore errors during cleanup as object is being destroyed - logger.warning("Error closing Weaviate client %s", e, exc_info=True) - def _init_client(self, config: WeaviateConfig) -> weaviate.WeaviateClient: """ Initializes and returns a connected Weaviate client. Configures both HTTP and gRPC connections with proper authentication. """ - p = urlparse(config.endpoint) - host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "") - http_secure = p.scheme == "https" - http_port = p.port or (443 if http_secure else 80) + global _weaviate_client + if _weaviate_client and _weaviate_client.is_ready(): + return _weaviate_client - # Parse gRPC configuration - if config.grpc_endpoint: - # Urls without scheme won't be parsed correctly in some python versions, - # see https://bugs.python.org/issue27657 - grpc_endpoint_with_scheme = ( - config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}" + with _weaviate_client_lock: + if _weaviate_client and _weaviate_client.is_ready(): + return _weaviate_client + + p = urlparse(config.endpoint) + host = p.hostname or config.endpoint.replace("https://", "").replace("http://", "") + http_secure = p.scheme == "https" + http_port = p.port or (443 if http_secure else 80) + + # Parse gRPC configuration + if config.grpc_endpoint: + # Urls without scheme won't be parsed correctly in some python versions, + # see https://bugs.python.org/issue27657 + grpc_endpoint_with_scheme = ( + config.grpc_endpoint if "://" in config.grpc_endpoint else f"grpc://{config.grpc_endpoint}" + ) + grpc_p = urlparse(grpc_endpoint_with_scheme) + grpc_host = grpc_p.hostname or "localhost" + grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051) + grpc_secure = grpc_p.scheme == "grpcs" + else: + # Infer from HTTP endpoint as fallback + grpc_host = host + grpc_secure = http_secure + grpc_port = 443 if grpc_secure else 50051 + + client = weaviate.connect_to_custom( + http_host=host, + http_port=http_port, + http_secure=http_secure, + grpc_host=grpc_host, + grpc_port=grpc_port, + grpc_secure=grpc_secure, + auth_credentials=Auth.api_key(config.api_key) if config.api_key else None, + skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests ) - grpc_p = urlparse(grpc_endpoint_with_scheme) - grpc_host = grpc_p.hostname or "localhost" - grpc_port = grpc_p.port or (443 if grpc_p.scheme == "grpcs" else 50051) - grpc_secure = grpc_p.scheme == "grpcs" - else: - # Infer from HTTP endpoint as fallback - grpc_host = host - grpc_secure = http_secure - grpc_port = 443 if grpc_secure else 50051 - client = weaviate.connect_to_custom( - http_host=host, - http_port=http_port, - http_secure=http_secure, - grpc_host=grpc_host, - grpc_port=grpc_port, - grpc_secure=grpc_secure, - auth_credentials=Auth.api_key(config.api_key) if config.api_key else None, - skip_init_checks=True, # Skip PyPI version check to avoid unnecessary HTTP requests - ) + if not client.is_ready(): + raise ConnectionError("Vector database is not ready") - if not client.is_ready(): - raise ConnectionError("Vector database is not ready") - - return client + _weaviate_client = client + return client def get_type(self) -> str: """Returns the vector database type identifier.""" @@ -196,6 +224,7 @@ class WeaviateVector(BaseVector): ), wc.Property(name="document_id", data_type=wc.DataType.TEXT), wc.Property(name="doc_id", data_type=wc.DataType.TEXT), + wc.Property(name="doc_type", data_type=wc.DataType.TEXT), wc.Property(name="chunk_index", data_type=wc.DataType.INT), ], vector_config=wc.Configure.Vectors.self_provided(), @@ -225,6 +254,8 @@ class WeaviateVector(BaseVector): to_add.append(wc.Property(name="document_id", data_type=wc.DataType.TEXT)) if "doc_id" not in existing: to_add.append(wc.Property(name="doc_id", data_type=wc.DataType.TEXT)) + if "doc_type" not in existing: + to_add.append(wc.Property(name="doc_type", data_type=wc.DataType.TEXT)) if "chunk_index" not in existing: to_add.append(wc.Property(name="chunk_index", data_type=wc.DataType.INT)) diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index 69adac522d..16a5588024 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -6,8 +6,8 @@ from typing import Any from sqlalchemy import func, select from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 0efe19a57c..6d1b65a055 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -9,9 +9,9 @@ from sqlalchemy.exc import IntegrityError from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.embedding.embedding_base import Embeddings +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper diff --git a/api/core/rag/embedding/retrieval.py b/api/core/rag/embedding/retrieval.py index f6834ab87b..030237559d 100644 --- a/api/core/rag/embedding/retrieval.py +++ b/api/core/rag/embedding/retrieval.py @@ -1,8 +1,18 @@ from pydantic import BaseModel +from typing_extensions import TypedDict from models.dataset import DocumentSegment +class AttachmentInfoDict(TypedDict): + id: str + name: str + extension: str + mime_type: str + source_url: str + size: int + + class RetrievalChildChunk(BaseModel): """Retrieval segments.""" @@ -19,5 +29,5 @@ class RetrievalSegments(BaseModel): segment: DocumentSegment child_chunks: list[RetrievalChildChunk] | None = None score: float | None = None - files: list[dict[str, str | int]] | None = None + files: list[AttachmentInfoDict] | None = None summary: str | None = None # Summary content if retrieved via summary index diff --git a/api/core/rag/extractor/extract_processor.py b/api/core/rag/extractor/extract_processor.py index 6d28ce25bc..449be6a448 100644 --- a/api/core/rag/extractor/extract_processor.py +++ b/api/core/rag/extractor/extract_processor.py @@ -74,7 +74,8 @@ class ExtractProcessor: else: suffix = "" # https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521 - file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}" + # Generate a temporary filename under the created temp_dir and ensure the directory exists + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore Path(file_path).write_bytes(response.content) extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model") if return_text: diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 5d6223db06..371f7b0865 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -1,12 +1,38 @@ import json import time -from typing import Any, cast +from typing import Any, NotRequired, cast import httpx +from typing_extensions import TypedDict from extensions.ext_storage import storage +class FirecrawlDocumentData(TypedDict): + title: str | None + description: str | None + source_url: str | None + markdown: str | None + + +class CrawlStatusResponse(TypedDict): + status: str + total: int | None + current: int | None + data: list[FirecrawlDocumentData] + + +class MapResponse(TypedDict): + success: bool + links: list[str] + + +class SearchResponse(TypedDict): + success: bool + data: list[dict[str, Any]] + warning: NotRequired[str] + + class FirecrawlApp: def __init__(self, api_key=None, base_url=None): self.api_key = api_key @@ -14,7 +40,7 @@ class FirecrawlApp: if self.api_key is None and self.base_url == "https://api.firecrawl.dev": raise ValueError("No API key provided") - def scrape_url(self, url, params=None) -> dict[str, Any]: + def scrape_url(self, url, params=None) -> FirecrawlDocumentData: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/scrape headers = self._prepare_headers() json_data = { @@ -32,9 +58,7 @@ class FirecrawlApp: return self._extract_common_fields(data) elif response.status_code in {402, 409, 500, 429, 408}: self._handle_error(response, "scrape URL") - return {} # Avoid additional exception after handling error - else: - raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") + raise Exception(f"Failed to scrape URL. Status code: {response.status_code}") def crawl_url(self, url, params=None) -> str: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post @@ -51,7 +75,7 @@ class FirecrawlApp: self._handle_error(response, "start crawl job") return "" # unreachable - def map(self, url: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def map(self, url: str, params: dict[str, Any] | None = None) -> MapResponse: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/map headers = self._prepare_headers() json_data: dict[str, Any] = {"url": url, "integration": "dify"} @@ -60,14 +84,12 @@ class FirecrawlApp: json_data.update(params) response = self._post_request(self._build_url("v2/map"), json_data, headers) if response.status_code == 200: - return cast(dict[str, Any], response.json()) + return cast(MapResponse, response.json()) elif response.status_code in {402, 409, 500, 429, 408}: self._handle_error(response, "start map job") - return {} - else: - raise Exception(f"Failed to start map job. Status code: {response.status_code}") + raise Exception(f"Failed to start map job. Status code: {response.status_code}") - def check_crawl_status(self, job_id) -> dict[str, Any]: + def check_crawl_status(self, job_id) -> CrawlStatusResponse: headers = self._prepare_headers() response = self._get_request(self._build_url(f"v2/crawl/{job_id}"), headers) if response.status_code == 200: @@ -77,7 +99,7 @@ class FirecrawlApp: if total == 0: raise Exception("Failed to check crawl status. Error: No page found") data = crawl_status_response.get("data", []) - url_data_list = [] + url_data_list: list[FirecrawlDocumentData] = [] for item in data: if isinstance(item, dict) and "metadata" in item and "markdown" in item: url_data = self._extract_common_fields(item) @@ -95,13 +117,15 @@ class FirecrawlApp: return self._format_crawl_status_response( crawl_status_response.get("status"), crawl_status_response, [] ) - else: - self._handle_error(response, "check crawl status") - return {} # unreachable + self._handle_error(response, "check crawl status") + raise RuntimeError("unreachable: _handle_error always raises") def _format_crawl_status_response( - self, status: str, crawl_status_response: dict[str, Any], url_data_list: list[dict[str, Any]] - ) -> dict[str, Any]: + self, + status: str, + crawl_status_response: dict[str, Any], + url_data_list: list[FirecrawlDocumentData], + ) -> CrawlStatusResponse: return { "status": status, "total": crawl_status_response.get("total"), @@ -109,7 +133,7 @@ class FirecrawlApp: "data": url_data_list, } - def _extract_common_fields(self, item: dict[str, Any]) -> dict[str, Any]: + def _extract_common_fields(self, item: dict[str, Any]) -> FirecrawlDocumentData: return { "title": item.get("metadata", {}).get("title"), "description": item.get("metadata", {}).get("description"), @@ -117,7 +141,7 @@ class FirecrawlApp: "markdown": item.get("markdown"), } - def _prepare_headers(self) -> dict[str, Any]: + def _prepare_headers(self) -> dict[str, str]: return {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} def _build_url(self, path: str) -> str: @@ -150,10 +174,10 @@ class FirecrawlApp: error_message = response.text or "Unknown error occurred" raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return] - def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + def search(self, query: str, params: dict[str, Any] | None = None) -> SearchResponse: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search headers = self._prepare_headers() - json_data = { + json_data: dict[str, Any] = { "query": query, "limit": 5, "lang": "en", @@ -170,12 +194,10 @@ class FirecrawlApp: json_data.update(params) response = self._post_request(self._build_url("v2/search"), json_data, headers) if response.status_code == 200: - response_data = response.json() + response_data: SearchResponse = response.json() if not response_data.get("success"): raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}") - return cast(dict[str, Any], response_data) + return response_data elif response.status_code in {402, 409, 500, 429, 408}: self._handle_error(response, "perform search") - return {} # Avoid additional exception after handling error - else: - raise Exception(f"Failed to perform search. Status code: {response.status_code}") + raise Exception(f"Failed to perform search. Status code: {response.status_code}") diff --git a/api/core/rag/extractor/pdf_extractor.py b/api/core/rag/extractor/pdf_extractor.py index 6aabcac704..9abdb31325 100644 --- a/api/core/rag/extractor/pdf_extractor.py +++ b/api/core/rag/extractor/pdf_extractor.py @@ -15,6 +15,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import UploadFile @@ -150,7 +151,7 @@ class PdfExtractor(BaseExtractor): # save file to db upload_file = UploadFile( tenant_id=self._tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=file_key, size=len(img_bytes), diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 7cf6c4d289..e8da866870 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -1,10 +1,11 @@ import json from collections.abc import Generator -from typing import Union +from typing import Any, Union from urllib.parse import urljoin import httpx from httpx import Response +from typing_extensions import TypedDict from core.rag.extractor.watercrawl.exceptions import ( WaterCrawlAuthenticationError, @@ -13,6 +14,27 @@ from core.rag.extractor.watercrawl.exceptions import ( ) +class SpiderOptions(TypedDict): + max_depth: int + page_limit: int + allowed_domains: list[str] + exclude_paths: list[str] + include_paths: list[str] + + +class PageOptions(TypedDict): + exclude_tags: list[str] + include_tags: list[str] + wait_time: int + include_html: bool + only_main_content: bool + include_links: bool + timeout: int + accept_cookies_selector: str + locale: str + actions: list[Any] + + class BaseAPIClient: def __init__(self, api_key, base_url): self.api_key = api_key @@ -121,9 +143,9 @@ class WaterCrawlAPIClient(BaseAPIClient): def create_crawl_request( self, url: Union[list, str] | None = None, - spider_options: dict | None = None, - page_options: dict | None = None, - plugin_options: dict | None = None, + spider_options: SpiderOptions | None = None, + page_options: PageOptions | None = None, + plugin_options: dict[str, Any] | None = None, ): data = { # 'urls': url if isinstance(url, list) else [url], @@ -176,8 +198,8 @@ class WaterCrawlAPIClient(BaseAPIClient): def scrape_url( self, url: str, - page_options: dict | None = None, - plugin_options: dict | None = None, + page_options: PageOptions | None = None, + plugin_options: dict[str, Any] | None = None, sync: bool = True, prefetched: bool = True, ): diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index fe983aa86a..81c19005db 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -2,16 +2,39 @@ from collections.abc import Generator from datetime import datetime from typing import Any -from core.rag.extractor.watercrawl.client import WaterCrawlAPIClient +from typing_extensions import TypedDict + +from core.rag.extractor.watercrawl.client import PageOptions, SpiderOptions, WaterCrawlAPIClient + + +class WatercrawlDocumentData(TypedDict): + title: str | None + description: str | None + source_url: str | None + markdown: str | None + + +class CrawlJobResponse(TypedDict): + status: str + job_id: str | None + + +class WatercrawlCrawlStatusResponse(TypedDict): + status: str + job_id: str | None + total: int + current: int + data: list[WatercrawlDocumentData] + time_consuming: float class WaterCrawlProvider: def __init__(self, api_key, base_url: str | None = None): self.client = WaterCrawlAPIClient(api_key, base_url) - def crawl_url(self, url, options: dict | Any | None = None): + def crawl_url(self, url: str, options: dict[str, Any] | None = None) -> CrawlJobResponse: options = options or {} - spider_options = { + spider_options: SpiderOptions = { "max_depth": 1, "page_limit": 1, "allowed_domains": [], @@ -25,7 +48,7 @@ class WaterCrawlProvider: spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else [] wait_time = options.get("wait_time", 1000) - page_options = { + page_options: PageOptions = { "exclude_tags": options.get("exclude_tags", "").split(",") if options.get("exclude_tags") else [], "include_tags": options.get("include_tags", "").split(",") if options.get("include_tags") else [], "wait_time": max(1000, wait_time), # minimum wait time is 1 second @@ -41,9 +64,9 @@ class WaterCrawlProvider: return {"status": "active", "job_id": result.get("uuid")} - def get_crawl_status(self, crawl_request_id): + def get_crawl_status(self, crawl_request_id: str) -> WatercrawlCrawlStatusResponse: response = self.client.get_crawl_request(crawl_request_id) - data = [] + data: list[WatercrawlDocumentData] = [] if response["status"] in ["new", "running"]: status = "active" else: @@ -67,7 +90,7 @@ class WaterCrawlProvider: "time_consuming": time_consuming, } - def get_crawl_url_data(self, job_id, url) -> dict | None: + def get_crawl_url_data(self, job_id: str, url: str) -> WatercrawlDocumentData | None: if not job_id: return self.scrape_url(url) @@ -82,11 +105,11 @@ class WaterCrawlProvider: return None - def scrape_url(self, url: str): + def scrape_url(self, url: str) -> WatercrawlDocumentData: response = self.client.scrape_url(url=url, sync=True, prefetched=True) return self._structure_data(response) - def _structure_data(self, result_object: dict): + def _structure_data(self, result_object: dict[str, Any]) -> WatercrawlDocumentData: if isinstance(result_object.get("result", {}), str): raise ValueError("Invalid result object. Expected a dictionary.") @@ -98,7 +121,9 @@ class WaterCrawlProvider: "markdown": result_object.get("result", {}).get("markdown"), } - def _get_results(self, crawl_request_id: str, query_params: dict | None = None) -> Generator[dict, None, None]: + def _get_results( + self, crawl_request_id: str, query_params: dict | None = None + ) -> Generator[WatercrawlDocumentData, None, None]: page = 0 page_size = 100 diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 1ddbfc5864..052fca930d 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -21,6 +21,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import UploadFile @@ -112,7 +113,7 @@ class WordExtractor(BaseExtractor): # save file to db upload_file = UploadFile( tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=file_key, size=0, @@ -140,7 +141,7 @@ class WordExtractor(BaseExtractor): # save file to db upload_file = UploadFile( tenant_id=self.tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=file_key, size=0, @@ -204,26 +205,61 @@ class WordExtractor(BaseExtractor): return " ".join(unique_content) def _parse_cell_paragraph(self, paragraph, image_map): - paragraph_content = [] - for run in paragraph.runs: - if run.element.xpath(".//a:blip"): - for blip in run.element.xpath(".//a:blip"): - image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed") - if not image_id: - continue - rel = paragraph.part.rels.get(image_id) - if rel is None: - continue - # For external images, use image_id as key; for internal, use target_part - if rel.is_external: - if image_id in image_map: - paragraph_content.append(image_map[image_id]) - else: - image_part = rel.target_part - if image_part in image_map: - paragraph_content.append(image_map[image_part]) - else: - paragraph_content.append(run.text) + paragraph_content: list[str] = [] + + for child in paragraph._element: + tag = child.tag + if tag == qn("w:hyperlink"): + # Note: w:hyperlink elements may also use w:anchor for internal bookmarks. + # This extractor intentionally only converts external links (HTTP/mailto, etc.) + # that are backed by a relationship id (r:id) with rel.is_external == True. + # Hyperlinks without such an external rel (including anchor-only bookmarks) + # are left as plain text link_text. + r_id = child.get(qn("r:id")) + link_text_parts: list[str] = [] + for run_elem in child.findall(qn("w:r")): + run = Run(run_elem, paragraph) + if run.text: + link_text_parts.append(run.text) + link_text = "".join(link_text_parts).strip() + if r_id: + try: + rel = paragraph.part.rels.get(r_id) + if rel: + target_ref = getattr(rel, "target_ref", None) + if target_ref: + parsed_target = urlparse(str(target_ref)) + if rel.is_external or parsed_target.scheme in ("http", "https", "mailto"): + display_text = link_text or str(target_ref) + link_text = f"[{display_text}]({target_ref})" + except Exception: + logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id) + if link_text: + paragraph_content.append(link_text) + + elif tag == qn("w:r"): + run = Run(child, paragraph) + if run.element.xpath(".//a:blip"): + for blip in run.element.xpath(".//a:blip"): + image_id = blip.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed" + ) + if not image_id: + continue + rel = paragraph.part.rels.get(image_id) + if rel is None: + continue + if rel.is_external: + if image_id in image_map: + paragraph_content.append(image_map[image_id]) + else: + image_part = rel.target_part + if image_part in image_map: + paragraph_content.append(image_map[image_part]) + else: + if run.text: + paragraph_content.append(run.text) + return "".join(paragraph_content).strip() def parse_docx(self, docx_path): @@ -330,7 +366,7 @@ class WordExtractor(BaseExtractor): paragraph_content = [] # State for legacy HYPERLINK fields hyperlink_field_url = None - hyperlink_field_text_parts: list = [] + hyperlink_field_text_parts: list[str] = [] is_collecting_field_text = False # Iterate through paragraph elements in document order for child in paragraph._element: diff --git a/api/core/rag/index_processor/index_processor.py b/api/core/rag/index_processor/index_processor.py new file mode 100644 index 0000000000..d9145023ac --- /dev/null +++ b/api/core/rag/index_processor/index_processor.py @@ -0,0 +1,258 @@ +import concurrent.futures +import datetime +import logging +import time +from collections.abc import Mapping +from typing import Any + +from flask import current_app +from sqlalchemy import delete, func, select + +from core.db.session_factory import session_factory +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict +from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError +from core.workflow.nodes.knowledge_index.protocols import Preview, PreviewItem, QaPreview +from models.dataset import Dataset, Document, DocumentSegment + +from .index_processor_factory import IndexProcessorFactory +from .processor.paragraph_index_processor import ParagraphIndexProcessor + +logger = logging.getLogger(__name__) + + +class IndexProcessor: + def format_preview(self, chunk_structure: str, chunks: Any) -> Preview: + index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() + preview = index_processor.format_preview(chunks) + data = Preview( + chunk_structure=preview["chunk_structure"], + total_segments=preview["total_segments"], + preview=[], + parent_mode=None, + qa_preview=[], + ) + if "parent_mode" in preview: + data.parent_mode = preview["parent_mode"] + + for item in preview["preview"]: + if "content" in item and "child_chunks" in item: + data.preview.append( + PreviewItem(content=item["content"], child_chunks=item["child_chunks"], summary=None) + ) + elif "question" in item and "answer" in item: + data.qa_preview.append(QaPreview(question=item["question"], answer=item["answer"])) + elif "content" in item: + data.preview.append(PreviewItem(content=item["content"], child_chunks=None, summary=None)) + return data + + def index_and_clean( + self, + dataset_id: str, + document_id: str, + original_document_id: str, + chunks: Mapping[str, Any], + batch: Any, + summary_index_setting: SummaryIndexSettingDict | None = None, + ): + with session_factory.create_session() as session: + document = session.query(Document).filter_by(id=document_id).first() + if not document: + raise KnowledgeIndexNodeError(f"Document {document_id} not found.") + + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") + + dataset_name_value = dataset.name + document_name_value = document.name + created_at_value = document.created_at + if summary_index_setting is None: + summary_index_setting = dataset.summary_index_setting + index_node_ids = [] + + index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() + if original_document_id: + segments = session.scalars( + select(DocumentSegment).where(DocumentSegment.document_id == original_document_id) + ).all() + if segments: + index_node_ids = [segment.index_node_id for segment in segments] + + indexing_start_at = time.perf_counter() + # delete from vector index + if index_node_ids: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + + with session_factory.create_session() as session, session.begin(): + if index_node_ids: + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == original_document_id) + session.execute(segment_delete_stmt) + + index_processor.index(dataset, document, chunks) + indexing_end_at = time.perf_counter() + + with session_factory.create_session() as session, session.begin(): + document.indexing_latency = indexing_end_at - indexing_start_at + document.indexing_status = "completed" + document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.word_count = ( + session.query(func.sum(DocumentSegment.word_count)) + .where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ) + .scalar() + ) or 0 + # Update need_summary based on dataset's summary_index_setting + if summary_index_setting and summary_index_setting.get("enable") is True: + document.need_summary = True + else: + document.need_summary = False + session.add(document) + # update document segment status + session.query(DocumentSegment).where( + DocumentSegment.document_id == document_id, + DocumentSegment.dataset_id == dataset_id, + ).update( + { + DocumentSegment.status: "completed", + DocumentSegment.enabled: True, + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + } + ) + + return { + "dataset_id": dataset_id, + "dataset_name": dataset_name_value, + "batch": batch, + "document_id": document_id, + "document_name": document_name_value, + "created_at": created_at_value.timestamp(), + "display_status": "completed", + } + + def get_preview_output( + self, + chunks: Any, + dataset_id: str, + document_id: str, + chunk_structure: str, + summary_index_setting: SummaryIndexSettingDict | None, + ) -> Preview: + doc_language = None + with session_factory.create_session() as session: + if document_id: + document = session.query(Document).filter_by(id=document_id).first() + else: + document = None + + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset: + raise KnowledgeIndexNodeError(f"Dataset {dataset_id} not found.") + + if summary_index_setting is None: + summary_index_setting = dataset.summary_index_setting + + if document: + doc_language = document.doc_language + indexing_technique = dataset.indexing_technique + tenant_id = dataset.tenant_id + + preview_output = self.format_preview(chunk_structure, chunks) + if indexing_technique != "high_quality": + return preview_output + + if not summary_index_setting or not summary_index_setting.get("enable"): + return preview_output + + if preview_output.preview is not None: + chunk_count = len(preview_output.preview) + logger.info( + "Generating summaries for %s chunks in preview mode (dataset: %s)", + chunk_count, + dataset_id, + ) + + flask_app = None + try: + flask_app = current_app._get_current_object() # type: ignore + except RuntimeError: + logger.warning("No Flask application context available, summary generation may fail") + + def generate_summary_for_chunk(preview_item: PreviewItem) -> None: + """Generate summary for a single chunk.""" + if flask_app: + with flask_app.app_context(): + if preview_item.content is not None: + # Set Flask application context in worker thread + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview_item.content, + summary_index_setting=summary_index_setting, + document_language=doc_language, + ) + if summary: + preview_item.summary = summary + + else: + summary, _ = ParagraphIndexProcessor.generate_summary( + tenant_id=tenant_id, + text=preview_item.content if preview_item.content is not None else "", + summary_index_setting=summary_index_setting, + document_language=doc_language, + ) + if summary: + preview_item.summary = summary + + # Generate summaries concurrently using ThreadPoolExecutor + # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) + timeout_seconds = min(300, 60 * len(preview_output.preview)) + errors: list[Exception] = [] + + with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output.preview))) as executor: + futures = [ + executor.submit(generate_summary_for_chunk, preview_item) for preview_item in preview_output.preview + ] + # Wait for all tasks to complete with timeout + done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) + + # Cancel tasks that didn't complete in time + if not_done: + timeout_error_msg = ( + f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" + ) + logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) + # In preview mode, timeout is also an error + errors.append(TimeoutError(timeout_error_msg)) + for future in not_done: + future.cancel() + # Wait a bit for cancellation to take effect + concurrent.futures.wait(not_done, timeout=5) + + # Collect exceptions from completed futures + for future in done: + try: + future.result() # This will raise any exception that occurred + except Exception as e: + logger.exception("Error in summary generation future") + errors.append(e) + + # In preview mode, if there are any errors, fail the request + if errors: + error_messages = [str(e) for e in errors] + error_summary = ( + f"Failed to generate summaries for {len(errors)} chunk(s). " + f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors + ) + if len(errors) > 3: + error_summary += f" (and {len(errors) - 3} more)" + logger.error("Summary generation failed in preview mode: %s", error_summary) + raise KnowledgeIndexNodeError(error_summary) + + completed_count = sum(1 for item in preview_output.preview if item.summary is not None) + logger.info( + "Completed summary generation for preview chunks: %s/%s succeeded", + completed_count, + len(preview_output.preview), + ) + return preview_output diff --git a/api/core/rag/index_processor/index_processor_base.py b/api/core/rag/index_processor/index_processor_base.py index e8b3fa1508..a435dfc46a 100644 --- a/api/core/rag/index_processor/index_processor_base.py +++ b/api/core/rag/index_processor/index_processor_base.py @@ -7,14 +7,16 @@ import os import re from abc import ABC, abstractmethod from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, NotRequired, Optional from urllib.parse import unquote, urlparse import httpx +from typing_extensions import TypedDict from configs import dify_config from core.entities.knowledge_entities import PreviewDetail from core.helper import ssrf_proxy +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import AttachmentDocument, Document @@ -35,6 +37,13 @@ if TYPE_CHECKING: from core.model_manager import ModelInstance +class SummaryIndexSettingDict(TypedDict): + enable: bool + model_name: NotRequired[str] + model_provider_name: NotRequired[str] + summary_prompt: NotRequired[str] + + class BaseIndexProcessor(ABC): """Interface for extract files.""" @@ -51,7 +60,7 @@ class BaseIndexProcessor(ABC): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -98,7 +107,7 @@ class BaseIndexProcessor(ABC): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: raise NotImplementedError @@ -294,7 +303,7 @@ class BaseIndexProcessor(ABC): logging.warning("Error downloading image from %s: %s", image_url, str(e)) return None except Exception: - logging.exception("Unexpected error downloading image from %s", image_url) + logging.warning("Unexpected error downloading image from %s", image_url, exc_info=True) return None def _download_tool_file(self, tool_file_id: str, current_user: Account) -> str | None: diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index df5c89a522..80163b1707 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -12,17 +12,9 @@ from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.provider_manager import ProviderManager from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector @@ -31,11 +23,20 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols -from core.workflow.file import File, FileTransferMethod, FileType, file_manager +from dify_graph.file import File, FileTransferMethod, FileType, file_manager +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs import helper @@ -175,7 +176,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -278,7 +279,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ @@ -362,7 +363,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): def generate_summary( tenant_id: str, text: str, - summary_index_setting: dict | None = None, + summary_index_setting: SummaryIndexSettingDict | None = None, segment_id: str | None = None, document_language: str | None = None, ) -> tuple[str, LLMUsage]: diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py index 367f0aec00..df0761ca73 100644 --- a/api/core/rag/index_processor/processor/parent_child_index_processor.py +++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py @@ -11,6 +11,7 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.model_manager import ModelInstance from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore @@ -18,7 +19,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, ChildDocument, Document, ParentChildStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db @@ -215,7 +216,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ) -> list[Document]: # Set search parameters. results = RetrievalService.retrieve( @@ -361,7 +362,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 503cce2132..62f88b7760 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -15,13 +15,14 @@ from core.db.session_factory import session_factory from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.llm_generator import LLMGenerator from core.rag.cleaner.clean_processor import CleanProcessor +from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.vdb.vector_factory import Vector from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.extract_processor import ExtractProcessor from core.rag.index_processor.constant.index_type import IndexStructureType -from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.index_processor.index_processor_base import BaseIndexProcessor, SummaryIndexSettingDict from core.rag.models.document import AttachmentDocument, Document, QAStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols @@ -185,7 +186,7 @@ class QAIndexProcessor(BaseIndexProcessor): dataset: Dataset, top_k: int, score_threshold: float, - reranking_model: dict, + reranking_model: RerankingModelDict, ): # Set search parameters. results = RetrievalService.retrieve( @@ -244,7 +245,7 @@ class QAIndexProcessor(BaseIndexProcessor): self, tenant_id: str, preview_texts: list[PreviewDetail], - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, doc_language: str | None = None, ) -> list[PreviewDetail]: """ diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 48639bf4c8..dc3b771406 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -4,7 +4,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.file import File +from dify_graph.file import File class ChildDocument(BaseModel): diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index 690e780921..fcb14ffc52 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,12 +1,12 @@ import base64 from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.rerank_entities import RerankResult from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankResult from extensions.ext_database import db from extensions.ext_storage import storage from models.model import UploadFile diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 18020608cb..7edd05d2d1 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -4,7 +4,6 @@ from collections import Counter import numpy as np from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.index_processor.constant.doc_type import DocType @@ -12,6 +11,7 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner +from dify_graph.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index cfea8d114a..78a97f79a5 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -25,19 +25,15 @@ from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode -from core.rag.data_post_processor.data_post_processor import DataPostProcessor +from core.rag.data_post_processor.data_post_processor import DataPostProcessor, RerankingModelDict, WeightsDict from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler -from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.datasource.retrieval_service import DefaultRetrievalModelDict, RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition @@ -60,14 +56,18 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool -from core.workflow.file import File, FileTransferMethod, FileType from core.workflow.nodes.knowledge_retrieval import exc -from core.workflow.repositories.rag_retrieval_protocol import ( +from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, Source, SourceChildChunk, SourceMetadata, ) +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.json_in_md_parser import parse_and_check_json_markdown @@ -83,10 +83,11 @@ from models.dataset import ( ) from models.dataset import Document as DatasetDocument from models.dataset import Document as DocumentModel +from models.enums import CreatorUserRole, DatasetQuerySource from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureService -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -127,11 +128,12 @@ class DatasetRetrieval: metadata_filter_document_ids, metadata_condition = None, None if request.metadata_filtering_mode != "disabled": - # Convert workflow layer types to app_config layer types - if not request.metadata_model_config: - raise ValueError("metadata_model_config is required for this method") + app_metadata_model_config = ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}) + if request.metadata_filtering_mode == "automatic": + if not request.metadata_model_config: + raise ValueError("metadata_model_config is required for this method") - app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump()) + app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump()) app_metadata_filtering_conditions = None if request.metadata_filtering_conditions is not None: @@ -248,19 +250,22 @@ class DatasetRetrieval: retrieval_resource_list = [] # deal with external documents for item in external_documents: + ext_meta = item.metadata or {} + title = ext_meta.get("title") or "" + doc_id = ext_meta.get("document_id") or title source = Source( metadata=SourceMetadata( source="knowledge", - dataset_id=item.metadata.get("dataset_id"), - dataset_name=item.metadata.get("dataset_name"), - document_id=item.metadata.get("document_id"), - document_name=item.metadata.get("title"), + dataset_id=ext_meta.get("dataset_id") or "", + dataset_name=ext_meta.get("dataset_name") or "", + document_id=str(doc_id), + document_name=ext_meta.get("title") or "", data_source_type="external", retriever_from="workflow", - score=item.metadata.get("score"), - doc_metadata=item.metadata, + score=float(ext_meta.get("score") or 0.0), + doc_metadata=ext_meta, ), - title=item.metadata.get("title"), + title=title, content=item.page_content, ) retrieval_resource_list.append(source) @@ -586,7 +591,7 @@ class DatasetRetrieval: user_id: str, user_from: str, query: str, - available_datasets: list, + available_datasets: list[Dataset], model_instance: ModelInstance, model_config: ModelConfigWithCredentialsEntity, planning_strategy: PlanningStrategy, @@ -628,15 +633,15 @@ class DatasetRetrieval: if dataset_id: # get retrieval model config dataset_stmt = select(Dataset).where(Dataset.id == dataset_id) - dataset = db.session.scalar(dataset_stmt) - if dataset: + selected_dataset = db.session.scalar(dataset_stmt) + if selected_dataset: results = [] - if dataset.provider == "external": + if selected_dataset.provider == "external": external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( - tenant_id=dataset.tenant_id, + tenant_id=selected_dataset.tenant_id, dataset_id=dataset_id, query=query, - external_retrieval_parameters=dataset.retrieval_model, + external_retrieval_parameters=selected_dataset.retrieval_model, metadata_condition=metadata_condition, ) for external_document in external_documents: @@ -649,24 +654,28 @@ class DatasetRetrieval: document.metadata["score"] = external_document.get("score") document.metadata["title"] = external_document.get("title") document.metadata["dataset_id"] = dataset_id - document.metadata["dataset_name"] = dataset.name + document.metadata["dataset_name"] = selected_dataset.name results.append(document) else: if metadata_condition and not metadata_filter_document_ids: return [] document_ids_filter = None if metadata_filter_document_ids: - document_ids = metadata_filter_document_ids.get(dataset.id, []) + document_ids = metadata_filter_document_ids.get(selected_dataset.id, []) if document_ids: document_ids_filter = document_ids else: return [] - retrieval_model_config = dataset.retrieval_model or default_retrieval_model + retrieval_model_config: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, selected_dataset.retrieval_model) + if selected_dataset.retrieval_model + else default_retrieval_model + ) # get top k top_k = retrieval_model_config["top_k"] # get retrieval method - if dataset.indexing_technique == "economy": + if selected_dataset.indexing_technique == "economy": retrieval_method = RetrievalMethod.KEYWORD_SEARCH else: retrieval_method = retrieval_model_config["search_method"] @@ -685,7 +694,7 @@ class DatasetRetrieval: with measure_time() as timer: results = RetrievalService.retrieve( retrieval_method=retrieval_method, - dataset_id=dataset.id, + dataset_id=selected_dataset.id, query=query, top_k=top_k, score_threshold=score_threshold, @@ -717,13 +726,13 @@ class DatasetRetrieval: tenant_id: str, user_id: str, user_from: str, - available_datasets: list, + available_datasets: list[Dataset], query: str | None, top_k: int, score_threshold: float, reranking_mode: str, - reranking_model: dict | None = None, - weights: dict[str, Any] | None = None, + reranking_model: RerankingModelDict | None = None, + weights: WeightsDict | None = None, reranking_enable: bool = True, message_id: str | None = None, metadata_filter_document_ids: dict[str, list[str]] | None = None, @@ -1003,9 +1012,9 @@ class DatasetRetrieval: dataset_query = DatasetQuery( dataset_id=dataset_id, content=json.dumps(contents), - source="app", + source=DatasetQuerySource.APP, source_app_id=app_id, - created_by_role=user_from, + created_by_role=CreatorUserRole(user_from), created_by=user_id, ) dataset_queries.append(dataset_query) @@ -1019,7 +1028,7 @@ class DatasetRetrieval: dataset_id: str, query: str, top_k: int, - all_documents: list, + all_documents: list[Document], document_ids_filter: list[str] | None = None, metadata_condition: MetadataCondition | None = None, attachment_ids: list[str] | None = None, @@ -1053,7 +1062,11 @@ class DatasetRetrieval: all_documents.append(document) else: # get retrieval model , if the model is not setting , using default - retrieval_model = dataset.retrieval_model or default_retrieval_model + retrieval_model: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) if dataset.indexing_technique == "economy": # use keyword table query @@ -1127,7 +1140,7 @@ class DatasetRetrieval: if retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: # get retrieval model config - default_retrieval_model = { + default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -1136,7 +1149,11 @@ class DatasetRetrieval: } for dataset in available_datasets: - retrieval_model_config = dataset.retrieval_model or default_retrieval_model + retrieval_model_config: DefaultRetrievalModelDict = ( + cast(DefaultRetrievalModelDict, dataset.retrieval_model) + if dataset.retrieval_model + else default_retrieval_model + ) # get top k top_k = retrieval_model_config["top_k"] @@ -1176,8 +1193,8 @@ class DatasetRetrieval: hit_callbacks=[hit_callback], return_resource=return_resource, retriever_from=invoke_from.to_source(), - reranking_provider_name=retrieve_config.reranking_model.get("reranking_provider_name"), - reranking_model_name=retrieve_config.reranking_model.get("reranking_model_name"), + reranking_provider_name=retrieve_config.reranking_model["reranking_provider_name"], + reranking_model_name=retrieve_config.reranking_model["reranking_model_name"], ) tools.append(tool) @@ -1281,7 +1298,7 @@ class DatasetRetrieval: def get_metadata_filter_condition( self, - dataset_ids: list, + dataset_ids: list[str], query: str, tenant_id: str, user_id: str, @@ -1383,7 +1400,7 @@ class DatasetRetrieval: return output def _automatic_metadata_filter_func( - self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig + self, dataset_ids: list[str], query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig ) -> list[dict[str, Any]] | None: # get all metadata field metadata_stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) @@ -1581,7 +1598,7 @@ class DatasetRetrieval: ) def _get_prompt_template( - self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str + self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list[str], query: str ): model_mode = ModelMode(mode) input_text = query @@ -1673,15 +1690,15 @@ class DatasetRetrieval: def _multiple_retrieve_thread( self, flask_app: Flask, - available_datasets: list, + available_datasets: list[Dataset], metadata_condition: MetadataCondition | None, metadata_filter_document_ids: dict[str, list[str]] | None, all_documents: list[Document], tenant_id: str, reranking_enable: bool, reranking_mode: str, - reranking_model: dict | None, - weights: dict[str, Any] | None, + reranking_model: RerankingModelDict | None, + weights: WeightsDict | None, top_k: int, score_threshold: float, query: str | None, diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 5f3e1a8cae..23a2ac8386 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -2,8 +2,8 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index fa2007122d..ea110fa0a7 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -4,12 +4,12 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index b65cb14d8e..7a00e8a886 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -7,7 +7,6 @@ import re from typing import Any from core.model_manager import ModelInstance -from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer from core.rag.splitter.text_splitter import ( TS, Collection, @@ -16,6 +15,7 @@ from core.rag.splitter.text_splitter import ( Set, Union, ) +from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/model_runtime/callbacks/__init__.py b/api/core/rag/summary_index/__init__.py similarity index 100% rename from api/core/model_runtime/callbacks/__init__.py rename to api/core/rag/summary_index/__init__.py diff --git a/api/core/rag/summary_index/summary_index.py b/api/core/rag/summary_index/summary_index.py new file mode 100644 index 0000000000..31d21dbeee --- /dev/null +++ b/api/core/rag/summary_index/summary_index.py @@ -0,0 +1,91 @@ +import concurrent.futures +import logging + +from core.db.session_factory import session_factory +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict +from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary +from services.summary_index_service import SummaryIndexService +from tasks.generate_summary_index_task import generate_summary_index_task + +logger = logging.getLogger(__name__) + + +class SummaryIndex: + def generate_and_vectorize_summary( + self, + dataset_id: str, + document_id: str, + is_preview: bool, + summary_index_setting: SummaryIndexSettingDict | None = None, + ) -> None: + if is_preview: + with session_factory.create_session() as session: + dataset = session.query(Dataset).filter_by(id=dataset_id).first() + if not dataset or dataset.indexing_technique != "high_quality": + return + + if summary_index_setting is None: + summary_index_setting = dataset.summary_index_setting + + if not summary_index_setting or not summary_index_setting.get("enable"): + return + + if not document_id: + return + + document = session.query(Document).filter_by(id=document_id).first() + # Skip qa_model documents + if document is None or document.doc_form == "qa_model": + return + + query = session.query(DocumentSegment).filter_by( + dataset_id=dataset_id, + document_id=document_id, + status="completed", + enabled=True, + ) + segments = query.all() + segment_ids = [segment.id for segment in segments] + + if not segment_ids: + return + + existing_summaries = ( + session.query(DocumentSegmentSummary) + .filter( + DocumentSegmentSummary.chunk_id.in_(segment_ids), + DocumentSegmentSummary.dataset_id == dataset_id, + DocumentSegmentSummary.status == "completed", + ) + .all() + ) + completed_summary_segment_ids = {i.chunk_id for i in existing_summaries} + # Preview mode should process segments that are MISSING completed summaries + pending_segment_ids = [sid for sid in segment_ids if sid not in completed_summary_segment_ids] + + # If all segments already have completed summaries, nothing to do in preview mode + if not pending_segment_ids: + return + + max_workers = min(10, len(pending_segment_ids)) + + def process_segment(segment_id: str) -> None: + """Process a single segment in a thread with a fresh DB session.""" + with session_factory.create_session() as session: + segment = session.query(DocumentSegment).filter_by(id=segment_id).first() + if segment is None: + return + try: + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting) + except Exception: + logger.exception( + "Failed to generate summary for segment %s", + segment_id, + ) + # Continue processing other segments + + with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = [executor.submit(process_segment, segment_id) for segment_id in pending_segment_ids] + concurrent.futures.wait(futures) + else: + generate_summary_index_task.delay(dataset_id, document_id, None) diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index c7f5942f5f..57764574d7 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -11,8 +11,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.entities.workflow_execution import WorkflowExecution -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.entities.workflow_execution import WorkflowExecution +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 9b8e45b1eb..650cf79550 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -12,8 +12,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution -from core.workflow.repositories.workflow_node_execution_repository import ( +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.repositories.workflow_node_execution_repository import ( OrderConfig, WorkflowNodeExecutionRepository, ) diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index 02fcabab5d..dc9f8c96bf 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -11,8 +11,8 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 0e04c56e0e..6607a87032 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -4,10 +4,11 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any -from sqlalchemy import Engine, select -from sqlalchemy.orm import Session, selectinload, sessionmaker +from sqlalchemy import select +from sqlalchemy.orm import Session, selectinload -from core.workflow.nodes.human_input.entities import ( +from core.db.session_factory import session_factory +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, @@ -17,12 +18,12 @@ from core.workflow.nodes.human_input.entities import ( MemberRecipient, WebAppDeliveryMethod, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( DeliveryMethodType, HumanInputFormKind, HumanInputFormStatus, ) -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, FormNotFoundError, HumanInputFormEntity, @@ -198,12 +199,9 @@ class _InvalidTimeoutStatusError(ValueError): class HumanInputFormRepositoryImpl: def __init__( self, - session_factory: sessionmaker | Engine, + *, tenant_id: str, ): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory self._tenant_id = tenant_id def _delivery_method_to_model( @@ -217,7 +215,7 @@ class HumanInputFormRepositoryImpl: id=delivery_id, form_id=form_id, delivery_method_type=delivery_method.type, - delivery_config_id=delivery_method.id, + delivery_config_id=str(delivery_method.id), channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] @@ -343,7 +341,7 @@ class HumanInputFormRepositoryImpl: def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): # Generate unique form ID form_id = str(uuidv7()) start_time = naive_utc_now() @@ -435,7 +433,7 @@ class HumanInputFormRepositoryImpl: HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: form_model: HumanInputForm | None = session.scalars(form_query).first() if form_model is None: return None @@ -448,18 +446,13 @@ class HumanInputFormRepositoryImpl: class HumanInputFormSubmissionRepository: """Repository for fetching and submitting human input forms.""" - def __init__(self, session_factory: sessionmaker | Engine): - if isinstance(session_factory, Engine): - session_factory = sessionmaker(bind=session_factory) - self._session_factory = session_factory - def get_by_token(self, form_token: str) -> HumanInputFormRecord | None: query = ( select(HumanInputFormRecipient) .options(selectinload(HumanInputFormRecipient.form)) .where(HumanInputFormRecipient.access_token == form_token) ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: recipient_model = session.scalars(query).first() if recipient_model is None or recipient_model.form is None: return None @@ -478,7 +471,7 @@ class HumanInputFormSubmissionRepository: HumanInputFormRecipient.recipient_type == recipient_type, ) ) - with self._session_factory(expire_on_commit=False) as session: + with session_factory.create_session() as session: recipient_model = session.scalars(query).first() if recipient_model is None or recipient_model.form is None: return None @@ -494,7 +487,7 @@ class HumanInputFormSubmissionRepository: submission_user_id: str | None, submission_end_user_id: str | None, ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): form_model = session.get(HumanInputForm, form_id) if form_model is None: raise FormNotFoundError(f"form not found, id={form_id}") @@ -524,7 +517,7 @@ class HumanInputFormSubmissionRepository: timeout_status: HumanInputFormStatus, reason: str | None = None, ) -> HumanInputFormRecord: - with self._session_factory(expire_on_commit=False) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): form_model = session.get(HumanInputForm, form_id) if form_model is None: raise FormNotFoundError(f"form not found, id={form_id}") diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 9091a3190b..55e96515ac 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -9,10 +9,10 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.entities import WorkflowExecution -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowExecution +from dify_graph.enums import WorkflowExecutionStatus, WorkflowType +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( Account, @@ -146,7 +146,9 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # No sequence number generation needed anymore - db_model.type = domain_model.workflow_type + from models.workflow import WorkflowType as ModelWorkflowType + + db_model.type = ModelWorkflowType(domain_model.workflow_type.value) db_model.version = domain_model.workflow_version db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None @@ -194,6 +196,13 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # Create a new database session with self._session_factory() as session: + existing_model = session.get(WorkflowRun, db_model.id) + if existing_model: + if existing_model.tenant_id != self._tenant_id: + raise ValueError("Unauthorized access to workflow run") + # Preserve the original start time for pause/resume flows. + db_model.created_at = existing_model.created_at + # SQLAlchemy merge intelligently handles both insert and update operations # based on the presence of the primary key session.merge(db_model) diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 324dd059d1..7373ebc7cc 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -17,11 +17,11 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_storage import storage from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 @@ -146,7 +146,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) index=db_model.index, predecessor_node_id=db_model.predecessor_node_id, node_id=db_model.node_id, - node_type=NodeType(db_model.node_type), + node_type=db_model.node_type, title=db_model.title, inputs=inputs, process_data=process_data, diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index 50105bd707..20cdb3e57f 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -113,17 +113,26 @@ class BuiltinToolProviderController(ToolProviderController): """ return self.get_credentials_schema_by_type(CredentialType.API_KEY) - def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + def get_credentials_schema_by_type(self, credential_type: CredentialType | str) -> list[ProviderConfig]: """ returns the credentials schema of the provider - :param credential_type: the type of the credential - :return: the credentials schema of the provider + :param credential_type: the type of the credential, as CredentialType or str; str values + are normalized via CredentialType.of and may raise ValueError for invalid values. + :return: list[ProviderConfig] for CredentialType.OAUTH2 or CredentialType.API_KEY, an + empty list for CredentialType.UNAUTHORIZED or missing schemas. + + Reads from self.entity.oauth_schema and self.entity.credentials_schema. + Raises ValueError for invalid credential types. """ - if credential_type == CredentialType.OAUTH2.value: + if isinstance(credential_type, str): + credential_type = CredentialType.of(credential_type) + if credential_type == CredentialType.OAUTH2: return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] if credential_type == CredentialType.API_KEY: return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + if credential_type == CredentialType.UNAUTHORIZED: + return [] raise ValueError(f"Invalid credential type: {credential_type}") def get_oauth_client_schema(self) -> list[ProviderConfig]: diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index 2c1e9fb555..dacc49c746 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -3,13 +3,13 @@ from collections.abc import Generator from typing import Any from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.workflow.file.enums import FileType -from core.workflow.file.file_manager import download +from dify_graph.file.enums import FileType +from dify_graph.file.file_manager import download +from dify_graph.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 5009f7ac21..7818bff0ab 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -3,11 +3,11 @@ from collections.abc import Generator from typing import Any from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index 51b0407886..bcf58394ba 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,11 @@ from __future__ import annotations -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but @@ -50,7 +50,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.invoke( user_id=user_id, tenant_id=self.runtime.tenant_id or "", - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, ) diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index afa2ddffed..c6a84e27c6 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -13,7 +13,7 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from core.workflow.file.file_manager import download +from dify_graph.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 218ffafd55..2545290b57 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -5,11 +5,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType +from dify_graph.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 1d439323f2..9025ff6ef1 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -17,11 +17,11 @@ from core.mcp.types import ( TextContent, TextResourceContents, ) -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index de476f6461..64212a2636 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -31,10 +31,10 @@ from core.tools.errors import ( ) from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool -from core.workflow.file import FileType -from core.workflow.file.models import FileTransferMethod +from dify_graph.file import FileType +from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile logger = logging.getLogger(__name__) @@ -352,7 +352,7 @@ class ToolEngine: message_id=agent_message.id, type=file_type, transfer_method=FileTransferMethod.TOOL_FILE, - belongs_to="assistant", + belongs_to=MessageFileBelongsTo.ASSISTANT, url=message.url, upload_file_id=tool_file_id, created_by_role=( diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index ca0dc27f3d..210f488afc 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,28 +10,19 @@ from typing import Union from uuid import uuid4 import httpx -from sqlalchemy.orm import Session from configs import dify_config +from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from extensions.ext_database import db as global_db +from dify_graph.file.models import ToolFile as ToolFilePydanticModel from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile logger = logging.getLogger(__name__) -from sqlalchemy.engine import Engine - class ToolFileManager: - _engine: Engine - - def __init__(self, engine: Engine | None = None): - if engine is None: - engine = global_db.engine - self._engine = engine - @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -89,7 +80,7 @@ class ToolFileManager: filepath = f"tools/{tenant_id}/{unique_filename}" storage.save(filepath, file_binary) - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, @@ -132,7 +123,7 @@ class ToolFileManager: filename = f"{unique_name}{extension}" filepath = f"tools/{tenant_id}/{filename}" storage.save(filepath, blob) - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, @@ -146,6 +137,7 @@ class ToolFileManager: session.add(tool_file) session.commit() + session.refresh(tool_file) return tool_file @@ -157,7 +149,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file: ToolFile | None = ( session.query(ToolFile) .where( @@ -181,7 +173,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: message_file: MessageFile | None = ( session.query(MessageFile) .where( @@ -217,7 +209,9 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: + def get_file_generator_by_tool_file_id( + self, tool_file_id: str + ) -> tuple[Generator | None, ToolFilePydanticModel | None]: """ get file binary @@ -225,7 +219,7 @@ class ToolFileManager: :return: the binary of the file, mime type """ - with Session(self._engine, expire_on_commit=False) as session: + with session_factory.create_session() as session: tool_file: ToolFile | None = ( session.query(ToolFile) .where( @@ -239,11 +233,11 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, tool_file + return stream, ToolFilePydanticModel.model_validate(tool_file) # init tool_file_parser -from core.workflow.file.tool_file_parser import set_tool_file_manager_factory +from dify_graph.file.tool_file_parser import set_tool_file_manager_factory def _factory() -> ToolFileManager: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 90d5a647e9..250dd91bfd 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -38,7 +38,7 @@ class ToolLabelManager: db.session.add( ToolLabelBinding( tool_id=provider_id, - tool_type=controller.provider_type.value, + tool_type=controller.provider_type, label_name=label, ) ) @@ -58,7 +58,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") stmt = select(ToolLabelBinding.label_name).where( ToolLabelBinding.tool_id == provider_id, - ToolLabelBinding.tool_type == controller.provider_type.value, + ToolLabelBinding.tool_type == controller.provider_type, ) labels = db.session.scalars(stmt).all() diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d561d39923..23a877b7e3 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -24,20 +24,19 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from core.workflow.runtime.variable_pool import VariablePool +from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_database import db from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from core.workflow.nodes.tool.entities import ToolEntity + from dify_graph.nodes.tool.entities import ToolEntity from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -58,11 +57,12 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from core.workflow.nodes.tool.entities import ToolEntity + from dify_graph.nodes.tool.entities import ToolEntity logger = logging.getLogger(__name__) @@ -72,6 +72,11 @@ class ApiProviderControllerItem(TypedDict): controller: ApiToolProviderController +class EmojiIconDict(TypedDict): + background: str + content: str + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -179,7 +184,6 @@ class ToolManager: :return: the tool """ - if provider_type == ToolProviderType.BUILT_IN: # check if the builtin tool need credentials provider_controller = cls.get_builtin_provider(provider_id, tenant_id) @@ -628,9 +632,9 @@ class ToolManager: # MySQL: Use window function to achieve same result sql = """ SELECT id FROM ( - SELECT id, + SELECT id, ROW_NUMBER() OVER ( - PARTITION BY tenant_id, provider + PARTITION BY tenant_id, provider ORDER BY is_default DESC, created_at DESC ) as rn FROM tool_builtin_providers @@ -917,7 +921,7 @@ class ToolManager: ) @classmethod - def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: workflow_provider: WorkflowToolProvider | None = ( db.session.query(WorkflowToolProvider) @@ -934,7 +938,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str]: + def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: api_provider: ApiToolProvider | None = ( db.session.query(ApiToolProvider) @@ -951,7 +955,7 @@ class ToolManager: return {"background": "#252525", "content": "\ud83d\ude01"} @classmethod - def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> Mapping[str, str] | str: + def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict | dict[str, str] | str: try: with Session(db.engine) as session: mcp_service = MCPToolManageService(session=session) @@ -971,7 +975,7 @@ class ToolManager: tenant_id: str, provider_type: ToolProviderType, provider_id: str, - ) -> str | Mapping[str, str]: + ) -> str | EmojiIconDict | dict[str, str]: """ get the tool icon @@ -1017,8 +1021,8 @@ class ToolManager: """ Convert tool parameters type """ - from core.workflow.nodes.tool.entities import ToolNodeData - from core.workflow.nodes.tool.exc import ToolParameterError + from dify_graph.nodes.tool.entities import ToolNodeData + from dify_graph.nodes.tool.exc import ToolParameterError runtime_parameters = {} for parameter in parameters: diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 3ac487a471..37a2c957b0 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -116,6 +116,7 @@ class ToolParameterConfigurationManager: return a deep copy of parameters with decrypted values """ + parameters = self._deep_copy(parameters) cache = ToolParameterCache( tenant_id=self.tenant_id, diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 20e10be075..c2b520fa99 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,5 +1,4 @@ import threading -from typing import Any from flask import Flask, current_app from pydantic import BaseModel, Field @@ -7,17 +6,18 @@ from sqlalchemy import select from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment -default_retrieval_model: dict[str, Any] = { +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index 057ec41f65..429b7e6622 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -1,9 +1,10 @@ -from typing import Any, cast +from typing import NotRequired, TypedDict, cast from pydantic import BaseModel, Field from sqlalchemy import select from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext @@ -16,7 +17,19 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model: dict[str, Any] = { + +class DefaultRetrievalModelDict(TypedDict): + search_method: RetrievalMethod + reranking_enable: bool + reranking_model: RerankingModelDict + reranking_mode: NotRequired[str] + weights: NotRequired[WeightsDict | None] + score_threshold: NotRequired[float] + top_k: int + score_threshold_enabled: bool + + +default_retrieval_model: DefaultRetrievalModelDict = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -125,7 +138,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if metadata_condition and not document_ids_filter: return "" # get retrieval model , if the model is not setting , using default - retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + retrieval_model = dataset.retrieval_model or default_retrieval_model retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 622cdcf73b..6fc5fead2d 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -10,7 +10,7 @@ import pytz from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index e7fba09359..373bd1b1c8 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -9,18 +9,19 @@ from decimal import Decimal from typing import cast from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMResult -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.errors.invoke import ( +from core.tools.entities.tool_entities import ToolProviderType +from dify_graph.model_runtime.entities.llm_entities import LLMResult +from dify_graph.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.tools import ToolModelInvoke @@ -78,7 +79,7 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: str, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] ) -> LLMResult: """ invoke model with parameters in user's own context diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index fc2b41d960..f7484b93fb 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,4 +1,5 @@ import re +from collections.abc import Mapping from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -20,10 +21,18 @@ class InterfaceDict(TypedDict): operation: dict[str, Any] +class OpenAPISpecDict(TypedDict): + openapi: str + info: dict[str, str] + servers: list[dict[str, Any]] + paths: dict[str, Any] + components: dict[str, Any] + + class ApiBasedToolSchemaParser: @staticmethod def parse_openapi_to_tool_bundle( - openapi: dict, extra_info: dict | None = None, warning: dict | None = None + openapi: Mapping[str, Any], extra_info: dict | None = None, warning: dict | None = None ) -> list[ApiToolBundle]: warning = warning if warning is not None else {} extra_info = extra_info if extra_info is not None else {} @@ -277,7 +286,7 @@ class ApiBasedToolSchemaParser: @staticmethod def parse_swagger_to_openapi( swagger: dict, extra_info: dict | None = None, warning: dict | None = None - ) -> dict[str, Any]: + ) -> OpenAPISpecDict: warning = warning or {} """ parse swagger to openapi @@ -293,7 +302,7 @@ class ApiBasedToolSchemaParser: if len(servers) == 0: raise ToolApiSchemaError("No server found in the swagger yaml.") - converted_openapi: dict[str, Any] = { + converted_openapi: OpenAPISpecDict = { "openapi": "3.0.0", "info": { "title": info.get("title", "Swagger"), diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 8e8c5e9c6a..28f1376655 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -3,9 +3,9 @@ from typing import Any from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.variables.input_entities import VariableEntity class WorkflowToolConfigurationUtils: @@ -51,7 +51,7 @@ class WorkflowToolConfigurationUtils: def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None: nodes = graph.get("nodes", []) for node in nodes: - if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT: + if node.get("data", {}).get("type") == BuiltinNodeTypes.HUMAN_INPUT: raise WorkflowToolHumanInputNotSupportedError() @classmethod diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index 56faccb407..aef8b3f779 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -22,7 +22,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from extensions.ext_database import db from models.account import Account from models.model import App, AppMode @@ -37,6 +37,7 @@ VARIABLE_TO_PARAMETER_TYPE_MAPPING = { VariableEntityType.CHECKBOX: ToolParameter.ToolParameterType.BOOLEAN, VariableEntityType.FILE: ToolParameter.ToolParameterType.FILE, VariableEntityType.FILE_LIST: ToolParameter.ToolParameterType.FILES, + VariableEntityType.JSON_OBJECT: ToolParameter.ToolParameterType.OBJECT, } diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index b2606009a6..9b9aa7a741 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -8,7 +8,6 @@ from typing import Any, cast from sqlalchemy import select from core.db.session_factory import session_factory -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( @@ -18,7 +17,8 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from factories.file_factory import build_from_mapping from models import Account, Tenant from models.model import App, EndUser diff --git a/api/core/trigger/constants.py b/api/core/trigger/constants.py new file mode 100644 index 0000000000..192faa2d3e --- /dev/null +++ b/api/core/trigger/constants.py @@ -0,0 +1,17 @@ +from typing import Final + +TRIGGER_WEBHOOK_NODE_TYPE: Final[str] = "trigger-webhook" +TRIGGER_SCHEDULE_NODE_TYPE: Final[str] = "trigger-schedule" +TRIGGER_PLUGIN_NODE_TYPE: Final[str] = "trigger-plugin" + +TRIGGER_NODE_TYPES: Final[frozenset[str]] = frozenset( + { + TRIGGER_WEBHOOK_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_PLUGIN_NODE_TYPE, + } +) + + +def is_trigger_node_type(node_type: str) -> bool: + return node_type in TRIGGER_NODE_TYPES diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index bd1ff4ebfe..2a133b2b94 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -11,6 +11,11 @@ from typing import Any from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.events import ( PluginTriggerDebugEvent, @@ -19,9 +24,9 @@ from core.trigger.debug.events import ( build_plugin_pool_key, build_webhook_pool_key, ) -from core.workflow.enums import NodeType from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig +from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_redis import redis_client from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at @@ -41,10 +46,10 @@ class TriggerDebugEventPoller(ABC): app_id: str user_id: str tenant_id: str - node_config: Mapping[str, Any] + node_config: NodeConfigDict node_id: str - def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: Mapping[str, Any], node_id: str): + def __init__(self, tenant_id: str, user_id: str, app_id: str, node_config: NodeConfigDict, node_id: str): self.tenant_id = tenant_id self.user_id = user_id self.app_id = app_id @@ -60,7 +65,7 @@ class PluginTriggerDebugEventPoller(TriggerDebugEventPoller): def poll(self) -> TriggerDebugEvent | None: from services.trigger.trigger_service import TriggerService - plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config.get("data", {})) + plugin_trigger_data = TriggerEventNodeData.model_validate(self.node_config["data"], from_attributes=True) provider_id = TriggerProviderID(plugin_trigger_data.provider_id) pool_key: str = build_plugin_pool_key( name=plugin_trigger_data.event_name, @@ -205,21 +210,19 @@ def create_event_poller( if not node_config: raise ValueError("Node data not found for node %s", node_id) node_type = draft_workflow.get_node_type_from_node_config(node_config) - match node_type: - case NodeType.TRIGGER_PLUGIN: - return PluginTriggerDebugEventPoller( - tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id - ) - case NodeType.TRIGGER_WEBHOOK: - return WebhookTriggerDebugEventPoller( - tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id - ) - case NodeType.TRIGGER_SCHEDULE: - return ScheduleTriggerDebugEventPoller( - tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id - ) - case _: - raise ValueError("unable to create event poller for node type %s", node_type) + if node_type == TRIGGER_PLUGIN_NODE_TYPE: + return PluginTriggerDebugEventPoller( + tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id + ) + if node_type == TRIGGER_WEBHOOK_NODE_TYPE: + return WebhookTriggerDebugEventPoller( + tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id + ) + if node_type == TRIGGER_SCHEDULE_NODE_TYPE: + return ScheduleTriggerDebugEventPoller( + tenant_id=tenant_id, user_id=user_id, app_id=app_id, node_config=node_config, node_id=node_id + ) + raise ValueError("unable to create event poller for node type %s", node_type) def select_trigger_debug_events( diff --git a/api/core/workflow/__init__.py b/api/core/workflow/__init__.py index e69de29bb2..937012dcee 100644 --- a/api/core/workflow/__init__.py +++ b/api/core/workflow/__init__.py @@ -0,0 +1 @@ +"""Core workflow package.""" diff --git a/api/core/workflow/entities/agent.py b/api/core/workflow/entities/agent.py deleted file mode 100644 index 2b4d6db76f..0000000000 --- a/api/core/workflow/entities/agent.py +++ /dev/null @@ -1,8 +0,0 @@ -from pydantic import BaseModel - - -class AgentNodeStrategyInit(BaseModel): - """Agent node strategy initialization data.""" - - name: str - icon: str | None = None diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py new file mode 100644 index 0000000000..ab34263a79 --- /dev/null +++ b/api/core/workflow/node_factory.py @@ -0,0 +1,466 @@ +import importlib +import pkgutil +from collections.abc import Callable, Iterator, Mapping, MutableMapping +from functools import lru_cache +from typing import TYPE_CHECKING, Any, TypeAlias, cast, final + +from sqlalchemy import select +from sqlalchemy.orm import Session +from typing_extensions import override + +from configs import dify_config +from core.app.entities.app_invoke_entities import DifyRunContext +from core.app.llm.model_access import build_dify_model_access +from core.helper.code_executor.code_executor import ( + CodeExecutionError, + CodeExecutor, +) +from core.helper.ssrf_proxy import ssrf_proxy +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.tools.tool_file_manager import ToolFileManager +from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from core.workflow.nodes.agent.plugin_strategy_adapter import ( + PluginAgentStrategyPresentationProvider, + PluginAgentStrategyResolver, +) +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey +from dify_graph.file.file_manager import file_manager +from dify_graph.graph.graph import NodeFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.code.code_node import WorkflowCodeExecutor +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.nodes.document_extractor import UnstructuredApiConfig +from dify_graph.nodes.http_request import build_http_request_config +from dify_graph.nodes.llm.entities import LLMNodeData +from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from dify_graph.nodes.llm.protocols import TemplateRenderer +from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData +from dify_graph.nodes.template_transform.template_renderer import ( + CodeExecutorJinja2TemplateRenderer, +) +from dify_graph.variables.segments import StringSegment +from extensions.ext_database import db +from models.model import Conversation + +if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState + +LATEST_VERSION = "latest" +_START_NODE_TYPES: frozenset[NodeType] = frozenset( + (BuiltinNodeTypes.START, BuiltinNodeTypes.DATASOURCE, *TRIGGER_NODE_TYPES) +) + + +def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] = frozenset()) -> None: + package = importlib.import_module(package_name) + for _, module_name, _ in pkgutil.walk_packages(package.__path__, package.__name__ + "."): + if module_name in excluded_modules: + continue + importlib.import_module(module_name) + + +@lru_cache(maxsize=1) +def register_nodes() -> None: + """Import production node modules so they self-register with ``Node``.""" + _import_node_package("dify_graph.nodes") + _import_node_package("core.workflow.nodes") + + +def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node]]]: + """Return a read-only snapshot of the current production node registry. + + The workflow layer owns node bootstrap because it must compose built-in + `dify_graph.nodes.*` implementations with workflow-local nodes under + `core.workflow.nodes.*`. Keeping this import side effect here avoids + reintroducing registry bootstrapping into lower-level graph primitives. + """ + register_nodes() + return Node.get_node_type_classes_mapping() + + +def resolve_workflow_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + node_mapping = get_node_type_classes_mapping().get(node_type) + if not node_mapping: + raise ValueError(f"No class mapping found for node type: {node_type}") + + latest_node_class = node_mapping.get(LATEST_VERSION) + matched_node_class = node_mapping.get(node_version) + node_class = matched_node_class or latest_node_class + if not node_class: + raise ValueError(f"No latest version class found for node type: {node_type}") + return node_class + + +def is_start_node_type(node_type: NodeType) -> bool: + """Return True when the node type can serve as a workflow entry point.""" + return node_type in _START_NODE_TYPES + + +def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str: + """Resolve the default entry node for a persisted top-level workflow graph. + + This workflow-layer helper depends on start-node semantics defined by + `is_start_node_type`, so it intentionally lives next to the node registry + instead of in the raw `dify_graph.entities.graph_config` schema module. + """ + nodes = graph_config.get("nodes") + if not isinstance(nodes, list): + raise ValueError("nodes in workflow graph must be a list") + + for node in nodes: + if not isinstance(node, Mapping): + continue + + if node.get("type") == "custom-note": + continue + + node_id = node.get("id") + data = node.get("data") + if not isinstance(node_id, str) or not isinstance(data, Mapping): + continue + + node_type = data.get("type") + if isinstance(node_type, str) and is_start_node_type(node_type): + return node_id + + raise ValueError("Unable to determine default root node ID from workflow graph") + + +class _LazyNodeTypeClassesMapping(MutableMapping[NodeType, Mapping[str, type[Node]]]): + """Mutable dict-like view over the current node registry.""" + + def __init__(self) -> None: + self._cached_snapshot: dict[NodeType, Mapping[str, type[Node]]] = {} + self._cached_version = -1 + self._deleted: set[NodeType] = set() + self._overrides: dict[NodeType, Mapping[str, type[Node]]] = {} + + def _snapshot(self) -> dict[NodeType, Mapping[str, type[Node]]]: + current_version = Node.get_registry_version() + if self._cached_version != current_version: + self._cached_snapshot = dict(get_node_type_classes_mapping()) + self._cached_version = current_version + if not self._deleted and not self._overrides: + return self._cached_snapshot + + snapshot = {key: value for key, value in self._cached_snapshot.items() if key not in self._deleted} + snapshot.update(self._overrides) + return snapshot + + def __getitem__(self, key: NodeType) -> Mapping[str, type[Node]]: + return self._snapshot()[key] + + def __setitem__(self, key: NodeType, value: Mapping[str, type[Node]]) -> None: + self._deleted.discard(key) + self._overrides[key] = value + + def __delitem__(self, key: NodeType) -> None: + if key in self._overrides: + del self._overrides[key] + return + if key in self._cached_snapshot: + self._deleted.add(key) + return + raise KeyError(key) + + def __iter__(self) -> Iterator[NodeType]: + return iter(self._snapshot()) + + def __len__(self) -> int: + return len(self._snapshot()) + + +# Keep the canonical node-class mapping in the workflow layer that also bootstraps +# legacy `core.workflow.nodes.*` registrations. +NODE_TYPE_CLASSES_MAPPING: MutableMapping[NodeType, Mapping[str, type[Node]]] = _LazyNodeTypeClassesMapping() + + +LLMCompatibleNodeData: TypeAlias = LLMNodeData | QuestionClassifierNodeData | ParameterExtractorNodeData + + +def fetch_memory( + *, + conversation_id: str | None, + app_id: str, + node_data_memory: MemoryConfig | None, + model_instance: ModelInstance, +) -> TokenBufferMemory | None: + if not node_data_memory or not conversation_id: + return None + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) + if not conversation: + return None + + return TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + +class DefaultWorkflowCodeExecutor: + def execute( + self, + *, + language: CodeLanguage, + code: str, + inputs: Mapping[str, Any], + ) -> Mapping[str, Any]: + return CodeExecutor.execute_workflow_code_template( + language=language, + code=code, + inputs=inputs, + ) + + def is_execution_error(self, error: Exception) -> bool: + return isinstance(error, CodeExecutionError) + + +class DefaultLLMTemplateRenderer(TemplateRenderer): + def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=inputs, + ) + return str(result.get("result", "")) + + +@final +class DifyNodeFactory(NodeFactory): + """ + Default implementation of NodeFactory that resolves node classes from the live registry. + """ + + def __init__( + self, + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + ) -> None: + self.graph_init_params = graph_init_params + self.graph_runtime_state = graph_runtime_state + self._dify_context = self._resolve_dify_context(graph_init_params.run_context) + self._code_executor: WorkflowCodeExecutor = DefaultWorkflowCodeExecutor() + self._code_limits = CodeNodeLimits( + max_string_length=dify_config.CODE_MAX_STRING_LENGTH, + max_number=dify_config.CODE_MAX_NUMBER, + min_number=dify_config.CODE_MIN_NUMBER, + max_precision=dify_config.CODE_MAX_PRECISION, + max_depth=dify_config.CODE_MAX_DEPTH, + max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, + max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, + max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, + ) + self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) + self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH + self._http_request_http_client = ssrf_proxy + self._http_request_tool_file_manager_factory = ToolFileManager + self._http_request_file_manager = file_manager + self._document_extractor_unstructured_api_config = UnstructuredApiConfig( + api_url=dify_config.UNSTRUCTURED_API_URL, + api_key=dify_config.UNSTRUCTURED_API_KEY or "", + ) + self._http_request_config = build_http_request_config( + max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, + max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, + max_write_timeout=dify_config.HTTP_REQUEST_MAX_WRITE_TIMEOUT, + max_binary_size=dify_config.HTTP_REQUEST_NODE_MAX_BINARY_SIZE, + max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE, + ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY, + ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, + ) + + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._agent_strategy_resolver = PluginAgentStrategyResolver() + self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() + self._agent_runtime_support = AgentRuntimeSupport() + self._agent_message_transformer = AgentMessageTransformer() + + @staticmethod + def _resolve_dify_context(run_context: Mapping[str, Any]) -> DifyRunContext: + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) + + @override + def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: + """ + Create a Node instance from node configuration data using the traditional mapping. + + :param node_config: node configuration dictionary containing type and other data + :return: initialized Node instance + :raises ValueError: if node_config fails NodeConfigDict/BaseNodeData validation + (including pydantic ValidationError, which subclasses ValueError), + if node type is unknown, or if no implementation exists for the resolved version + """ + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_id = typed_node_config["id"] + node_data = typed_node_config["data"] + node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) + node_type = node_data.type + node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = { + BuiltinNodeTypes.CODE: lambda: { + "code_executor": self._code_executor, + "code_limits": self._code_limits, + }, + BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { + "template_renderer": self._template_renderer, + "max_output_length": self._template_transform_max_output_length, + }, + BuiltinNodeTypes.HTTP_REQUEST: lambda: { + "http_request_config": self._http_request_config, + "http_client": self._http_request_http_client, + "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "file_manager": self._http_request_file_manager, + }, + BuiltinNodeTypes.HUMAN_INPUT: lambda: { + "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + }, + BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( + node_class=node_class, + node_data=node_data, + include_http_client=True, + ), + BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { + "unstructured_api_config": self._document_extractor_unstructured_api_config, + "http_client": self._http_request_http_client, + }, + BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( + node_class=node_class, + node_data=node_data, + include_http_client=True, + ), + BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( + node_class=node_class, + node_data=node_data, + include_http_client=False, + ), + BuiltinNodeTypes.TOOL: lambda: { + "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + }, + BuiltinNodeTypes.AGENT: lambda: { + "strategy_resolver": self._agent_strategy_resolver, + "presentation_provider": self._agent_strategy_presentation_provider, + "runtime_support": self._agent_runtime_support, + "message_transformer": self._agent_message_transformer, + }, + } + node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})() + return node_class( + id=node_id, + config=typed_node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + **node_init_kwargs, + ) + + @staticmethod + def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData: + """ + Re-validate the permissive graph payload with the concrete NodeData model declared by the resolved node class. + """ + return node_class.validate_node_data(node_data) + + @staticmethod + def _resolve_node_class(*, node_type: NodeType, node_version: str) -> type[Node]: + return resolve_workflow_node_class(node_type=node_type, node_version=node_version) + + def _build_llm_compatible_node_init_kwargs( + self, + *, + node_class: type[Node], + node_data: BaseNodeData, + include_http_client: bool, + ) -> dict[str, object]: + validated_node_data = cast( + LLMCompatibleNodeData, + self._validate_resolved_node_data(node_class=node_class, node_data=node_data), + ) + model_instance = self._build_model_instance_for_llm_node(validated_node_data) + node_init_kwargs: dict[str, object] = { + "credentials_provider": self._llm_credentials_provider, + "model_factory": self._llm_model_factory, + "model_instance": model_instance, + "memory": self._build_memory_for_llm_node( + node_data=validated_node_data, + model_instance=model_instance, + ), + } + if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: + node_init_kwargs["template_renderer"] = self._llm_template_renderer + if include_http_client: + node_init_kwargs["http_client"] = self._http_request_http_client + return node_init_kwargs + + def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: + node_data_model = node_data.model + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") + + credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) + model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) + provider_model_bundle = model_instance.provider_model_bundle + + provider_model = provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, + model_type=ModelType.LLM, + ) + if provider_model is None: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() + + completion_params = dict(node_data_model.completion_params) + stop = completion_params.pop("stop", []) + if not isinstance(stop, list): + stop = [] + + model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) + if not model_schema: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + model_instance.provider = node_data_model.provider + model_instance.model_name = node_data_model.name + model_instance.credentials = credentials + model_instance.parameters = completion_params + model_instance.stop = tuple(stop) + model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) + return model_instance + + def _build_memory_for_llm_node( + self, + *, + node_data: LLMCompatibleNodeData, + model_instance: ModelInstance, + ) -> PromptMessageMemory | None: + if node_data.memory is None: + return None + + conversation_id_variable = self.graph_runtime_state.variable_pool.get( + ["sys", SystemVariableKey.CONVERSATION_ID] + ) + conversation_id = ( + conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None + ) + return fetch_memory( + conversation_id=conversation_id, + app_id=self._dify_context.app_id, + node_data_memory=node_data.memory, + model_instance=model_instance, + ) diff --git a/api/core/workflow/nodes/__init__.py b/api/core/workflow/nodes/__init__.py index 82a37acbfa..d23f80be59 100644 --- a/api/core/workflow/nodes/__init__.py +++ b/api/core/workflow/nodes/__init__.py @@ -1,3 +1 @@ -from core.workflow.enums import NodeType - -__all__ = ["NodeType"] +"""Workflow node implementations that remain under the legacy core.workflow namespace.""" diff --git a/api/core/workflow/nodes/agent/__init__.py b/api/core/workflow/nodes/agent/__init__.py index 95e7cf895b..ba6c667194 100644 --- a/api/core/workflow/nodes/agent/__init__.py +++ b/api/core/workflow/nodes/agent/__init__.py @@ -1,3 +1,4 @@ from .agent_node import AgentNode +from .entities import AgentNodeData -__all__ = ["AgentNode"] +__all__ = ["AgentNode", "AgentNodeData"] diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index ac86b1784f..5699ccf404 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,88 +1,81 @@ from __future__ import annotations -import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any -from packaging.version import Version -from pydantic import ValidationError -from sqlalchemy import select -from sqlalchemy.orm import Session +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from core.agent.entities import AgentToolEntity -from core.agent.plugin_entities import AgentStrategyParameter -from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.model_runtime.utils.encoders import jsonable_encoder -from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ( - ToolIdentity, - ToolInvokeMessage, - ToolParameter, - ToolProviderType, -) -from core.tools.tool_manager import ToolManager -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.enums import ( - NodeType, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import ( - AgentLogEvent, - NodeEventBase, - NodeRunResult, - StreamChunkEvent, - StreamCompletedEvent, -) -from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ArrayFileSegment, StringSegment -from extensions.ext_database import db -from factories import file_factory -from factories.agent_factory import get_plugin_agent_strategy -from models import ToolFile -from models.model import Conversation -from services.tools.builtin_tools_manage_service import BuiltinToolManageService - -from .exc import ( - AgentInputTypeError, +from .entities import AgentNodeData +from .exceptions import ( AgentInvocationError, AgentMessageTransformError, - AgentNodeError, - AgentVariableNotFoundError, - AgentVariableTypeError, - ToolFileNotFoundError, ) +from .message_transformer import AgentMessageTransformer +from .runtime_support import AgentRuntimeSupport +from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver if TYPE_CHECKING: - from core.agent.strategy.plugin import PluginAgentStrategy - from core.plugin.entities.request import InvokeCredentials + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class AgentNode(Node[AgentNodeData]): - """ - Agent Node - """ + node_type = BuiltinNodeTypes.AGENT - node_type = NodeType.AGENT + _strategy_resolver: AgentStrategyResolver + _presentation_provider: AgentStrategyPresentationProvider + _runtime_support: AgentRuntimeSupport + _message_transformer: AgentMessageTransformer + + def __init__( + self, + id: str, + config: NodeConfigDict, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + *, + strategy_resolver: AgentStrategyResolver, + presentation_provider: AgentStrategyPresentationProvider, + runtime_support: AgentRuntimeSupport, + message_transformer: AgentMessageTransformer, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._strategy_resolver = strategy_resolver + self._presentation_provider = presentation_provider + self._runtime_support = runtime_support + self._message_transformer = message_transformer @classmethod def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + dify_ctx = self.require_dify_context() + event.extras["agent_strategy"] = { + "name": self.node_data.agent_strategy_name, + "icon": self._presentation_provider.get_icon( + tenant_id=dify_ctx.tenant_id, + agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, + ), + } + def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError + dify_ctx = self.require_dify_context() + try: - strategy = get_plugin_agent_strategy( - tenant_id=self.tenant_id, + strategy = self._strategy_resolver.resolve( + tenant_id=dify_ctx.tenant_id, agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, agent_strategy_name=self.node_data.agent_strategy_name, ) @@ -98,30 +91,34 @@ class AgentNode(Node[AgentNodeData]): agent_parameters = strategy.get_parameters() - # get parameters - parameters = self._generate_agent_parameters( + parameters = self._runtime_support.build_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=self.node_data, strategy=strategy, + tenant_id=dify_ctx.tenant_id, + app_id=dify_ctx.app_id, + invoke_from=dify_ctx.invoke_from, ) - parameters_for_log = self._generate_agent_parameters( + parameters_for_log = self._runtime_support.build_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, node_data=self.node_data, + strategy=strategy, + tenant_id=dify_ctx.tenant_id, + app_id=dify_ctx.app_id, + invoke_from=dify_ctx.invoke_from, for_log=True, - strategy=strategy, ) - credentials = self._generate_credentials(parameters=parameters) + credentials = self._runtime_support.build_credentials(parameters=parameters) - # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) try: message_stream = strategy.invoke( params=parameters, - user_id=self.user_id, - app_id=self.app_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, credentials=credentials, ) @@ -137,15 +134,18 @@ class AgentNode(Node[AgentNodeData]): return try: - yield from self._transform_message( + yield from self._message_transformer.transform( messages=message_stream, tool_info={ - "icon": self.agent_strategy_icon, + "icon": self._presentation_provider.get_icon( + tenant_id=dify_ctx.tenant_id, + agent_strategy_provider_name=self.node_data.agent_strategy_provider_name, + ), "agent_strategy": self.node_data.agent_strategy_name, }, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, @@ -162,217 +162,17 @@ class AgentNode(Node[AgentNodeData]): ) ) - def _generate_agent_parameters( - self, - *, - agent_parameters: Sequence[AgentStrategyParameter], - variable_pool: VariablePool, - node_data: AgentNodeData, - for_log: bool = False, - strategy: PluginAgentStrategy, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - agent_parameters (Sequence[AgentParameter]): The list of agent parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (AgentNodeData): The data associated with the agent node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. - - """ - agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} - - result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - parameter = agent_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - agent_input = node_data.agent_parameters[parameter_name] - match agent_input.type: - case "variable": - variable = variable_pool.get(agent_input.value) # type: ignore - if variable is None: - raise AgentVariableNotFoundError(str(agent_input.value)) - parameter_value = variable.value - case "mixed" | "constant": - # variable_pool.convert_template expects a string template, - # but if passing a dict, convert to JSON string first before rendering - try: - if not isinstance(agent_input.value, str): - parameter_value = json.dumps(agent_input.value, ensure_ascii=False) - else: - parameter_value = str(agent_input.value) - except TypeError: - parameter_value = str(agent_input.value) - segment_group = variable_pool.convert_template(parameter_value) - parameter_value = segment_group.log if for_log else segment_group.text - # variable_pool.convert_template returns a string, - # so we need to convert it back to a dictionary - try: - if not isinstance(agent_input.value, str): - parameter_value = json.loads(parameter_value) - except json.JSONDecodeError: - parameter_value = parameter_value - case _: - raise AgentInputTypeError(agent_input.type) - value = parameter_value - if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - value = [tool for tool in value if tool.get("enabled", False)] - value = self._filter_mcp_type_tool(strategy, value) - for tool in value: - if "schemas" in tool: - tool.pop("schemas") - parameters = tool.get("parameters", {}) - if all(isinstance(v, dict) for _, v in parameters.items()): - params = {} - for key, param in parameters.items(): - if param.get("auto", ParamsAutoGenerated.OPEN) in ( - ParamsAutoGenerated.CLOSE, - 0, - ): - value_param = param.get("value", {}) - if value_param and value_param.get("type", "") == "variable": - variable_selector = value_param.get("value") - if not variable_selector: - raise ValueError("Variable selector is missing for a variable-type parameter.") - - variable = variable_pool.get(variable_selector) - if variable is None: - raise AgentVariableNotFoundError(str(variable_selector)) - - params[key] = variable.value - else: - params[key] = value_param.get("value", "") if value_param is not None else None - else: - params[key] = None - parameters = params - tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} - tool["parameters"] = parameters - - if not for_log: - if parameter.type == "array[tools]": - value = cast(list[dict[str, Any]], value) - tool_value = [] - for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) - setting_params = tool.get("settings", {}) - parameters = tool.get("parameters", {}) - manual_input_params = [key for key, value in parameters.items() if value is not None] - - parameters = {**parameters, **setting_params} - entity = AgentToolEntity( - provider_id=tool.get("provider_name", ""), - provider_type=provider_type, - tool_name=tool.get("tool_name", ""), - tool_parameters=parameters, - plugin_unique_identifier=tool.get("plugin_unique_identifier", None), - credential_id=tool.get("credential_id", None), - ) - - extra = tool.get("extra", {}) - - # This is an issue that caused problems before. - # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. - runtime_variable_pool: VariablePool | None = None - if node_data.version != "1" or node_data.tool_node_version is not None: - runtime_variable_pool = variable_pool - tool_runtime = ToolManager.get_agent_tool_runtime( - self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool - ) - if tool_runtime.entity.description: - tool_runtime.entity.description.llm = ( - extra.get("description", "") or tool_runtime.entity.description.llm - ) - for tool_runtime_params in tool_runtime.entity.parameters: - tool_runtime_params.form = ( - ToolParameter.ToolParameterForm.FORM - if tool_runtime_params.name in manual_input_params - else tool_runtime_params.form - ) - manual_input_value = {} - if tool_runtime.entity.parameters: - manual_input_value = { - key: value for key, value in parameters.items() if key in manual_input_params - } - runtime_parameters = { - **tool_runtime.runtime.runtime_parameters, - **manual_input_value, - } - tool_value.append( - { - **tool_runtime.entity.model_dump(mode="json"), - "runtime_parameters": runtime_parameters, - "credential_id": tool.get("credential_id", None), - "provider_type": provider_type.value, - } - ) - value = tool_value - if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: - value = cast(dict[str, Any], value) - model_instance, model_schema = self._fetch_model(value) - # memory config - history_prompt_messages = [] - if node_data.memory: - memory = self._fetch_memory(model_instance) - if memory: - prompt_messages = memory.get_history_prompt_messages( - message_limit=node_data.memory.window.size or None - ) - history_prompt_messages = [ - prompt_message.model_dump(mode="json") for prompt_message in prompt_messages - ] - value["history_prompt_messages"] = history_prompt_messages - if model_schema: - # remove structured output feature to support old version agent plugin - model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) - value["entity"] = model_schema.model_dump(mode="json") - else: - value["entity"] = None - result[parameter_name] = value - - return result - - def _generate_credentials( - self, - parameters: dict[str, Any], - ) -> InvokeCredentials: - """ - Generate credentials based on the given agent parameters. - """ - from core.plugin.entities.request import InvokeCredentials - - credentials = InvokeCredentials() - - # generate credentials for tools selector - credentials.tool_credentials = {} - for tool in parameters.get("tools", []): - if tool.get("credential_id"): - try: - identity = ToolIdentity.model_validate(tool.get("identity", {})) - credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) - except ValidationError: - continue - return credentials - @classmethod def _extract_variable_selector_to_variable_mapping( cls, *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: AgentNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = AgentNodeData.model_validate(node_data) - + _ = graph_config # Explicitly mark as unused result: dict[str, Any] = {} + typed_node_data = node_data for parameter_name in typed_node_data.agent_parameters: input = typed_node_data.agent_parameters[parameter_name] match input.type: @@ -386,365 +186,3 @@ class AgentNode(Node[AgentNodeData]): result = {node_id + "." + key: value for key, value in result.items()} return result - - @property - def agent_strategy_icon(self) -> str | None: - """ - Get agent strategy icon - :return: - """ - from core.plugin.impl.plugin import PluginInstaller - - manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name - ) - icon = current_plugin.declaration.icon - except StopIteration: - icon = None - return icon - - def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None: - # get conversation id - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - if not isinstance(conversation_id_variable, StringSegment): - return None - conversation_id = conversation_id_variable.value - - with Session(db.engine, expire_on_commit=False) as session: - stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) - conversation = session.scalar(stmt) - - if not conversation: - return None - - memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) - - return memory - - def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( - tenant_id=self.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM - ) - model_name = value.get("model", "") - model_credentials = provider_model_bundle.configuration.get_current_credentials( - model_type=ModelType.LLM, model=model_name - ) - provider_name = provider_model_bundle.configuration.provider.provider - model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( - tenant_id=self.tenant_id, - provider=provider_name, - model_type=ModelType(value.get("model_type", "")), - model=model_name, - ) - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - return model_instance, model_schema - - def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: - if model_schema.features: - for feature in model_schema.features[:]: # Create a copy to safely modify during iteration - try: - AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value - except ValueError: - model_schema.features.remove(feature) - return model_schema - - def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: - """ - Filter MCP type tool - :param strategy: plugin agent strategy - :param tool: tool - :return: filtered tool dict - """ - meta_version = strategy.meta_version - if meta_version and Version(meta_version) > Version("0.0.1"): - return tools - else: - return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] - - def _transform_message( - self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, - node_type: NodeType, - node_id: str, - node_execution_id: str, - ) -> Generator[NodeEventBase, None, None]: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json_list: list[dict | list] = [] - - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - llm_usage = LLMUsage.empty_usage() - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileNotFoundError(tool_file_id) - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileNotFoundError(tool_file_id) - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if node_type == NodeType.AGENT: - if isinstance(message.message.json_object, dict): - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } - else: - msg_metadata = {} - llm_usage = LLMUsage.empty_usage() - agent_execution_metadata = {} - if message.message.json_object: - json_list.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - stream_text = f"Link: {message.message.text}\n" - text += stream_text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise AgentVariableTypeError( - "When 'stream' is True, 'variable_value' must be a string.", - variable_name=variable_name, - expected_type="str", - actual_type=type(variable_value).__name__, - ) - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: - assert message.meta is not None - assert isinstance(message.meta, dict) - # Validate that meta contains a 'file' key - if "file" not in message.meta: - raise AgentNodeError("File message is missing 'file' key in meta") - - # Validate that the file is an instance of File - if not isinstance(message.meta["file"], File): - raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") - files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) - if message.message.metadata: - icon = tool_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - - dict_metadata["icon"] = icon - dict_metadata["icon_dark"] = icon_dark - message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - message_id=message.message.id, - node_execution_id=node_execution_id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.message_id == agent_log.message_id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log - - # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any] | list[Any]] = [] - - # Step 1: append each agent log as its own dict. - if agent_logs: - for log in agent_logs: - json_output.append( - { - "id": log.message_id, - "parent_id": log.parent_id, - "error": log.error, - "status": log.status, - "data": log.data, - "label": log.label, - "metadata": log.metadata, - "node_id": log.node_id, - } - ) - # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] - if json_list: - json_output.extend(json_list) - else: - json_output.append({"data": []}) - - # Send final chunk events for all streamed outputs - # Final chunk for text stream - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk="", - is_final=True, - ) - - # Final chunks for any streamed variables - for var_name in variables: - yield StreamChunkEvent( - selector=[node_id, var_name], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "text": text, - "usage": jsonable_encoder(llm_usage), - "files": ArrayFileSegment(value=files), - "json": json_output, - **variables, - }, - metadata={ - **agent_execution_metadata, - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, - }, - inputs=parameters_for_log, - llm_usage=llm_usage, - ) - ) diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 985ee5eef2..91fed39795 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -5,13 +5,15 @@ from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from core.workflow.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): - agent_strategy_provider_name: str # redundancy + type: NodeType = BuiltinNodeTypes.AGENT + agent_strategy_provider_name: str agent_strategy_name: str - agent_strategy_label: str # redundancy + agent_strategy_label: str memory: MemoryConfig | None = None # The version of the tool parameter. # If this value is None, it indicates this is a previous version diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exceptions.py similarity index 90% rename from api/core/workflow/nodes/agent/exc.py rename to api/core/workflow/nodes/agent/exceptions.py index ba2c83d8a6..944f5f0b20 100644 --- a/api/core/workflow/nodes/agent/exc.py +++ b/api/core/workflow/nodes/agent/exceptions.py @@ -119,14 +119,3 @@ class AgentVariableTypeError(AgentNodeError): self.expected_type = expected_type self.actual_type = actual_type super().__init__(message) - - -class AgentMaxIterationError(AgentNodeError): - """Exception raised when the agent exceeds the maximum iteration limit.""" - - def __init__(self, max_iteration: int): - self.max_iteration = max_iteration - super().__init__( - f"Agent exceeded the maximum iteration limit of {max_iteration}. " - f"The agent was unable to complete the task within the allowed number of iterations." - ) diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py new file mode 100644 index 0000000000..f58a5665f4 --- /dev/null +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping +from typing import Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ( + AgentLogEvent, + NodeEventBase, + NodeRunResult, + StreamChunkEvent, + StreamCompletedEvent, +) +from dify_graph.variables.segments import ArrayFileSegment +from extensions.ext_database import db +from factories import file_factory +from models import ToolFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError + + +class AgentMessageTransformer: + def transform( + self, + *, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + ) -> Generator[NodeEventBase, None, None]: + from core.plugin.impl.plugin import PluginInstaller + + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json_list: list[dict | list] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage = LLMUsage.empty_usage() + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=message.message.text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == BuiltinNodeTypes.AGENT: + if isinstance(message.message.json_object, dict): + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata)) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + else: + llm_usage = LLMUsage.empty_usage() + agent_execution_metadata = {} + if message.message.json_object: + json_list.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk=stream_text, + is_final=False, + ) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield StreamChunkEvent( + selector=[node_id, variable_name], + chunk=variable_value, + is_final=False, + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, dict) + if "file" not in message.meta: + raise AgentNodeError("File message is missing 'file' key in meta") + + if not isinstance(message.meta["file"], File): + raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + message_id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + for log in agent_logs: + if log.message_id == agent_log.message_id: + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + json_output: list[dict[str, Any] | list[Any]] = [] + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.message_id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + if json_list: + json_output.extend(json_list) + else: + json_output.append({"data": []}) + + yield StreamChunkEvent( + selector=[node_id, "text"], + chunk="", + is_final=True, + ) + + for var_name in variables: + yield StreamChunkEvent( + selector=[node_id, var_name], + chunk="", + is_final=True, + ) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={ + "text": text, + "usage": jsonable_encoder(llm_usage), + "files": ArrayFileSegment(value=files), + "json": json_output, + **variables, + }, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/core/workflow/nodes/agent/plugin_strategy_adapter.py b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py new file mode 100644 index 0000000000..1fc427ad6c --- /dev/null +++ b/api/core/workflow/nodes/agent/plugin_strategy_adapter.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from factories.agent_factory import get_plugin_agent_strategy + +from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver, ResolvedAgentStrategy + + +class PluginAgentStrategyResolver(AgentStrategyResolver): + def resolve( + self, + *, + tenant_id: str, + agent_strategy_provider_name: str, + agent_strategy_name: str, + ) -> ResolvedAgentStrategy: + return get_plugin_agent_strategy( + tenant_id=tenant_id, + agent_strategy_provider_name=agent_strategy_provider_name, + agent_strategy_name=agent_strategy_name, + ) + + +class PluginAgentStrategyPresentationProvider(AgentStrategyPresentationProvider): + def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: + from core.plugin.impl.plugin import PluginInstaller + + manager = PluginInstaller() + try: + plugins = manager.list_plugins(tenant_id) + except Exception: + return None + + try: + current_plugin = next( + plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == agent_strategy_provider_name + ) + except StopIteration: + return None + + return current_plugin.declaration.icon diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py new file mode 100644 index 0000000000..2ff7c964b9 --- /dev/null +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import json +from collections.abc import Sequence +from typing import Any, cast + +from packaging.version import Version +from pydantic import ValidationError +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.agent.entities import AgentToolEntity +from core.agent.plugin_entities import AgentStrategyParameter +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.plugin.entities.request import InvokeCredentials +from core.provider_manager import ProviderManager +from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType +from core.tools.tool_manager import ToolManager +from dify_graph.enums import SystemVariableKey +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import StringSegment +from extensions.ext_database import db +from models.model import Conversation + +from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated +from .exceptions import AgentInputTypeError, AgentVariableNotFoundError +from .strategy_protocols import ResolvedAgentStrategy + + +class AgentRuntimeSupport: + def build_parameters( + self, + *, + agent_parameters: Sequence[AgentStrategyParameter], + variable_pool: VariablePool, + node_data: AgentNodeData, + strategy: ResolvedAgentStrategy, + tenant_id: str, + app_id: str, + invoke_from: Any, + for_log: bool = False, + ) -> dict[str, Any]: + agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters} + + result: dict[str, Any] = {} + for parameter_name in node_data.agent_parameters: + parameter = agent_parameters_dictionary.get(parameter_name) + if not parameter: + result[parameter_name] = None + continue + + agent_input = node_data.agent_parameters[parameter_name] + match agent_input.type: + case "variable": + variable = variable_pool.get(agent_input.value) # type: ignore[arg-type] + if variable is None: + raise AgentVariableNotFoundError(str(agent_input.value)) + parameter_value = variable.value + case "mixed" | "constant": + try: + if not isinstance(agent_input.value, str): + parameter_value = json.dumps(agent_input.value, ensure_ascii=False) + else: + parameter_value = str(agent_input.value) + except TypeError: + parameter_value = str(agent_input.value) + + segment_group = variable_pool.convert_template(parameter_value) + parameter_value = segment_group.log if for_log else segment_group.text + try: + if not isinstance(agent_input.value, str): + parameter_value = json.loads(parameter_value) + except json.JSONDecodeError: + parameter_value = parameter_value + case _: + raise AgentInputTypeError(agent_input.type) + + value = parameter_value + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + value = [tool for tool in value if tool.get("enabled", False)] + value = self._filter_mcp_type_tool(strategy, value) + for tool in value: + if "schemas" in tool: + tool.pop("schemas") + parameters = tool.get("parameters", {}) + if all(isinstance(v, dict) for _, v in parameters.items()): + params = {} + for key, param in parameters.items(): + if param.get("auto", ParamsAutoGenerated.OPEN) in ( + ParamsAutoGenerated.CLOSE, + 0, + ): + value_param = param.get("value", {}) + if value_param and value_param.get("type", "") == "variable": + variable_selector = value_param.get("value") + if not variable_selector: + raise ValueError("Variable selector is missing for a variable-type parameter.") + + variable = variable_pool.get(variable_selector) + if variable is None: + raise AgentVariableNotFoundError(str(variable_selector)) + + params[key] = variable.value + else: + params[key] = value_param.get("value", "") if value_param is not None else None + else: + params[key] = None + parameters = params + tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()} + tool["parameters"] = parameters + + if not for_log: + if parameter.type == "array[tools]": + value = cast(list[dict[str, Any]], value) + tool_value = [] + for tool in value: + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + setting_params = tool.get("settings", {}) + parameters = tool.get("parameters", {}) + manual_input_params = [key for key, value in parameters.items() if value is not None] + + parameters = {**parameters, **setting_params} + entity = AgentToolEntity( + provider_id=tool.get("provider_name", ""), + provider_type=provider_type, + tool_name=tool.get("tool_name", ""), + tool_parameters=parameters, + plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), + ) + + extra = tool.get("extra", {}) + + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version is not None: + runtime_variable_pool = variable_pool + tool_runtime = ToolManager.get_agent_tool_runtime( + tenant_id, + app_id, + entity, + invoke_from, + runtime_variable_pool, + ) + if tool_runtime.entity.description: + tool_runtime.entity.description.llm = ( + extra.get("description", "") or tool_runtime.entity.description.llm + ) + for tool_runtime_params in tool_runtime.entity.parameters: + tool_runtime_params.form = ( + ToolParameter.ToolParameterForm.FORM + if tool_runtime_params.name in manual_input_params + else tool_runtime_params.form + ) + manual_input_value = {} + if tool_runtime.entity.parameters: + manual_input_value = { + key: value for key, value in parameters.items() if key in manual_input_params + } + runtime_parameters = { + **tool_runtime.runtime.runtime_parameters, + **manual_input_value, + } + tool_value.append( + { + **tool_runtime.entity.model_dump(mode="json"), + "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), + "provider_type": provider_type.value, + } + ) + value = tool_value + if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: + value = cast(dict[str, Any], value) + model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + history_prompt_messages = [] + if node_data.memory: + memory = self.fetch_memory( + variable_pool=variable_pool, + app_id=app_id, + model_instance=model_instance, + ) + if memory: + prompt_messages = memory.get_history_prompt_messages( + message_limit=node_data.memory.window.size or None + ) + history_prompt_messages = [ + prompt_message.model_dump(mode="json") for prompt_message in prompt_messages + ] + value["history_prompt_messages"] = history_prompt_messages + if model_schema: + model_schema = self._remove_unsupported_model_features_for_old_version(model_schema) + value["entity"] = model_schema.model_dump(mode="json") + else: + value["entity"] = None + result[parameter_name] = value + + return result + + def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials: + credentials = InvokeCredentials() + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if not tool.get("credential_id"): + continue + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + except ValidationError: + continue + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + return credentials + + def fetch_memory( + self, + *, + variable_pool: VariablePool, + app_id: str, + model_instance: ModelInstance, + ) -> TokenBufferMemory | None: + conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + if not isinstance(conversation_id_variable, StringSegment): + return None + conversation_id = conversation_id_variable.value + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) + if not conversation: + return None + + return TokenBufferMemory(conversation=conversation, model_instance=model_instance) + + def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: + provider_manager = ProviderManager() + provider_model_bundle = provider_manager.get_provider_model_bundle( + tenant_id=tenant_id, + provider=value.get("provider", ""), + model_type=ModelType.LLM, + ) + model_name = value.get("model", "") + model_credentials = provider_model_bundle.configuration.get_current_credentials( + model_type=ModelType.LLM, + model=model_name, + ) + provider_name = provider_model_bundle.configuration.provider.provider + model_type_instance = provider_model_bundle.model_type_instance + model_instance = ModelManager().get_model_instance( + tenant_id=tenant_id, + provider=provider_name, + model_type=ModelType(value.get("model_type", "")), + model=model_name, + ) + model_schema = model_type_instance.get_model_schema(model_name, model_credentials) + return model_instance, model_schema + + @staticmethod + def _remove_unsupported_model_features_for_old_version(model_schema: AIModelEntity) -> AIModelEntity: + if model_schema.features: + for feature in model_schema.features[:]: + try: + AgentOldVersionModelFeatures(feature.value) + except ValueError: + model_schema.features.remove(feature) + return model_schema + + @staticmethod + def _filter_mcp_type_tool( + strategy: ResolvedAgentStrategy, + tools: list[dict[str, Any]], + ) -> list[dict[str, Any]]: + meta_version = strategy.meta_version + if meta_version and Version(meta_version) > Version("0.0.1"): + return tools + return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP] diff --git a/api/core/workflow/nodes/agent/strategy_protocols.py b/api/core/workflow/nodes/agent/strategy_protocols.py new file mode 100644 index 0000000000..643d916d15 --- /dev/null +++ b/api/core/workflow/nodes/agent/strategy_protocols.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from collections.abc import Generator, Sequence +from typing import Any, Protocol + +from core.agent.plugin_entities import AgentStrategyParameter +from core.plugin.entities.request import InvokeCredentials +from core.tools.entities.tool_entities import ToolInvokeMessage + + +class ResolvedAgentStrategy(Protocol): + meta_version: str | None + + def get_parameters(self) -> Sequence[AgentStrategyParameter]: ... + + def invoke( + self, + *, + params: dict[str, Any], + user_id: str, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + credentials: InvokeCredentials | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: ... + + +class AgentStrategyResolver(Protocol): + def resolve( + self, + *, + tenant_id: str, + agent_strategy_provider_name: str, + agent_strategy_name: str, + ) -> ResolvedAgentStrategy: ... + + +class AgentStrategyPresentationProvider(Protocol): + def get_icon(self, *, tenant_id: str, agent_strategy_provider_name: str) -> str | None: ... diff --git a/api/core/workflow/nodes/datasource/__init__.py b/api/core/workflow/nodes/datasource/__init__.py index f6ec44cb77..2e9bed5e00 100644 --- a/api/core/workflow/nodes/datasource/__init__.py +++ b/api/core/workflow/nodes/datasource/__init__.py @@ -1,3 +1 @@ -from .datasource_node import DatasourceNode - -__all__ = ["DatasourceNode"] +"""Datasource workflow node package.""" diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 17f8bcb2db..44f4a23a5a 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,26 +1,22 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.repositories.datasource_manager_protocol import ( - DatasourceManagerProtocol, - DatasourceParameter, - OnlineDriveDownloadFileParam, -) +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey +from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from .entities import DatasourceNodeData +from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -28,16 +24,15 @@ class DatasourceNode(Node[DatasourceNodeData]): Datasource Node """ - node_type = NodeType.DATASOURCE + node_type = BuiltinNodeTypes.DATASOURCE execution_type = NodeExecutionType.ROOT def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - datasource_manager: DatasourceManagerProtocol, ): super().__init__( id=id, @@ -45,13 +40,18 @@ class DatasourceNode(Node[DatasourceNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self.datasource_manager = datasource_manager + self.datasource_manager = DatasourceManager + + def populate_start_event(self, event) -> None: + event.provider_id = f"{self.node_data.plugin_id}/{self.node_data.provider_name}" + event.provider_type = self.node_data.provider_type def _run(self) -> Generator: """ Run the datasource node """ + dify_ctx = self.require_dify_context() node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) @@ -75,7 +75,7 @@ class DatasourceNode(Node[DatasourceNodeData]): datasource_info["icon"] = self.datasource_manager.get_icon_url( provider_id=provider_id, datasource_name=node_data.datasource_name or "", - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, datasource_type=datasource_type.value, ) @@ -104,11 +104,11 @@ class DatasourceNode(Node[DatasourceNodeData]): yield from self.datasource_manager.stream_node_events( node_id=self._node_id, - user_id=self.user_id, + user_id=dify_ctx.user_id, datasource_name=node_data.datasource_name or "", datasource_type=datasource_type.value, provider_id=provider_id, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, provider=node_data.provider_name, plugin_id=node_data.plugin_id, credential_id=credential_id, @@ -136,7 +136,7 @@ class DatasourceNode(Node[DatasourceNodeData]): raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=self.tenant_id + file_id=related_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) @@ -180,7 +180,7 @@ class DatasourceNode(Node[DatasourceNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: DatasourceNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -189,11 +189,10 @@ class DatasourceNode(Node[DatasourceNodeData]): :param node_data: node data :return: """ - typed_node_data = DatasourceNodeData.model_validate(node_data) result = {} - if typed_node_data.datasource_parameters: - for parameter_name in typed_node_data.datasource_parameters: - input = typed_node_data.datasource_parameters[parameter_name] + if node_data.datasource_parameters: + for parameter_name in node_data.datasource_parameters: + input = node_data.datasource_parameters[parameter_name] match input.type: case "mixed": assert isinstance(input.value, str) diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 4802d3ed98..65864474b0 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -3,7 +3,8 @@ from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.workflow.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class DatasourceEntity(BaseModel): @@ -16,6 +17,8 @@ class DatasourceEntity(BaseModel): class DatasourceNodeData(BaseNodeData, DatasourceEntity): + type: NodeType = BuiltinNodeTypes.DATASOURCE + class DatasourceInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] @@ -39,3 +42,14 @@ class DatasourceNodeData(BaseNodeData, DatasourceEntity): return typ datasource_parameters: dict[str, DatasourceInput] | None = None + + +class DatasourceParameter(BaseModel): + workspace_id: str + page_id: str + type: str + + +class OnlineDriveDownloadFileParam(BaseModel): + id: str + bucket: str diff --git a/api/core/workflow/repositories/datasource_manager_protocol.py b/api/core/workflow/nodes/datasource/protocols.py similarity index 71% rename from api/core/workflow/repositories/datasource_manager_protocol.py rename to api/core/workflow/nodes/datasource/protocols.py index 4acf486bef..c006e0885c 100644 --- a/api/core/workflow/repositories/datasource_manager_protocol.py +++ b/api/core/workflow/nodes/datasource/protocols.py @@ -1,25 +1,10 @@ from collections.abc import Generator from typing import Any, Protocol -from pydantic import BaseModel +from dify_graph.file import File +from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent -from core.workflow.file import File -from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent - - -class DatasourceParameter(BaseModel): - workspace_id: str - page_id: str - type: str - - -class OnlineDriveDownloadFileParam(BaseModel): - id: str - bucket: str - - -class DatasourceFinal(BaseModel): - data: dict[str, Any] | None = None +from .entities import DatasourceParameter, OnlineDriveDownloadFileParam class DatasourceManagerProtocol(Protocol): diff --git a/api/core/workflow/nodes/knowledge_index/__init__.py b/api/core/workflow/nodes/knowledge_index/__init__.py index 23897a1e42..efc6a57b3d 100644 --- a/api/core/workflow/nodes/knowledge_index/__init__.py +++ b/api/core/workflow/nodes/knowledge_index/__init__.py @@ -1,3 +1,5 @@ -from .knowledge_index_node import KnowledgeIndexNode +"""Knowledge index workflow node package.""" -__all__ = ["KnowledgeIndexNode"] +KNOWLEDGE_INDEX_NODE_TYPE = "knowledge-index" + +__all__ = ["KNOWLEDGE_INDEX_NODE_TYPE"] diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index bfeb9b5b79..8d2e9bf3cb 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -2,8 +2,11 @@ from typing import Literal, Union from pydantic import BaseModel +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.nodes.base import BaseNodeData +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class RerankingModelConfig(BaseModel): @@ -155,8 +158,8 @@ class KnowledgeIndexNodeData(BaseNodeData): Knowledge index Node Data. """ - type: str = "knowledge-index" + type: NodeType = KNOWLEDGE_INDEX_NODE_TYPE chunk_structure: str index_chunk_variable_selector: list[str] indexing_technique: str | None = None - summary_index_setting: dict | None = None + summary_index_setting: SummaryIndexSettingDict | None = None diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 2aff953bc6..4ea9091c5b 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -1,66 +1,68 @@ -import concurrent.futures -import datetime import logging -import time from collections.abc import Mapping -from typing import Any +from typing import TYPE_CHECKING, Any -from flask import current_app -from sqlalchemy import func, select - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType, SystemVariableKey -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.template import Template -from core.workflow.runtime import VariablePool -from extensions.ext_database import db -from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary -from services.summary_index_service import SummaryIndexService -from tasks.generate_summary_index_task import generate_summary_index_task +from core.rag.index_processor.index_processor import IndexProcessor +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict +from core.rag.summary_index.summary_index import SummaryIndex +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, SystemVariableKey +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.template import Template from .entities import KnowledgeIndexNodeData from .exc import ( KnowledgeIndexNodeError, ) -logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState -default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 2, - "score_threshold_enabled": False, -} +logger = logging.getLogger(__name__) +_INVOKE_FROM_DEBUGGER = "debugger" class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): - node_type = NodeType.KNOWLEDGE_INDEX + node_type = KNOWLEDGE_INDEX_NODE_TYPE execution_type = NodeExecutionType.RESPONSE + def __init__( + self, + id: str, + config: NodeConfigDict, + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + ) -> None: + super().__init__(id, config, graph_init_params, graph_runtime_state) + self.index_processor = IndexProcessor() + self.summary_index_service = SummaryIndex() + def _run(self) -> NodeRunResult: # type: ignore node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) - if not dataset_id: + + # get dataset id as string + dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + if not dataset_id_segment: raise KnowledgeIndexNodeError("Dataset ID is required.") - dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first() - if not dataset: - raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.") + dataset_id: str = dataset_id_segment.value + + # get document id as string (may be empty when not provided) + document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id: str = document_id_segment.value if document_id_segment else "" # extract variables variable = variable_pool.get(node_data.index_chunk_variable_selector) if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - if invoke_from: - is_preview = invoke_from.value == InvokeFrom.DEBUGGER - else: - is_preview = False + invoke_from_value = str(invoke_from.value) if invoke_from else None + is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER + chunks = variable.value variables = {"chunks": chunks} if not chunks: @@ -68,52 +70,49 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required." ) - # index knowledge try: + summary_index_setting = node_data.summary_index_setting if is_preview: # Preview mode: generate summaries for chunks directly without saving to database # Format preview and generate summaries on-the-fly # Get indexing_technique and summary_index_setting from node_data (workflow graph config) # or fallback to dataset if not available in node_data - indexing_technique = node_data.indexing_technique or dataset.indexing_technique - summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting - # Try to get document language if document_id is available - doc_language = None - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) - if document_id: - document = db.session.query(Document).filter_by(id=document_id.value).first() - if document and document.doc_language: - doc_language = document.doc_language - - outputs = self._get_preview_output_with_summaries( - node_data.chunk_structure, - chunks, - dataset=dataset, - indexing_technique=indexing_technique, - summary_index_setting=summary_index_setting, - doc_language=doc_language, + outputs = self.index_processor.get_preview_output( + chunks, dataset_id, document_id, node_data.chunk_structure, summary_index_setting ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, - outputs=outputs, + outputs=outputs.model_dump(exclude_none=True), ) + + original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) + batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + if not batch: + raise KnowledgeIndexNodeError("Batch is required.") + results = self._invoke_knowledge_index( - dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool + dataset_id=dataset_id, + document_id=document_id, + original_document_id=original_document_id_segment.value if original_document_id_segment else "", + is_preview=is_preview, + batch=batch.value, + chunks=chunks, + summary_index_setting=summary_index_setting, ) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=results) except KnowledgeIndexNodeError as e: - logger.warning("Error when running knowledge index node") + logger.warning("Error when running knowledge index node", exc_info=True) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__, ) - # Temporary handle all exceptions from DatasetRetrieval class here. except Exception as e: + logger.error(e, exc_info=True) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -123,392 +122,23 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): def _invoke_knowledge_index( self, - dataset: Dataset, - node_data: KnowledgeIndexNodeData, + dataset_id: str, + document_id: str, + original_document_id: str, + is_preview: bool, + batch: Any, chunks: Mapping[str, Any], - variable_pool: VariablePool, - ) -> Any: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + summary_index_setting: SummaryIndexSettingDict | None = None, + ): if not document_id: - raise KnowledgeIndexNodeError("Document ID is required.") - original_document_id = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) - - batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) - if not batch: - raise KnowledgeIndexNodeError("Batch is required.") - document = db.session.query(Document).filter_by(id=document_id.value).first() - if not document: - raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.") - doc_id_value = document.id - ds_id_value = dataset.id - dataset_name_value = dataset.name - document_name_value = document.name - created_at_value = document.created_at - # chunk nodes by chunk size - indexing_start_at = time.perf_counter() - index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor() - if original_document_id: - segments = db.session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == original_document_id.value) - ).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - - for segment in segments: - db.session.delete(segment) - db.session.commit() - index_processor.index(dataset, document, chunks) - indexing_end_at = time.perf_counter() - document.indexing_latency = indexing_end_at - indexing_start_at - # update document status - document.indexing_status = "completed" - document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - document.word_count = ( - db.session.query(func.sum(DocumentSegment.word_count)) - .where( - DocumentSegment.document_id == doc_id_value, - DocumentSegment.dataset_id == ds_id_value, - ) - .scalar() + raise KnowledgeIndexNodeError("document_id is required.") + rst = self.index_processor.index_and_clean( + dataset_id, document_id, original_document_id, chunks, batch, summary_index_setting ) - # Update need_summary based on dataset's summary_index_setting - if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True: - document.need_summary = True - else: - document.need_summary = False - db.session.add(document) - # update document segment status - db.session.query(DocumentSegment).where( - DocumentSegment.document_id == doc_id_value, - DocumentSegment.dataset_id == ds_id_value, - ).update( - { - DocumentSegment.status: "completed", - DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), - } + self.summary_index_service.generate_and_vectorize_summary( + dataset_id, document_id, is_preview, summary_index_setting ) - - db.session.commit() - - # Generate summary index if enabled - self._handle_summary_index_generation(dataset, document, variable_pool) - - return { - "dataset_id": ds_id_value, - "dataset_name": dataset_name_value, - "batch": batch.value, - "document_id": doc_id_value, - "document_name": document_name_value, - "created_at": created_at_value.timestamp(), - "display_status": "completed", - } - - def _handle_summary_index_generation( - self, - dataset: Dataset, - document: Document, - variable_pool: VariablePool, - ) -> None: - """ - Handle summary index generation based on mode (debug/preview or production). - - Args: - dataset: Dataset containing the document - document: Document to generate summaries for - variable_pool: Variable pool to check invoke_from - """ - # Only generate summary index for high_quality indexing technique - if dataset.indexing_technique != "high_quality": - return - - # Check if summary index is enabled - summary_index_setting = dataset.summary_index_setting - if not summary_index_setting or not summary_index_setting.get("enable"): - return - - # Skip qa_model documents - if document.doc_form == "qa_model": - return - - # Determine if in preview/debug mode - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER - - if is_preview: - try: - # Query segments that need summary generation - query = db.session.query(DocumentSegment).filter_by( - dataset_id=dataset.id, - document_id=document.id, - status="completed", - enabled=True, - ) - segments = query.all() - - if not segments: - logger.info("No segments found for document %s", document.id) - return - - # Filter segments based on mode - segments_to_process = [] - for segment in segments: - # Skip if summary already exists - existing_summary = ( - db.session.query(DocumentSegmentSummary) - .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed") - .first() - ) - if existing_summary: - continue - - # For parent-child mode, all segments are parent chunks, so process all - segments_to_process.append(segment) - - if not segments_to_process: - logger.info("No segments need summary generation for document %s", document.id) - return - - # Use ThreadPoolExecutor for concurrent generation - flask_app = current_app._get_current_object() # type: ignore - max_workers = min(10, len(segments_to_process)) # Limit to 10 workers - - def process_segment(segment: DocumentSegment) -> None: - """Process a single segment in a thread with Flask app context.""" - with flask_app.app_context(): - try: - SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting) - except Exception: - logger.exception( - "Failed to generate summary for segment %s", - segment.id, - ) - # Continue processing other segments - - with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(process_segment, segment) for segment in segments_to_process] - # Wait for all tasks to complete - concurrent.futures.wait(futures) - - logger.info( - "Successfully generated summary index for %s segments in document %s", - len(segments_to_process), - document.id, - ) - except Exception: - logger.exception("Failed to generate summary index for document %s", document.id) - # Don't fail the entire indexing process if summary generation fails - else: - # Production mode: asynchronous generation - logger.info( - "Queuing summary index generation task for document %s (production mode)", - document.id, - ) - try: - generate_summary_index_task.delay(dataset.id, document.id, None) - logger.info("Summary index generation task queued for document %s", document.id) - except Exception: - logger.exception( - "Failed to queue summary index generation task for document %s", - document.id, - ) - # Don't fail the entire indexing process if task queuing fails - - def _get_preview_output_with_summaries( - self, - chunk_structure: str, - chunks: Any, - dataset: Dataset, - indexing_technique: str | None = None, - summary_index_setting: dict | None = None, - doc_language: str | None = None, - ) -> Mapping[str, Any]: - """ - Generate preview output with summaries for chunks in preview mode. - This method generates summaries on-the-fly without saving to database. - - Args: - chunk_structure: Chunk structure type - chunks: Chunks to generate preview for - dataset: Dataset object (for tenant_id) - indexing_technique: Indexing technique from node config or dataset - summary_index_setting: Summary index setting from node config or dataset - doc_language: Optional document language to ensure summary is generated in the correct language - """ - index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - preview_output = index_processor.format_preview(chunks) - - # Check if summary index is enabled - if indexing_technique != "high_quality": - return preview_output - - if not summary_index_setting or not summary_index_setting.get("enable"): - return preview_output - - # Generate summaries for chunks - if "preview" in preview_output and isinstance(preview_output["preview"], list): - chunk_count = len(preview_output["preview"]) - logger.info( - "Generating summaries for %s chunks in preview mode (dataset: %s)", - chunk_count, - dataset.id, - ) - # Use ParagraphIndexProcessor's generate_summary method - from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor - - # Get Flask app for application context in worker threads - flask_app = None - try: - flask_app = current_app._get_current_object() # type: ignore - except RuntimeError: - logger.warning("No Flask application context available, summary generation may fail") - - def generate_summary_for_chunk(preview_item: dict) -> None: - """Generate summary for a single chunk.""" - if "content" in preview_item: - # Set Flask application context in worker thread - if flask_app: - with flask_app.app_context(): - summary, _ = ParagraphIndexProcessor.generate_summary( - tenant_id=dataset.tenant_id, - text=preview_item["content"], - summary_index_setting=summary_index_setting, - document_language=doc_language, - ) - if summary: - preview_item["summary"] = summary - else: - # Fallback: try without app context (may fail) - summary, _ = ParagraphIndexProcessor.generate_summary( - tenant_id=dataset.tenant_id, - text=preview_item["content"], - summary_index_setting=summary_index_setting, - document_language=doc_language, - ) - if summary: - preview_item["summary"] = summary - - # Generate summaries concurrently using ThreadPoolExecutor - # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total) - timeout_seconds = min(300, 60 * len(preview_output["preview"])) - errors: list[Exception] = [] - - with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor: - futures = [ - executor.submit(generate_summary_for_chunk, preview_item) - for preview_item in preview_output["preview"] - ] - # Wait for all tasks to complete with timeout - done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds) - - # Cancel tasks that didn't complete in time - if not_done: - timeout_error_msg = ( - f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s" - ) - logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg) - # In preview mode, timeout is also an error - errors.append(TimeoutError(timeout_error_msg)) - for future in not_done: - future.cancel() - # Wait a bit for cancellation to take effect - concurrent.futures.wait(not_done, timeout=5) - - # Collect exceptions from completed futures - for future in done: - try: - future.result() # This will raise any exception that occurred - except Exception as e: - logger.exception("Error in summary generation future") - errors.append(e) - - # In preview mode, if there are any errors, fail the request - if errors: - error_messages = [str(e) for e in errors] - error_summary = ( - f"Failed to generate summaries for {len(errors)} chunk(s). " - f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors - ) - if len(errors) > 3: - error_summary += f" (and {len(errors) - 3} more)" - logger.error("Summary generation failed in preview mode: %s", error_summary) - raise KnowledgeIndexNodeError(error_summary) - - completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None) - logger.info( - "Completed summary generation for preview chunks: %s/%s succeeded", - completed_count, - len(preview_output["preview"]), - ) - - return preview_output - - def _get_preview_output( - self, - chunk_structure: str, - chunks: Any, - dataset: Dataset | None = None, - variable_pool: VariablePool | None = None, - ) -> Mapping[str, Any]: - index_processor = IndexProcessorFactory(chunk_structure).init_index_processor() - preview_output = index_processor.format_preview(chunks) - - # If dataset is provided, try to enrich preview with summaries - if dataset and variable_pool: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) - if document_id: - document = db.session.query(Document).filter_by(id=document_id.value).first() - if document: - # Query summaries for this document - summaries = ( - db.session.query(DocumentSegmentSummary) - .filter_by( - dataset_id=dataset.id, - document_id=document.id, - status="completed", - enabled=True, - ) - .all() - ) - - if summaries: - # Create a map of segment content to summary for matching - # Use content matching as chunks in preview might not be indexed yet - summary_by_content = {} - for summary in summaries: - segment = ( - db.session.query(DocumentSegment) - .filter_by(id=summary.chunk_id, dataset_id=dataset.id) - .first() - ) - if segment: - # Normalize content for matching (strip whitespace) - normalized_content = segment.content.strip() - summary_by_content[normalized_content] = summary.summary_content - - # Enrich preview with summaries by content matching - if "preview" in preview_output and isinstance(preview_output["preview"], list): - matched_count = 0 - for preview_item in preview_output["preview"]: - if "content" in preview_item: - # Normalize content for matching - normalized_chunk_content = preview_item["content"].strip() - if normalized_chunk_content in summary_by_content: - preview_item["summary"] = summary_by_content[normalized_chunk_content] - matched_count += 1 - - if matched_count > 0: - logger.info( - "Enriched preview with %s existing summaries (dataset: %s, document: %s)", - matched_count, - dataset.id, - document.id, - ) - - return preview_output + return rst @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/knowledge_index/protocols.py b/api/core/workflow/nodes/knowledge_index/protocols.py new file mode 100644 index 0000000000..bb52123082 --- /dev/null +++ b/api/core/workflow/nodes/knowledge_index/protocols.py @@ -0,0 +1,47 @@ +from collections.abc import Mapping +from typing import Any, Protocol + +from pydantic import BaseModel, Field + + +class PreviewItem(BaseModel): + content: str | None = Field(default=None) + child_chunks: list[str] | None = Field(default=None) + summary: str | None = Field(default=None) + + +class QaPreview(BaseModel): + answer: str | None = Field(default=None) + question: str | None = Field(default=None) + + +class Preview(BaseModel): + chunk_structure: str + parent_mode: str | None = Field(default=None) + preview: list[PreviewItem] = Field(default_factory=list) + qa_preview: list[QaPreview] = Field(default_factory=list) + total_segments: int + + +class IndexProcessorProtocol(Protocol): + def format_preview(self, chunk_structure: str, chunks: Any) -> Preview: ... + + def index_and_clean( + self, + dataset_id: str, + document_id: str, + original_document_id: str, + chunks: Mapping[str, Any], + batch: Any, + summary_index_setting: dict | None = None, + ) -> dict[str, Any]: ... + + def get_preview_output( + self, chunks: Any, dataset_id: str, document_id: str, chunk_structure: str, summary_index_setting: dict | None + ) -> Preview: ... + + +class SummaryIndexServiceProtocol(Protocol): + def generate_and_vectorize_summary( + self, dataset_id: str, document_id: str, is_preview: bool, summary_index_setting: dict | None = None + ) -> None: ... diff --git a/api/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/core/workflow/nodes/knowledge_retrieval/__init__.py index 4d4a4cbd9f..33ea4277b4 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/__init__.py +++ b/api/core/workflow/nodes/knowledge_retrieval/__init__.py @@ -1,3 +1 @@ -from .knowledge_retrieval_node import KnowledgeRetrievalNode - -__all__ = ["KnowledgeRetrievalNode"] +"""Knowledge retrieval workflow node package.""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 86bb2495e7..bc5618685a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -3,8 +3,9 @@ from typing import Literal from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): @@ -113,7 +114,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): Knowledge retrieval Node Data. """ - type: str = "knowledge-retrieval" + type: NodeType = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL query_variable_selector: list[str] | None | str = None query_attachment_selector: list[str] | None | str = None dataset_ids: list[str] diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0cfd39e485..80f59140be 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,59 +1,66 @@ +"""Knowledge retrieval workflow node implementation. + +This node now lives under ``core.workflow.nodes`` and is discovered directly by +the workflow node registry. +""" + import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ( - NodeType, +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ( + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source -from core.workflow.variables import ( +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base import LLMUsageTrackingMixin +from dify_graph.nodes.base.node import Node +from dify_graph.variables import ( ArrayFileSegment, FileSegment, StringSegment, ) -from core.workflow.variables.segments import ArrayObjectSegment +from dify_graph.variables.segments import ArrayObjectSegment -from .entities import KnowledgeRetrievalNodeData +from .entities import ( + Condition, + KnowledgeRetrievalNodeData, + MetadataFilteringCondition, +) from .exc import ( KnowledgeRetrievalNodeError, RateLimitExceededError, ) +from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from core.workflow.file.models import File - from core.workflow.runtime import GraphRuntimeState + from dify_graph.file.models import File + from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): - node_type = NodeType.KNOWLEDGE_RETRIEVAL + node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL # Instance attributes specific to LLMNode. # Output variable for file _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - rag_retrieval: RAGRetrievalProtocol, - *, - llm_file_saver: LLMFileSaver | None = None, ): super().__init__( id=id, @@ -63,14 +70,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._rag_retrieval = rag_retrieval - - if llm_file_saver is None: - llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, - ) - self._llm_file_saver = llm_file_saver + self._rag_retrieval = DatasetRetrieval() @classmethod def version(cls): @@ -115,7 +115,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD try: results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) - outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])} + outputs = {"result": ArrayObjectSegment(value=[item.model_dump(by_alias=True) for item in results])} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -160,6 +160,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: + dify_ctx = self.require_dify_context() dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -169,6 +170,12 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if node_data.metadata_filtering_mode is not None: metadata_filtering_mode = node_data.metadata_filtering_mode + resolved_metadata_conditions = ( + self._resolve_metadata_filtering_conditions(node_data.metadata_filtering_conditions) + if node_data.metadata_filtering_conditions + else None + ) + if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: @@ -176,10 +183,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD model = node_data.single_retrieval_config.model retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( request=KnowledgeRetrievalRequest( - tenant_id=self.tenant_id, - user_id=self.user_id, - app_id=self.app_id, - user_from=self.user_from.value, + tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, + app_id=dify_ctx.app_id, + user_from=dify_ctx.user_from.value, dataset_ids=dataset_ids, retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value, completion_params=model.completion_params, @@ -187,7 +194,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD model_mode=model.mode, model_name=model.name, metadata_model_config=node_data.metadata_model_config, - metadata_filtering_conditions=node_data.metadata_filtering_conditions, + metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, query=query, ) @@ -195,8 +202,8 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") - reranking_model = None - weights = None + reranking_model: RerankingModelDict | None = None + weights: WeightsDict | None = None match node_data.multiple_retrieval_config.reranking_mode: case "reranking_model": if node_data.multiple_retrieval_config.reranking_model: @@ -229,10 +236,10 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( request=KnowledgeRetrievalRequest( - app_id=self.app_id, - tenant_id=self.tenant_id, - user_id=self.user_id, - user_from=self.user_from.value, + app_id=dify_ctx.app_id, + tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, + user_from=dify_ctx.user_from.value, dataset_ids=dataset_ids, query=query, retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value, @@ -245,7 +252,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD weights=weights, reranking_enable=node_data.multiple_retrieval_config.reranking_enable, metadata_model_config=node_data.metadata_model_config, - metadata_filtering_conditions=node_data.metadata_filtering_conditions, + metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) @@ -254,21 +261,60 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD usage = self._rag_retrieval.llm_usage return retrieval_resource_list, usage + def _resolve_metadata_filtering_conditions( + self, conditions: MetadataFilteringCondition + ) -> MetadataFilteringCondition: + if conditions.conditions is None: + return MetadataFilteringCondition( + logical_operator=conditions.logical_operator, + conditions=None, + ) + + variable_pool = self.graph_runtime_state.variable_pool + resolved_conditions: list[Condition] = [] + for cond in conditions.conditions or []: + value = cond.value + if isinstance(value, str): + segment_group = variable_pool.convert_template(value) + if len(segment_group.value) == 1: + resolved_value = segment_group.value[0].to_object() + else: + resolved_value = segment_group.text + elif isinstance(value, Sequence) and all(isinstance(v, str) for v in value): + resolved_values = [] + for v in value: # type: ignore + segment_group = variable_pool.convert_template(v) + if len(segment_group.value) == 1: + resolved_values.append(segment_group.value[0].to_object()) + else: + resolved_values.append(segment_group.text) + resolved_value = resolved_values + else: + resolved_value = value + resolved_conditions.append( + Condition( + name=cond.name, + comparison_operator=cond.comparison_operator, + value=resolved_value, + ) + ) + return MetadataFilteringCondition( + logical_operator=conditions.logical_operator or "and", + conditions=resolved_conditions, + ) + @classmethod def _extract_variable_selector_to_variable_mapping( cls, *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: KnowledgeRetrievalNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type - # Create typed NodeData from dict - typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) - variable_mapping = {} - if typed_node_data.query_variable_selector: - variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector - if typed_node_data.query_attachment_selector: - variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector + if node_data.query_variable_selector: + variable_mapping[node_id + ".query"] = node_data.query_variable_selector + if node_data.query_attachment_selector: + variable_mapping[node_id + ".queryAttachment"] = node_data.query_attachment_selector return variable_mapping diff --git a/api/core/workflow/repositories/rag_retrieval_protocol.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py similarity index 77% rename from api/core/workflow/repositories/rag_retrieval_protocol.py rename to api/core/workflow/nodes/knowledge_retrieval/retrieval.py index f91cecb694..e1311ab962 100644 --- a/api/core/workflow/repositories/rag_retrieval_protocol.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -2,9 +2,11 @@ from typing import Any, Literal, Protocol from pydantic import BaseModel, Field -from core.model_runtime.entities import LLMUsage -from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition -from core.workflow.nodes.llm.entities import ModelConfig +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from dify_graph.model_runtime.entities import LLMUsage +from dify_graph.nodes.llm.entities import ModelConfig + +from .entities import MetadataFilteringCondition class SourceChildChunk(BaseModel): @@ -28,7 +30,7 @@ class SourceMetadata(BaseModel): segment_id: str | None = Field(default=None, description="Segment unique identifier") retriever_from: str = Field(default="workflow", description="Retriever source context") score: float = Field(default=0.0, description="Retrieval relevance score") - child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks") + child_chunks: list[SourceChildChunk] = Field(default_factory=list, description="List of child chunks") segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved") segment_word_count: int | None = Field(default=0, description="Word count of the segment") segment_position: int | None = Field(default=0, description="Position of segment in document") @@ -74,35 +76,14 @@ class KnowledgeRetrievalRequest(BaseModel): top_k: int = Field(default=0, description="Number of top results to return") score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") reranking_mode: str = Field(default="reranking_model", description="Reranking strategy") - reranking_model: dict | None = Field(default=None, description="Reranking model configuration") - weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking") + reranking_model: RerankingModelDict | None = Field(default=None, description="Reranking model configuration") + weights: WeightsDict | None = Field(default=None, description="Weights for weighted score reranking") reranking_enable: bool = Field(default=True, description="Whether reranking is enabled") attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval") class RAGRetrievalProtocol(Protocol): - """Protocol for RAG-based knowledge retrieval implementations. - - Implementations of this protocol handle knowledge retrieval from datasets - including rate limiting, dataset filtering, and document retrieval. - """ - @property - def llm_usage(self) -> LLMUsage: - """Return accumulated LLM usage for retrieval operations.""" - ... + def llm_usage(self) -> LLMUsage: ... - def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: - """Retrieve knowledge from datasets based on the provided request. - - Args: - request: Knowledge retrieval request with search parameters - - Returns: - List of sources matching the search criteria - - Raises: - RateLimitExceededError: If rate limit is exceeded - ModelNotExistError: If specified model doesn't exist - """ - ... + def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: ... diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py deleted file mode 100644 index 72f150d920..0000000000 --- a/api/core/workflow/nodes/llm/llm_utils.py +++ /dev/null @@ -1,91 +0,0 @@ -from collections.abc import Sequence -from typing import cast - -from core.model_manager import ModelInstance -from core.model_runtime.entities import PromptMessageRole -from core.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from core.model_runtime.entities.model_entities import AIModelEntity -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.workflow.file.models import File -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment - -from .exc import InvalidVariableTypeError - - -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - model_instance.credentials, - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") - return model_schema - - -def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]: - variable = variable_pool.get(selector) - if variable is None: - return [] - elif isinstance(variable, FileSegment): - return [variable.value] - elif isinstance(variable, ArrayFileSegment): - return variable.value - elif isinstance(variable, NoneSegment | ArrayAnySegment): - return [] - raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") - - -def convert_history_messages_to_text( - *, - history_messages: Sequence[PromptMessage], - human_prefix: str, - ai_prefix: str, -) -> str: - string_messages: list[str] = [] - for message in history_messages: - if message.role == PromptMessageRole.USER: - role = human_prefix - elif message.role == PromptMessageRole.ASSISTANT: - role = ai_prefix - else: - continue - - if isinstance(message.content, list): - content_parts = [] - for content in message.content: - if isinstance(content, TextPromptMessageContent): - content_parts.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - content_parts.append("[image]") - - inner_msg = "\n".join(content_parts) - string_messages.append(f"{role}: {inner_msg}") - else: - string_messages.append(f"{role}: {message.content}") - - return "\n".join(string_messages) - - -def fetch_memory_text( - *, - memory: PromptMessageMemory, - max_token_limit: int, - message_limit: int | None = None, - human_prefix: str = "Human", - ai_prefix: str = "Assistant", -) -> str: - history_messages = memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit, - ) - return convert_history_messages_to_text( - history_messages=history_messages, - human_prefix=human_prefix, - ai_prefix=ai_prefix, - ) diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py deleted file mode 100644 index 85df543a2a..0000000000 --- a/api/core/workflow/nodes/node_mapping.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections.abc import Mapping - -from core.workflow.enums import NodeType -from core.workflow.nodes.base.node import Node - -LATEST_VERSION = "latest" - -# Mapping is built by Node.get_node_type_classes_mapping(), which imports and walks core.workflow.nodes -NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() diff --git a/api/core/workflow/nodes/start/entities.py b/api/core/workflow/nodes/start/entities.py deleted file mode 100644 index 3a99e2cbc2..0000000000 --- a/api/core/workflow/nodes/start/entities.py +++ /dev/null @@ -1,14 +0,0 @@ -from collections.abc import Sequence - -from pydantic import Field - -from core.workflow.nodes.base import BaseNodeData -from core.workflow.variables.input_entities import VariableEntity - - -class StartNodeData(BaseNodeData): - """ - Start Node Data - """ - - variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/template_transform/entities.py b/api/core/workflow/nodes/template_transform/entities.py deleted file mode 100644 index efb7a72f59..0000000000 --- a/api/core/workflow/nodes/template_transform/entities.py +++ /dev/null @@ -1,11 +0,0 @@ -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.entities import VariableSelector - - -class TemplateTransformNodeData(BaseNodeData): - """ - Template Transform Node Data. - """ - - variables: list[VariableSelector] - template: str diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index 6c53acee4f..ea7d20befe 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -3,14 +3,19 @@ from typing import Any, Literal, Union from pydantic import BaseModel, Field, ValidationInfo, field_validator +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.trigger_plugin.exc import TriggerEventParameterError +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType + +from .exc import TriggerEventParameterError class TriggerEventNodeData(BaseNodeData): """Plugin trigger node data""" + type: NodeType = TRIGGER_PLUGIN_NODE_TYPE + class TriggerEventInput(BaseModel): value: Union[Any, list[str]] type: Literal["mixed", "variable", "constant"] @@ -38,8 +43,6 @@ class TriggerEventNodeData(BaseNodeData): raise ValueError("value must be a string, int, float, bool or dict") return type - title: str - desc: str | None = None plugin_id: str = Field(..., description="Plugin ID") provider_id: str = Field(..., description="Provider ID") event_name: str = Field(..., description="Event name") diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index e11cb30a7f..118c2f2668 100644 --- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,16 +1,18 @@ from collections.abc import Mapping +from typing import Any -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node from .entities import TriggerEventNodeData class TriggerEventNode(Node[TriggerEventNodeData]): - node_type = NodeType.TRIGGER_PLUGIN + node_type = TRIGGER_PLUGIN_NODE_TYPE execution_type = NodeExecutionType.ROOT @classmethod @@ -32,6 +34,9 @@ class TriggerEventNode(Node[TriggerEventNodeData]): def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + event.provider_id = self.node_data.provider_id + def _run(self) -> NodeRunResult: """ Run the plugin trigger node. @@ -41,7 +46,7 @@ class TriggerEventNode(Node[TriggerEventNodeData]): """ # Get trigger data passed when workflow was triggered - metadata = { + metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { "provider_id": self.node_data.provider_id, "event_name": self.node_data.event_name, diff --git a/api/core/workflow/nodes/trigger_schedule/__init__.py b/api/core/workflow/nodes/trigger_schedule/__init__.py index 6773bae502..07b711a0fd 100644 --- a/api/core/workflow/nodes/trigger_schedule/__init__.py +++ b/api/core/workflow/nodes/trigger_schedule/__init__.py @@ -1,3 +1,3 @@ -from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode +from .trigger_schedule_node import TriggerScheduleNode __all__ = ["TriggerScheduleNode"] diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index a515d02d55..95a2548678 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -2,7 +2,9 @@ from typing import Literal, Union from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): @@ -10,6 +12,7 @@ class TriggerScheduleNodeData(BaseNodeData): Trigger Schedule Node Data """ + type: NodeType = TRIGGER_SCHEDULE_NODE_TYPE mode: str = Field(default="visual", description="Schedule mode: visual or cron") frequency: str | None = Field(default=None, description="Frequency for visual mode: hourly, daily, weekly, monthly") cron_expression: str | None = Field(default=None, description="Cron expression for cron mode") diff --git a/api/core/workflow/nodes/trigger_schedule/exc.py b/api/core/workflow/nodes/trigger_schedule/exc.py index 2f99880ff1..336d64d58f 100644 --- a/api/core/workflow/nodes/trigger_schedule/exc.py +++ b/api/core/workflow/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index fb5c8a4dce..b9580e6ab1 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,15 +1,17 @@ from collections.abc import Mapping -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.trigger_schedule.entities import TriggerScheduleNodeData +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node + +from .entities import TriggerScheduleNodeData class TriggerScheduleNode(Node[TriggerScheduleNodeData]): - node_type = NodeType.TRIGGER_SCHEDULE + node_type = TRIGGER_SCHEDULE_NODE_TYPE execution_type = NodeExecutionType.ROOT @classmethod @@ -19,7 +21,7 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { - "type": "trigger-schedule", + "type": TRIGGER_SCHEDULE_NODE_TYPE, "config": { "mode": "visual", "frequency": "daily", diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 1011e60b43..242bf5ef6a 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,10 +1,42 @@ from collections.abc import Sequence from enum import StrEnum -from typing import Literal from pydantic import BaseModel, Field, field_validator -from core.workflow.nodes.base import BaseNodeData +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import NodeType +from dify_graph.variables.types import SegmentType + +_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + } +) + +_WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + } +) + +_WEBHOOK_PARAMETER_ALLOWED_TYPES = _WEBHOOK_HEADER_ALLOWED_TYPES | _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES + +_WEBHOOK_BODY_ALLOWED_TYPES = frozenset( + { + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.BOOLEAN, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_BOOLEAN, + SegmentType.ARRAY_OBJECT, + SegmentType.FILE, + } +) class Method(StrEnum): @@ -25,29 +57,34 @@ class ContentType(StrEnum): class WebhookParameter(BaseModel): - """Parameter definition for headers, query params, or body.""" + """Parameter definition for headers or query params.""" name: str + type: SegmentType = SegmentType.STRING required: bool = False + @field_validator("type", mode="after") + @classmethod + def validate_type(cls, v: SegmentType) -> SegmentType: + if v not in _WEBHOOK_PARAMETER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook parameter type: {v}") + return v + class WebhookBodyParameter(BaseModel): """Body parameter with type information.""" name: str - type: Literal[ - "string", - "number", - "boolean", - "object", - "array[string]", - "array[number]", - "array[boolean]", - "array[object]", - "file", - ] = "string" + type: SegmentType = SegmentType.STRING required: bool = False + @field_validator("type", mode="after") + @classmethod + def validate_type(cls, v: SegmentType) -> SegmentType: + if v not in _WEBHOOK_BODY_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook body parameter type: {v}") + return v + class WebhookData(BaseNodeData): """ @@ -57,6 +94,7 @@ class WebhookData(BaseNodeData): class SyncMode(StrEnum): SYNC = "async" # only support + type: NodeType = TRIGGER_WEBHOOK_NODE_TYPE method: Method = Method.GET content_type: ContentType = Field(default=ContentType.JSON) headers: Sequence[WebhookParameter] = Field(default_factory=list) @@ -71,6 +109,22 @@ class WebhookData(BaseNodeData): return v.lower() return v + @field_validator("headers", mode="after") + @classmethod + def validate_header_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]: + for param in v: + if param.type not in _WEBHOOK_HEADER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook header parameter type: {param.type}") + return v + + @field_validator("params", mode="after") + @classmethod + def validate_query_parameter_types(cls, v: Sequence[WebhookParameter]) -> Sequence[WebhookParameter]: + for param in v: + if param.type not in _WEBHOOK_QUERY_PARAMETER_ALLOWED_TYPES: + raise ValueError(f"Unsupported webhook query parameter type: {param.type}") + return v + status_code: int = 200 # Expected status code for response response_body: str = "" # Template for response body diff --git a/api/core/workflow/nodes/trigger_webhook/exc.py b/api/core/workflow/nodes/trigger_webhook/exc.py index dc2239c287..4d87f2a069 100644 --- a/api/core/workflow/nodes/trigger_webhook/exc.py +++ b/api/core/workflow/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base.exc import BaseNodeError +from dify_graph.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 9f6046c11a..317844cbda 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,14 +2,15 @@ import logging from collections.abc import Mapping from typing import Any -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import NodeExecutionType, NodeType -from core.workflow.file import FileTransferMethod -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import FileVariable +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.enums import NodeExecutionType +from dify_graph.file import FileTransferMethod +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import FileVariable from factories import file_factory from factories.variable_factory import build_segment_with_type @@ -19,7 +20,7 @@ logger = logging.getLogger(__name__) class TriggerWebhookNode(Node[WebhookData]): - node_type = NodeType.TRIGGER_WEBHOOK + node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT @classmethod @@ -69,6 +70,7 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): + dify_ctx = self.require_dify_context() related_id = file.get("related_id") transfer_method_value = file.get("transfer_method") if transfer_method_value: @@ -84,7 +86,7 @@ class TriggerWebhookNode(Node[WebhookData]): try: file_obj = file_factory.build_from_mapping( mapping=file, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, ) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) @@ -151,7 +153,7 @@ class TriggerWebhookNode(Node[WebhookData]): outputs[param_name] = raw_data continue - if param_type == "file": + if param_type == SegmentType.FILE: # Get File object (already processed by webhook controller) files = webhook_data.get("files", {}) if files and isinstance(files, dict): diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2ea4266b16..2e51a06bab 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,35 +5,94 @@ from typing import Any, cast from configs import dify_config from core.app.apps.exc import GenerateTaskStoppedError -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams -from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.file.models import File -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from core.workflow.nodes import NodeType -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class +from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file.models import File +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_engine.protocols.command_channel import CommandChannel +from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory -from models.enums import UserFrom from models.workflow import Workflow logger = logging.getLogger(__name__) +class _WorkflowChildEngineBuilder: + @staticmethod + def _has_node_id(graph_config: Mapping[str, Any], node_id: str) -> bool | None: + """ + Return whether `graph_config["nodes"]` contains the given node id. + + Returns `None` when the nodes payload shape is unexpected, so graph-level + validation can surface the original configuration error. + """ + nodes = graph_config.get("nodes") + if not isinstance(nodes, list): + return None + + for node in nodes: + if not isinstance(node, Mapping): + return None + current_id = node.get("id") + if isinstance(current_id, str) and current_id == node_id: + return True + return False + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) + if has_root_node is False: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + child_graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=root_node_id, + ) + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + child_engine.layer(LLMQuotaLayer()) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + class WorkflowEntry: def __init__( self, @@ -77,6 +136,7 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, graph=graph, @@ -88,6 +148,7 @@ class WorkflowEntry: scale_up_threshold=dify_config.GRAPH_ENGINE_SCALE_UP_THRESHOLD, scale_down_idle_time=dify_config.GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME, ), + child_engine_builder=self._child_engine_builder, ) # Add debug logging layer when in debug mode @@ -150,17 +211,19 @@ class WorkflowEntry: node_config_data = node_config["data"] # Get node type - node_type = NodeType(node_config_data["type"]) + node_type = node_config_data.type # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -170,8 +233,7 @@ class WorkflowEntry: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - typed_node_config = cast(dict[str, object], node_config) - node = cast(Any, node_factory).create_node(typed_node_config) + node = node_factory.create_node(node_config) node_cls = type(node) try: @@ -190,7 +252,7 @@ class WorkflowEntry: variable_mapping=variable_mapping, user_inputs=user_inputs, ) - if node_type != NodeType.DATASOURCE: + if node_type != BuiltinNodeTypes.DATASOURCE: cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -240,7 +302,7 @@ class WorkflowEntry: "height": node_height, "type": "custom", "data": { - "type": NodeType.START, + "type": BuiltinNodeTypes.START, "title": "Start", "desc": "Start", }, @@ -276,11 +338,11 @@ class WorkflowEntry: # Create a minimal graph for single node execution graph_dict = cls._create_single_node_graph(node_id, node_data) - node_type = NodeType(node_data.get("type", "")) - if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}: + node_type = node_data.get("type", "") + if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}: raise ValueError(f"Node type {node_type} not supported") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type]["1"] + node_cls = resolve_workflow_node_class(node_type=node_type, node_version="1") if not node_cls: raise ValueError(f"Node class not found for node type {node_type}") @@ -293,22 +355,21 @@ class WorkflowEntry: # init graph init params and runtime state graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="", workflow_id="", graph_config=graph_dict, - user_id=user_id, - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context=build_dify_run_context( + tenant_id=tenant_id, + app_id="", + user_id=user_id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init workflow run state - node_config: NodeConfigDict = { - "id": node_id, - "data": cast(NodeConfigData, node_data), - } + node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, diff --git a/api/core/workflow/README.md b/api/dify_graph/README.md similarity index 97% rename from api/core/workflow/README.md rename to api/dify_graph/README.md index 9a39f976a6..2fc5b8b890 100644 --- a/api/core/workflow/README.md +++ b/api/dify_graph/README.md @@ -113,8 +113,8 @@ The codebase enforces strict layering via import-linter: 1. Create node class in `nodes//` 1. Inherit from `BaseNode` or appropriate base class 1. Implement `_run()` method -1. Register in `nodes/node_mapping.py` -1. Add tests in `tests/unit_tests/core/workflow/nodes/` +1. Ensure the node module is importable under `nodes//` +1. Add tests in `tests/unit_tests/dify_graph/nodes/` ### Implementing a Custom Layer diff --git a/api/core/model_runtime/errors/__init__.py b/api/dify_graph/__init__.py similarity index 100% rename from api/core/model_runtime/errors/__init__.py rename to api/dify_graph/__init__.py diff --git a/api/core/workflow/constants.py b/api/dify_graph/constants.py similarity index 100% rename from api/core/workflow/constants.py rename to api/dify_graph/constants.py diff --git a/api/core/workflow/context/__init__.py b/api/dify_graph/context/__init__.py similarity index 86% rename from api/core/workflow/context/__init__.py rename to api/dify_graph/context/__init__.py index 1237d6a017..103f526bec 100644 --- a/api/core/workflow/context/__init__.py +++ b/api/dify_graph/context/__init__.py @@ -5,7 +5,7 @@ This package provides Flask-independent context management for workflow execution in multi-threaded environments. """ -from core.workflow.context.execution_context import ( +from dify_graph.context.execution_context import ( AppContext, ContextProviderNotFoundError, ExecutionContext, @@ -17,7 +17,7 @@ from core.workflow.context.execution_context import ( register_context_capturer, reset_context_provider, ) -from core.workflow.context.models import SandboxContext +from dify_graph.context.models import SandboxContext __all__ = [ "AppContext", diff --git a/api/core/workflow/context/execution_context.py b/api/dify_graph/context/execution_context.py similarity index 100% rename from api/core/workflow/context/execution_context.py rename to api/dify_graph/context/execution_context.py diff --git a/api/core/workflow/context/models.py b/api/dify_graph/context/models.py similarity index 100% rename from api/core/workflow/context/models.py rename to api/dify_graph/context/models.py diff --git a/api/core/workflow/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py similarity index 96% rename from api/core/workflow/conversation_variable_updater.py rename to api/dify_graph/conversation_variable_updater.py index 6bfb2b2880..17b19f2502 100644 --- a/api/core/workflow/conversation_variable_updater.py +++ b/api/dify_graph/conversation_variable_updater.py @@ -1,7 +1,7 @@ import abc from typing import Protocol -from core.workflow.variables import VariableBase +from dify_graph.variables import VariableBase class ConversationVariableUpdater(Protocol): diff --git a/api/core/workflow/entities/__init__.py b/api/dify_graph/entities/__init__.py similarity index 82% rename from api/core/workflow/entities/__init__.py rename to api/dify_graph/entities/__init__.py index e73c38c1d3..ef7789c49c 100644 --- a/api/core/workflow/entities/__init__.py +++ b/api/dify_graph/entities/__init__.py @@ -1,11 +1,9 @@ -from .agent import AgentNodeStrategyInit from .graph_init_params import GraphInitParams from .workflow_execution import WorkflowExecution from .workflow_node_execution import WorkflowNodeExecution from .workflow_start_reason import WorkflowStartReason __all__ = [ - "AgentNodeStrategyInit", "GraphInitParams", "WorkflowExecution", "WorkflowNodeExecution", diff --git a/api/core/workflow/nodes/base/entities.py b/api/dify_graph/entities/base_node_data.py similarity index 64% rename from api/core/workflow/nodes/base/entities.py rename to api/dify_graph/entities/base_node_data.py index c5426e3fb7..47b37c9daf 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/dify_graph/entities/base_node_data.py @@ -3,16 +3,15 @@ from __future__ import annotations import json from abc import ABC from builtins import type as type_ -from collections.abc import Sequence from enum import StrEnum from typing import Any, Union -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator -from core.workflow.enums import ErrorStrategy - -from .exc import DefaultValueTypeError +from dify_graph.entities.exc import DefaultValueTypeError +from dify_graph.enums import ErrorStrategy, NodeType +# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. _NumberType = Union[int, float] @@ -28,54 +27,6 @@ class RetryConfig(BaseModel): return self.retry_interval / 1000 -class VariableSelector(BaseModel): - """ - Variable Selector. - """ - - variable: str - value_selector: Sequence[str] - - -class OutputVariableType(StrEnum): - STRING = "string" - NUMBER = "number" - INTEGER = "integer" - SECRET = "secret" - BOOLEAN = "boolean" - OBJECT = "object" - FILE = "file" - ARRAY = "array" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_FILE = "array[file]" - ANY = "any" - ARRAY_ANY = "array[any]" - - -class OutputVariableEntity(BaseModel): - """ - Output Variable Entity. - """ - - variable: str - value_type: OutputVariableType = OutputVariableType.ANY - value_selector: Sequence[str] - - @field_validator("value_type", mode="before") - @classmethod - def normalize_value_type(cls, v: Any) -> Any: - """ - Normalize value_type to handle case-insensitive array types. - Converts 'Array[...]' to 'array[...]' for backward compatibility. - """ - if isinstance(v, str) and v.startswith("Array["): - return v.lower() - return v - - class DefaultValueType(StrEnum): STRING = "string" NUMBER = "number" @@ -168,12 +119,24 @@ class DefaultValue(BaseModel): class BaseNodeData(ABC, BaseModel): - title: str + # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where + # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. + # `type` therefore accepts downstream string node kinds; unknown node implementations + # are rejected later when the node factory resolves the node registry. + # At that boundary, node-specific fields are still "extra" relative to this shared DTO, + # and persisted templates/workflows also carry undeclared compatibility keys such as + # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive + # here until graph parsing becomes discriminated by node type or those legacy payloads + # are normalized. + model_config = ConfigDict(extra="allow") + + type: NodeType + title: str = "" desc: str | None = None version: str = "1" error_strategy: ErrorStrategy | None = None default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = RetryConfig() + retry_config: RetryConfig = Field(default_factory=RetryConfig) @property def default_value_dict(self) -> dict[str, Any]: @@ -181,32 +144,35 @@ class BaseNodeData(ABC, BaseModel): return {item.key: item.value for item in self.default_value} return {} + def __getitem__(self, key: str) -> Any: + """ + Dict-style access without calling model_dump() on every lookup. + Prefer using model fields and Pydantic's extra storage. + """ + # First, check declared model fields + if key in self.__class__.model_fields: + return getattr(self, key) -class BaseIterationNodeData(BaseNodeData): - start_node_id: str | None = None + # Then, check undeclared compatibility fields stored in Pydantic's extra dict. + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras[key] + raise KeyError(key) -class BaseIterationState(BaseModel): - iteration_node_id: str - index: int - inputs: dict + def get(self, key: str, default: Any = None) -> Any: + """ + Dict-style .get() without calling model_dump() on every lookup. + """ + if key in self.__class__.model_fields: + return getattr(self, key) - class MetaData(BaseModel): - pass + extras = getattr(self, "__pydantic_extra__", None) + if extras is None: + extras = getattr(self, "model_extra", None) + if extras is not None and key in extras: + return extras.get(key, default) - metadata: MetaData - - -class BaseLoopNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseLoopState(BaseModel): - loop_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData + return default diff --git a/api/core/workflow/nodes/base/exc.py b/api/dify_graph/entities/exc.py similarity index 100% rename from api/core/workflow/nodes/base/exc.py rename to api/dify_graph/entities/exc.py diff --git a/api/core/workflow/entities/graph_config.py b/api/dify_graph/entities/graph_config.py similarity index 57% rename from api/core/workflow/entities/graph_config.py rename to api/dify_graph/entities/graph_config.py index 209dcfe6bc..36f7b94e82 100644 --- a/api/core/workflow/entities/graph_config.py +++ b/api/dify_graph/entities/graph_config.py @@ -4,21 +4,20 @@ import sys from pydantic import TypeAdapter, with_config +from dify_graph.entities.base_node_data import BaseNodeData + if sys.version_info >= (3, 12): from typing import TypedDict else: from typing_extensions import TypedDict -@with_config(extra="allow") -class NodeConfigData(TypedDict): - type: str - - @with_config(extra="allow") class NodeConfigDict(TypedDict): id: str - data: NodeConfigData + # This is the permissive raw graph boundary. Node factories re-validate `data` + # with the concrete `NodeData` subtype after resolving the node implementation. + data: BaseNodeData NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/core/workflow/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py similarity index 62% rename from api/core/workflow/entities/graph_init_params.py rename to api/dify_graph/entities/graph_init_params.py index ff224a28d1..f785d58a52 100644 --- a/api/core/workflow/entities/graph_init_params.py +++ b/api/dify_graph/entities/graph_init_params.py @@ -3,6 +3,8 @@ from typing import Any from pydantic import BaseModel, Field +DIFY_RUN_CONTEXT_KEY = "_dify" + class GraphInitParams(BaseModel): """GraphInitParams encapsulates the configurations and contextual information @@ -16,15 +18,7 @@ class GraphInitParams(BaseModel): """ # init params - tenant_id: str = Field(..., description="tenant / workspace id") - app_id: str = Field(..., description="app id") workflow_id: str = Field(..., description="workflow id") graph_config: Mapping[str, Any] = Field(..., description="graph config") - user_id: str = Field(..., description="user id") - user_from: str = Field( - ..., description="user from, account or end-user" - ) # Should be UserFrom enum: 'account' | 'end-user' - invoke_from: str = Field( - ..., description="invoke from, service-api, web-app, explore or debugger" - ) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger' + run_context: Mapping[str, Any] = Field(..., description="runtime context") call_depth: int = Field(..., description="call depth") diff --git a/api/core/workflow/entities/pause_reason.py b/api/dify_graph/entities/pause_reason.py similarity index 96% rename from api/core/workflow/entities/pause_reason.py rename to api/dify_graph/entities/pause_reason.py index 147f56e8be..86d8c8ca16 100644 --- a/api/core/workflow/entities/pause_reason.py +++ b/api/dify_graph/entities/pause_reason.py @@ -4,7 +4,7 @@ from typing import Annotated, Any, Literal, TypeAlias from pydantic import BaseModel, Field -from core.workflow.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.entities import FormInput, UserAction class PauseReasonType(StrEnum): diff --git a/api/core/workflow/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py similarity index 96% rename from api/core/workflow/entities/workflow_execution.py rename to api/dify_graph/entities/workflow_execution.py index 1b3fb36f1f..459ac46415 100644 --- a/api/core/workflow/entities/workflow_execution.py +++ b/api/dify_graph/entities/workflow_execution.py @@ -13,7 +13,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType +from dify_graph.enums import WorkflowExecutionStatus, WorkflowType from libs.datetime_utils import naive_utc_now diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py similarity index 96% rename from api/core/workflow/entities/workflow_node_execution.py rename to api/dify_graph/entities/workflow_node_execution.py index 4abc9c068d..bc7e0d02e5 100644 --- a/api/core/workflow/entities/workflow_node_execution.py +++ b/api/dify_graph/entities/workflow_node_execution.py @@ -12,7 +12,7 @@ from typing import Any from pydantic import BaseModel, Field, PrivateAttr -from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class WorkflowNodeExecution(BaseModel): @@ -48,7 +48,7 @@ class WorkflowNodeExecution(BaseModel): index: int # Sequence number for ordering in trace visualization predecessor_node_id: str | None = None # ID of the node that executed before this one node_id: str # ID of the node being executed - node_type: NodeType # Type of node (e.g., start, llm, knowledge) + node_type: NodeType # Type of node (e.g., start, llm, downstream response node) title: str # Display title of the node # Execution data diff --git a/api/core/workflow/entities/workflow_start_reason.py b/api/dify_graph/entities/workflow_start_reason.py similarity index 100% rename from api/core/workflow/entities/workflow_start_reason.py rename to api/dify_graph/entities/workflow_start_reason.py diff --git a/api/core/workflow/enums.py b/api/dify_graph/enums.py similarity index 75% rename from api/core/workflow/enums.py rename to api/dify_graph/enums.py index bb3b13e8c6..cfb135cbb0 100644 --- a/api/core/workflow/enums.py +++ b/api/dify_graph/enums.py @@ -1,4 +1,5 @@ from enum import StrEnum +from typing import ClassVar, TypeAlias class NodeState(StrEnum): @@ -33,56 +34,71 @@ class SystemVariableKey(StrEnum): INVOKE_FROM = "invoke_from" -class NodeType(StrEnum): - START = "start" - END = "end" - ANSWER = "answer" - LLM = "llm" - KNOWLEDGE_RETRIEVAL = "knowledge-retrieval" - KNOWLEDGE_INDEX = "knowledge-index" - IF_ELSE = "if-else" - CODE = "code" - TEMPLATE_TRANSFORM = "template-transform" - QUESTION_CLASSIFIER = "question-classifier" - HTTP_REQUEST = "http-request" - TOOL = "tool" - DATASOURCE = "datasource" - VARIABLE_AGGREGATOR = "variable-aggregator" - LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. - LOOP = "loop" - LOOP_START = "loop-start" - LOOP_END = "loop-end" - ITERATION = "iteration" - ITERATION_START = "iteration-start" # Fake start node for iteration. - PARAMETER_EXTRACTOR = "parameter-extractor" - VARIABLE_ASSIGNER = "assigner" - DOCUMENT_EXTRACTOR = "document-extractor" - LIST_OPERATOR = "list-operator" - AGENT = "agent" - TRIGGER_WEBHOOK = "trigger-webhook" - TRIGGER_SCHEDULE = "trigger-schedule" - TRIGGER_PLUGIN = "trigger-plugin" - HUMAN_INPUT = "human-input" +NodeType: TypeAlias = str - @property - def is_trigger_node(self) -> bool: - """Check if this node type is a trigger node.""" - return self in [ - NodeType.TRIGGER_WEBHOOK, - NodeType.TRIGGER_SCHEDULE, - NodeType.TRIGGER_PLUGIN, - ] - @property - def is_start_node(self) -> bool: - """Check if this node type can serve as a workflow entry point.""" - return self in [ - NodeType.START, - NodeType.DATASOURCE, - NodeType.TRIGGER_WEBHOOK, - NodeType.TRIGGER_SCHEDULE, - NodeType.TRIGGER_PLUGIN, - ] +class BuiltinNodeTypes: + """Built-in node type string constants. + + `node_type` values are plain strings throughout the graph runtime. This namespace + only exposes the built-in values shipped by `dify_graph`; downstream packages can + use additional strings without extending this class. + """ + + START: ClassVar[NodeType] = "start" + END: ClassVar[NodeType] = "end" + ANSWER: ClassVar[NodeType] = "answer" + LLM: ClassVar[NodeType] = "llm" + KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval" + IF_ELSE: ClassVar[NodeType] = "if-else" + CODE: ClassVar[NodeType] = "code" + TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform" + QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier" + HTTP_REQUEST: ClassVar[NodeType] = "http-request" + TOOL: ClassVar[NodeType] = "tool" + DATASOURCE: ClassVar[NodeType] = "datasource" + VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator" + LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner" + LOOP: ClassVar[NodeType] = "loop" + LOOP_START: ClassVar[NodeType] = "loop-start" + LOOP_END: ClassVar[NodeType] = "loop-end" + ITERATION: ClassVar[NodeType] = "iteration" + ITERATION_START: ClassVar[NodeType] = "iteration-start" + PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor" + VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner" + DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor" + LIST_OPERATOR: ClassVar[NodeType] = "list-operator" + AGENT: ClassVar[NodeType] = "agent" + HUMAN_INPUT: ClassVar[NodeType] = "human-input" + + +BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = ( + BuiltinNodeTypes.START, + BuiltinNodeTypes.END, + BuiltinNodeTypes.ANSWER, + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + BuiltinNodeTypes.IF_ELSE, + BuiltinNodeTypes.CODE, + BuiltinNodeTypes.TEMPLATE_TRANSFORM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.HTTP_REQUEST, + BuiltinNodeTypes.TOOL, + BuiltinNodeTypes.DATASOURCE, + BuiltinNodeTypes.VARIABLE_AGGREGATOR, + BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR, + BuiltinNodeTypes.LOOP, + BuiltinNodeTypes.LOOP_START, + BuiltinNodeTypes.LOOP_END, + BuiltinNodeTypes.ITERATION, + BuiltinNodeTypes.ITERATION_START, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + BuiltinNodeTypes.VARIABLE_ASSIGNER, + BuiltinNodeTypes.DOCUMENT_EXTRACTOR, + BuiltinNodeTypes.LIST_OPERATOR, + BuiltinNodeTypes.AGENT, + BuiltinNodeTypes.HUMAN_INPUT, +) class NodeExecutionType(StrEnum): @@ -229,6 +245,9 @@ _END_STATE = frozenset( class WorkflowNodeExecutionMetadataKey(StrEnum): """ Node Run Metadata Key. + + Values in this enum are persisted as execution metadata and must stay in sync + with every node that writes `NodeRunResult.metadata`. """ TOTAL_TOKENS = "total_tokens" @@ -236,7 +255,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): CURRENCY = "currency" TOOL_INFO = "tool_info" AGENT_LOG = "agent_log" - TRIGGER_INFO = "trigger_info" ITERATION_ID = "iteration_id" ITERATION_INDEX = "iteration_index" LOOP_ID = "loop_id" @@ -251,6 +269,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum): ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output DATASOURCE_INFO = "datasource_info" + TRIGGER_INFO = "trigger_info" COMPLETED_REASON = "completed_reason" # completed reason for loop node diff --git a/api/core/workflow/errors.py b/api/dify_graph/errors.py similarity index 88% rename from api/core/workflow/errors.py rename to api/dify_graph/errors.py index 5bf1faee5d..463d17713e 100644 --- a/api/core/workflow/errors.py +++ b/api/dify_graph/errors.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.base.node import Node +from dify_graph.nodes.base.node import Node class WorkflowNodeRunFailedError(Exception): diff --git a/api/core/workflow/file/__init__.py b/api/dify_graph/file/__init__.py similarity index 100% rename from api/core/workflow/file/__init__.py rename to api/dify_graph/file/__init__.py diff --git a/api/core/workflow/file/constants.py b/api/dify_graph/file/constants.py similarity index 100% rename from api/core/workflow/file/constants.py rename to api/dify_graph/file/constants.py diff --git a/api/core/workflow/file/enums.py b/api/dify_graph/file/enums.py similarity index 100% rename from api/core/workflow/file/enums.py rename to api/dify_graph/file/enums.py diff --git a/api/core/workflow/file/file_manager.py b/api/dify_graph/file/file_manager.py similarity index 97% rename from api/core/workflow/file/file_manager.py rename to api/dify_graph/file/file_manager.py index a7719400d9..8d998054db 100644 --- a/api/core/workflow/file/file_manager.py +++ b/api/dify_graph/file/file_manager.py @@ -3,14 +3,14 @@ from __future__ import annotations import base64 from collections.abc import Mapping -from core.model_runtime.entities import ( +from dify_graph.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, VideoPromptMessageContent, ) -from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes from . import helpers from .enums import FileAttribute diff --git a/api/core/workflow/file/helpers.py b/api/dify_graph/file/helpers.py similarity index 100% rename from api/core/workflow/file/helpers.py rename to api/dify_graph/file/helpers.py diff --git a/api/core/workflow/file/models.py b/api/dify_graph/file/models.py similarity index 85% rename from api/core/workflow/file/models.py rename to api/dify_graph/file/models.py index cd7d3edde8..dcba00978e 100644 --- a/api/core/workflow/file/models.py +++ b/api/dify_graph/file/models.py @@ -2,10 +2,11 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any +from uuid import UUID, uuid4 from pydantic import BaseModel, Field, model_validator -from core.model_runtime.entities.message_entities import ImagePromptMessageContent +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from . import helpers from .constants import FILE_MODEL_IDENTITY @@ -43,6 +44,24 @@ class FileUploadConfig(BaseModel): number_limits: int = 0 +class ToolFile(BaseModel): + id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") + user_id: UUID = Field(..., description="ID of the user who owns this file") + tenant_id: UUID = Field(..., description="ID of the tenant/organization") + conversation_id: UUID | None = Field(None, description="ID of the associated conversation") + file_key: str = Field(..., max_length=255, description="Storage key for the file") + mimetype: str = Field(..., max_length=255, description="MIME type of the file") + original_url: str | None = Field( + None, max_length=2048, description="Original URL if file was fetched from external source" + ) + name: str = Field(default="", max_length=255, description="Display name of the file") + size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") + + class Config: + from_attributes = True # Enable ORM mode for SQLAlchemy compatibility + populate_by_name = True + + class File(BaseModel): # NOTE: dify_model_identity is a special identifier used to distinguish between # new and old data formats during serialization and deserialization. diff --git a/api/core/workflow/file/protocols.py b/api/dify_graph/file/protocols.py similarity index 94% rename from api/core/workflow/file/protocols.py rename to api/dify_graph/file/protocols.py index 8d923148e0..24cbb42735 100644 --- a/api/core/workflow/file/protocols.py +++ b/api/dify_graph/file/protocols.py @@ -14,7 +14,7 @@ class HttpResponseProtocol(Protocol): class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``core.workflow.file``. + """Runtime dependencies required by ``dify_graph.file``. Implementations are expected to be provided by integration layers (for example, ``core.app.workflow.file_runtime``) so the workflow package avoids importing diff --git a/api/core/workflow/file/runtime.py b/api/dify_graph/file/runtime.py similarity index 100% rename from api/core/workflow/file/runtime.py rename to api/dify_graph/file/runtime.py diff --git a/api/core/workflow/file/tool_file_parser.py b/api/dify_graph/file/tool_file_parser.py similarity index 100% rename from api/core/workflow/file/tool_file_parser.py rename to api/dify_graph/file/tool_file_parser.py diff --git a/api/core/workflow/graph/__init__.py b/api/dify_graph/graph/__init__.py similarity index 100% rename from api/core/workflow/graph/__init__.py rename to api/dify_graph/graph/__init__.py diff --git a/api/core/workflow/graph/edge.py b/api/dify_graph/graph/edge.py similarity index 91% rename from api/core/workflow/graph/edge.py rename to api/dify_graph/graph/edge.py index 1d57747dbb..f4f67ea6be 100644 --- a/api/core/workflow/graph/edge.py +++ b/api/dify_graph/graph/edge.py @@ -1,7 +1,7 @@ import uuid from dataclasses import dataclass, field -from core.workflow.enums import NodeState +from dify_graph.enums import NodeState @dataclass diff --git a/api/core/workflow/graph/graph.py b/api/dify_graph/graph/graph.py similarity index 86% rename from api/core/workflow/graph/graph.py rename to api/dify_graph/graph/graph.py index 52bbbb20cc..85117583e0 100644 --- a/api/core/workflow/graph/graph.py +++ b/api/dify_graph/graph/graph.py @@ -7,9 +7,9 @@ from typing import Protocol, cast, final from pydantic import TypeAdapter -from core.workflow.entities.graph_config import NodeConfigDict -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType -from core.workflow.nodes.base.node import Node +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState +from dify_graph.nodes.base.node import Node from libs.typing import is_str from .edge import Edge @@ -34,7 +34,8 @@ class NodeFactory(Protocol): :param node_config: node configuration dictionary containing type and other data :return: initialized Node instance - :raises ValueError: if node type is unknown or configuration is invalid + :raises ValueError: if node type is unknown or no implementation exists for the resolved version + :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation """ ... @@ -82,53 +83,6 @@ class Graph: return node_configs_map - @classmethod - def _find_root_node_id( - cls, - node_configs_map: Mapping[str, NodeConfigDict], - edge_configs: Sequence[Mapping[str, object]], - root_node_id: str | None = None, - ) -> str: - """ - Find the root node ID if not specified. - - :param node_configs_map: mapping of node ID to node config - :param edge_configs: list of edge configurations - :param root_node_id: explicitly specified root node ID - :return: determined root node ID - """ - if root_node_id: - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - return root_node_id - - # Find nodes with no incoming edges - nodes_with_incoming: set[str] = set() - for edge_config in edge_configs: - target = edge_config.get("target") - if isinstance(target, str): - nodes_with_incoming.add(target) - - root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming] - - # Prefer START node if available - start_node_id = None - for nid in root_candidates: - node_data = node_configs_map[nid]["data"] - node_type = node_data["type"] - if not isinstance(node_type, str): - continue - if NodeType(node_type).is_start_node: - start_node_id = nid - break - - root_node_id = start_node_id or (root_candidates[0] if root_candidates else None) - - if not root_node_id: - raise ValueError("Unable to determine root node ID") - - return root_node_id - @classmethod def _build_edges( cls, edge_configs: list[dict[str, object]] @@ -203,6 +157,23 @@ class Graph: return GraphBuilder(graph_cls=cls) + @staticmethod + def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: + """ + Remove editor-only nodes before `NodeConfigDict` validation. + + Persisted note widgets use a top-level `type == "custom-note"` but leave + `data.type` empty because they are never executable graph nodes. Filter + them while configs are still raw dicts so Pydantic does not validate + their placeholder payloads against `BaseNodeData.type: NodeType`. + """ + filtered_node_configs: list[dict[str, object]] = [] + for node_config in node_configs: + if node_config.get("type", "") == "custom-note": + continue + filtered_node_configs.append(dict(node_config)) + return filtered_node_configs + @classmethod def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: """ @@ -286,15 +257,15 @@ class Graph: *, graph_config: Mapping[str, object], node_factory: NodeFactory, - root_node_id: str | None = None, + root_node_id: str, skip_validation: bool = False, ) -> Graph: """ - Initialize graph + Initialize a graph with an explicit execution entry point. :param graph_config: graph config containing nodes and edges :param node_factory: factory for creating node instances from config data - :param root_node_id: root node id + :param root_node_id: active root node id :return: graph instance """ # Parse configs @@ -302,18 +273,18 @@ class Graph: node_configs = graph_config.get("nodes", []) edge_configs = cast(list[dict[str, object]], edge_configs) + node_configs = cast(list[dict[str, object]], node_configs) + node_configs = cls._filter_canvas_only_nodes(node_configs) node_configs = _ListNodeConfigDict.validate_python(node_configs) if not node_configs: raise ValueError("Graph must have at least one node") - node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"] - # Parse node configurations node_configs_map = cls._parse_node_configs(node_configs) - # Find root node - root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id) + if root_node_id not in node_configs_map: + raise ValueError(f"Root node id {root_node_id} not found in the graph") # Build edges edges, in_edges, out_edges = cls._build_edges(edge_configs) diff --git a/api/core/workflow/graph/graph_template.py b/api/dify_graph/graph/graph_template.py similarity index 100% rename from api/core/workflow/graph/graph_template.py rename to api/dify_graph/graph/graph_template.py diff --git a/api/core/workflow/graph/validation.py b/api/dify_graph/graph/validation.py similarity index 73% rename from api/core/workflow/graph/validation.py rename to api/dify_graph/graph/validation.py index 41b4fdfa60..50d1440b04 100644 --- a/api/core/workflow/graph/validation.py +++ b/api/dify_graph/graph/validation.py @@ -4,7 +4,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import TYPE_CHECKING, Protocol -from core.workflow.enums import NodeExecutionType, NodeType +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType if TYPE_CHECKING: from .graph import Graph @@ -71,7 +71,7 @@ class _RootNodeValidator: """Validates root node invariants.""" invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START) + container_entry_types: tuple[NodeType, ...] = (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: root_node = graph.root_node @@ -86,7 +86,7 @@ class _RootNodeValidator: ) return issues - node_type = getattr(root_node, "node_type", None) + node_type = root_node.node_type if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: issues.append( GraphValidationIssue( @@ -114,45 +114,9 @@ class GraphValidator: raise GraphValidationError(issues) -@dataclass(frozen=True, slots=True) -class _TriggerStartExclusivityValidator: - """Ensures trigger nodes do not coexist with UserInput (start) nodes.""" - - conflict_code: str = "TRIGGER_START_NODE_CONFLICT" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - start_node_id: str | None = None - trigger_node_ids: list[str] = [] - - for node in graph.nodes.values(): - node_type = getattr(node, "node_type", None) - if not isinstance(node_type, NodeType): - continue - - if node_type == NodeType.START: - start_node_id = node.id - elif node_type.is_trigger_node: - trigger_node_ids.append(node.id) - - if start_node_id and trigger_node_ids: - trigger_list = ", ".join(trigger_node_ids) - return [ - GraphValidationIssue( - code=self.conflict_code, - message=( - f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}." - ), - node_id=start_node_id, - ) - ] - - return [] - - _DEFAULT_RULES: tuple[GraphValidationRule, ...] = ( _EdgeEndpointValidator(), _RootNodeValidator(), - _TriggerStartExclusivityValidator(), ) diff --git a/api/core/workflow/graph_engine/__init__.py b/api/dify_graph/graph_engine/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/__init__.py rename to api/dify_graph/graph_engine/__init__.py diff --git a/api/core/workflow/graph_engine/_engine_utils.py b/api/dify_graph/graph_engine/_engine_utils.py similarity index 100% rename from api/core/workflow/graph_engine/_engine_utils.py rename to api/dify_graph/graph_engine/_engine_utils.py diff --git a/api/core/workflow/graph_engine/command_channels/README.md b/api/dify_graph/graph_engine/command_channels/README.md similarity index 100% rename from api/core/workflow/graph_engine/command_channels/README.md rename to api/dify_graph/graph_engine/command_channels/README.md diff --git a/api/core/workflow/graph_engine/command_channels/__init__.py b/api/dify_graph/graph_engine/command_channels/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/command_channels/__init__.py rename to api/dify_graph/graph_engine/command_channels/__init__.py diff --git a/api/core/workflow/graph_engine/command_channels/in_memory_channel.py b/api/dify_graph/graph_engine/command_channels/in_memory_channel.py similarity index 100% rename from api/core/workflow/graph_engine/command_channels/in_memory_channel.py rename to api/dify_graph/graph_engine/command_channels/in_memory_channel.py diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/dify_graph/graph_engine/command_channels/redis_channel.py similarity index 100% rename from api/core/workflow/graph_engine/command_channels/redis_channel.py rename to api/dify_graph/graph_engine/command_channels/redis_channel.py diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/dify_graph/graph_engine/command_processing/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/command_processing/__init__.py rename to api/dify_graph/graph_engine/command_processing/__init__.py diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/dify_graph/graph_engine/command_processing/command_handlers.py similarity index 94% rename from api/core/workflow/graph_engine/command_processing/command_handlers.py rename to api/dify_graph/graph_engine/command_processing/command_handlers.py index cfe856d9e8..eefd0c366b 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/dify_graph/graph_engine/command_processing/command_handlers.py @@ -3,8 +3,8 @@ from typing import final from typing_extensions import override -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.runtime import VariablePool +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.runtime import VariablePool from ..domain.graph_execution import GraphExecution from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand diff --git a/api/core/workflow/graph_engine/command_processing/command_processor.py b/api/dify_graph/graph_engine/command_processing/command_processor.py similarity index 100% rename from api/core/workflow/graph_engine/command_processing/command_processor.py rename to api/dify_graph/graph_engine/command_processing/command_processor.py diff --git a/api/core/workflow/graph_engine/config.py b/api/dify_graph/graph_engine/config.py similarity index 100% rename from api/core/workflow/graph_engine/config.py rename to api/dify_graph/graph_engine/config.py diff --git a/api/core/workflow/graph_engine/domain/__init__.py b/api/dify_graph/graph_engine/domain/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/domain/__init__.py rename to api/dify_graph/graph_engine/domain/__init__.py diff --git a/api/core/workflow/graph_engine/domain/graph_execution.py b/api/dify_graph/graph_engine/domain/graph_execution.py similarity index 97% rename from api/core/workflow/graph_engine/domain/graph_execution.py rename to api/dify_graph/graph_engine/domain/graph_execution.py index 3ba6e5e37c..0ee4a9f9a7 100644 --- a/api/core/workflow/graph_engine/domain/graph_execution.py +++ b/api/dify_graph/graph_engine/domain/graph_execution.py @@ -8,9 +8,9 @@ from typing import Literal from pydantic import BaseModel, Field -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.enums import NodeState -from core.workflow.runtime.graph_runtime_state import GraphExecutionProtocol +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.enums import NodeState +from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol from .node_execution import NodeExecution diff --git a/api/core/workflow/graph_engine/domain/node_execution.py b/api/dify_graph/graph_engine/domain/node_execution.py similarity index 96% rename from api/core/workflow/graph_engine/domain/node_execution.py rename to api/dify_graph/graph_engine/domain/node_execution.py index 85700caa3a..ae8f9a5e50 100644 --- a/api/core/workflow/graph_engine/domain/node_execution.py +++ b/api/dify_graph/graph_engine/domain/node_execution.py @@ -4,7 +4,7 @@ NodeExecution entity representing a node's execution state. from dataclasses import dataclass -from core.workflow.enums import NodeState +from dify_graph.enums import NodeState @dataclass diff --git a/api/core/model_runtime/model_providers/__base/__init__.py b/api/dify_graph/graph_engine/entities/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/__init__.py rename to api/dify_graph/graph_engine/entities/__init__.py diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/dify_graph/graph_engine/entities/commands.py similarity index 96% rename from api/core/workflow/graph_engine/entities/commands.py rename to api/dify_graph/graph_engine/entities/commands.py index 7e7b65247b..c56845cfc4 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/dify_graph/graph_engine/entities/commands.py @@ -11,7 +11,7 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.variables.variables import Variable +from dify_graph.variables.variables import Variable class CommandType(StrEnum): diff --git a/api/core/workflow/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py similarity index 95% rename from api/core/workflow/graph_engine/error_handler.py rename to api/dify_graph/graph_engine/error_handler.py index 62e144c12a..e206f21592 100644 --- a/api/core/workflow/graph_engine/error_handler.py +++ b/api/dify_graph/graph_engine/error_handler.py @@ -6,21 +6,21 @@ import logging import time from typing import TYPE_CHECKING, final -from core.workflow.enums import ( +from dify_graph.enums import ( ErrorStrategy as ErrorStrategyEnum, ) -from core.workflow.enums import ( +from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunRetryEvent, ) -from core.workflow.node_events import NodeRunResult +from dify_graph.node_events import NodeRunResult if TYPE_CHECKING: from .domain import GraphExecution @@ -159,6 +159,7 @@ class ErrorHandler: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.EXCEPTION, inputs=event.node_run_result.inputs, @@ -198,6 +199,7 @@ class ErrorHandler: node_id=event.node_id, node_type=event.node_type, start_at=event.start_at, + finished_at=event.finished_at, node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.EXCEPTION, inputs=event.node_run_result.inputs, diff --git a/api/core/workflow/graph_engine/event_management/__init__.py b/api/dify_graph/graph_engine/event_management/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/event_management/__init__.py rename to api/dify_graph/graph_engine/event_management/__init__.py diff --git a/api/core/workflow/graph_engine/event_management/event_handlers.py b/api/dify_graph/graph_engine/event_management/event_handlers.py similarity index 97% rename from api/core/workflow/graph_engine/event_management/event_handlers.py rename to api/dify_graph/graph_engine/event_management/event_handlers.py index 98a0702e1c..7f5ad40e0e 100644 --- a/api/core/workflow/graph_engine/event_management/event_handlers.py +++ b/api/dify_graph/graph_engine/event_management/event_handlers.py @@ -7,10 +7,9 @@ from collections.abc import Mapping from functools import singledispatchmethod from typing import TYPE_CHECKING, final -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunExceptionEvent, @@ -30,7 +29,8 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.runtime import GraphRuntimeState +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState from ..domain.graph_execution import GraphExecution from ..response_coordinator import ResponseStreamCoordinator diff --git a/api/core/workflow/graph_engine/event_management/event_manager.py b/api/dify_graph/graph_engine/event_management/event_manager.py similarity index 98% rename from api/core/workflow/graph_engine/event_management/event_manager.py rename to api/dify_graph/graph_engine/event_management/event_manager.py index ae2e659543..616f621c3e 100644 --- a/api/core/workflow/graph_engine/event_management/event_manager.py +++ b/api/dify_graph/graph_engine/event_management/event_manager.py @@ -9,7 +9,7 @@ from collections.abc import Generator from contextlib import contextmanager from typing import final -from core.workflow.graph_events import GraphEngineEvent +from dify_graph.graph_events import GraphEngineEvent from ..layers.base import GraphEngineLayer diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py similarity index 88% rename from api/core/workflow/graph_engine/graph_engine.py rename to api/dify_graph/graph_engine/graph_engine.py index 7c46fc2239..ea98a46b06 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/dify_graph/graph_engine/graph_engine.py @@ -9,14 +9,14 @@ from __future__ import annotations import logging import queue -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, cast, final -from core.workflow.context import capture_current_context -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeExecutionType -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.context import capture_current_context +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import NodeExecutionType +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphEngineEvent, GraphNodeEventBase, GraphRunAbortedEvent, @@ -26,10 +26,11 @@ from core.workflow.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from core.workflow.runtime.graph_runtime_state import GraphProtocol + from dify_graph.runtime.graph_runtime_state import GraphProtocol from .command_processing import ( AbortCommandHandler, @@ -49,8 +50,9 @@ from .protocols.command_channel import CommandChannel from .worker_management import WorkerPool if TYPE_CHECKING: - from core.workflow.graph_engine.domain.graph_execution import GraphExecution - from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator + from dify_graph.entities import GraphInitParams + from dify_graph.graph_engine.domain.graph_execution import GraphExecution + from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator logger = logging.getLogger(__name__) @@ -74,6 +76,7 @@ class GraphEngine: graph_runtime_state: GraphRuntimeState, command_channel: CommandChannel, config: GraphEngineConfig = _DEFAULT_CONFIG, + child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, ) -> None: """Initialize the graph engine with all subsystems and dependencies.""" @@ -83,6 +86,9 @@ class GraphEngine: self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) self._command_channel = command_channel self._config = config + self._child_engine_builder = child_engine_builder + if child_engine_builder is not None: + self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) # Graph execution tracks the overall execution state self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) @@ -214,6 +220,25 @@ class GraphEngine: self._bind_layer_context(layer) return self + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: dict[str, object] | Mapping[str, object], + root_node_id: str, + layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (), + ) -> GraphEngine: + return self._graph_runtime_state.create_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + def run(self) -> Generator[GraphEngineEvent, None, None]: """ Execute the graph using the modular architecture. diff --git a/api/core/workflow/graph_engine/graph_state_manager.py b/api/dify_graph/graph_engine/graph_state_manager.py similarity index 98% rename from api/core/workflow/graph_engine/graph_state_manager.py rename to api/dify_graph/graph_engine/graph_state_manager.py index d9773645c3..922a968435 100644 --- a/api/core/workflow/graph_engine/graph_state_manager.py +++ b/api/dify_graph/graph_engine/graph_state_manager.py @@ -6,8 +6,8 @@ import threading from collections.abc import Sequence from typing import TypedDict, final -from core.workflow.enums import NodeState -from core.workflow.graph import Edge, Graph +from dify_graph.enums import NodeState +from dify_graph.graph import Edge, Graph from .ready_queue import ReadyQueue diff --git a/api/core/workflow/graph_engine/graph_traversal/__init__.py b/api/dify_graph/graph_engine/graph_traversal/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/graph_traversal/__init__.py rename to api/dify_graph/graph_engine/graph_traversal/__init__.py diff --git a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py b/api/dify_graph/graph_engine/graph_traversal/edge_processor.py similarity index 97% rename from api/core/workflow/graph_engine/graph_traversal/edge_processor.py rename to api/dify_graph/graph_engine/graph_traversal/edge_processor.py index 9bd0f86fbf..c4625a8ff7 100644 --- a/api/core/workflow/graph_engine/graph_traversal/edge_processor.py +++ b/api/dify_graph/graph_engine/graph_traversal/edge_processor.py @@ -5,9 +5,9 @@ Edge processing logic for graph traversal. from collections.abc import Sequence from typing import TYPE_CHECKING, final -from core.workflow.enums import NodeExecutionType -from core.workflow.graph import Edge, Graph -from core.workflow.graph_events import NodeRunStreamChunkEvent +from dify_graph.enums import NodeExecutionType +from dify_graph.graph import Edge, Graph +from dify_graph.graph_events import NodeRunStreamChunkEvent from ..graph_state_manager import GraphStateManager from ..response_coordinator import ResponseStreamCoordinator diff --git a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py b/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py similarity index 98% rename from api/core/workflow/graph_engine/graph_traversal/skip_propagator.py rename to api/dify_graph/graph_engine/graph_traversal/skip_propagator.py index b9c9243963..76445bccd2 100644 --- a/api/core/workflow/graph_engine/graph_traversal/skip_propagator.py +++ b/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py @@ -5,7 +5,7 @@ Skip state propagation through the graph. from collections.abc import Sequence from typing import final -from core.workflow.graph import Edge, Graph +from dify_graph.graph import Edge, Graph from ..graph_state_manager import GraphStateManager diff --git a/api/core/workflow/graph_engine/layers/README.md b/api/dify_graph/graph_engine/layers/README.md similarity index 100% rename from api/core/workflow/graph_engine/layers/README.md rename to api/dify_graph/graph_engine/layers/README.md diff --git a/api/core/workflow/graph_engine/layers/__init__.py b/api/dify_graph/graph_engine/layers/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/layers/__init__.py rename to api/dify_graph/graph_engine/layers/__init__.py diff --git a/api/core/workflow/graph_engine/layers/base.py b/api/dify_graph/graph_engine/layers/base.py similarity index 94% rename from api/core/workflow/graph_engine/layers/base.py rename to api/dify_graph/graph_engine/layers/base.py index ff4a483aed..890336c1ca 100644 --- a/api/core/workflow/graph_engine/layers/base.py +++ b/api/dify_graph/graph_engine/layers/base.py @@ -7,10 +7,10 @@ intercept and respond to GraphEngine events. from abc import ABC, abstractmethod -from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.runtime import ReadOnlyGraphRuntimeState +from dify_graph.graph_engine.protocols.command_channel import CommandChannel +from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import ReadOnlyGraphRuntimeState class GraphEngineLayerNotInitializedError(Exception): diff --git a/api/core/workflow/graph_engine/layers/debug_logging.py b/api/dify_graph/graph_engine/layers/debug_logging.py similarity index 99% rename from api/core/workflow/graph_engine/layers/debug_logging.py rename to api/dify_graph/graph_engine/layers/debug_logging.py index e0402cd09c..1af2e2db9e 100644 --- a/api/core/workflow/graph_engine/layers/debug_logging.py +++ b/api/dify_graph/graph_engine/layers/debug_logging.py @@ -11,7 +11,7 @@ from typing import Any, final from typing_extensions import override -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, diff --git a/api/core/workflow/graph_engine/layers/execution_limits.py b/api/dify_graph/graph_engine/layers/execution_limits.py similarity index 94% rename from api/core/workflow/graph_engine/layers/execution_limits.py rename to api/dify_graph/graph_engine/layers/execution_limits.py index a2d36d142d..48ba5608d9 100644 --- a/api/core/workflow/graph_engine/layers/execution_limits.py +++ b/api/dify_graph/graph_engine/layers/execution_limits.py @@ -15,13 +15,13 @@ from typing import final from typing_extensions import override -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType -from core.workflow.graph_engine.layers import GraphEngineLayer -from core.workflow.graph_events import ( +from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType +from dify_graph.graph_engine.layers import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, NodeRunStartedEvent, ) -from core.workflow.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent class LimitType(StrEnum): diff --git a/api/core/workflow/graph_engine/manager.py b/api/dify_graph/graph_engine/manager.py similarity index 94% rename from api/core/workflow/graph_engine/manager.py rename to api/dify_graph/graph_engine/manager.py index 36f1612af0..955c149069 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/dify_graph/graph_engine/manager.py @@ -10,8 +10,8 @@ import logging from collections.abc import Sequence from typing import final -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from core.workflow.graph_engine.entities.commands import ( +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol +from dify_graph.graph_engine.entities.commands import ( AbortCommand, GraphEngineCommand, PauseCommand, diff --git a/api/core/workflow/graph_engine/orchestration/__init__.py b/api/dify_graph/graph_engine/orchestration/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/orchestration/__init__.py rename to api/dify_graph/graph_engine/orchestration/__init__.py diff --git a/api/core/workflow/graph_engine/orchestration/dispatcher.py b/api/dify_graph/graph_engine/orchestration/dispatcher.py similarity index 99% rename from api/core/workflow/graph_engine/orchestration/dispatcher.py rename to api/dify_graph/graph_engine/orchestration/dispatcher.py index 76dd1c7768..f8aaf20b2f 100644 --- a/api/core/workflow/graph_engine/orchestration/dispatcher.py +++ b/api/dify_graph/graph_engine/orchestration/dispatcher.py @@ -8,7 +8,7 @@ import threading import time from typing import TYPE_CHECKING, final -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunExceptionEvent, NodeRunFailedEvent, diff --git a/api/core/workflow/graph_engine/orchestration/execution_coordinator.py b/api/dify_graph/graph_engine/orchestration/execution_coordinator.py similarity index 100% rename from api/core/workflow/graph_engine/orchestration/execution_coordinator.py rename to api/dify_graph/graph_engine/orchestration/execution_coordinator.py diff --git a/api/core/workflow/graph_engine/protocols/command_channel.py b/api/dify_graph/graph_engine/protocols/command_channel.py similarity index 100% rename from api/core/workflow/graph_engine/protocols/command_channel.py rename to api/dify_graph/graph_engine/protocols/command_channel.py diff --git a/api/core/workflow/graph_engine/ready_queue/__init__.py b/api/dify_graph/graph_engine/ready_queue/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/__init__.py rename to api/dify_graph/graph_engine/ready_queue/__init__.py diff --git a/api/core/workflow/graph_engine/ready_queue/factory.py b/api/dify_graph/graph_engine/ready_queue/factory.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/factory.py rename to api/dify_graph/graph_engine/ready_queue/factory.py diff --git a/api/core/workflow/graph_engine/ready_queue/in_memory.py b/api/dify_graph/graph_engine/ready_queue/in_memory.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/in_memory.py rename to api/dify_graph/graph_engine/ready_queue/in_memory.py diff --git a/api/core/workflow/graph_engine/ready_queue/protocol.py b/api/dify_graph/graph_engine/ready_queue/protocol.py similarity index 100% rename from api/core/workflow/graph_engine/ready_queue/protocol.py rename to api/dify_graph/graph_engine/ready_queue/protocol.py diff --git a/api/core/workflow/graph_engine/response_coordinator/__init__.py b/api/dify_graph/graph_engine/response_coordinator/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/response_coordinator/__init__.py rename to api/dify_graph/graph_engine/response_coordinator/__init__.py diff --git a/api/core/workflow/graph_engine/response_coordinator/coordinator.py b/api/dify_graph/graph_engine/response_coordinator/coordinator.py similarity index 98% rename from api/core/workflow/graph_engine/response_coordinator/coordinator.py rename to api/dify_graph/graph_engine/response_coordinator/coordinator.py index e82ba29438..941a8a496b 100644 --- a/api/core/workflow/graph_engine/response_coordinator/coordinator.py +++ b/api/dify_graph/graph_engine/response_coordinator/coordinator.py @@ -14,11 +14,11 @@ from uuid import uuid4 from pydantic import BaseModel, Field -from core.workflow.enums import NodeExecutionType, NodeState -from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from core.workflow.nodes.base.template import TextSegment, VariableSegment -from core.workflow.runtime import VariablePool -from core.workflow.runtime.graph_runtime_state import GraphProtocol +from dify_graph.enums import NodeExecutionType, NodeState +from dify_graph.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent +from dify_graph.nodes.base.template import TextSegment, VariableSegment +from dify_graph.runtime import VariablePool +from dify_graph.runtime.graph_runtime_state import GraphProtocol from .path import Path from .session import ResponseSession diff --git a/api/core/workflow/graph_engine/response_coordinator/path.py b/api/dify_graph/graph_engine/response_coordinator/path.py similarity index 100% rename from api/core/workflow/graph_engine/response_coordinator/path.py rename to api/dify_graph/graph_engine/response_coordinator/path.py diff --git a/api/core/workflow/graph_engine/response_coordinator/session.py b/api/dify_graph/graph_engine/response_coordinator/session.py similarity index 54% rename from api/core/workflow/graph_engine/response_coordinator/session.py rename to api/dify_graph/graph_engine/response_coordinator/session.py index 5e4fada7d9..11a9f5dac5 100644 --- a/api/core/workflow/graph_engine/response_coordinator/session.py +++ b/api/dify_graph/graph_engine/response_coordinator/session.py @@ -8,12 +8,16 @@ by ResponseStreamCoordinator to manage streaming sessions. from __future__ import annotations from dataclasses import dataclass +from typing import Protocol, cast -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.nodes.base.template import Template -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.knowledge_index import KnowledgeIndexNode -from core.workflow.runtime.graph_runtime_state import NodeProtocol +from dify_graph.nodes.base.template import Template +from dify_graph.runtime.graph_runtime_state import NodeProtocol + + +class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): + """Structural contract required from nodes that can open a response session.""" + + def get_streaming_template(self) -> Template: ... @dataclass @@ -33,10 +37,9 @@ class ResponseSession: """ Create a ResponseSession from a response-capable node. - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer, - but at runtime this must be an `AnswerNode`, `EndNode`, or `KnowledgeIndexNode` that provides: - - `id: str` - - `get_streaming_template() -> Template` + The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. + At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which + graph nodes should be treated as response-capable before they reach this factory. Args: node: Node from the materialized workflow graph. @@ -45,13 +48,17 @@ class ResponseSession: ResponseSession configured with the node's streaming template Raises: - TypeError: If node is not a supported response node type. + TypeError: If node does not implement the response-session streaming contract. """ - if not isinstance(node, AnswerNode | EndNode | KnowledgeIndexNode): - raise TypeError("ResponseSession.from_node only supports AnswerNode, EndNode, or KnowledgeIndexNode") + response_node = cast(_ResponseSessionNodeProtocol, node) + try: + template = response_node.get_streaming_template() + except AttributeError as exc: + raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc + return cls( node_id=node.id, - template=node.get_streaming_template(), + template=template, ) def is_complete(self) -> bool: diff --git a/api/core/workflow/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py similarity index 74% rename from api/core/workflow/graph_engine/worker.py rename to api/dify_graph/graph_engine/worker.py index 9e218f6fa6..988c20d72a 100644 --- a/api/core/workflow/graph_engine/worker.py +++ b/api/dify_graph/graph_engine/worker.py @@ -14,11 +14,14 @@ from typing import TYPE_CHECKING, final from typing_extensions import override -from core.workflow.context import IExecutionContext -from core.workflow.graph import Graph -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event -from core.workflow.nodes.base.node import Node +from dify_graph.context import IExecutionContext +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from libs.datetime_utils import naive_utc_now from .ready_queue import ReadyQueue @@ -65,6 +68,7 @@ class Worker(threading.Thread): self._stop_event = threading.Event() self._layers = layers if layers is not None else [] self._last_task_time = time.time() + self._current_node_started_at: datetime | None = None def stop(self) -> None: """Signal the worker to stop processing.""" @@ -104,18 +108,15 @@ class Worker(threading.Thread): self._last_task_time = time.time() node = self._graph.nodes[node_id] try: + self._current_node_started_at = None self._execute_node(node) self._ready_queue.task_done() except Exception as e: - error_event = NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=str(e), - start_at=datetime.now(), + self._event_queue.put( + self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) ) - self._event_queue.put(error_event) + finally: + self._current_node_started_at = None def _execute_node(self, node: Node) -> None: """ @@ -136,6 +137,8 @@ class Worker(threading.Thread): try: node_events = node.run() for event in node_events: + if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: + self._current_node_started_at = event.start_at self._event_queue.put(event) if is_node_result_event(event): result_event = event @@ -149,6 +152,8 @@ class Worker(threading.Thread): try: node_events = node.run() for event in node_events: + if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: + self._current_node_started_at = event.start_at self._event_queue.put(event) if is_node_result_event(event): result_event = event @@ -177,3 +182,24 @@ class Worker(threading.Thread): except Exception: # Silently ignore layer errors to prevent disrupting node execution continue + + def _build_fallback_failure_event( + self, node: Node, error: Exception, *, started_at: datetime | None = None + ) -> NodeRunFailedEvent: + """Build a failed event when worker-level execution aborts before a node emits its own result event.""" + failure_time = naive_utc_now() + error_message = str(error) + return NodeRunFailedEvent( + id=node.execution_id, + node_id=node.id, + node_type=node.node_type, + in_iteration_id=None, + error=error_message, + start_at=started_at or failure_time, + finished_at=failure_time, + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=error_message, + error_type=type(error).__name__, + ), + ) diff --git a/api/core/workflow/graph_engine/worker_management/__init__.py b/api/dify_graph/graph_engine/worker_management/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/worker_management/__init__.py rename to api/dify_graph/graph_engine/worker_management/__init__.py diff --git a/api/core/workflow/graph_engine/worker_management/worker_pool.py b/api/dify_graph/graph_engine/worker_management/worker_pool.py similarity index 98% rename from api/core/workflow/graph_engine/worker_management/worker_pool.py rename to api/dify_graph/graph_engine/worker_management/worker_pool.py index 2c14f53746..cc93087783 100644 --- a/api/core/workflow/graph_engine/worker_management/worker_pool.py +++ b/api/dify_graph/graph_engine/worker_management/worker_pool.py @@ -10,9 +10,9 @@ import queue import threading from typing import final -from core.workflow.context import IExecutionContext -from core.workflow.graph import Graph -from core.workflow.graph_events import GraphNodeEventBase +from dify_graph.context import IExecutionContext +from dify_graph.graph import Graph +from dify_graph.graph_events import GraphNodeEventBase from ..config import GraphEngineConfig from ..layers.base import GraphEngineLayer diff --git a/api/core/workflow/graph_events/__init__.py b/api/dify_graph/graph_events/__init__.py similarity index 100% rename from api/core/workflow/graph_events/__init__.py rename to api/dify_graph/graph_events/__init__.py diff --git a/api/core/workflow/graph_events/agent.py b/api/dify_graph/graph_events/agent.py similarity index 100% rename from api/core/workflow/graph_events/agent.py rename to api/dify_graph/graph_events/agent.py diff --git a/api/core/workflow/graph_events/base.py b/api/dify_graph/graph_events/base.py similarity index 87% rename from api/core/workflow/graph_events/base.py rename to api/dify_graph/graph_events/base.py index 3714679201..4560cf5085 100644 --- a/api/core/workflow/graph_events/base.py +++ b/api/dify_graph/graph_events/base.py @@ -1,7 +1,7 @@ from pydantic import BaseModel, Field -from core.workflow.enums import NodeType -from core.workflow.node_events import NodeRunResult +from dify_graph.enums import NodeType +from dify_graph.node_events import NodeRunResult class GraphEngineEvent(BaseModel): diff --git a/api/core/workflow/graph_events/graph.py b/api/dify_graph/graph_events/graph.py similarity index 90% rename from api/core/workflow/graph_events/graph.py rename to api/dify_graph/graph_events/graph.py index f46526bcab..f4aaba64d6 100644 --- a/api/core/workflow/graph_events/graph.py +++ b/api/dify_graph/graph_events/graph.py @@ -1,8 +1,8 @@ from pydantic import Field -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph_events import BaseGraphEvent +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph_events import BaseGraphEvent class GraphRunStartedEvent(BaseGraphEvent): diff --git a/api/core/workflow/graph_events/human_input.py b/api/dify_graph/graph_events/human_input.py similarity index 100% rename from api/core/workflow/graph_events/human_input.py rename to api/dify_graph/graph_events/human_input.py diff --git a/api/core/workflow/graph_events/iteration.py b/api/dify_graph/graph_events/iteration.py similarity index 100% rename from api/core/workflow/graph_events/iteration.py rename to api/dify_graph/graph_events/iteration.py diff --git a/api/core/workflow/graph_events/loop.py b/api/dify_graph/graph_events/loop.py similarity index 100% rename from api/core/workflow/graph_events/loop.py rename to api/dify_graph/graph_events/loop.py diff --git a/api/core/workflow/graph_events/node.py b/api/dify_graph/graph_events/node.py similarity index 89% rename from api/core/workflow/graph_events/node.py rename to api/dify_graph/graph_events/node.py index 975d72ad1f..df19d6c03b 100644 --- a/api/core/workflow/graph_events/node.py +++ b/api/dify_graph/graph_events/node.py @@ -4,8 +4,7 @@ from datetime import datetime from pydantic import Field from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities import AgentNodeStrategyInit -from core.workflow.entities.pause_reason import PauseReason +from dify_graph.entities.pause_reason import PauseReason from .base import GraphNodeEventBase @@ -13,8 +12,8 @@ from .base import GraphNodeEventBase class NodeRunStartedEvent(GraphNodeEventBase): node_title: str predecessor_node_id: str | None = None - agent_strategy: AgentNodeStrategyInit | None = None start_at: datetime = Field(..., description="node start time") + extras: dict[str, object] = Field(default_factory=dict) # FIXME(-LAN-): only for ToolNode provider_type: str = "" @@ -37,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase): class NodeRunSucceededEvent(GraphNodeEventBase): start_at: datetime = Field(..., description="node start time") + finished_at: datetime | None = Field(default=None, description="node finish time") class NodeRunFailedEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") + finished_at: datetime | None = Field(default=None, description="node finish time") class NodeRunExceptionEvent(GraphNodeEventBase): error: str = Field(..., description="error") start_at: datetime = Field(..., description="node start time") + finished_at: datetime | None = Field(default=None, description="node finish time") class NodeRunRetryEvent(NodeRunStartedEvent): diff --git a/api/core/model_runtime/README.md b/api/dify_graph/model_runtime/README.md similarity index 100% rename from api/core/model_runtime/README.md rename to api/dify_graph/model_runtime/README.md diff --git a/api/core/model_runtime/README_CN.md b/api/dify_graph/model_runtime/README_CN.md similarity index 100% rename from api/core/model_runtime/README_CN.md rename to api/dify_graph/model_runtime/README_CN.md diff --git a/api/core/model_runtime/model_providers/__init__.py b/api/dify_graph/model_runtime/__init__.py similarity index 100% rename from api/core/model_runtime/model_providers/__init__.py rename to api/dify_graph/model_runtime/__init__.py diff --git a/api/core/model_runtime/schema_validators/__init__.py b/api/dify_graph/model_runtime/callbacks/__init__.py similarity index 100% rename from api/core/model_runtime/schema_validators/__init__.py rename to api/dify_graph/model_runtime/callbacks/__init__.py diff --git a/api/core/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py similarity index 94% rename from api/core/model_runtime/callbacks/base_callback.py rename to api/dify_graph/model_runtime/callbacks/base_callback.py index a745a91510..20faf3d6cd 100644 --- a/api/core/model_runtime/callbacks/base_callback.py +++ b/api/dify_graph/model_runtime/callbacks/base_callback.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel _TEXT_COLOR_MAPPING = { "blue": "36;1", diff --git a/api/core/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py similarity index 94% rename from api/core/model_runtime/callbacks/logging_callback.py rename to api/dify_graph/model_runtime/callbacks/logging_callback.py index b366fcc57b..49b9ab27eb 100644 --- a/api/core/model_runtime/callbacks/logging_callback.py +++ b/api/dify_graph/model_runtime/callbacks/logging_callback.py @@ -4,10 +4,10 @@ import sys from collections.abc import Sequence from typing import cast -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/entities/__init__.py b/api/dify_graph/model_runtime/entities/__init__.py similarity index 100% rename from api/core/model_runtime/entities/__init__.py rename to api/dify_graph/model_runtime/entities/__init__.py diff --git a/api/core/model_runtime/entities/common_entities.py b/api/dify_graph/model_runtime/entities/common_entities.py similarity index 100% rename from api/core/model_runtime/entities/common_entities.py rename to api/dify_graph/model_runtime/entities/common_entities.py diff --git a/api/core/model_runtime/entities/defaults.py b/api/dify_graph/model_runtime/entities/defaults.py similarity index 98% rename from api/core/model_runtime/entities/defaults.py rename to api/dify_graph/model_runtime/entities/defaults.py index 51c9c51257..53b732e5c6 100644 --- a/api/core/model_runtime/entities/defaults.py +++ b/api/dify_graph/model_runtime/entities/defaults.py @@ -1,4 +1,4 @@ -from core.model_runtime.entities.model_entities import DefaultParameterName +from dify_graph.model_runtime.entities.model_entities import DefaultParameterName PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { DefaultParameterName.TEMPERATURE: { diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/dify_graph/model_runtime/entities/llm_entities.py similarity index 97% rename from api/core/model_runtime/entities/llm_entities.py rename to api/dify_graph/model_runtime/entities/llm_entities.py index 2c7c421eed..eec682a2ae 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/dify_graph/model_runtime/entities/llm_entities.py @@ -7,8 +7,8 @@ from typing import Any, TypedDict, Union from pydantic import BaseModel, Field -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from core.model_runtime.entities.model_entities import ModelUsage, PriceInfo +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo class LLMMode(StrEnum): diff --git a/api/core/model_runtime/entities/message_entities.py b/api/dify_graph/model_runtime/entities/message_entities.py similarity index 98% rename from api/core/model_runtime/entities/message_entities.py rename to api/dify_graph/model_runtime/entities/message_entities.py index 9e46d72893..402bfdc606 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/dify_graph/model_runtime/entities/message_entities.py @@ -276,7 +276,4 @@ class ToolPromptMessage(PromptMessage): :return: True if prompt message is empty, False otherwise """ - if not super().is_empty() and not self.tool_call_id: - return False - - return True + return super().is_empty() and not self.tool_call_id diff --git a/api/core/model_runtime/entities/model_entities.py b/api/dify_graph/model_runtime/entities/model_entities.py similarity index 98% rename from api/core/model_runtime/entities/model_entities.py rename to api/dify_graph/model_runtime/entities/model_entities.py index 19194d162c..fbcde6740a 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/dify_graph/model_runtime/entities/model_entities.py @@ -6,7 +6,7 @@ from typing import Any from pydantic import BaseModel, ConfigDict, model_validator -from core.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.common_entities import I18nObject class ModelType(StrEnum): diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/dify_graph/model_runtime/entities/provider_entities.py similarity index 95% rename from api/core/model_runtime/entities/provider_entities.py rename to api/dify_graph/model_runtime/entities/provider_entities.py index 2d88751668..97a99ea7ce 100644 --- a/api/core/model_runtime/entities/provider_entities.py +++ b/api/dify_graph/model_runtime/entities/provider_entities.py @@ -3,8 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType class ConfigurateMethod(StrEnum): diff --git a/api/core/model_runtime/entities/rerank_entities.py b/api/dify_graph/model_runtime/entities/rerank_entities.py similarity index 100% rename from api/core/model_runtime/entities/rerank_entities.py rename to api/dify_graph/model_runtime/entities/rerank_entities.py diff --git a/api/core/model_runtime/entities/text_embedding_entities.py b/api/dify_graph/model_runtime/entities/text_embedding_entities.py similarity index 89% rename from api/core/model_runtime/entities/text_embedding_entities.py rename to api/dify_graph/model_runtime/entities/text_embedding_entities.py index 854c448250..a0210c169d 100644 --- a/api/core/model_runtime/entities/text_embedding_entities.py +++ b/api/dify_graph/model_runtime/entities/text_embedding_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal from pydantic import BaseModel -from core.model_runtime.entities.model_entities import ModelUsage +from dify_graph.model_runtime.entities.model_entities import ModelUsage class EmbeddingUsage(ModelUsage): diff --git a/api/core/model_runtime/utils/__init__.py b/api/dify_graph/model_runtime/errors/__init__.py similarity index 100% rename from api/core/model_runtime/utils/__init__.py rename to api/dify_graph/model_runtime/errors/__init__.py diff --git a/api/core/model_runtime/errors/invoke.py b/api/dify_graph/model_runtime/errors/invoke.py similarity index 92% rename from api/core/model_runtime/errors/invoke.py rename to api/dify_graph/model_runtime/errors/invoke.py index 80cf01fb6c..1a57078b98 100644 --- a/api/core/model_runtime/errors/invoke.py +++ b/api/dify_graph/model_runtime/errors/invoke.py @@ -4,7 +4,8 @@ class InvokeError(ValueError): description: str | None = None def __init__(self, description: str | None = None): - self.description = description + if description is not None: + self.description = description def __str__(self): return self.description or self.__class__.__name__ diff --git a/api/core/model_runtime/errors/validate.py b/api/dify_graph/model_runtime/errors/validate.py similarity index 100% rename from api/core/model_runtime/errors/validate.py rename to api/dify_graph/model_runtime/errors/validate.py diff --git a/api/core/model_runtime/memory/__init__.py b/api/dify_graph/model_runtime/memory/__init__.py similarity index 100% rename from api/core/model_runtime/memory/__init__.py rename to api/dify_graph/model_runtime/memory/__init__.py diff --git a/api/core/model_runtime/memory/prompt_message_memory.py b/api/dify_graph/model_runtime/memory/prompt_message_memory.py similarity index 89% rename from api/core/model_runtime/memory/prompt_message_memory.py rename to api/dify_graph/model_runtime/memory/prompt_message_memory.py index 4491ddfd05..a76a7faf71 100644 --- a/api/core/model_runtime/memory/prompt_message_memory.py +++ b/api/dify_graph/model_runtime/memory/prompt_message_memory.py @@ -3,7 +3,7 @@ from __future__ import annotations from collections.abc import Sequence from typing import Protocol -from core.model_runtime.entities import PromptMessage +from dify_graph.model_runtime.entities import PromptMessage DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 diff --git a/api/core/workflow/graph_engine/entities/__init__.py b/api/dify_graph/model_runtime/model_providers/__base/__init__.py similarity index 100% rename from api/core/workflow/graph_engine/entities/__init__.py rename to api/dify_graph/model_runtime/model_providers/__base/__init__.py diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py similarity index 97% rename from api/core/model_runtime/model_providers/__base/ai_model.py rename to api/dify_graph/model_runtime/model_providers/__base/ai_model.py index c3e50eaddd..ac7ae9925b 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py @@ -6,9 +6,10 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationError from redis import RedisError from configs import dify_config -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from core.model_runtime.entities.model_entities import ( +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE +from dify_graph.model_runtime.entities.model_entities import ( AIModelEntity, DefaultParameterName, ModelType, @@ -16,7 +17,7 @@ from core.model_runtime.entities.model_entities import ( PriceInfo, PriceType, ) -from core.model_runtime.errors.invoke import ( +from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -24,7 +25,6 @@ from core.model_runtime.errors.invoke import ( InvokeRateLimitError, InvokeServerUnavailableError, ) -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py similarity index 98% rename from api/core/model_runtime/model_providers/__base/large_language_model.py rename to api/dify_graph/model_runtime/model_providers/__base/large_language_model.py index c32ab0879e..bf864ca227 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py @@ -7,21 +7,21 @@ from typing import Union from pydantic import ConfigDict from configs import dify_config -from core.model_runtime.callbacks.base_callback import Callback -from core.model_runtime.callbacks.logging_callback import LoggingCallback -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageContentUnionTypes, PromptMessageTool, TextPromptMessageContent, ) -from core.model_runtime.entities.model_entities import ( +from dify_graph.model_runtime.entities.model_entities import ( ModelType, PriceType, ) -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py similarity index 89% rename from api/core/model_runtime/model_providers/__base/moderation_model.py rename to api/dify_graph/model_runtime/model_providers/__base/moderation_model.py index 7aff0184f4..5fa3d1634b 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py @@ -2,8 +2,8 @@ import time from pydantic import ConfigDict -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class ModerationModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py similarity index 92% rename from api/core/model_runtime/model_providers/__base/rerank_model.py rename to api/dify_graph/model_runtime/model_providers/__base/rerank_model.py index 0a576b832a..5da2b84b95 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py @@ -1,6 +1,6 @@ -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.rerank_entities import RerankResult -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class RerankModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py similarity index 88% rename from api/core/model_runtime/model_providers/__base/speech2text_model.py rename to api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py index 9d3bf13e79..e69069a85d 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py @@ -2,8 +2,8 @@ from typing import IO from pydantic import ConfigDict -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class Speech2TextModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py similarity index 94% rename from api/core/model_runtime/model_providers/__base/text_embedding_model.py rename to api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py index 4c902e2c11..3438da2ada 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,9 +1,9 @@ from pydantic import ConfigDict from core.entities.embedding_type import EmbeddingInputType -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel class TextEmbeddingModel(AIModel): diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py similarity index 100% rename from api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py rename to api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py similarity index 94% rename from api/core/model_runtime/model_providers/__base/tts_model.py rename to api/dify_graph/model_runtime/model_providers/__base/tts_model.py index a83c8be37c..0656529f22 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py @@ -3,8 +3,8 @@ from collections.abc import Iterable from pydantic import ConfigDict -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel logger = logging.getLogger(__name__) diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/dify_graph/model_runtime/model_providers/__init__.py similarity index 100% rename from api/core/workflow/nodes/answer/__init__.py rename to api/dify_graph/model_runtime/model_providers/__init__.py diff --git a/api/core/model_runtime/model_providers/_position.yaml b/api/dify_graph/model_runtime/model_providers/_position.yaml similarity index 100% rename from api/core/model_runtime/model_providers/_position.yaml rename to api/dify_graph/model_runtime/model_providers/_position.yaml diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py similarity index 92% rename from api/core/model_runtime/model_providers/model_provider_factory.py rename to api/dify_graph/model_runtime/model_providers/model_provider_factory.py index 9cfc6889ac..de0677a348 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py @@ -10,18 +10,20 @@ from redis import RedisError import contexts from configs import dify_config -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity -from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.model_providers.__base.moderation_model import ModerationModel -from core.model_runtime.model_providers.__base.rerank_model import RerankModel -from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from core.model_runtime.model_providers.__base.tts_model import TTSModel -from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel +from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator +from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) from extensions.ext_redis import redis_client from models.provider_ids import ModelProviderID @@ -280,7 +282,8 @@ class ModelProviderFactory: all_model_type_models.append(model_schema) simple_provider_schema = provider_schema.to_simple_provider() - simple_provider_schema.models.extend(all_model_type_models) + if model_type: + simple_provider_schema.models = all_model_type_models providers.append(simple_provider_schema) diff --git a/api/core/workflow/nodes/end/__init__.py b/api/dify_graph/model_runtime/schema_validators/__init__.py similarity index 100% rename from api/core/workflow/nodes/end/__init__.py rename to api/dify_graph/model_runtime/schema_validators/__init__.py diff --git a/api/core/model_runtime/schema_validators/common_validator.py b/api/dify_graph/model_runtime/schema_validators/common_validator.py similarity index 97% rename from api/core/model_runtime/schema_validators/common_validator.py rename to api/dify_graph/model_runtime/schema_validators/common_validator.py index 2caedeaf48..04cdb8e4f7 100644 --- a/api/core/model_runtime/schema_validators/common_validator.py +++ b/api/dify_graph/model_runtime/schema_validators/common_validator.py @@ -1,6 +1,6 @@ from typing import Union, cast -from core.model_runtime.entities.provider_entities import CredentialFormSchema, FormType +from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType class CommonValidator: diff --git a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py similarity index 78% rename from api/core/model_runtime/schema_validators/model_credential_schema_validator.py rename to api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py index 0ac935ca31..a97796e98f 100644 --- a/api/core/model_runtime/schema_validators/model_credential_schema_validator.py +++ b/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py @@ -1,6 +1,6 @@ -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ModelCredentialSchema -from core.model_runtime.schema_validators.common_validator import CommonValidator +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator class ModelCredentialSchemaValidator(CommonValidator): diff --git a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py similarity index 79% rename from api/core/model_runtime/schema_validators/provider_credential_schema_validator.py rename to api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py index 06350f92a9..2fed75a76c 100644 --- a/api/core/model_runtime/schema_validators/provider_credential_schema_validator.py +++ b/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py @@ -1,5 +1,5 @@ -from core.model_runtime.entities.provider_entities import ProviderCredentialSchema -from core.model_runtime.schema_validators.common_validator import CommonValidator +from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator class ProviderCredentialSchemaValidator(CommonValidator): diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/dify_graph/model_runtime/utils/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/__init__.py rename to api/dify_graph/model_runtime/utils/__init__.py diff --git a/api/core/model_runtime/utils/encoders.py b/api/dify_graph/model_runtime/utils/encoders.py similarity index 100% rename from api/core/model_runtime/utils/encoders.py rename to api/dify_graph/model_runtime/utils/encoders.py diff --git a/api/core/workflow/node_events/__init__.py b/api/dify_graph/node_events/__init__.py similarity index 100% rename from api/core/workflow/node_events/__init__.py rename to api/dify_graph/node_events/__init__.py diff --git a/api/core/workflow/node_events/agent.py b/api/dify_graph/node_events/agent.py similarity index 100% rename from api/core/workflow/node_events/agent.py rename to api/dify_graph/node_events/agent.py diff --git a/api/core/workflow/node_events/base.py b/api/dify_graph/node_events/base.py similarity index 86% rename from api/core/workflow/node_events/base.py rename to api/dify_graph/node_events/base.py index 7fec47e21f..2f6259ae7d 100644 --- a/api/core/workflow/node_events/base.py +++ b/api/dify_graph/node_events/base.py @@ -3,8 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage class NodeEventBase(BaseModel): diff --git a/api/core/workflow/node_events/iteration.py b/api/dify_graph/node_events/iteration.py similarity index 100% rename from api/core/workflow/node_events/iteration.py rename to api/dify_graph/node_events/iteration.py diff --git a/api/core/workflow/node_events/loop.py b/api/dify_graph/node_events/loop.py similarity index 100% rename from api/core/workflow/node_events/loop.py rename to api/dify_graph/node_events/loop.py diff --git a/api/core/workflow/node_events/node.py b/api/dify_graph/node_events/node.py similarity index 79% rename from api/core/workflow/node_events/node.py rename to api/dify_graph/node_events/node.py index 2468bd0ac3..2e3973b8fa 100644 --- a/api/core/workflow/node_events/node.py +++ b/api/dify_graph/node_events/node.py @@ -1,19 +1,19 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from datetime import datetime +from typing import Any from pydantic import Field -from core.model_runtime.entities.llm_entities import LLMUsage -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.file import File -from core.workflow.node_events import NodeRunResult +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.file import File +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult from .base import NodeEventBase class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") + retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") context: str = Field(..., description="context") context_files: list[File] | None = Field(default=None, description="context files") diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py new file mode 100644 index 0000000000..0223149bb8 --- /dev/null +++ b/api/dify_graph/nodes/__init__.py @@ -0,0 +1,3 @@ +from dify_graph.enums import BuiltinNodeTypes + +__all__ = ["BuiltinNodeTypes"] diff --git a/api/core/workflow/nodes/variable_assigner/common/__init__.py b/api/dify_graph/nodes/answer/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/common/__init__.py rename to api/dify_graph/nodes/answer/__init__.py diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py similarity index 76% rename from api/core/workflow/nodes/answer/answer_node.py rename to api/dify_graph/nodes/answer/answer_node.py index 388447368e..4286e1a492 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/dify_graph/nodes/answer/answer_node.py @@ -1,17 +1,17 @@ from collections.abc import Mapping, Sequence from typing import Any -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.answer.entities import AnswerNodeData -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.template import Template -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.variables import ArrayFileSegment, FileSegment, Segment +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.answer.entities import AnswerNodeData +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.template import Template +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.variables import ArrayFileSegment, FileSegment, Segment class AnswerNode(Node[AnswerNodeData]): - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER execution_type = NodeExecutionType.RESPONSE @classmethod @@ -48,12 +48,10 @@ class AnswerNode(Node[AnswerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: AnswerNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = AnswerNodeData.model_validate(node_data) - - variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) + _ = graph_config # Explicitly mark as unused + variable_template_parser = VariableTemplateParser(template=node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git a/api/core/workflow/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py similarity index 91% rename from api/core/workflow/nodes/answer/entities.py rename to api/dify_graph/nodes/answer/entities.py index 850ff14880..cd82df1ac4 100644 --- a/api/core/workflow/nodes/answer/entities.py +++ b/api/dify_graph/nodes/answer/entities.py @@ -3,7 +3,8 @@ from enum import StrEnum, auto from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class AnswerNodeData(BaseNodeData): @@ -11,6 +12,7 @@ class AnswerNodeData(BaseNodeData): Answer Node Data. """ + type: NodeType = BuiltinNodeTypes.ANSWER answer: str = Field(..., description="answer template string") diff --git a/api/core/workflow/nodes/base/__init__.py b/api/dify_graph/nodes/base/__init__.py similarity index 79% rename from api/core/workflow/nodes/base/__init__.py rename to api/dify_graph/nodes/base/__init__.py index f83df0e323..036e25895d 100644 --- a/api/core/workflow/nodes/base/__init__.py +++ b/api/dify_graph/nodes/base/__init__.py @@ -1,4 +1,4 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData +from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState from .usage_tracking_mixin import LLMUsageTrackingMixin __all__ = [ @@ -6,6 +6,5 @@ __all__ = [ "BaseIterationState", "BaseLoopNodeData", "BaseLoopState", - "BaseNodeData", "LLMUsageTrackingMixin", ] diff --git a/api/dify_graph/nodes/base/entities.py b/api/dify_graph/nodes/base/entities.py new file mode 100644 index 0000000000..4f8b2682e1 --- /dev/null +++ b/api/dify_graph/nodes/base/entities.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from collections.abc import Sequence +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, field_validator + +from dify_graph.entities.base_node_data import BaseNodeData + + +class VariableSelector(BaseModel): + """ + Variable Selector. + """ + + variable: str + value_selector: Sequence[str] + + +class OutputVariableType(StrEnum): + STRING = "string" + NUMBER = "number" + INTEGER = "integer" + SECRET = "secret" + BOOLEAN = "boolean" + OBJECT = "object" + FILE = "file" + ARRAY = "array" + ARRAY_STRING = "array[string]" + ARRAY_NUMBER = "array[number]" + ARRAY_OBJECT = "array[object]" + ARRAY_BOOLEAN = "array[boolean]" + ARRAY_FILE = "array[file]" + ANY = "any" + ARRAY_ANY = "array[any]" + + +class OutputVariableEntity(BaseModel): + """ + Output Variable Entity. + """ + + variable: str + value_type: OutputVariableType = OutputVariableType.ANY + value_selector: Sequence[str] + + @field_validator("value_type", mode="before") + @classmethod + def normalize_value_type(cls, v: Any) -> Any: + """ + Normalize value_type to handle case-insensitive array types. + Converts 'Array[...]' to 'array[...]' for backward compatibility. + """ + if isinstance(v, str) and v.startswith("Array["): + return v.lower() + return v + + +class BaseIterationNodeData(BaseNodeData): + start_node_id: str | None = None + + +class BaseIterationState(BaseModel): + iteration_node_id: str + index: int + inputs: dict + + class MetaData(BaseModel): + pass + + metadata: MetaData + + +class BaseLoopNodeData(BaseNodeData): + start_node_id: str | None = None + + +class BaseLoopState(BaseModel): + loop_node_id: str + index: int + inputs: dict + + class MetaData(BaseModel): + pass + + metadata: MetaData diff --git a/api/core/workflow/nodes/base/node.py b/api/dify_graph/nodes/base/node.py similarity index 81% rename from api/core/workflow/nodes/base/node.py rename to api/dify_graph/nodes/base/node.py index 976e8032b8..56b46a5894 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -1,20 +1,26 @@ from __future__ import annotations -import importlib import logging import operator -import pkgutil from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import AgentNodeStrategyInit, GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_events import ( +from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import ( + ErrorStrategy, + NodeExecutionType, + NodeState, + NodeType, + WorkflowNodeExecutionStatus, +) +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunAgentLogEvent, NodeRunFailedEvent, @@ -34,7 +40,7 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import ( +from dify_graph.node_events import ( AgentLogEvent, HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, @@ -53,17 +59,32 @@ from core.workflow.node_events import ( StreamChunkEvent, StreamCompletedEvent, ) -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from libs.datetime_utils import naive_utc_now -from models.enums import UserFrom - -from .entities import BaseNodeData, RetryConfig NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) +_MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) +class DifyRunContextProtocol(Protocol): + tenant_id: str + app_id: str + user_id: str + user_from: Any + invoke_from: Any + + +class _MappingDifyRunContext: + def __init__(self, mapping: Mapping[str, Any]) -> None: + self.tenant_id = str(mapping["tenant_id"]) + self.app_id = str(mapping["app_id"]) + self.user_id = str(mapping["user_id"]) + self.user_from = mapping["user_from"] + self.invoke_from = mapping["invoke_from"] + + class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -130,15 +151,15 @@ class Node(Generic[NodeDataT]): Later, in __init__: :: - config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate() - │ - ▼ - CodeNodeData instance - (stored in self._node_data) + config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) + │ + ▼ + CodeNodeData instance + (stored in self._node_data) Example: class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE # No need to implement _get_title, _get_error_strategy, etc. """ super().__init_subclass__(**kwargs) @@ -156,7 +177,8 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under core.workflow.nodes.* + # Only register production node implementations defined under the + # canonical workflow namespaces. # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). module_name = getattr(cls, "__module__", "") @@ -164,7 +186,7 @@ class Node(Generic[NodeDataT]): node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith("core.workflow.nodes."): + if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -180,6 +202,7 @@ class Node(Generic[NodeDataT]): else: latest_key = max(version_keys) if version_keys else version bucket["latest"] = bucket[latest_key] + Node._registry_version += 1 @classmethod def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: @@ -214,43 +237,47 @@ class Node(Generic[NodeDataT]): # Global registry populated via __init_subclass__ _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} + _registry_version: ClassVar[int] = 0 + + @classmethod + def get_registry_version(cls) -> int: + return cls._registry_version def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, ) -> None: self._graph_init_params = graph_init_params + self._run_context = MappingProxyType(dict(graph_init_params.run_context)) self.id = id - self.tenant_id = graph_init_params.tenant_id - self.app_id = graph_init_params.app_id self.workflow_id = graph_init_params.workflow_id self.graph_config = graph_init_params.graph_config - self.user_id = graph_init_params.user_id - self.user_from = UserFrom(graph_init_params.user_from) - self.invoke_from = InvokeFrom(graph_init_params.invoke_from) self.workflow_call_depth = graph_init_params.call_depth self.graph_runtime_state = graph_runtime_state self.state: NodeState = NodeState.UNKNOWN # node execution state - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required.") + node_id = config["id"] self._node_id = node_id self._node_execution_id: str = "" self._start_at = naive_utc_now() - raw_node_data = config.get("data") or {} - if not isinstance(raw_node_data, Mapping): - raise ValueError("Node config data must be a mapping.") - - self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data) + self._node_data = self.validate_node_data(config["data"]) self.post_init() + @classmethod + def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: + """Validate shared graph node payloads against the subclass-declared NodeData model.""" + return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + + def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: + """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" + self._node_data = self.validate_node_data(cast(BaseNodeData, data)) + def post_init(self) -> None: """Optional hook for subclasses requiring extra initialization.""" return @@ -259,6 +286,38 @@ class Node(Generic[NodeDataT]): def graph_init_params(self) -> GraphInitParams: return self._graph_init_params + @property + def run_context(self) -> Mapping[str, Any]: + return self._run_context + + def get_run_context_value(self, key: str, default: Any = None) -> Any: + return self._run_context.get(key, default) + + def require_run_context_value(self, key: str) -> Any: + value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) + if value is _MISSING_RUN_CONTEXT_VALUE: + raise ValueError(f"run_context missing required key: {key}") + return value + + def require_dify_context(self) -> DifyRunContextProtocol: + raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + + if isinstance(raw_ctx, Mapping): + missing_keys = [ + key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx + ] + if missing_keys: + raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") + return _MappingDifyRunContext(raw_ctx) + + for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): + if not hasattr(raw_ctx, attr): + raise TypeError(f"invalid dify context object, missing attribute: {attr}") + + return cast(DifyRunContextProtocol, raw_ctx) + @property def execution_id(self) -> str: return self._node_execution_id @@ -291,9 +350,6 @@ class Node(Generic[NodeDataT]): return None return str(execution_id) - def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT: - return cast(NodeDataT, self._node_data_type.model_validate(data)) - @abstractmethod def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: """ @@ -302,6 +358,10 @@ class Node(Generic[NodeDataT]): """ raise NotImplementedError + def populate_start_event(self, event: NodeRunStartedEvent) -> None: + """Allow subclasses to enrich the started event without cross-node imports in the base class.""" + _ = event + def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() self._start_at = naive_utc_now() @@ -315,41 +375,10 @@ class Node(Generic[NodeDataT]): in_iteration_id=None, start_at=self._start_at, ) - - # === FIXME(-LAN-): Needs to refactor. - from core.workflow.nodes.tool.tool_node import ToolNode - - if isinstance(self, ToolNode): - start_event.provider_id = getattr(self.node_data, "provider_id", "") - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from core.workflow.nodes.datasource.datasource_node import DatasourceNode - - if isinstance(self, DatasourceNode): - plugin_id = getattr(self.node_data, "plugin_id", "") - provider_name = getattr(self.node_data, "provider_name", "") - - start_event.provider_id = f"{plugin_id}/{provider_name}" - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode - - if isinstance(self, TriggerEventNode): - start_event.provider_id = getattr(self.node_data, "provider_id", "") - start_event.provider_type = getattr(self.node_data, "provider_type", "") - - from typing import cast - - from core.workflow.nodes.agent.agent_node import AgentNode - from core.workflow.nodes.agent.entities import AgentNodeData - - if isinstance(self, AgentNode): - start_event.agent_strategy = AgentNodeStrategyInit( - name=cast(AgentNodeData, self.node_data).agent_strategy_name, - icon=self.agent_strategy_icon, - ) - - # === + try: + self.populate_start_event(start_event) + except Exception: + logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) yield start_event try: @@ -377,11 +406,13 @@ class Node(Generic[NodeDataT]): error=str(e), error_type="WorkflowNodeError", ) + finished_at = naive_utc_now() yield NodeRunFailedEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=result, error=str(e), ) @@ -391,7 +422,7 @@ class Node(Generic[NodeDataT]): cls, *, graph_config: Mapping[str, Any], - config: Mapping[str, Any], + config: NodeConfigDict, ) -> Mapping[str, Sequence[str]]: """Extracts references variable selectors from node configuration. @@ -429,13 +460,12 @@ class Node(Generic[NodeDataT]): :param config: node config :return: """ - node_id = config.get("id") - if not node_id: - raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - - # Pass raw dict data instead of creating NodeData instance + node_id = config["id"] + node_data = cls.validate_node_data(config["data"]) data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) + graph_config=graph_config, + node_id=node_id, + node_data=node_data, ) return data @@ -445,7 +475,7 @@ class Node(Generic[NodeDataT]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: NodeDataT, ) -> Mapping[str, Sequence[str]]: return {} @@ -469,30 +499,20 @@ class Node(Generic[NodeDataT]): @abstractmethod def version(cls) -> str: """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`. - # - # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING` - # in `api/core/workflow/nodes/__init__.py`. + # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so + # registry lookups can resolve numeric versions and `latest`. raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @classmethod def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return mapping of NodeType -> {version -> Node subclass} using __init_subclass__ registry. + """Return a read-only view of the currently registered node classes. - Import all modules under core.workflow.nodes so subclasses register themselves on import. - Then we return a readonly view of the registry to avoid accidental mutation. + This accessor intentionally performs no imports. The embedding layer that + owns bootstrap (for example `core.workflow.node_factory`) must import any + extension node packages before calling it so their subclasses register via + `__init_subclass__`. """ - # Import all node modules to ensure they are loaded (thus registered) - import core.workflow.nodes as _nodes_pkg - - for _, _modname, _ in pkgutil.walk_packages(_nodes_pkg.__path__, _nodes_pkg.__name__ + "."): - # Avoid importing modules that depend on the registry to prevent circular imports. - if _modname == "core.workflow.nodes.node_mapping": - continue - importlib.import_module(_modname) - - # Return a readonly view so callers can't mutate the registry by accident - return {nt: MappingProxyType(ver_map) for nt, ver_map in cls._registry.items()} + return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} @property def retry(self) -> bool: @@ -550,6 +570,7 @@ class Node(Generic[NodeDataT]): return self._node_data def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: + finished_at = naive_utc_now() match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( @@ -557,6 +578,7 @@ class Node(Generic[NodeDataT]): node_id=self.id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=result, error=result.error, ) @@ -566,6 +588,7 @@ class Node(Generic[NodeDataT]): node_id=self.id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=result, ) case _: @@ -588,6 +611,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: + finished_at = naive_utc_now() match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -595,6 +619,7 @@ class Node(Generic[NodeDataT]): node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=event.node_run_result, ) case WorkflowNodeExecutionStatus.FAILED: @@ -603,6 +628,7 @@ class Node(Generic[NodeDataT]): node_id=self._node_id, node_type=self.node_type, start_at=self._start_at, + finished_at=finished_at, node_run_result=event.node_run_result, error=event.node_run_result.error, ) @@ -767,11 +793,16 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: + from core.rag.entities.citation_metadata import RetrievalSourceMetadata + + retriever_resources = [ + RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources + ] return NodeRunRetrieverResourceEvent( id=self.execution_id, node_id=self._node_id, node_type=self.node_type, - retriever_resources=event.retriever_resources, + retriever_resources=retriever_resources, context=event.context, node_version=self.version(), ) diff --git a/api/core/workflow/nodes/base/template.py b/api/dify_graph/nodes/base/template.py similarity index 98% rename from api/core/workflow/nodes/base/template.py rename to api/dify_graph/nodes/base/template.py index 81f4b9f6fb..5976e808e3 100644 --- a/api/core/workflow/nodes/base/template.py +++ b/api/dify_graph/nodes/base/template.py @@ -11,7 +11,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Any, Union -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser @dataclass(frozen=True) diff --git a/api/core/workflow/nodes/base/usage_tracking_mixin.py b/api/dify_graph/nodes/base/usage_tracking_mixin.py similarity index 89% rename from api/core/workflow/nodes/base/usage_tracking_mixin.py rename to api/dify_graph/nodes/base/usage_tracking_mixin.py index d9a0ef8972..bd49419fd3 100644 --- a/api/core/workflow/nodes/base/usage_tracking_mixin.py +++ b/api/dify_graph/nodes/base/usage_tracking_mixin.py @@ -1,5 +1,5 @@ -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.runtime import GraphRuntimeState +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState class LLMUsageTrackingMixin: diff --git a/api/core/workflow/nodes/base/variable_template_parser.py b/api/dify_graph/nodes/base/variable_template_parser.py similarity index 100% rename from api/core/workflow/nodes/base/variable_template_parser.py rename to api/dify_graph/nodes/base/variable_template_parser.py diff --git a/api/core/workflow/nodes/code/__init__.py b/api/dify_graph/nodes/code/__init__.py similarity index 100% rename from api/core/workflow/nodes/code/__init__.py rename to api/dify_graph/nodes/code/__init__.py diff --git a/api/core/workflow/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py similarity index 96% rename from api/core/workflow/nodes/code/code_node.py rename to api/dify_graph/nodes/code/code_node.py index 7b1cbfcfea..82d5fced62 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/dify_graph/nodes/code/code_node.py @@ -3,13 +3,14 @@ from decimal import Decimal from textwrap import dedent from typing import TYPE_CHECKING, Any, Protocol, cast -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.variables.segments import ArrayFileSegment -from core.workflow.variables.types import SegmentType +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.variables.segments import ArrayFileSegment +from dify_graph.variables.types import SegmentType from .exc import ( CodeNodeError, @@ -18,8 +19,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class WorkflowCodeExecutor(Protocol): @@ -71,13 +72,13 @@ _DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { class CodeNode(Node[CodeNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE _limits: CodeNodeLimits def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -466,15 +467,12 @@ class CodeNode(Node[CodeNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: CodeNodeData, ) -> Mapping[str, Sequence[str]]: _ = graph_config # Explicitly mark as unused - # Create typed NodeData from dict - typed_node_data = CodeNodeData.model_validate(node_data) - return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in typed_node_data.variables + for variable_selector in node_data.variables } @property diff --git a/api/core/workflow/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py similarity index 82% rename from api/core/workflow/nodes/code/entities.py rename to api/dify_graph/nodes/code/entities.py index 8b73b89e2f..55b4ee4862 100644 --- a/api/core/workflow/nodes/code/entities.py +++ b/api/dify_graph/nodes/code/entities.py @@ -3,9 +3,10 @@ from typing import Annotated, Literal from pydantic import AfterValidator, BaseModel -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.variables.types import SegmentType +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.variables.types import SegmentType class CodeLanguage(StrEnum): @@ -39,6 +40,8 @@ class CodeNodeData(BaseNodeData): Code Node Data. """ + type: NodeType = BuiltinNodeTypes.CODE + class Output(BaseModel): type: Annotated[SegmentType, AfterValidator(_validate_type)] children: dict[str, "CodeNodeData.Output"] | None = None diff --git a/api/core/workflow/nodes/code/exc.py b/api/dify_graph/nodes/code/exc.py similarity index 100% rename from api/core/workflow/nodes/code/exc.py rename to api/dify_graph/nodes/code/exc.py diff --git a/api/core/workflow/nodes/code/limits.py b/api/dify_graph/nodes/code/limits.py similarity index 100% rename from api/core/workflow/nodes/code/limits.py rename to api/dify_graph/nodes/code/limits.py diff --git a/api/core/workflow/nodes/document_extractor/__init__.py b/api/dify_graph/nodes/document_extractor/__init__.py similarity index 100% rename from api/core/workflow/nodes/document_extractor/__init__.py rename to api/dify_graph/nodes/document_extractor/__init__.py diff --git a/api/core/workflow/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py similarity index 60% rename from api/core/workflow/nodes/document_extractor/entities.py rename to api/dify_graph/nodes/document_extractor/entities.py index db05bbf4fe..1110cc2710 100644 --- a/api/core/workflow/nodes/document_extractor/entities.py +++ b/api/dify_graph/nodes/document_extractor/entities.py @@ -1,10 +1,12 @@ from collections.abc import Sequence from dataclasses import dataclass -from core.workflow.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class DocumentExtractorNodeData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/document_extractor/exc.py b/api/dify_graph/nodes/document_extractor/exc.py similarity index 100% rename from api/core/workflow/nodes/document_extractor/exc.py rename to api/dify_graph/nodes/document_extractor/exc.py diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py similarity index 87% rename from api/core/workflow/nodes/document_extractor/node.py rename to api/dify_graph/nodes/document_extractor/node.py index 59be4c54ef..27196f1aca 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/dify_graph/nodes/document_extractor/node.py @@ -4,6 +4,7 @@ import json import logging import os import tempfile +import zipfile from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any @@ -20,13 +21,14 @@ from docx.oxml.text.paragraph import CT_P from docx.table import Table from docx.text.paragraph import Paragraph -from core.helper import ssrf_proxy -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, file_manager -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayStringSegment, FileSegment +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, file_manager +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.variables import ArrayFileSegment +from dify_graph.variables.segments import ArrayStringSegment, FileSegment from .entities import DocumentExtractorNodeData, UnstructuredApiConfig from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -34,8 +36,8 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class DocumentExtractorNode(Node[DocumentExtractorNodeData]): @@ -44,7 +46,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): Supports plain text, PDF, and DOC/DOCX files. """ - node_type = NodeType.DOCUMENT_EXTRACTOR + node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR @classmethod def version(cls) -> str: @@ -53,11 +55,12 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, unstructured_api_config: UnstructuredApiConfig | None = None, + http_client: HttpClientProtocol, ) -> None: super().__init__( id=id, @@ -66,6 +69,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): graph_runtime_state=graph_runtime_state, ) self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() + self._http_client = http_client def _run(self): variable_selector = self.node_data.variable_selector @@ -80,12 +84,24 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): value = variable.value inputs = {"variable_selector": variable_selector} + if isinstance(value, list): + value = list(filter(lambda x: x, value)) process_data = {"documents": value if isinstance(value, list) else [value]} + if not value: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={"text": ArrayStringSegment(value=[])}, + ) + try: if isinstance(value, list): extracted_text_list = [ - _extract_text_from_file(file, unstructured_api_config=self._unstructured_api_config) + _extract_text_from_file( + self._http_client, file, unstructured_api_config=self._unstructured_api_config + ) for file in value ] return NodeRunResult( @@ -95,7 +111,9 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): outputs={"text": ArrayStringSegment(value=extracted_text_list)}, ) elif isinstance(value, File): - extracted_text = _extract_text_from_file(value, unstructured_api_config=self._unstructured_api_config) + extracted_text = _extract_text_from_file( + self._http_client, value, unstructured_api_config=self._unstructured_api_config + ) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=inputs, @@ -105,6 +123,7 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): else: raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") except DocumentExtractorError as e: + logger.warning(e, exc_info=True) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=str(e), @@ -118,12 +137,10 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: DocumentExtractorNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = DocumentExtractorNodeData.model_validate(node_data) - - return {node_id + ".files": typed_node_data.variable_selector} + _ = graph_config # Explicitly mark as unused + return {node_id + ".files": node_data.variable_selector} def _extract_text_by_mime_type( @@ -379,6 +396,32 @@ def parser_docx_part(block, doc: Document, content_items, i): content_items.append((i, "table", Table(block, doc))) +def _normalize_docx_zip(file_content: bytes) -> bytes: + """ + Some DOCX files (e.g. exported by Evernote on Windows) are malformed: + ZIP entry names use backslash (\\) as path separator instead of the forward + slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry + "word\\document.xml" is never found when python-docx looks for + "word/document.xml", which triggers a KeyError about a missing relationship. + + This function rewrites the ZIP in-memory, normalizing all entry names to + use forward slashes without touching any actual document content. + """ + try: + with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: + out_buf = io.BytesIO() + with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: + for item in zin.infolist(): + data = zin.read(item.filename) + # Normalize backslash path separators to forward slash + item.filename = item.filename.replace("\\", "/") + zout.writestr(item, data) + return out_buf.getvalue() + except zipfile.BadZipFile: + # Not a valid zip — return as-is and let python-docx report the real error + return file_content + + def _extract_text_from_docx(file_content: bytes) -> str: """ Extract text from a DOCX file. @@ -386,7 +429,15 @@ def _extract_text_from_docx(file_content: bytes) -> str: """ try: doc_file = io.BytesIO(file_content) - doc = docx.Document(doc_file) + try: + doc = docx.Document(doc_file) + except Exception as e: + logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) + # Some DOCX files exported by tools like Evernote on Windows use + # backslash path separators in ZIP entries and/or single-quoted XML + # attributes, both of which break python-docx on Linux. Normalize and retry. + file_content = _normalize_docx_zip(file_content) + doc = docx.Document(io.BytesIO(file_content)) text = [] # Keep track of paragraph and table positions @@ -439,13 +490,13 @@ def _extract_text_from_docx(file_content: bytes) -> str: raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e -def _download_file_content(file: File) -> bytes: +def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: """Download the content of a file based on its transfer method.""" try: if file.transfer_method == FileTransferMethod.REMOTE_URL: if file.remote_url is None: raise FileDownloadError("Missing URL for remote file") - response = ssrf_proxy.get(file.remote_url) + response = http_client.get(file.remote_url) response.raise_for_status() return response.content else: @@ -454,8 +505,10 @@ def _download_file_content(file: File) -> bytes: raise FileDownloadError(f"Error downloading file: {str(e)}") from e -def _extract_text_from_file(file: File, *, unstructured_api_config: UnstructuredApiConfig) -> str: - file_content = _download_file_content(file) +def _extract_text_from_file( + http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig +) -> str: + file_content = _download_file_content(http_client, file) if file.extension: extracted_text = _extract_text_by_file_extension( file_content=file_content, diff --git a/api/core/workflow/utils/__init__.py b/api/dify_graph/nodes/end/__init__.py similarity index 100% rename from api/core/workflow/utils/__init__.py rename to api/dify_graph/nodes/end/__init__.py diff --git a/api/core/workflow/nodes/end/end_node.py b/api/dify_graph/nodes/end/end_node.py similarity index 79% rename from api/core/workflow/nodes/end/end_node.py rename to api/dify_graph/nodes/end/end_node.py index 2efcb4f418..1f5cfab22b 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/dify_graph/nodes/end/end_node.py @@ -1,12 +1,12 @@ -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.template import Template -from core.workflow.nodes.end.entities import EndNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.template import Template +from dify_graph.nodes.end.entities import EndNodeData class EndNode(Node[EndNodeData]): - node_type = NodeType.END + node_type = BuiltinNodeTypes.END execution_type = NodeExecutionType.RESPONSE @classmethod diff --git a/api/core/workflow/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py similarity index 71% rename from api/core/workflow/nodes/end/entities.py rename to api/dify_graph/nodes/end/entities.py index 87a221b5f6..be7f0c8de8 100644 --- a/api/core/workflow/nodes/end/entities.py +++ b/api/dify_graph/nodes/end/entities.py @@ -1,6 +1,8 @@ from pydantic import BaseModel, Field -from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.entities import OutputVariableEntity class EndNodeData(BaseNodeData): @@ -8,6 +10,7 @@ class EndNodeData(BaseNodeData): END Node Data. """ + type: NodeType = BuiltinNodeTypes.END outputs: list[OutputVariableEntity] diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/dify_graph/nodes/http_request/__init__.py similarity index 100% rename from api/core/workflow/nodes/http_request/__init__.py rename to api/dify_graph/nodes/http_request/__init__.py diff --git a/api/core/workflow/nodes/http_request/config.py b/api/dify_graph/nodes/http_request/config.py similarity index 100% rename from api/core/workflow/nodes/http_request/config.py rename to api/dify_graph/nodes/http_request/config.py diff --git a/api/core/workflow/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py similarity index 97% rename from api/core/workflow/nodes/http_request/entities.py rename to api/dify_graph/nodes/http_request/entities.py index 0eda20f485..f594d58ae6 100644 --- a/api/core/workflow/nodes/http_request/entities.py +++ b/api/dify_graph/nodes/http_request/entities.py @@ -8,7 +8,8 @@ import charset_normalizer import httpx from pydantic import BaseModel, Field, ValidationInfo, field_validator -from core.workflow.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" @@ -89,6 +90,7 @@ class HttpRequestNodeData(BaseNodeData): Code Node Data. """ + type: NodeType = BuiltinNodeTypes.HTTP_REQUEST method: Literal[ "get", "post", diff --git a/api/core/workflow/nodes/http_request/exc.py b/api/dify_graph/nodes/http_request/exc.py similarity index 100% rename from api/core/workflow/nodes/http_request/exc.py rename to api/dify_graph/nodes/http_request/exc.py diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/dify_graph/nodes/http_request/executor.py similarity index 99% rename from api/core/workflow/nodes/http_request/executor.py rename to api/dify_graph/nodes/http_request/executor.py index de14c8c517..892b0fc688 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/dify_graph/nodes/http_request/executor.py @@ -10,9 +10,9 @@ from urllib.parse import urlencode, urlparse import httpx from json_repair import repair_json -from core.workflow.file.enums import FileTransferMethod -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ArrayFileSegment, FileSegment +from dify_graph.file.enums import FileTransferMethod +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import ArrayFileSegment, FileSegment from ..protocols import FileManagerProtocol, HttpClientProtocol from .entities import ( diff --git a/api/core/workflow/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py similarity index 87% rename from api/core/workflow/nodes/http_request/node.py rename to api/dify_graph/nodes/http_request/node.py index 11458db758..3e5253d809 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -3,15 +3,16 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.http_request.executor import Executor -from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol -from core.workflow.variables.segments import ArrayFileSegment +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base import variable_template_parser +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.http_request.executor import Executor +from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.variables.segments import ArrayFileSegment from factories import file_factory from .config import build_http_request_config, resolve_http_request_config @@ -27,17 +28,17 @@ from .exc import HttpRequestNodeError, RequestBodyError logger = logging.getLogger(__name__) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = NodeType.HTTP_REQUEST + node_type = BuiltinNodeTypes.HTTP_REQUEST def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -100,6 +101,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]): timeout=self._get_request_timeout(self.node_data), variable_pool=self.graph_runtime_state.variable_pool, http_request_config=self._http_request_config, + # Must be 0 to disable executor-level retries, as the graph engine handles them. + # This is critical to prevent nested retries. max_retries=0, ssl_verify=self.node_data.ssl_verify, http_client=self._http_client, @@ -163,18 +166,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: HttpRequestNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = HttpRequestNodeData.model_validate(node_data) - selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) - if typed_node_data.body: - body_type = typed_node_data.body.type - data = typed_node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(node_data.params) + if node_data.body: + body_type = node_data.body.type + data = node_data.body.data match body_type: case "none": pass @@ -212,6 +212,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ + dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -236,8 +237,8 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, conversation_id=None, file_binary=content, mimetype=mime_type, @@ -249,7 +250,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): } file = file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=dify_ctx.tenant_id, ) files.append(file) diff --git a/api/core/workflow/nodes/human_input/__init__.py b/api/dify_graph/nodes/human_input/__init__.py similarity index 100% rename from api/core/workflow/nodes/human_input/__init__.py rename to api/dify_graph/nodes/human_input/__init__.py diff --git a/api/core/workflow/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py similarity index 82% rename from api/core/workflow/nodes/human_input/entities.py rename to api/dify_graph/nodes/human_input/entities.py index a4473dfa7d..2a33b4a0a8 100644 --- a/api/core/workflow/nodes/human_input/entities.py +++ b/api/dify_graph/nodes/human_input/entities.py @@ -8,12 +8,15 @@ from collections.abc import Mapping, Sequence from datetime import datetime, timedelta from typing import Annotated, Any, ClassVar, Literal, Self +import bleach +import markdown from pydantic import BaseModel, Field, field_validator, model_validator -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.runtime import VariablePool -from core.workflow.variables.consts import SELECTORS_LENGTH +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.runtime import VariablePool +from dify_graph.variables.consts import SELECTORS_LENGTH from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit @@ -57,6 +60,39 @@ class EmailDeliveryConfig(BaseModel): """Configuration for email delivery method.""" URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") + _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ + "a", + "blockquote", + "br", + "code", + "em", + "h1", + "h2", + "h3", + "h4", + "h5", + "h6", + "hr", + "li", + "ol", + "p", + "pre", + "strong", + "table", + "tbody", + "td", + "th", + "thead", + "tr", + "ul", + ] + _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { + "a": ["href", "title"], + "td": ["align"], + "th": ["align"], + } + _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] recipients: EmailRecipients @@ -71,8 +107,8 @@ class EmailDeliveryConfig(BaseModel): body: str debug_mode: bool = False - def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig": - if not user_id: + def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": + if user_id is None: debug_recipients = EmailRecipients(whole_workspace=False, items=[]) return self.model_copy(update={"recipients": debug_recipients}) debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) @@ -97,6 +133,43 @@ class EmailDeliveryConfig(BaseModel): return templated_body return variable_pool.convert_template(templated_body).text + @classmethod + def render_markdown_body(cls, body: str) -> str: + """Render markdown to safe HTML for email delivery.""" + sanitized_markdown = bleach.clean( + body, + tags=[], + attributes={}, + strip=True, + strip_comments=True, + ) + rendered_html = markdown.markdown( + sanitized_markdown, + extensions=["nl2br", "tables"], + extension_configs={"tables": {"use_align_attribute": True}}, + ) + return bleach.clean( + rendered_html, + tags=cls._ALLOWED_HTML_TAGS, + attributes=cls._ALLOWED_HTML_ATTRIBUTES, + protocols=cls._ALLOWED_PROTOCOLS, + strip=True, + strip_comments=True, + ) + + @classmethod + def sanitize_subject(cls, subject: str) -> str: + """Sanitize email subject to plain text and prevent CRLF injection.""" + sanitized_subject = bleach.clean( + subject, + tags=[], + attributes={}, + strip=True, + strip_comments=True, + ) + sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) + return " ".join(sanitized_subject.split()) + class _DeliveryMethodBase(BaseModel): """Base delivery method configuration.""" @@ -140,7 +213,7 @@ def apply_debug_email_recipient( method: DeliveryChannelConfig, *, enabled: bool, - user_id: str, + user_id: str | None, ) -> DeliveryChannelConfig: if not enabled: return method @@ -148,7 +221,7 @@ def apply_debug_email_recipient( return method if not method.config.debug_mode: return method - debug_config = method.config.with_debug_recipient(user_id or "") + debug_config = method.config.with_debug_recipient(user_id) return method.model_copy(update={"config": debug_config}) @@ -214,6 +287,7 @@ class UserAction(BaseModel): class HumanInputNodeData(BaseNodeData): """Human Input node data.""" + type: NodeType = BuiltinNodeTypes.HUMAN_INPUT delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py similarity index 100% rename from api/core/workflow/nodes/human_input/enums.py rename to api/dify_graph/nodes/human_input/enums.py diff --git a/api/core/workflow/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py similarity index 85% rename from api/core/workflow/nodes/human_input/human_input_node.py rename to api/dify_graph/nodes/human_input/human_input_node.py index 1d7522ea25..794e33d92e 100644 --- a/api/core/workflow/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -3,44 +3,44 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.app.entities.app_invoke_entities import InvokeFrom -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import ( +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from dify_graph.node_events import ( HumanInputFormFilledEvent, HumanInputFormTimeoutEvent, NodeRunResult, PauseRequestedEvent, ) -from core.workflow.node_events.base import NodeEventBase -from core.workflow.node_events.node import StreamCompletedEvent -from core.workflow.nodes.base.node import Node -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.node_events.base import NodeEventBase +from dify_graph.node_events.node import StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter -from extensions.ext_database import db +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.datetime_utils import naive_utc_now from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: - from core.workflow.entities.graph_init_params import GraphInitParams - from core.workflow.runtime.graph_runtime_state import GraphRuntimeState + from dify_graph.entities.graph_init_params import GraphInitParams + from dify_graph.runtime.graph_runtime_state import GraphRuntimeState _SELECTED_BRANCH_KEY = "selected_branch" +_INVOKE_FROM_DEBUGGER = "debugger" +_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) class HumanInputNode(Node[HumanInputNodeData]): - node_type = NodeType.HUMAN_INPUT + node_type = BuiltinNodeTypes.HUMAN_INPUT execution_type = NodeExecutionType.BRANCH _BRANCH_SELECTION_KEYS: tuple[str, ...] = ( @@ -64,10 +64,10 @@ class HumanInputNode(Node[HumanInputNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository | None = None, + form_repository: HumanInputFormRepository, ) -> None: super().__init__( id=id, @@ -75,11 +75,6 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - if form_repository is None: - form_repository = HumanInputFormRepositoryImpl( - session_factory=db.engine, - tenant_id=self.tenant_id, - ) self._form_repository = form_repository @classmethod @@ -163,30 +158,39 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults def _should_require_console_recipient(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: + invoke_from = self._invoke_from_value() + if invoke_from == _INVOKE_FROM_DEBUGGER: return True - if self.invoke_from == InvokeFrom.EXPLORE: + if invoke_from == _INVOKE_FROM_EXPLORE: return self._node_data.is_webapp_enabled() return False def _display_in_ui(self) -> bool: - if self.invoke_from == InvokeFrom.DEBUGGER: + if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: return True return self._node_data.is_webapp_enabled() def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: + dify_ctx = self.require_dify_context() + invoke_from = self._invoke_from_value() enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}: + if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] return [ apply_debug_email_recipient( method, - enabled=self.invoke_from == InvokeFrom.DEBUGGER, - user_id=self.user_id or "", + enabled=invoke_from == _INVOKE_FROM_DEBUGGER, + user_id=dify_ctx.user_id, ) for method in enabled_methods ] + def _invoke_from_value(self) -> str: + invoke_from = self.require_dify_context().invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() @@ -220,10 +224,11 @@ class HumanInputNode(Node[HumanInputNodeData]): """ repo = self._form_repository form = repo.get_form(self._workflow_execution_id, self.id) + dify_ctx = self.require_dify_context() if form is None: display_in_ui = self._display_in_ui() params = FormCreateParams( - app_id=self.app_id, + app_id=dify_ctx.app_id, workflow_execution_id=self._workflow_execution_id, node_id=self.id, form_config=self._node_data, @@ -233,7 +238,9 @@ class HumanInputNode(Node[HumanInputNodeData]): resolved_default_values=self.resolve_default_values(), console_recipient_required=self._should_require_console_recipient(), console_creator_account_id=( - self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None + dify_ctx.user_id + if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} + else None ), backstage_recipient_required=True, ) @@ -342,7 +349,7 @@ class HumanInputNode(Node[HumanInputNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: HumanInputNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selectors referenced in form content and input default values. @@ -351,5 +358,4 @@ class HumanInputNode(Node[HumanInputNodeData]): 1. Variables referenced in form_content ({{#node_name.var_name#}}) 2. Variables referenced in input default values """ - validated_node_data = HumanInputNodeData.model_validate(node_data) - return validated_node_data.extract_variable_selector_to_variable_mapping(node_id) + return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/core/workflow/nodes/if_else/__init__.py b/api/dify_graph/nodes/if_else/__init__.py similarity index 100% rename from api/core/workflow/nodes/if_else/__init__.py rename to api/dify_graph/nodes/if_else/__init__.py diff --git a/api/core/workflow/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py similarity index 70% rename from api/core/workflow/nodes/if_else/entities.py rename to api/dify_graph/nodes/if_else/entities.py index b22bd6f508..ff09f3c023 100644 --- a/api/core/workflow/nodes/if_else/entities.py +++ b/api/dify_graph/nodes/if_else/entities.py @@ -2,8 +2,9 @@ from typing import Literal from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData -from core.workflow.utils.condition.entities import Condition +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.utils.condition.entities import Condition class IfElseNodeData(BaseNodeData): @@ -11,6 +12,8 @@ class IfElseNodeData(BaseNodeData): If Else Node Data. """ + type: NodeType = BuiltinNodeTypes.IF_ELSE + class Case(BaseModel): """ Case entity representing a single logical condition group diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py similarity index 85% rename from api/core/workflow/nodes/if_else/if_else_node.py rename to api/dify_graph/nodes/if_else/if_else_node.py index cda5f1dd42..7c0370e48c 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -3,17 +3,17 @@ from typing import Any, Literal from typing_extensions import deprecated -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.if_else.entities import IfElseNodeData -from core.workflow.runtime import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.condition.processor import ConditionProcessor +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.if_else.entities import IfElseNodeData +from dify_graph.runtime import VariablePool +from dify_graph.utils.condition.entities import Condition +from dify_graph.utils.condition.processor import ConditionProcessor class IfElseNode(Node[IfElseNodeData]): - node_type = NodeType.IF_ELSE + node_type = BuiltinNodeTypes.IF_ELSE execution_type = NodeExecutionType.BRANCH @classmethod @@ -97,13 +97,11 @@ class IfElseNode(Node[IfElseNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: IfElseNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = IfElseNodeData.model_validate(node_data) - var_mapping: dict[str, list[str]] = {} - for case in typed_node_data.cases or []: + _ = graph_config # Explicitly mark as unused + for case in node_data.cases or []: for condition in case.conditions: key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" var_mapping[key] = condition.variable_selector diff --git a/api/core/workflow/nodes/iteration/__init__.py b/api/dify_graph/nodes/iteration/__init__.py similarity index 100% rename from api/core/workflow/nodes/iteration/__init__.py rename to api/dify_graph/nodes/iteration/__init__.py diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py similarity index 83% rename from api/core/workflow/nodes/iteration/entities.py rename to api/dify_graph/nodes/iteration/entities.py index 63a41ec755..58fd112b12 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/dify_graph/nodes/iteration/entities.py @@ -3,7 +3,9 @@ from typing import Any from pydantic import Field -from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState class ErrorHandleMode(StrEnum): @@ -17,6 +19,7 @@ class IterationNodeData(BaseIterationNodeData): Iteration Node Data. """ + type: NodeType = BuiltinNodeTypes.ITERATION parent_loop_id: str | None = None # redundant field, not used currently iterator_selector: list[str] # variable selector output_selector: list[str] # output selector @@ -31,7 +34,7 @@ class IterationStartNodeData(BaseNodeData): Iteration Start Node Data. """ - pass + type: NodeType = BuiltinNodeTypes.ITERATION_START class IterationState(BaseIterationState): diff --git a/api/core/workflow/nodes/iteration/exc.py b/api/dify_graph/nodes/iteration/exc.py similarity index 100% rename from api/core/workflow/nodes/iteration/exc.py rename to api/dify_graph/nodes/iteration/exc.py diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py similarity index 85% rename from api/core/workflow/nodes/iteration/iteration_node.py rename to api/dify_graph/nodes/iteration/iteration_node.py index 54b0561dd8..033ec8672f 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -6,21 +6,22 @@ from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.enums import ( +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import ( + BuiltinNodeTypes, NodeExecutionType, - NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphNodeEventBase, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunSucceededEvent, ) -from core.workflow.node_events import ( +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import ( IterationFailedEvent, IterationNextEvent, IterationStartedEvent, @@ -29,13 +30,13 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from core.workflow.runtime import VariablePool -from core.workflow.variables import IntegerVariable, NoneSegment -from core.workflow.variables.segments import ArrayAnySegment, ArraySegment -from core.workflow.variables.variables import Variable +from dify_graph.nodes.base import LLMUsageTrackingMixin +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.runtime import VariablePool +from dify_graph.variables import IntegerVariable, NoneSegment +from dify_graph.variables.segments import ArrayAnySegment, ArraySegment +from dify_graph.variables.variables import Variable from libs.datetime_utils import naive_utc_now from .exc import ( @@ -48,8 +49,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.context import IExecutionContext - from core.workflow.graph_engine import GraphEngine + from dify_graph.context import IExecutionContext + from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -61,7 +62,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): Iteration Node. """ - node_type = NodeType.ITERATION + node_type = BuiltinNodeTypes.ITERATION execution_type = NodeExecutionType.CONTAINER @classmethod @@ -235,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): future_to_index: dict[ Future[ tuple[ - datetime, + float, list[GraphNodeEventBase], object | None, dict[str, Variable], @@ -260,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): try: result = future.result() ( - iter_start_at, + iteration_duration, events, output_value, conversation_snapshot, @@ -273,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # Yield all events from this iteration yield from events - # Update tokens and timing - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # The worker computes duration before we replay buffered events here, + # so slow downstream consumers don't inflate per-iteration timing. + iter_run_map[str(index)] = iteration_duration usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) @@ -304,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): index: int, item: object, execution_context: "IExecutionContext", - ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: + ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: """Execute a single iteration in parallel mode and return results.""" with execution_context: iter_start_at = datetime.now(UTC).replace(tzinfo=None) @@ -326,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): conversation_snapshot = self._extract_conversation_variable_snapshot( variable_pool=graph_engine.graph_runtime_state.variable_pool ) + iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() return ( - iter_start_at, + iteration_duration, events, output_value, conversation_snapshot, @@ -337,7 +340,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _capture_execution_context(self) -> "IExecutionContext": """Capture current execution context for parallel iterations.""" - from core.workflow.context import capture_current_context + from dify_graph.context import capture_current_context return capture_current_context() @@ -460,21 +463,18 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: IterationNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = IterationNodeData.model_validate(node_data) - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": typed_node_data.iterator_selector, + f"{node_id}.input_selector": node_data.iterator_selector, } iteration_node_ids = set() # Find all nodes that belong to this loop nodes = graph_config.get("nodes", []) for node in nodes: - node_data = node.get("data", {}) - if node_data.get("iteration_id") == node_id: + node_config_data = node.get("data", {}) + if node_config_data.get("iteration_id") == node_id: in_iteration_node_id = node.get("id") if in_iteration_node_id: iteration_node_ids.add(in_iteration_node_id) @@ -487,17 +487,16 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # variable selector to variable mapping try: - # Get node class - from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING - - node_type = NodeType(sub_node_config.get("data", {}).get("type")) - if node_type not in NODE_TYPE_CLASSES_MAPPING: + typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) + node_type = typed_sub_node_config["data"].type + node_mapping = Node.get_node_type_classes_mapping() + if node_type not in node_mapping: continue - node_version = sub_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_version = str(typed_sub_node_config["data"].version) + node_cls = node_mapping[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=sub_node_config + graph_config=graph_config, config=typed_sub_node_config ) sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: @@ -563,7 +562,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") current_index = index_variable.value for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.ITERATION_START: + if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: continue if isinstance(event, GraphNodeEventBase): @@ -587,24 +586,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return def _create_graph_engine(self, index: int, item: object): - # Import dependencies - from core.app.workflow.layers.llm_quota import LLMQuotaLayer - from core.app.workflow.node_factory import DifyNodeFactory - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState - # Create GraphInitParams from node attributes + # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) # Create a deep copy of the variable pool for each iteration @@ -621,28 +610,17 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): total_tokens=0, node_run_steps=0, ) + root_node_id = self.node_data.start_node_id + if root_node_id is None: + raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - # Create a new node factory with the new GraphRuntimeState - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy - ) - - # Initialize the iteration graph with the new node factory - iteration_graph = Graph.init( - graph_config=self.graph_config, node_factory=node_factory, root_node_id=self.node_data.start_node_id - ) - - if not iteration_graph: - raise IterationGraphNotFoundError("iteration graph not found") - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( - workflow_id=self.workflow_id, - graph=iteration_graph, - graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), - ) - graph_engine.layer(LLMQuotaLayer()) - - return graph_engine + try: + return self.graph_runtime_state.create_child_engine( + workflow_id=self.workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state_copy, + graph_config=self.graph_config, + root_node_id=root_node_id, + ) + except ChildGraphNotFoundError as exc: + raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/dify_graph/nodes/iteration/iteration_start_node.py similarity index 53% rename from api/core/workflow/nodes/iteration/iteration_start_node.py rename to api/dify_graph/nodes/iteration/iteration_start_node.py index 30d9fccbfd..a8ecf3d83b 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/dify_graph/nodes/iteration/iteration_start_node.py @@ -1,7 +1,7 @@ -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.iteration.entities import IterationStartNodeData +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.iteration.entities import IterationStartNodeData class IterationStartNode(Node[IterationStartNodeData]): @@ -9,7 +9,7 @@ class IterationStartNode(Node[IterationStartNodeData]): Iteration Start Node. """ - node_type = NodeType.ITERATION_START + node_type = BuiltinNodeTypes.ITERATION_START @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/list_operator/__init__.py b/api/dify_graph/nodes/list_operator/__init__.py similarity index 100% rename from api/core/workflow/nodes/list_operator/__init__.py rename to api/dify_graph/nodes/list_operator/__init__.py diff --git a/api/core/workflow/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py similarity index 89% rename from api/core/workflow/nodes/list_operator/entities.py rename to api/dify_graph/nodes/list_operator/entities.py index e51a91f07f..41b3a40b78 100644 --- a/api/core/workflow/nodes/list_operator/entities.py +++ b/api/dify_graph/nodes/list_operator/entities.py @@ -3,7 +3,8 @@ from enum import StrEnum from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class FilterOperator(StrEnum): @@ -62,6 +63,7 @@ class ExtractConfig(BaseModel): class ListOperatorNodeData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.LIST_OPERATOR variable: Sequence[str] = Field(default_factory=list) filter_by: FilterBy order_by: OrderByConfig diff --git a/api/core/workflow/nodes/list_operator/exc.py b/api/dify_graph/nodes/list_operator/exc.py similarity index 100% rename from api/core/workflow/nodes/list_operator/exc.py rename to api/dify_graph/nodes/list_operator/exc.py diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/dify_graph/nodes/list_operator/node.py similarity index 96% rename from api/core/workflow/nodes/list_operator/node.py rename to api/dify_graph/nodes/list_operator/node.py index d9ef16fbe7..dc8b8904f7 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/dify_graph/nodes/list_operator/node.py @@ -1,12 +1,12 @@ from collections.abc import Callable, Sequence from typing import Any, TypeAlias, TypeVar -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from core.workflow.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.file import File +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment +from dify_graph.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment from .entities import FilterOperator, ListOperatorNodeData, Order from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError @@ -35,7 +35,7 @@ def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = NodeType.LIST_OPERATOR + node_type = BuiltinNodeTypes.LIST_OPERATOR @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/llm/__init__.py b/api/dify_graph/nodes/llm/__init__.py similarity index 100% rename from api/core/workflow/nodes/llm/__init__.py rename to api/dify_graph/nodes/llm/__init__.py diff --git a/api/core/workflow/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py similarity index 91% rename from api/core/workflow/nodes/llm/entities.py rename to api/dify_graph/nodes/llm/entities.py index fe6f2290aa..6ca01a21da 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -3,10 +3,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator -from core.model_runtime.entities import ImagePromptMessageContent, LLMMode from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.base.entities import VariableSelector +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode +from dify_graph.nodes.base.entities import VariableSelector class ModelConfig(BaseModel): @@ -59,6 +60,7 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): class LLMNodeData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.LLM model: ModelConfig prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate prompt_config: PromptConfig = Field(default_factory=PromptConfig) diff --git a/api/core/workflow/nodes/llm/exc.py b/api/dify_graph/nodes/llm/exc.py similarity index 100% rename from api/core/workflow/nodes/llm/exc.py rename to api/dify_graph/nodes/llm/exc.py diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py similarity index 88% rename from api/core/workflow/nodes/llm/file_saver.py rename to api/dify_graph/nodes/llm/file_saver.py index 3c06ab7d81..50e52a3b6f 100644 --- a/api/core/workflow/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -1,14 +1,11 @@ import mimetypes import typing as tp -from sqlalchemy import Engine - from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.helper import ssrf_proxy from core.tools.signature import sign_tool_file from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import File, FileTransferMethod, FileType -from extensions.ext_database import db as global_db +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.nodes.protocols import HttpClientProtocol class LLMFileSaver(tp.Protocol): @@ -59,30 +56,20 @@ class LLMFileSaver(tp.Protocol): raise NotImplementedError() -EngineFactory: tp.TypeAlias = tp.Callable[[], Engine] - - class FileSaverImpl(LLMFileSaver): - _engine_factory: EngineFactory _tenant_id: str _user_id: str - def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None): - if engine_factory is None: - - def _factory(): - return global_db.engine - - engine_factory = _factory - self._engine_factory = engine_factory + def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): self._user_id = user_id self._tenant_id = tenant_id + self._http_client = http_client def _get_tool_file_manager(self): - return ToolFileManager(engine=self._engine_factory()) + return ToolFileManager() def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = ssrf_proxy.get(url) + http_response = self._http_client.get(url) http_response.raise_for_status() data = http_response.content mime_type_from_header = http_response.headers.get("Content-Type") diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py new file mode 100644 index 0000000000..2be391a424 --- /dev/null +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -0,0 +1,477 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, cast + +from core.model_manager import ModelInstance +from dify_graph.file import FileType, file_manager +from dify_graph.file.models import File +from dify_graph.model_runtime.entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentType, + PromptMessageRole, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContentUnionTypes, + SystemPromptMessage, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.runtime import VariablePool +from dify_graph.variables import ArrayFileSegment, FileSegment +from dify_graph.variables.segments import ArrayAnySegment, NoneSegment + +from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig +from .exc import ( + InvalidVariableTypeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, +) +from .protocols import TemplateRenderer + + +def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( + model_instance.model_name, + dict(model_instance.credentials), + ) + if not model_schema: + raise ValueError(f"Model schema not found for {model_instance.model_name}") + return model_schema + + +def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]: + variable = variable_pool.get(selector) + if variable is None: + return [] + elif isinstance(variable, FileSegment): + return [variable.value] + elif isinstance(variable, ArrayFileSegment): + return variable.value + elif isinstance(variable, NoneSegment | ArrayAnySegment): + return [] + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") + + +def convert_history_messages_to_text( + *, + history_messages: Sequence[PromptMessage], + human_prefix: str, + ai_prefix: str, +) -> str: + string_messages: list[str] = [] + for message in history_messages: + if message.role == PromptMessageRole.USER: + role = human_prefix + elif message.role == PromptMessageRole.ASSISTANT: + role = ai_prefix + else: + continue + + if isinstance(message.content, list): + content_parts = [] + for content in message.content: + if isinstance(content, TextPromptMessageContent): + content_parts.append(content.data) + elif isinstance(content, ImagePromptMessageContent): + content_parts.append("[image]") + + inner_msg = "\n".join(content_parts) + string_messages.append(f"{role}: {inner_msg}") + else: + string_messages.append(f"{role}: {message.content}") + + return "\n".join(string_messages) + + +def fetch_memory_text( + *, + memory: PromptMessageMemory, + max_token_limit: int, + message_limit: int | None = None, + human_prefix: str = "Human", + ai_prefix: str = "Assistant", +) -> str: + history_messages = memory.get_history_prompt_messages( + max_token_limit=max_token_limit, + message_limit=message_limit, + ) + return convert_history_messages_to_text( + history_messages=history_messages, + human_prefix=human_prefix, + ai_prefix=ai_prefix, + ) + + +def fetch_prompt_messages( + *, + sys_query: str | None = None, + sys_files: Sequence[File], + context: str | None = None, + memory: PromptMessageMemory | None = None, + model_instance: ModelInstance, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + stop: Sequence[str] | None = None, + memory_config: MemoryConfig | None = None, + vision_enabled: bool = False, + vision_detail: ImagePromptMessageContent.DETAIL, + variable_pool: VariablePool, + jinja2_variables: Sequence[VariableSelector], + context_files: list[File] | None = None, + template_renderer: TemplateRenderer | None = None, +) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: + prompt_messages: list[PromptMessage] = [] + model_schema = fetch_model_schema(model_instance=model_instance) + + if isinstance(prompt_template, list): + prompt_messages.extend( + handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + template_renderer=template_renderer, + ) + ) + + prompt_messages.extend( + handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + ) + + if sys_query: + prompt_messages.extend( + handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text=sys_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + template_renderer=template_renderer, + ) + ) + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + prompt_messages.extend( + handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + template_renderer=template_renderer, + ) + ) + + memory_text = handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + prompt_content = prompt_messages[0].content + if isinstance(prompt_content, str): + prompt_content = str(prompt_content) + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + if "#histories#" in content_item.data: + content_item.data = content_item.data.replace("#histories#", memory_text) + else: + content_item.data = memory_text + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + + if sys_query: + if isinstance(prompt_content, str): + prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + content_item.data = sys_query + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + else: + raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) + + _append_file_prompts( + prompt_messages=prompt_messages, + files=sys_files, + vision_enabled=vision_enabled, + vision_detail=vision_detail, + ) + _append_file_prompts( + prompt_messages=prompt_messages, + files=context_files or [], + vision_enabled=vision_enabled, + vision_detail=vision_detail, + ) + + filtered_prompt_messages: list[PromptMessage] = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message.content, list): + prompt_message_content: list[PromptMessageContentUnionTypes] = [] + for content_item in prompt_message.content: + if not model_schema.features: + if content_item.type == PromptMessageContentType.TEXT: + prompt_message_content.append(content_item) + continue + + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_schema.features + ) + ): + continue + prompt_message_content.append(content_item) + if not prompt_message_content: + continue + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: + prompt_message.content = prompt_message_content[0].data + else: + prompt_message.content = prompt_message_content + filtered_prompt_messages.append(prompt_message) + elif not prompt_message.is_empty(): + filtered_prompt_messages.append(prompt_message) + + if len(filtered_prompt_messages) == 0: + raise NoPromptFoundError( + "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding." + ) + + return filtered_prompt_messages, stop + + +def handle_list_messages( + *, + messages: Sequence[LLMNodeChatModelMessage], + context: str | None, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + vision_detail_config: ImagePromptMessageContent.DETAIL, + template_renderer: TemplateRenderer | None = None, +) -> Sequence[PromptMessage]: + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = render_jinja2_message( + template=message.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + template_renderer=template_renderer, + ) + prompt_messages.append( + combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], + role=message.role, + ) + ) + continue + + template = message.text.replace("{#context#}", context) if context else message.text + segment_group = variable_pool.convert_template(template) + file_contents: list[PromptMessageContentUnionTypes] = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) + ) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_contents.append( + file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) + ) + + if segment_group.text: + prompt_messages.append( + combine_message_content_with_role( + contents=[TextPromptMessageContent(data=segment_group.text)], + role=message.role, + ) + ) + if file_contents: + prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role)) + + return prompt_messages + + +def render_jinja2_message( + *, + template: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + template_renderer: TemplateRenderer | None = None, +) -> str: + if not template: + return "" + if template_renderer is None: + raise ValueError("template_renderer is required for jinja2 prompt rendering") + + jinja2_inputs: dict[str, Any] = {} + for jinja2_variable in jinja2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) + + +def handle_completion_template( + *, + template: LLMNodeCompletionModelPromptTemplate, + context: str | None, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + template_renderer: TemplateRenderer | None = None, +) -> Sequence[PromptMessage]: + if template.edition_type == "jinja2": + result_text = render_jinja2_message( + template=template.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + template_renderer=template_renderer, + ) + else: + template_text = template.text.replace("{#context#}", context) if context else template.text + result_text = variable_pool.convert_template(template_text).text + return [ + combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], + role=PromptMessageRole.USER, + ) + ] + + +def combine_message_content_with_role( + *, + contents: str | list[PromptMessageContentUnionTypes] | None = None, + role: PromptMessageRole, +) -> PromptMessage: + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=contents) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=contents) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=contents) + case _: + raise NotImplementedError(f"Role {role} is not supported") + + +def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int: + rest_tokens = 2000 + runtime_model_schema = fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters + + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in runtime_model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def handle_memory_chat_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: ModelInstance, +) -> Sequence[PromptMessage]: + if not memory or not memory_config: + return [] + rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) + return memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + + +def handle_memory_completion_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: ModelInstance, +) -> str: + if not memory or not memory_config: + return "" + + rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) + if not memory_config.role_prefix: + raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") + + return fetch_memory_text( + memory=memory, + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + human_prefix=memory_config.role_prefix.user, + ai_prefix=memory_config.role_prefix.assistant, + ) + + +def _append_file_prompts( + *, + prompt_messages: list[PromptMessage], + files: Sequence[File], + vision_enabled: bool, + vision_detail: ImagePromptMessageContent.DETAIL, +) -> None: + if not vision_enabled or not files: + return + + file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files] + if ( + prompt_messages + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + existing_contents = prompt_messages[-1].content + assert isinstance(existing_contents, list) + prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) diff --git a/api/core/workflow/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py similarity index 65% rename from api/core/workflow/nodes/llm/node.py rename to api/dify_graph/nodes/llm/node.py index c06db0dc16..5ed90ed7e3 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -11,17 +11,29 @@ from typing import TYPE_CHECKING, Any, Literal from sqlalchemy import select -from core.helper.code_executor import CodeExecutor, CodeLanguage from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelInstance -from core.model_runtime.entities import ( +from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.tools.signature import sign_upload_file +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ( + BuiltinNodeTypes, + NodeType, + SystemVariableKey, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, - PromptMessageContentType, TextPromptMessageContent, ) -from core.model_runtime.entities.llm_entities import ( +from dify_graph.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkWithStructuredOutput, @@ -29,30 +41,10 @@ from core.model_runtime.entities.llm_entities import ( LLMStructuredOutput, LLMUsage, ) -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - PromptMessageRole, - SystemPromptMessage, - UserPromptMessage, -) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from core.tools.signature import sign_upload_file -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ( - NodeType, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from core.workflow.file import File, FileTransferMethod, FileType, file_manager -from core.workflow.node_events import ( +from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ( ModelInvokeCompletedEvent, NodeEventBase, NodeRunResult, @@ -60,15 +52,15 @@ from core.workflow.node_events import ( StreamChunkEvent, StreamCompletedEvent, ) -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.runtime import VariablePool -from core.workflow.variables import ( +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.runtime import VariablePool +from dify_graph.variables import ( ArrayFileSegment, ArraySegment, - FileSegment, NoneSegment, ObjectSegment, StringSegment, @@ -87,22 +79,19 @@ from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, LLMNodeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, VariableNotFoundError, ) from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: - from core.workflow.file.models import File - from core.workflow.runtime import GraphRuntimeState + from dify_graph.file.models import File + from dify_graph.runtime import GraphRuntimeState logger = logging.getLogger(__name__) class LLMNode(Node[LLMNodeData]): - node_type = NodeType.LLM + node_type = BuiltinNodeTypes.LLM # Compiled regex for extracting blocks (with compatibility for attributes) _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) @@ -116,17 +105,20 @@ class LLMNode(Node[LLMNodeData]): _model_factory: ModelFactory _model_instance: ModelInstance _memory: PromptMessageMemory | None + _template_renderer: TemplateRenderer def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, credentials_provider: CredentialsProvider, model_factory: ModelFactory, model_instance: ModelInstance, + http_client: HttpClientProtocol, + template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): @@ -143,11 +135,14 @@ class LLMNode(Node[LLMNodeData]): self._model_factory = model_factory self._model_instance = model_instance self._memory = memory + self._template_renderer = template_renderer if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, + http_client=http_client, ) self._llm_file_saver = llm_file_saver @@ -235,6 +230,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, context_files=context_files, + template_renderer=self._template_renderer, ) # handle invoke result @@ -242,7 +238,7 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.user_id, + user_id=self.require_dify_context().user_id, structured_output_enabled=self.node_data.structured_output_enabled, structured_output=self.node_data.structured_output, file_saver=self._llm_file_saver, @@ -672,7 +668,7 @@ class LLMNode(Node[LLMNodeData]): ) elif isinstance(context_value_variable, ArraySegment): context_str = "" - original_retriever_resource: list[RetrievalSourceMetadata] = [] + original_retriever_resource: list[dict[str, Any]] = [] context_files: list[File] = [] for item in context_value_variable.value: if isinstance(item, str): @@ -688,11 +684,14 @@ class LLMNode(Node[LLMNodeData]): retriever_resource = self._convert_to_original_retriever_resource(item) if retriever_resource: original_retriever_resource.append(retriever_resource) + segment_id = retriever_resource.get("segment_id") + if not segment_id: + continue attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) .where( - SegmentAttachmentBinding.segment_id == retriever_resource.segment_id, + SegmentAttachmentBinding.segment_id == segment_id, ) ).all() if attachments_with_bindings: @@ -702,7 +701,7 @@ class LLMNode(Node[LLMNodeData]): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=self.tenant_id, + tenant_id=self.require_dify_context().tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, @@ -718,7 +717,7 @@ class LLMNode(Node[LLMNodeData]): context_files=context_files, ) - def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None: + def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -726,28 +725,26 @@ class LLMNode(Node[LLMNodeData]): ): metadata = context_dict.get("metadata", {}) - source = RetrievalSourceMetadata( - position=metadata.get("position"), - dataset_id=metadata.get("dataset_id"), - dataset_name=metadata.get("dataset_name"), - document_id=metadata.get("document_id"), - document_name=metadata.get("document_name"), - data_source_type=metadata.get("data_source_type"), - segment_id=metadata.get("segment_id"), - retriever_from=metadata.get("retriever_from"), - score=metadata.get("score"), - hit_count=metadata.get("segment_hit_count"), - word_count=metadata.get("segment_word_count"), - segment_position=metadata.get("segment_position"), - index_node_hash=metadata.get("segment_index_node_hash"), - content=context_dict.get("content"), - page=metadata.get("page"), - doc_metadata=metadata.get("doc_metadata"), - files=context_dict.get("files"), - summary=context_dict.get("summary"), - ) - - return source + return { + "position": metadata.get("position"), + "dataset_id": metadata.get("dataset_id"), + "dataset_name": metadata.get("dataset_name"), + "document_id": metadata.get("document_id"), + "document_name": metadata.get("document_name"), + "data_source_type": metadata.get("data_source_type"), + "segment_id": metadata.get("segment_id"), + "retriever_from": metadata.get("retriever_from"), + "score": metadata.get("score"), + "hit_count": metadata.get("segment_hit_count"), + "word_count": metadata.get("segment_word_count"), + "segment_position": metadata.get("segment_position"), + "index_node_hash": metadata.get("segment_index_node_hash"), + "content": context_dict.get("content"), + "page": metadata.get("page"), + "doc_metadata": metadata.get("doc_metadata"), + "files": context_dict.get("files"), + "summary": context_dict.get("summary"), + } return None @@ -767,182 +764,24 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, + template_renderer: TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - # For chat model - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - ) - ) - - # Get memory messages for chat mode - memory_messages = _handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Extend prompt_messages with memory messages - prompt_messages.extend(memory_messages) - - # Add current query to the prompt messages - if sys_query: - message = LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - prompt_messages.extend( - LLMNode.handle_list_messages( - messages=[message], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - ) - ) - - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - # For completion model - prompt_messages.extend( - _handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - ) - ) - - # Get memory text for completion model - memory_text = _handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - # Insert histories into the prompt - prompt_content = prompt_messages[0].content - # For issue #11247 - Check if prompt content is a string or a list - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - # Add current query to the prompt message - if sys_query: - if isinstance(prompt_content, str): - prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - # The sys_files will be deprecated later - if vision_enabled and sys_files: - file_prompts = [] - for file in sys_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # The context_files - if vision_enabled and context_files: - file_prompts = [] - for file in context_files: - file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) - file_prompts.append(file_prompt) - # If last prompt is a user prompt, add files into its contents, - # otherwise append a new user prompt - if ( - len(prompt_messages) > 0 - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - # Remove empty messages and filter unsupported content - filtered_prompt_messages = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - # Skip content if features are not defined - if not model_schema.features: - if content_item.type != PromptMessageContentType.TEXT: - continue - prompt_message_content.append(content_item) - continue - - # Skip content if corresponding feature is not supported - if ( - ( - content_item.type == PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - if prompt_message.is_empty(): - continue - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. " - "Please ensure a prompt is properly configured before proceeding." - ) - - return filtered_prompt_messages, stop + return llm_utils.fetch_prompt_messages( + sys_query=sys_query, + sys_files=sys_files, + context=context, + memory=memory, + model_instance=model_instance, + prompt_template=prompt_template, + stop=stop, + memory_config=memory_config, + vision_enabled=vision_enabled, + vision_detail=vision_detail, + variable_pool=variable_pool, + jinja2_variables=jinja2_variables, + context_files=context_files, + template_renderer=template_renderer, + ) @classmethod def _extract_variable_selector_to_variable_mapping( @@ -950,14 +789,11 @@ class LLMNode(Node[LLMNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: LLMNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type _ = graph_config # Explicitly mark as unused - # Create typed NodeData from dict - typed_node_data = LLMNodeData.model_validate(node_data) - - prompt_template = typed_node_data.prompt_template + prompt_template = node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list): for prompt in prompt_template: @@ -975,7 +811,7 @@ class LLMNode(Node[LLMNodeData]): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - memory = typed_node_data.memory + memory = node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template @@ -983,16 +819,16 @@ class LLMNode(Node[LLMNodeData]): for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - if typed_node_data.context.enabled: - variable_mapping["#context#"] = typed_node_data.context.variable_selector + if node_data.context.enabled: + variable_mapping["#context#"] = node_data.context.variable_selector - if typed_node_data.vision.enabled: - variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector + if node_data.vision.enabled: + variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if typed_node_data.memory: + if node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if typed_node_data.prompt_config: + if node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): @@ -1005,7 +841,7 @@ class LLMNode(Node[LLMNodeData]): break if enable_jinja: - for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: + for variable_selector in node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} @@ -1045,59 +881,16 @@ class LLMNode(Node[LLMNodeData]): jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, + template_renderer: TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - ) - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=message.role - ) - prompt_messages.append(prompt_message) - else: - # Get segment group from basic message - if context: - template = message.text.replace("{#context#}", context) - else: - template = message.text - segment_group = variable_pool.convert_template(template) - - # Process segments for images - file_contents = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_content = file_manager.to_prompt_message_content( - file, image_detail_config=vision_detail_config - ) - file_contents.append(file_content) - - # Create message with text from all segments - plain_text = segment_group.text - if plain_text: - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=plain_text)], role=message.role - ) - prompt_messages.append(prompt_message) - - if file_contents: - # Create message with image contents - prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) - prompt_messages.append(prompt_message) - - return prompt_messages + return llm_utils.handle_list_messages( + messages=messages, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail_config, + template_renderer=template_renderer, + ) @staticmethod def handle_blocking_result( @@ -1236,152 +1029,3 @@ class LLMNode(Node[LLMNodeData]): @property def model_instance(self) -> ModelInstance: return self._model_instance - - -def _combine_message_content_with_role( - *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole -): - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def _render_jinja2_message( - *, - template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, -): - if not template: - return "" - - jinja2_inputs = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - code_execute_resp = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=jinja2_inputs, - ) - result_text = code_execute_resp["result"] - return result_text - - -def _calculate_rest_token( - *, - prompt_messages: list[PromptMessage], - model_instance: ModelInstance, -) -> int: - rest_tokens = 2000 - runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def _handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: ModelInstance, -) -> Sequence[PromptMessage]: - memory_messages: Sequence[PromptMessage] = [] - # Get messages from memory for chat model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - memory_messages = memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - return memory_messages - - -def _handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: ModelInstance, -) -> str: - memory_text = "" - # Get history text from memory for completion model - if memory and memory_config: - rest_tokens = _calculate_rest_token( - prompt_messages=[], - model_instance=model_instance, - ) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - memory_text = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - return memory_text - - -def _handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, -) -> Sequence[PromptMessage]: - """Handle completion template processing outside of LLMNode class. - - Args: - template: The completion model prompt template - context: Optional context string - jinja2_variables: Variables for jinja2 template rendering - variable_pool: Variable pool for template conversion - - Returns: - Sequence of prompt messages - """ - prompt_messages = [] - if template.edition_type == "jinja2": - result_text = _render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - ) - else: - if context: - template_text = template.text.replace("{#context#}", context) - else: - template_text = template.text - result_text = variable_pool.convert_template(template_text).text - prompt_message = _combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER - ) - prompt_messages.append(prompt_message) - return prompt_messages diff --git a/api/core/workflow/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py similarity index 69% rename from api/core/workflow/nodes/llm/protocols.py rename to api/dify_graph/nodes/llm/protocols.py index 8e0365299d..9e95d341c9 100644 --- a/api/core/workflow/nodes/llm/protocols.py +++ b/api/dify_graph/nodes/llm/protocols.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Mapping from typing import Any, Protocol from core.model_manager import ModelInstance @@ -19,3 +20,11 @@ class ModelFactory(Protocol): def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: """Create a model instance that is ready for schema lookup and invocation.""" ... + + +class TemplateRenderer(Protocol): + """Port for rendering prompt templates used by LLM-compatible nodes.""" + + def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: + """Render the given Jinja2 template into plain text.""" + ... diff --git a/api/core/workflow/nodes/loop/__init__.py b/api/dify_graph/nodes/loop/__init__.py similarity index 100% rename from api/core/workflow/nodes/loop/__init__.py rename to api/dify_graph/nodes/loop/__init__.py diff --git a/api/core/workflow/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py similarity index 83% rename from api/core/workflow/nodes/loop/entities.py rename to api/dify_graph/nodes/loop/entities.py index 4090f27799..f0bfad5a0f 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/dify_graph/nodes/loop/entities.py @@ -3,9 +3,11 @@ from typing import Annotated, Any, Literal from pydantic import AfterValidator, BaseModel, Field, field_validator -from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData -from core.workflow.utils.condition.entities import Condition -from core.workflow.variables.types import SegmentType +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState +from dify_graph.utils.condition.entities import Condition +from dify_graph.variables.types import SegmentType _VALID_VAR_TYPE = frozenset( [ @@ -39,6 +41,7 @@ class LoopVariableData(BaseModel): class LoopNodeData(BaseLoopNodeData): + type: NodeType = BuiltinNodeTypes.LOOP loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] @@ -58,7 +61,7 @@ class LoopStartNodeData(BaseNodeData): Loop Start Node Data. """ - pass + type: NodeType = BuiltinNodeTypes.LOOP_START class LoopEndNodeData(BaseNodeData): @@ -66,7 +69,7 @@ class LoopEndNodeData(BaseNodeData): Loop End Node Data. """ - pass + type: NodeType = BuiltinNodeTypes.LOOP_END class LoopState(BaseLoopState): diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/dify_graph/nodes/loop/loop_end_node.py similarity index 53% rename from api/core/workflow/nodes/loop/loop_end_node.py rename to api/dify_graph/nodes/loop/loop_end_node.py index 1e3e317b53..0287708fb3 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/dify_graph/nodes/loop/loop_end_node.py @@ -1,7 +1,7 @@ -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.loop.entities import LoopEndNodeData +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.loop.entities import LoopEndNodeData class LoopEndNode(Node[LoopEndNodeData]): @@ -9,7 +9,7 @@ class LoopEndNode(Node[LoopEndNodeData]): Loop End Node. """ - node_type = NodeType.LOOP_END + node_type = BuiltinNodeTypes.LOOP_END @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py similarity index 85% rename from api/core/workflow/nodes/loop/loop_node.py rename to api/dify_graph/nodes/loop/loop_node.py index 40ec0cf8b1..3c546ffa23 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -5,19 +5,20 @@ from collections.abc import Callable, Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Literal, cast -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import ( +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import ( + BuiltinNodeTypes, NodeExecutionType, - NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphNodeEventBase, GraphRunFailedEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import ( +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import ( LoopFailedEvent, LoopNextEvent, LoopStartedEvent, @@ -26,16 +27,16 @@ from core.workflow.node_events import ( NodeRunResult, StreamCompletedEvent, ) -from core.workflow.nodes.base import LLMUsageTrackingMixin -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from core.workflow.utils.condition.processor import ConditionProcessor -from core.workflow.variables import Segment, SegmentType +from dify_graph.nodes.base import LLMUsageTrackingMixin +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData +from dify_graph.utils.condition.processor import ConditionProcessor +from dify_graph.variables import Segment, SegmentType from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable from libs.datetime_utils import naive_utc_now if TYPE_CHECKING: - from core.workflow.graph_engine import GraphEngine + from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) @@ -45,7 +46,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): Loop Node. """ - node_type = NodeType.LOOP + node_type = BuiltinNodeTypes.LOOP execution_type = NodeExecutionType.CONTAINER @classmethod @@ -249,11 +250,11 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): if isinstance(event, GraphNodeEventBase): self._append_loop_info_to_event(event=event, loop_run_index=current_index) - if isinstance(event, GraphNodeEventBase) and event.node_type == NodeType.LOOP_START: + if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: continue if isinstance(event, GraphNodeEventBase): yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == NodeType.LOOP_END: + if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: reach_break_node = True if isinstance(event, GraphRunFailedEvent): raise Exception(event.error) @@ -298,11 +299,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: LoopNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = LoopNodeData.model_validate(node_data) - variable_mapping = {} # Extract loop node IDs statically from graph_config @@ -317,17 +315,16 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # variable selector to variable mapping try: - # Get node class - from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING - - node_type = NodeType(sub_node_config.get("data", {}).get("type")) - if node_type not in NODE_TYPE_CLASSES_MAPPING: + typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) + node_type = typed_sub_node_config["data"].type + node_mapping = Node.get_node_type_classes_mapping() + if node_type not in node_mapping: continue - node_version = sub_node_config.get("data", {}).get("version", "1") - node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] + node_version = str(typed_sub_node_config["data"].version) + node_cls = node_mapping[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=sub_node_config + graph_config=graph_config, config=typed_sub_node_config ) sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) except NotImplementedError: @@ -342,7 +339,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) - for loop_variable in typed_node_data.loop_variables or []: + for loop_variable in node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping @@ -412,24 +409,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): return build_segment_with_type(var_type, value) def _create_graph_engine(self, start_at: datetime, root_node_id: str): - # Import dependencies - from core.app.workflow.layers.llm_quota import LLMQuotaLayer - from core.app.workflow.node_factory import DifyNodeFactory - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState - # Create GraphInitParams from node attributes + # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -439,22 +426,10 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): start_at=start_at.timestamp(), ) - # Create a new node factory with the new GraphRuntimeState - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy - ) - - # Initialize the loop graph with the new node factory - loop_graph = Graph.init(graph_config=self.graph_config, node_factory=node_factory, root_node_id=root_node_id) - - # Create a new GraphEngine for this iteration - graph_engine = GraphEngine( + return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, - graph=loop_graph, + graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state_copy, - command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs - config=GraphEngineConfig(), + graph_config=self.graph_config, + root_node_id=root_node_id, ) - graph_engine.layer(LLMQuotaLayer()) - - return graph_engine diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/dify_graph/nodes/loop/loop_start_node.py similarity index 53% rename from api/core/workflow/nodes/loop/loop_start_node.py rename to api/dify_graph/nodes/loop/loop_start_node.py index 95bb5c4018..e171b4df2f 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/dify_graph/nodes/loop/loop_start_node.py @@ -1,7 +1,7 @@ -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.loop.entities import LoopStartNodeData +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.loop.entities import LoopStartNodeData class LoopStartNode(Node[LoopStartNodeData]): @@ -9,7 +9,7 @@ class LoopStartNode(Node[LoopStartNodeData]): Loop Start Node. """ - node_type = NodeType.LOOP_START + node_type = BuiltinNodeTypes.LOOP_START @classmethod def version(cls) -> str: diff --git a/api/core/workflow/nodes/parameter_extractor/__init__.py b/api/dify_graph/nodes/parameter_extractor/__init__.py similarity index 100% rename from api/core/workflow/nodes/parameter_extractor/__init__.py rename to api/dify_graph/nodes/parameter_extractor/__init__.py diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py similarity index 93% rename from api/core/workflow/nodes/parameter_extractor/entities.py rename to api/dify_graph/nodes/parameter_extractor/entities.py index 90d78ae429..2fb042c16c 100644 --- a/api/core/workflow/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -8,9 +8,10 @@ from pydantic import ( ) from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig -from core.workflow.variables.types import SegmentType +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig +from dify_graph.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" _OLD_SELECT_TYPE_NAME = "select" @@ -83,6 +84,7 @@ class ParameterExtractorNodeData(BaseNodeData): Parameter Extractor Node Data. """ + type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR model: ModelConfig query: list[str] parameters: list[ParameterConfig] diff --git a/api/core/workflow/nodes/parameter_extractor/exc.py b/api/dify_graph/nodes/parameter_extractor/exc.py similarity index 97% rename from api/core/workflow/nodes/parameter_extractor/exc.py rename to api/dify_graph/nodes/parameter_extractor/exc.py index 5a58780575..c25b809d1c 100644 --- a/api/core/workflow/nodes/parameter_extractor/exc.py +++ b/api/dify_graph/nodes/parameter_extractor/exc.py @@ -1,6 +1,6 @@ from typing import Any -from core.workflow.variables.types import SegmentType +from dify_graph.variables.types import SegmentType class ParameterExtractorNodeError(ValueError): diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py similarity index 95% rename from api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py rename to api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index 4272b98116..3913a27697 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -6,9 +6,20 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast from core.model_manager import ModelInstance -from core.model_runtime.entities import ImagePromptMessageContent -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import ( +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.file import File +from dify_graph.model_runtime.entities import ImagePromptMessageContent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -16,26 +27,16 @@ from core.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.enums import ( - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from core.workflow.file import File -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.llm import llm_utils -from core.workflow.runtime import VariablePool -from core.workflow.variables.types import ArrayValidation, SegmentType +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base import variable_template_parser +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.llm import llm_utils +from dify_graph.runtime import VariablePool +from dify_graph.variables.types import ArrayValidation, SegmentType from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData @@ -64,9 +65,9 @@ from .prompts import ( logger = logging.getLogger(__name__) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory + from dify_graph.runtime import GraphRuntimeState def extract_json(text): @@ -96,7 +97,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): Parameter Extractor Node. """ - node_type = NodeType.PARAMETER_EXTRACTOR + node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR _model_instance: ModelInstance _credentials_provider: "CredentialsProvider" @@ -106,7 +107,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, @@ -297,7 +298,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): tools=tools, stop=list(stop), stream=False, - user=self.user_id, + user=self.require_dify_context().user_id, ) # handle invoke result @@ -837,15 +838,13 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: ParameterExtractorNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = ParameterExtractorNodeData.model_validate(node_data) + _ = graph_config # Explicitly mark as unused + variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} - - if typed_node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) + if node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/dify_graph/nodes/parameter_extractor/prompts.py similarity index 100% rename from api/core/workflow/nodes/parameter_extractor/prompts.py rename to api/dify_graph/nodes/parameter_extractor/prompts.py diff --git a/api/core/workflow/nodes/protocols.py b/api/dify_graph/nodes/protocols.py similarity index 83% rename from api/core/workflow/nodes/protocols.py rename to api/dify_graph/nodes/protocols.py index fda524d701..62d3bcdca1 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -1,8 +1,10 @@ +from collections.abc import Generator from typing import Any, Protocol import httpx -from core.workflow.file import File +from dify_graph.file import File +from dify_graph.file.models import ToolFile class HttpClientProtocol(Protocol): @@ -40,3 +42,5 @@ class ToolFileManagerProtocol(Protocol): mimetype: str, filename: str | None = None, ) -> Any: ... + + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/dify_graph/nodes/question_classifier/__init__.py similarity index 100% rename from api/core/workflow/nodes/question_classifier/__init__.py rename to api/dify_graph/nodes/question_classifier/__init__.py diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py similarity index 77% rename from api/core/workflow/nodes/question_classifier/entities.py rename to api/dify_graph/nodes/question_classifier/entities.py index edde30708a..0c1601d439 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -1,8 +1,9 @@ from pydantic import BaseModel, Field from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm import ModelConfig, VisionConfig +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.llm import ModelConfig, VisionConfig class ClassConfig(BaseModel): @@ -11,6 +12,7 @@ class ClassConfig(BaseModel): class QuestionClassifierNodeData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER query_variable_selector: list[str] model: ModelConfig classes: list[ClassConfig] diff --git a/api/core/workflow/nodes/question_classifier/exc.py b/api/dify_graph/nodes/question_classifier/exc.py similarity index 100% rename from api/core/workflow/nodes/question_classifier/exc.py rename to api/dify_graph/nodes/question_classifier/exc.py diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py similarity index 88% rename from api/core/workflow/nodes/question_classifier/question_classifier_node.py rename to api/dify_graph/nodes/question_classifier/question_classifier_node.py index 6005bff1a6..59d0a2a4d8 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -4,30 +4,32 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any from core.model_manager import ModelInstance -from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole -from core.model_runtime.memory import PromptMessageMemory -from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ( +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ( + BuiltinNodeTypes, NodeExecutionType, - NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult -from core.workflow.nodes.base.entities import VariableSelector -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.nodes.llm import ( +from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole +from dify_graph.model_runtime.memory import PromptMessageMemory +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.llm import ( LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils, ) -from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.protocols import HttpClientProtocol from libs.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData @@ -43,12 +45,12 @@ from .template_prompts import ( ) if TYPE_CHECKING: - from core.workflow.file.models import File - from core.workflow.runtime import GraphRuntimeState + from dify_graph.file.models import File + from dify_graph.runtime import GraphRuntimeState class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = NodeType.QUESTION_CLASSIFIER + node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER execution_type = NodeExecutionType.BRANCH _file_outputs: list["File"] @@ -57,17 +59,20 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): _model_factory: "ModelFactory" _model_instance: ModelInstance _memory: PromptMessageMemory | None + _template_renderer: TemplateRenderer def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, credentials_provider: "CredentialsProvider", model_factory: "ModelFactory", model_instance: ModelInstance, + http_client: HttpClientProtocol, + template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, llm_file_saver: LLMFileSaver | None = None, ): @@ -84,11 +89,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self._model_factory = model_factory self._model_instance = model_instance self._memory = memory + self._template_renderer = template_renderer if llm_file_saver is None: + dify_ctx = self.require_dify_context() llm_file_saver = FileSaverImpl( - user_id=graph_init_params.user_id, - tenant_id=graph_init_params.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, + http_client=http_client, ) self._llm_file_saver = llm_file_saver @@ -137,7 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, # two consecutive user prompts will be generated, causing model's error. # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. - prompt_messages, stop = LLMNode.fetch_prompt_messages( + prompt_messages, stop = llm_utils.fetch_prompt_messages( prompt_template=prompt_template, sys_query="", memory=memory, @@ -148,6 +156,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], + template_renderer=self._template_renderer, ) result_text = "" @@ -160,7 +169,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.user_id, + user_id=self.require_dify_context().user_id, structured_output_enabled=False, structured_output=None, file_saver=self._llm_file_saver, @@ -247,16 +256,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: QuestionClassifierNodeData, ) -> Mapping[str, Sequence[str]]: # graph_config is not used in this node type - # Create typed NodeData from dict - typed_node_data = QuestionClassifierNodeData.model_validate(node_data) - - variable_mapping = {"query": typed_node_data.query_variable_selector} + variable_mapping = {"query": node_data.query_variable_selector} variable_selectors: list[VariableSelector] = [] - if typed_node_data.instruction: - variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) + if node_data.instruction: + variable_template_parser = VariableTemplateParser(template=node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) @@ -285,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages, _ = LLMNode.fetch_prompt_messages( + prompt_messages, _ = llm_utils.fetch_prompt_messages( prompt_template=prompt_template, sys_query="", sys_files=[], @@ -298,6 +304,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): vision_detail=node_data.vision.configs.detail, variable_pool=self.graph_runtime_state.variable_pool, jinja2_variables=[], + template_renderer=self._template_renderer, ) rest_tokens = 2000 diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/dify_graph/nodes/question_classifier/template_prompts.py similarity index 100% rename from api/core/workflow/nodes/question_classifier/template_prompts.py rename to api/dify_graph/nodes/question_classifier/template_prompts.py diff --git a/api/core/workflow/nodes/start/__init__.py b/api/dify_graph/nodes/start/__init__.py similarity index 100% rename from api/core/workflow/nodes/start/__init__.py rename to api/dify_graph/nodes/start/__init__.py diff --git a/api/dify_graph/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py new file mode 100644 index 0000000000..92ebd1a2ec --- /dev/null +++ b/api/dify_graph/nodes/start/entities.py @@ -0,0 +1,16 @@ +from collections.abc import Sequence + +from pydantic import Field + +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.variables.input_entities import VariableEntity + + +class StartNodeData(BaseNodeData): + """ + Start Node Data + """ + + type: NodeType = BuiltinNodeTypes.START + variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py similarity index 83% rename from api/core/workflow/nodes/start/start_node.py rename to api/dify_graph/nodes/start/start_node.py index 4e5545d330..5e6055ea34 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/dify_graph/nodes/start/start_node.py @@ -2,16 +2,16 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.variables.input_entities import VariableEntityType +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.variables.input_entities import VariableEntityType class StartNode(Node[StartNodeData]): - node_type = NodeType.START + node_type = BuiltinNodeTypes.START execution_type = NodeExecutionType.ROOT @classmethod diff --git a/api/core/workflow/nodes/template_transform/__init__.py b/api/dify_graph/nodes/template_transform/__init__.py similarity index 100% rename from api/core/workflow/nodes/template_transform/__init__.py rename to api/dify_graph/nodes/template_transform/__init__.py diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py new file mode 100644 index 0000000000..ac29239958 --- /dev/null +++ b/api/dify_graph/nodes/template_transform/entities.py @@ -0,0 +1,13 @@ +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.entities import VariableSelector + + +class TemplateTransformNodeData(BaseNodeData): + """ + Template Transform Node Data. + """ + + type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM + variables: list[VariableSelector] + template: str diff --git a/api/core/workflow/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py similarity index 62% rename from api/core/workflow/nodes/template_transform/template_renderer.py rename to api/dify_graph/nodes/template_transform/template_renderer.py index a5f06bf2bb..9b679d4497 100644 --- a/api/core/workflow/nodes/template_transform/template_renderer.py +++ b/api/dify_graph/nodes/template_transform/template_renderer.py @@ -3,7 +3,8 @@ from __future__ import annotations from collections.abc import Mapping from typing import Any, Protocol -from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage +from dify_graph.nodes.code.code_node import WorkflowCodeExecutor +from dify_graph.nodes.code.entities import CodeLanguage class TemplateRenderError(ValueError): @@ -21,18 +22,18 @@ class Jinja2TemplateRenderer(Protocol): class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): """Adapter that renders Jinja2 templates via CodeExecutor.""" - _code_executor: type[CodeExecutor] + _code_executor: WorkflowCodeExecutor - def __init__(self, code_executor: type[CodeExecutor] | None = None) -> None: - self._code_executor = code_executor or CodeExecutor + def __init__(self, code_executor: WorkflowCodeExecutor) -> None: + self._code_executor = code_executor def render_template(self, template: str, variables: Mapping[str, Any]) -> str: try: - result = self._code_executor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=template, inputs=variables - ) - except CodeExecutionError as exc: - raise TemplateRenderError(str(exc)) from exc + result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) + except Exception as exc: + if self._code_executor.is_execution_error(exc): + raise TemplateRenderError(str(exc)) from exc + raise rendered = result.get("result") if not isinstance(rendered, str): diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py similarity index 75% rename from api/core/workflow/nodes/template_transform/template_transform_node.py rename to api/dify_graph/nodes/template_transform/template_transform_node.py index 3dc8afd9be..dc6fce2b0a 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -1,36 +1,36 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData -from core.workflow.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData +from dify_graph.nodes.template_transform.template_renderer import ( Jinja2TemplateRenderer, TemplateRenderError, ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = NodeType.TEMPLATE_TRANSFORM + node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM _template_renderer: Jinja2TemplateRenderer _max_output_length: int def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - template_renderer: Jinja2TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer, max_output_length: int | None = None, ) -> None: super().__init__( @@ -39,7 +39,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer() + self._template_renderer = template_renderer if max_output_length is not None and max_output_length <= 0: raise ValueError("max_output_length must be a positive integer") @@ -87,12 +87,9 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = TemplateTransformNodeData.model_validate(node_data) - return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in typed_node_data.variables + for variable_selector in node_data.variables } diff --git a/api/core/workflow/nodes/tool/__init__.py b/api/dify_graph/nodes/tool/__init__.py similarity index 100% rename from api/core/workflow/nodes/tool/__init__.py rename to api/dify_graph/nodes/tool/__init__.py diff --git a/api/core/workflow/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py similarity index 95% rename from api/core/workflow/nodes/tool/entities.py rename to api/dify_graph/nodes/tool/entities.py index 8fe33c240a..b041ee66fd 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -4,7 +4,8 @@ from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.nodes.base.entities import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class ToolEntity(BaseModel): @@ -32,6 +33,8 @@ class ToolEntity(BaseModel): class ToolNodeData(BaseNodeData, ToolEntity): + type: NodeType = BuiltinNodeTypes.TOOL + class ToolInput(BaseModel): # TODO: check this type value: Union[Any, list[str]] diff --git a/api/core/workflow/nodes/tool/exc.py b/api/dify_graph/nodes/tool/exc.py similarity index 100% rename from api/core/workflow/nodes/tool/exc.py rename to api/dify_graph/nodes/tool/exc.py diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py similarity index 88% rename from api/core/workflow/nodes/tool/tool_node.py rename to api/dify_graph/nodes/tool/tool_node.py index 0d7270a282..598f0da92e 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -1,31 +1,28 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.enums import ( - NodeType, +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import ( + BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser -from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment -from core.workflow.variables.variables import ArrayAnyVariable -from extensions.ext_database import db +from dify_graph.file import File, FileTransferMethod +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.nodes.protocols import ToolFileManagerProtocol +from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment +from dify_graph.variables.variables import ArrayAnyVariable from factories import file_factory -from models import ToolFile from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import ToolNodeData @@ -36,7 +33,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.runtime import VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool class ToolNode(Node[ToolNodeData]): @@ -44,18 +42,41 @@ class ToolNode(Node[ToolNodeData]): Tool Node """ - node_type = NodeType.TOOL + node_type = BuiltinNodeTypes.TOOL + + def __init__( + self, + id: str, + config: NodeConfigDict, + graph_init_params: "GraphInitParams", + graph_runtime_state: "GraphRuntimeState", + *, + tool_file_manager_factory: ToolFileManagerProtocol, + ): + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + self._tool_file_manager_factory = tool_file_manager_factory @classmethod def version(cls) -> str: return "1" + def populate_start_event(self, event) -> None: + event.provider_id = self.node_data.provider_id + event.provider_type = self.node_data.provider_type + def _run(self) -> Generator[NodeEventBase, None, None]: """ Run the tool node """ from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError + dify_ctx = self.require_dify_context() + # fetch tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -75,7 +96,12 @@ class ToolNode(Node[ToolNodeData]): if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self._node_id, self.node_data, self.invoke_from, variable_pool + dify_ctx.tenant_id, + dify_ctx.app_id, + self._node_id, + self.node_data, + dify_ctx.invoke_from, + variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -109,10 +135,10 @@ class ToolNode(Node[ToolNodeData]): message_stream = ToolEngine.generic_invoke( tool=tool_runtime, tool_parameters=parameters, - user_id=self.user_id, + user_id=dify_ctx.user_id, workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=self.app_id, + app_id=dify_ctx.app_id, conversation_id=conversation_id.text if conversation_id else None, ) except ToolNodeError as e: @@ -133,8 +159,8 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=dify_ctx.user_id, + tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) @@ -264,11 +290,9 @@ class ToolNode(Node[ToolNodeData]): tool_file_id = str(url).split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"Tool file {tool_file_id} does not exist") + _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + if not tool_file: + raise ToolFileError(f"tool file {tool_file_id} not found") mapping = { "tool_file_id": tool_file_id, @@ -287,11 +311,9 @@ class ToolNode(Node[ToolNodeData]): assert message.meta tool_file_id = message.message.text.split("/")[-1].split(".")[0] - with Session(db.engine) as session: - stmt = select(ToolFile).where(ToolFile.id == tool_file_id) - tool_file = session.scalar(stmt) - if tool_file is None: - raise ToolFileError(f"tool file {tool_file_id} not exists") + _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + if not tool_file: + raise ToolFileError(f"tool file {tool_file_id} not exists") mapping = { "tool_file_id": tool_file_id, @@ -467,7 +489,7 @@ class ToolNode(Node[ToolNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: ToolNodeData, ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -476,9 +498,8 @@ class ToolNode(Node[ToolNodeData]): :param node_data: node data :return: """ - # Create typed NodeData from dict - typed_node_data = ToolNodeData.model_validate(node_data) - + _ = graph_config # Explicitly mark as unused + typed_node_data = node_data result = {} for parameter_name in typed_node_data.tool_parameters: input = typed_node_data.tool_parameters[parameter_name] diff --git a/api/core/workflow/nodes/variable_aggregator/__init__.py b/api/dify_graph/nodes/variable_aggregator/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_aggregator/__init__.py rename to api/dify_graph/nodes/variable_aggregator/__init__.py diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py similarity index 70% rename from api/core/workflow/nodes/variable_aggregator/entities.py rename to api/dify_graph/nodes/variable_aggregator/entities.py index febbf1d1d6..4779ebd9a9 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/dify_graph/nodes/variable_aggregator/entities.py @@ -1,7 +1,8 @@ from pydantic import BaseModel -from core.workflow.nodes.base import BaseNodeData -from core.workflow.variables.types import SegmentType +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.variables.types import SegmentType class AdvancedSettings(BaseModel): @@ -28,6 +29,7 @@ class VariableAggregatorNodeData(BaseNodeData): Variable Aggregator Node Data. """ + type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR output_type: str variables: list[list[str]] advanced_settings: AdvancedSettings | None = None diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py similarity index 78% rename from api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py rename to api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py index 762b7dab07..7d26de6232 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,14 +1,14 @@ from collections.abc import Mapping -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from core.workflow.variables.segments import Segment +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData +from dify_graph.variables.segments import Segment class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = NodeType.VARIABLE_AGGREGATOR + node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR @classmethod def version(cls) -> str: diff --git a/api/core/workflow/utils/condition/__init__.py b/api/dify_graph/nodes/variable_assigner/__init__.py similarity index 100% rename from api/core/workflow/utils/condition/__init__.py rename to api/dify_graph/nodes/variable_assigner/__init__.py diff --git a/api/tests/unit_tests/core/model_runtime/__base/__init__.py b/api/dify_graph/nodes/variable_assigner/common/__init__.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__base/__init__.py rename to api/dify_graph/nodes/variable_assigner/common/__init__.py diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/dify_graph/nodes/variable_assigner/common/exc.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/common/exc.py rename to api/dify_graph/nodes/variable_assigner/common/exc.py diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/dify_graph/nodes/variable_assigner/common/helpers.py similarity index 90% rename from api/core/workflow/nodes/variable_assigner/common/helpers.py rename to api/dify_graph/nodes/variable_assigner/common/helpers.py index 37fde9d1b0..f0b22904a9 100644 --- a/api/core/workflow/nodes/variable_assigner/common/helpers.py +++ b/api/dify_graph/nodes/variable_assigner/common/helpers.py @@ -3,9 +3,9 @@ from typing import Any, TypeVar from pydantic import BaseModel -from core.workflow.variables import Segment -from core.workflow.variables.consts import SELECTORS_LENGTH -from core.workflow.variables.types import SegmentType +from dify_graph.variables import Segment +from dify_graph.variables.consts import SELECTORS_LENGTH +from dify_graph.variables.types import SegmentType # Use double underscore (`__`) prefix for internal variables # to minimize risk of collision with user-defined variable names. diff --git a/api/core/workflow/nodes/variable_assigner/v1/__init__.py b/api/dify_graph/nodes/variable_assigner/v1/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/v1/__init__.py rename to api/dify_graph/nodes/variable_assigner/v1/__init__.py diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py similarity index 76% rename from api/core/workflow/nodes/variable_assigner/v1/node.py rename to api/dify_graph/nodes/variable_assigner/v1/node.py index b987949541..f9b261b191 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -1,28 +1,29 @@ from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.variables import SegmentType, VariableBase +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from dify_graph.variables import SegmentType, VariableBase from .node_data import VariableAssignerData, WriteMode if TYPE_CHECKING: - from core.workflow.runtime import GraphRuntimeState + from dify_graph.runtime import GraphRuntimeState class VariableAssignerNode(Node[VariableAssignerData]): - node_type = NodeType.VARIABLE_ASSIGNER + node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): @@ -52,21 +53,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = VariableAssignerData.model_validate(node_data) - mapping = {} - assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] + assigned_variable_node_id = node_data.assigned_variable_selector[0] if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(typed_node_data.assigned_variable_selector) + selector_key = ".".join(node_data.assigned_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = typed_node_data.assigned_variable_selector + mapping[key] = node_data.assigned_variable_selector - selector_key = ".".join(typed_node_data.input_variable_selector) + selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = typed_node_data.input_variable_selector + mapping[key] = node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: diff --git a/api/core/workflow/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py similarity index 65% rename from api/core/workflow/nodes/variable_assigner/v1/node_data.py rename to api/dify_graph/nodes/variable_assigner/v1/node_data.py index 9734d64712..57acb29535 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node_data.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node_data.py @@ -1,7 +1,8 @@ from collections.abc import Sequence from enum import StrEnum -from core.workflow.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType class WriteMode(StrEnum): @@ -11,6 +12,7 @@ class WriteMode(StrEnum): class VariableAssignerData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/dify_graph/nodes/variable_assigner/v2/__init__.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/v2/__init__.py rename to api/dify_graph/nodes/variable_assigner/v2/__init__.py diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/dify_graph/nodes/variable_assigner/v2/entities.py similarity index 83% rename from api/core/workflow/nodes/variable_assigner/v2/entities.py rename to api/dify_graph/nodes/variable_assigner/v2/entities.py index 2955730289..2b2bbe85de 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/entities.py +++ b/api/dify_graph/nodes/variable_assigner/v2/entities.py @@ -3,7 +3,8 @@ from typing import Any from pydantic import BaseModel, Field -from core.workflow.nodes.base import BaseNodeData +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType from .enums import InputType, Operation @@ -22,5 +23,6 @@ class VariableOperationItem(BaseModel): class VariableAssignerNodeData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER version: str = "2" items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/core/workflow/nodes/variable_assigner/v2/enums.py b/api/dify_graph/nodes/variable_assigner/v2/enums.py similarity index 100% rename from api/core/workflow/nodes/variable_assigner/v2/enums.py rename to api/dify_graph/nodes/variable_assigner/v2/enums.py diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/dify_graph/nodes/variable_assigner/v2/exc.py similarity index 93% rename from api/core/workflow/nodes/variable_assigner/v2/exc.py rename to api/dify_graph/nodes/variable_assigner/v2/exc.py index 05173b3ca1..c50aab8668 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/exc.py +++ b/api/dify_graph/nodes/variable_assigner/v2/exc.py @@ -1,7 +1,7 @@ from collections.abc import Sequence from typing import Any -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError from .enums import InputType, Operation diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/dify_graph/nodes/variable_assigner/v2/helpers.py similarity index 98% rename from api/core/workflow/nodes/variable_assigner/v2/helpers.py rename to api/dify_graph/nodes/variable_assigner/v2/helpers.py index ce3fe9620c..38c69cbe3c 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/dify_graph/nodes/variable_assigner/v2/helpers.py @@ -1,6 +1,6 @@ from typing import Any -from core.workflow.variables import SegmentType +from dify_graph.variables import SegmentType from .enums import Operation diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py similarity index 90% rename from api/core/workflow/nodes/variable_assigner/v2/node.py rename to api/dify_graph/nodes/variable_assigner/v2/node.py index 0d4c3d2774..f04a6b3b80 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -2,14 +2,15 @@ import json from collections.abc import Mapping, MutableMapping, Sequence from typing import TYPE_CHECKING, Any -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from core.workflow.variables import SegmentType, VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from dify_graph.variables import SegmentType, VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH from . import helpers from .entities import VariableAssignerNodeData, VariableOperationItem @@ -23,8 +24,8 @@ from .exc import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): @@ -51,12 +52,12 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = NodeType.VARIABLE_ASSIGNER + node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER def __init__( self, id: str, - config: Mapping[str, Any], + config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", ): @@ -94,13 +95,10 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Mapping[str, Any], + node_data: VariableAssignerNodeData, ) -> Mapping[str, Sequence[str]]: - # Create typed NodeData from dict - typed_node_data = VariableAssignerNodeData.model_validate(node_data) - var_mapping: dict[str, Sequence[str]] = {} - for item in typed_node_data.items: + for item in node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping diff --git a/api/core/workflow/repositories/__init__.py b/api/dify_graph/repositories/__init__.py similarity index 69% rename from api/core/workflow/repositories/__init__.py rename to api/dify_graph/repositories/__init__.py index a778151baa..ef70eb09cc 100644 --- a/api/core/workflow/repositories/__init__.py +++ b/api/dify_graph/repositories/__init__.py @@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying storage mechanism. """ -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository __all__ = [ "OrderConfig", diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/dify_graph/repositories/draft_variable_repository.py similarity index 95% rename from api/core/workflow/repositories/draft_variable_repository.py rename to api/dify_graph/repositories/draft_variable_repository.py index 66ef714c16..b2ebfacffd 100644 --- a/api/core/workflow/repositories/draft_variable_repository.py +++ b/api/dify_graph/repositories/draft_variable_repository.py @@ -6,7 +6,7 @@ from typing import Any, Protocol from sqlalchemy.orm import Session -from core.workflow.enums import NodeType +from dify_graph.enums import NodeType class DraftVariableSaver(Protocol): diff --git a/api/core/workflow/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py similarity index 96% rename from api/core/workflow/repositories/human_input_form_repository.py rename to api/dify_graph/repositories/human_input_form_repository.py index efde59c6fd..88966831cb 100644 --- a/api/core/workflow/repositories/human_input_form_repository.py +++ b/api/dify_graph/repositories/human_input_form_repository.py @@ -4,8 +4,8 @@ from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Protocol -from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus class HumanInputError(Exception): diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py similarity index 95% rename from api/core/workflow/repositories/workflow_execution_repository.py rename to api/dify_graph/repositories/workflow_execution_repository.py index d9ce591db8..ef83f07649 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/dify_graph/repositories/workflow_execution_repository.py @@ -1,6 +1,6 @@ from typing import Protocol -from core.workflow.entities import WorkflowExecution +from dify_graph.entities import WorkflowExecution class WorkflowExecutionRepository(Protocol): diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py similarity index 97% rename from api/core/workflow/repositories/workflow_node_execution_repository.py rename to api/dify_graph/repositories/workflow_node_execution_repository.py index 43b41ff6b8..e6c1c3e497 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/dify_graph/repositories/workflow_node_execution_repository.py @@ -2,7 +2,7 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Protocol -from core.workflow.entities import WorkflowNodeExecution +from dify_graph.entities import WorkflowNodeExecution @dataclass diff --git a/api/core/workflow/runtime/__init__.py b/api/dify_graph/runtime/__init__.py similarity index 64% rename from api/core/workflow/runtime/__init__.py rename to api/dify_graph/runtime/__init__.py index 10014c7182..adca07e59a 100644 --- a/api/core/workflow/runtime/__init__.py +++ b/api/dify_graph/runtime/__init__.py @@ -1,9 +1,17 @@ -from .graph_runtime_state import GraphRuntimeState +from .graph_runtime_state import ( + ChildEngineBuilderNotConfiguredError, + ChildEngineError, + ChildGraphNotFoundError, + GraphRuntimeState, +) from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper from .variable_pool import VariablePool, VariableValue __all__ = [ + "ChildEngineBuilderNotConfiguredError", + "ChildEngineError", + "ChildGraphNotFoundError", "GraphRuntimeState", "ReadOnlyGraphRuntimeState", "ReadOnlyGraphRuntimeStateWrapper", diff --git a/api/core/workflow/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py similarity index 90% rename from api/core/workflow/runtime/graph_runtime_state.py rename to api/dify_graph/runtime/graph_runtime_state.py index 0af6bf49bc..41acc6db35 100644 --- a/api/core/workflow/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -10,12 +10,13 @@ from typing import TYPE_CHECKING, Any, ClassVar, Protocol from pydantic import BaseModel, Field from pydantic.json import pydantic_encoder -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import NodeExecutionType, NodeState, NodeType -from core.workflow.runtime.variable_pool import VariablePool +from dify_graph.enums import NodeExecutionType, NodeState, NodeType +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime.variable_pool import VariablePool if TYPE_CHECKING: - from core.workflow.entities.pause_reason import PauseReason + from dify_graph.entities import GraphInitParams + from dify_graph.entities.pause_reason import PauseReason class ReadyQueueProtocol(Protocol): @@ -135,6 +136,31 @@ class GraphProtocol(Protocol): def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... +class ChildGraphEngineBuilderProtocol(Protocol): + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> Any: ... + + +class ChildEngineError(ValueError): + """Base error type for child-engine creation failures.""" + + +class ChildEngineBuilderNotConfiguredError(ChildEngineError): + """Raised when child-engine creation is requested without a bound builder.""" + + +class ChildGraphNotFoundError(ChildEngineError): + """Raised when the requested child graph entry point cannot be resolved.""" + + class _GraphStateSnapshot(BaseModel): """Serializable graph state snapshot for node/edge states.""" @@ -209,6 +235,7 @@ class GraphRuntimeState: self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() self._deferred_nodes: set[str] = set() + self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None # Node and edges states needed to be restored into # graph object. @@ -250,6 +277,31 @@ class GraphRuntimeState: if self._graph is not None: _ = self.response_coordinator + def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: + self._child_engine_builder = builder + + def create_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> Any: + if self._child_engine_builder is None: + raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") + + return self._child_engine_builder.build_child_engine( + workflow_id=workflow_id, + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + graph_config=graph_config, + root_node_id=root_node_id, + layers=layers, + ) + # ------------------------------------------------------------------ # Primary collaborators # ------------------------------------------------------------------ @@ -433,13 +485,13 @@ class GraphRuntimeState: # ------------------------------------------------------------------ def _build_ready_queue(self) -> ReadyQueueProtocol: # Import lazily to avoid breaching architecture boundaries enforced by import-linter. - module = importlib.import_module("core.workflow.graph_engine.ready_queue") + module = importlib.import_module("dify_graph.graph_engine.ready_queue") in_memory_cls = module.InMemoryReadyQueue return in_memory_cls() def _build_graph_execution(self) -> GraphExecutionProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("core.workflow.graph_engine.domain.graph_execution") + module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution") graph_execution_cls = module.GraphExecution workflow_id = self._pending_graph_execution_workflow_id or "" self._pending_graph_execution_workflow_id = None @@ -447,7 +499,7 @@ class GraphRuntimeState: def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("core.workflow.graph_engine.response_coordinator") + module = importlib.import_module("dify_graph.graph_engine.response_coordinator") coordinator_cls = module.ResponseStreamCoordinator return coordinator_cls(variable_pool=self.variable_pool, graph=graph) diff --git a/api/core/workflow/runtime/graph_runtime_state_protocol.py b/api/dify_graph/runtime/graph_runtime_state_protocol.py similarity index 92% rename from api/core/workflow/runtime/graph_runtime_state_protocol.py rename to api/dify_graph/runtime/graph_runtime_state_protocol.py index 81d87e5a74..7e55ece3f1 100644 --- a/api/core/workflow/runtime/graph_runtime_state_protocol.py +++ b/api/dify_graph/runtime/graph_runtime_state_protocol.py @@ -1,9 +1,9 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.system_variable import SystemVariableReadOnlyView -from core.workflow.variables.segments import Segment +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.system_variable import SystemVariableReadOnlyView +from dify_graph.variables.segments import Segment class ReadOnlyVariablePool(Protocol): diff --git a/api/core/workflow/runtime/read_only_wrappers.py b/api/dify_graph/runtime/read_only_wrappers.py similarity index 93% rename from api/core/workflow/runtime/read_only_wrappers.py rename to api/dify_graph/runtime/read_only_wrappers.py index 25a834a539..ca06d88c3d 100644 --- a/api/core/workflow/runtime/read_only_wrappers.py +++ b/api/dify_graph/runtime/read_only_wrappers.py @@ -4,9 +4,9 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Any -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.system_variable import SystemVariableReadOnlyView -from core.workflow.variables.segments import Segment +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.system_variable import SystemVariableReadOnlyView +from dify_graph.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState from .variable_pool import VariablePool diff --git a/api/core/workflow/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py similarity index 92% rename from api/core/workflow/runtime/variable_pool.py rename to api/dify_graph/runtime/variable_pool.py index 48ad102b43..e3ef6a2897 100644 --- a/api/core/workflow/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -8,18 +8,18 @@ from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field -from core.workflow.constants import ( +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, RAG_PIPELINE_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from core.workflow.file import File, FileAttribute, file_manager -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import Segment, SegmentGroup, VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH -from core.workflow.variables.segments import FileSegment, ObjectSegment -from core.workflow.variables.variables import RAGPipelineVariableInput, Variable +from dify_graph.file import File, FileAttribute, file_manager +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import Segment, SegmentGroup, VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH +from dify_graph.variables.segments import FileSegment, ObjectSegment +from dify_graph.variables.variables import RAGPipelineVariableInput, Variable from factories import variable_factory VariableValue = Union[str, int, float, dict[str, object], list[object], File] @@ -65,9 +65,15 @@ class VariablePool(BaseModel): # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool + # Add conversation variables to the variable pool. When restoring from a serialized + # snapshot, `variable_dictionary` already carries the latest runtime values. + # In that case, keep existing entries instead of overwriting them with the + # bootstrap list. for var in self.conversation_variables: - self.add((CONVERSATION_VARIABLE_NODE_ID, var.name), var) + selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) + if self._has(selector): + continue + self.add(selector, var) # Add rag pipeline variables to the variable pool if self.rag_pipeline_variables: rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) diff --git a/api/core/workflow/system_variable.py b/api/dify_graph/system_variable.py similarity index 98% rename from api/core/workflow/system_variable.py rename to api/dify_graph/system_variable.py index 4144f79b8a..cc5deda892 100644 --- a/api/core/workflow/system_variable.py +++ b/api/dify_graph/system_variable.py @@ -7,8 +7,8 @@ from uuid import uuid4 from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator -from core.workflow.enums import SystemVariableKey -from core.workflow.file.models import File +from dify_graph.enums import SystemVariableKey +from dify_graph.file.models import File class SystemVariable(BaseModel): diff --git a/api/tests/unit_tests/core/model_runtime/__init__.py b/api/dify_graph/utils/__init__.py similarity index 100% rename from api/tests/unit_tests/core/model_runtime/__init__.py rename to api/dify_graph/utils/__init__.py diff --git a/web/app/components/header/account-setting/members-page/edit-workspace-modal/index.module.css b/api/dify_graph/utils/condition/__init__.py similarity index 100% rename from web/app/components/header/account-setting/members-page/edit-workspace-modal/index.module.css rename to api/dify_graph/utils/condition/__init__.py diff --git a/api/core/workflow/utils/condition/entities.py b/api/dify_graph/utils/condition/entities.py similarity index 100% rename from api/core/workflow/utils/condition/entities.py rename to api/dify_graph/utils/condition/entities.py diff --git a/api/core/workflow/utils/condition/processor.py b/api/dify_graph/utils/condition/processor.py similarity index 98% rename from api/core/workflow/utils/condition/processor.py rename to api/dify_graph/utils/condition/processor.py index 4e635cc2f2..dea72d96c2 100644 --- a/api/core/workflow/utils/condition/processor.py +++ b/api/dify_graph/utils/condition/processor.py @@ -2,10 +2,10 @@ import json from collections.abc import Mapping, Sequence from typing import Literal, NamedTuple -from core.workflow.file import FileAttribute, file_manager -from core.workflow.runtime import VariablePool -from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayBooleanSegment, BooleanSegment +from dify_graph.file import FileAttribute, file_manager +from dify_graph.runtime import VariablePool +from dify_graph.variables import ArrayFileSegment +from dify_graph.variables.segments import ArrayBooleanSegment, BooleanSegment from .entities import Condition, SubCondition, SupportedComparisonOperator diff --git a/api/core/workflow/variable_loader.py b/api/dify_graph/variable_loader.py similarity index 95% rename from api/core/workflow/variable_loader.py rename to api/dify_graph/variable_loader.py index dfa4ce2e75..d263450334 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/dify_graph/variable_loader.py @@ -2,9 +2,9 @@ import abc from collections.abc import Mapping, Sequence from typing import Any, Protocol -from core.workflow.runtime import VariablePool -from core.workflow.variables import VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH +from dify_graph.runtime import VariablePool +from dify_graph.variables import VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH class VariableLoader(Protocol): diff --git a/api/core/workflow/variables/__init__.py b/api/dify_graph/variables/__init__.py similarity index 100% rename from api/core/workflow/variables/__init__.py rename to api/dify_graph/variables/__init__.py diff --git a/api/core/workflow/variables/consts.py b/api/dify_graph/variables/consts.py similarity index 100% rename from api/core/workflow/variables/consts.py rename to api/dify_graph/variables/consts.py diff --git a/api/core/workflow/variables/exc.py b/api/dify_graph/variables/exc.py similarity index 100% rename from api/core/workflow/variables/exc.py rename to api/dify_graph/variables/exc.py diff --git a/api/core/workflow/variables/input_entities.py b/api/dify_graph/variables/input_entities.py similarity index 97% rename from api/core/workflow/variables/input_entities.py rename to api/dify_graph/variables/input_entities.py index 9a42012f0a..e6a68ea359 100644 --- a/api/core/workflow/variables/input_entities.py +++ b/api/dify_graph/variables/input_entities.py @@ -5,7 +5,7 @@ from typing import Any from jsonschema import Draft7Validator, SchemaError from pydantic import BaseModel, Field, field_validator -from core.workflow.file import FileTransferMethod, FileType +from dify_graph.file import FileTransferMethod, FileType class VariableEntityType(StrEnum): diff --git a/api/core/workflow/variables/segment_group.py b/api/dify_graph/variables/segment_group.py similarity index 100% rename from api/core/workflow/variables/segment_group.py rename to api/dify_graph/variables/segment_group.py diff --git a/api/core/workflow/variables/segments.py b/api/dify_graph/variables/segments.py similarity index 99% rename from api/core/workflow/variables/segments.py rename to api/dify_graph/variables/segments.py index 64bba7dbe2..bdb213ed48 100644 --- a/api/core/workflow/variables/segments.py +++ b/api/dify_graph/variables/segments.py @@ -5,7 +5,7 @@ from typing import Annotated, Any, TypeAlias from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator -from core.workflow.file import File +from dify_graph.file import File from .types import SegmentType diff --git a/api/core/workflow/variables/types.py b/api/dify_graph/variables/types.py similarity index 98% rename from api/core/workflow/variables/types.py rename to api/dify_graph/variables/types.py index 596905c26d..53bf495a27 100644 --- a/api/core/workflow/variables/types.py +++ b/api/dify_graph/variables/types.py @@ -4,10 +4,10 @@ from collections.abc import Mapping from enum import StrEnum from typing import TYPE_CHECKING, Any -from core.workflow.file.models import File +from dify_graph.file.models import File if TYPE_CHECKING: - pass + from dify_graph.variables.segments import Segment class ArrayValidation(StrEnum): @@ -219,7 +219,7 @@ class SegmentType(StrEnum): return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) @staticmethod - def get_zero_value(t: SegmentType): + def get_zero_value(t: SegmentType) -> Segment: # Lazy import to avoid circular dependency from factories import variable_factory diff --git a/api/core/workflow/variables/utils.py b/api/dify_graph/variables/utils.py similarity index 100% rename from api/core/workflow/variables/utils.py rename to api/dify_graph/variables/utils.py diff --git a/api/core/workflow/variables/variables.py b/api/dify_graph/variables/variables.py similarity index 100% rename from api/core/workflow/variables/variables.py rename to api/dify_graph/variables/variables.py diff --git a/api/core/workflow/workflow_type_encoder.py b/api/dify_graph/workflow_type_encoder.py similarity index 95% rename from api/core/workflow/workflow_type_encoder.py rename to api/dify_graph/workflow_type_encoder.py index a192b884f7..3dd846b3cb 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/dify_graph/workflow_type_encoder.py @@ -4,8 +4,8 @@ from typing import Any, overload from pydantic import BaseModel -from core.workflow.file.models import File -from core.workflow.variables import Segment +from dify_graph.file.models import File +from dify_graph.variables import Segment class WorkflowRuntimeTypeConverter: diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh index 1a675b3338..6b904b7d0d 100755 --- a/api/docker/entrypoint.sh +++ b/api/docker/entrypoint.sh @@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then if [[ -z "${CELERY_QUEUES}" ]]; then if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues - DEFAULT_QUEUES="api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + DEFAULT_QUEUES="api_token,dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi else DEFAULT_QUEUES="${CELERY_QUEUES}" diff --git a/api/docs/swagger.yaml b/api/docs/swagger.yaml new file mode 100644 index 0000000000..7bb26800a4 --- /dev/null +++ b/api/docs/swagger.yaml @@ -0,0 +1,1642 @@ +openapi: 3.0.3 +info: + title: Dify Console API + description: API documentation for Dify Console – Datasets, Snippets, Snippet Workflows, and Evaluation modules. + version: 1.0.0 + +servers: + - url: /console/api + +tags: + - name: Datasets + description: Knowledge base (dataset) CRUD and ancillary operations + - name: Dataset Evaluation + description: Knowledge base retrieval evaluation + - name: Snippets + description: Customized snippet CRUD, import/export, and dependency checks + - name: Snippet Workflows + description: Snippet draft/published workflow operations, node runs, and workflow runs + - name: Evaluation + description: Evaluation configuration, runs, metrics, and node info for App / Snippet targets + +# ============================================================ +# Paths +# ============================================================ +paths: + + # ---------------------------------------------------------- + # Datasets + # ---------------------------------------------------------- + + /datasets: + get: + tags: [Datasets] + summary: List datasets + operationId: get_datasets + parameters: + - $ref: '#/components/parameters/PageParam' + - $ref: '#/components/parameters/LimitParam' + - name: keyword + in: query + schema: { type: string } + - name: include_all + in: query + schema: { type: boolean, default: false } + - name: ids + in: query + schema: { type: array, items: { type: string } } + - name: tag_ids + in: query + schema: { type: array, items: { type: string } } + responses: + '200': + description: Datasets retrieved successfully + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/DatasetDetail' } } + page: { type: integer } + limit: { type: integer } + total: { type: integer } + has_more: { type: boolean } + post: + tags: [Datasets] + summary: Create a dataset + operationId: create_dataset + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/DatasetCreatePayload' } + responses: + '201': + description: Dataset created successfully + content: + application/json: + schema: { $ref: '#/components/schemas/DatasetDetail' } + '400': + $ref: '#/components/responses/BadRequest' + + /datasets/{dataset_id}: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get dataset details + operationId: get_dataset + responses: + '200': + description: Dataset retrieved successfully + content: + application/json: + schema: { $ref: '#/components/schemas/DatasetDetail' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + patch: + tags: [Datasets] + summary: Update dataset + operationId: update_dataset + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/DatasetUpdatePayload' } + responses: + '200': + description: Dataset updated successfully + content: + application/json: + schema: { $ref: '#/components/schemas/DatasetDetail' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + delete: + tags: [Datasets] + summary: Delete dataset + operationId: delete_dataset + responses: + '204': { description: Dataset deleted successfully } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/use-check: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Check if dataset is in use + operationId: check_dataset_use + responses: + '200': + description: Dataset use status + content: + application/json: + schema: + type: object + properties: + is_using: { type: boolean } + + /datasets/{dataset_id}/queries: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get dataset query history + operationId: get_dataset_queries + parameters: + - $ref: '#/components/parameters/PageParam' + - $ref: '#/components/parameters/LimitParam' + responses: + '200': + description: Query history retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { type: object } } + page: { type: integer } + limit: { type: integer } + total: { type: integer } + has_more: { type: boolean } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/indexing-estimate: + post: + tags: [Datasets] + summary: Estimate dataset indexing cost + operationId: estimate_dataset_indexing + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/IndexingEstimatePayload' } + responses: + '200': { description: Indexing estimate calculated } + + /datasets/{dataset_id}/related-apps: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get applications related to dataset + operationId: get_dataset_related_apps + responses: + '200': + description: Related apps retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { type: object } } + total: { type: integer } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/indexing-status: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get dataset indexing status + operationId: get_dataset_indexing_status + responses: + '200': + description: Indexing status retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { type: object } } + + /datasets/api-keys: + get: + tags: [Datasets] + summary: Get dataset API keys + operationId: get_dataset_api_keys + responses: + '200': + description: API keys retrieved + content: + application/json: + schema: + type: object + properties: + items: { type: array, items: { $ref: '#/components/schemas/ApiKeyItem' } } + post: + tags: [Datasets] + summary: Create dataset API key + operationId: create_dataset_api_key + responses: + '200': + description: API key created + content: + application/json: + schema: { $ref: '#/components/schemas/ApiKeyItem' } + '400': { $ref: '#/components/responses/BadRequest' } + + /datasets/api-keys/{api_key_id}: + parameters: + - name: api_key_id + in: path + required: true + schema: { type: string, format: uuid } + delete: + tags: [Datasets] + summary: Delete dataset API key + operationId: delete_dataset_api_key + responses: + '204': { description: API key deleted } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/api-keys/{status}: + parameters: + - $ref: '#/components/parameters/DatasetId' + - name: status + in: path + required: true + schema: { type: string, enum: [enable, disable] } + post: + tags: [Datasets] + summary: Enable or disable dataset API + operationId: toggle_dataset_api + responses: + '200': { description: Status updated } + + /datasets/api-base-info: + get: + tags: [Datasets] + summary: Get dataset API base URL + operationId: get_dataset_api_base_info + responses: + '200': + description: API base info + content: + application/json: + schema: + type: object + properties: + api_base_url: { type: string } + + /datasets/retrieval-setting: + get: + tags: [Datasets] + summary: Get dataset retrieval settings + operationId: get_dataset_retrieval_setting + responses: + '200': + description: Retrieval settings + content: + application/json: + schema: + type: object + properties: + retrieval_method: { type: array, items: { type: string } } + + /datasets/retrieval-setting/{vector_type}: + parameters: + - name: vector_type + in: path + required: true + schema: { type: string } + get: + tags: [Datasets] + summary: Get mock retrieval settings by vector type + operationId: get_dataset_retrieval_setting_mock + responses: + '200': + description: Mock retrieval settings + content: + application/json: + schema: + type: object + properties: + retrieval_method: { type: array, items: { type: string } } + + /datasets/{dataset_id}/error-docs: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get dataset error documents + operationId: get_dataset_error_docs + responses: + '200': + description: Error documents retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { type: object } } + total: { type: integer } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/permission-part-users: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get dataset permission user list + operationId: get_dataset_permission_users + responses: + '200': + description: Permission users retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { type: object } } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/auto-disable-logs: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Datasets] + summary: Get dataset auto-disable logs + operationId: get_dataset_auto_disable_logs + responses: + '200': { description: Auto-disable logs retrieved } + '404': { $ref: '#/components/responses/NotFound' } + + # ---------------------------------------------------------- + # Dataset Evaluation + # ---------------------------------------------------------- + + /datasets/{dataset_id}/evaluation/template/download: + parameters: + - $ref: '#/components/parameters/DatasetId' + post: + tags: [Dataset Evaluation] + summary: Download evaluation dataset template for knowledge base + operationId: download_dataset_evaluation_template + responses: + '200': + description: XLSX template streamed as attachment + content: + application/vnd.openxmlformats-officedocument.spreadsheetml.sheet: {} + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/evaluation: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Dataset Evaluation] + summary: Get evaluation configuration for knowledge base + operationId: get_dataset_evaluation_config + responses: + '200': + description: Evaluation configuration retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationConfig' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + put: + tags: [Dataset Evaluation] + summary: Save evaluation configuration for knowledge base + operationId: save_dataset_evaluation_config + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationConfigData' } + responses: + '200': + description: Configuration saved + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationConfig' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/evaluation/run: + parameters: + - $ref: '#/components/parameters/DatasetId' + post: + tags: [Dataset Evaluation] + summary: Start evaluation run for knowledge base + operationId: start_dataset_evaluation_run + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationRunRequest' } + responses: + '200': + description: Evaluation run started + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationRun' } + '400': { $ref: '#/components/responses/BadRequest' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + '429': { description: Max concurrent runs reached } + + /datasets/{dataset_id}/evaluation/logs: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Dataset Evaluation] + summary: Get evaluation run history for knowledge base + operationId: get_dataset_evaluation_logs + parameters: + - $ref: '#/components/parameters/PageParam' + - name: page_size + in: query + schema: { type: integer, default: 20 } + responses: + '200': + description: Evaluation logs retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/EvaluationRun' } } + total: { type: integer } + page: { type: integer } + page_size: { type: integer } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/evaluation/runs/{run_id}: + parameters: + - $ref: '#/components/parameters/DatasetId' + - $ref: '#/components/parameters/RunId' + get: + tags: [Dataset Evaluation] + summary: Get evaluation run detail for knowledge base + operationId: get_dataset_evaluation_run_detail + parameters: + - $ref: '#/components/parameters/PageParam' + - name: page_size + in: query + schema: { type: integer, default: 50 } + responses: + '200': + description: Run detail with items + content: + application/json: + schema: + type: object + properties: + run: { $ref: '#/components/schemas/EvaluationRun' } + items: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/EvaluationRunItem' } } + total: { type: integer } + page: { type: integer } + page_size: { type: integer } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/evaluation/runs/{run_id}/cancel: + parameters: + - $ref: '#/components/parameters/DatasetId' + - $ref: '#/components/parameters/RunId' + post: + tags: [Dataset Evaluation] + summary: Cancel a running knowledge base evaluation + operationId: cancel_dataset_evaluation_run + responses: + '200': + description: Evaluation run cancelled + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationRun' } + '400': { $ref: '#/components/responses/BadRequest' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/evaluation/metrics: + parameters: + - $ref: '#/components/parameters/DatasetId' + get: + tags: [Dataset Evaluation] + summary: Get available retrieval evaluation metrics + operationId: get_dataset_evaluation_metrics + responses: + '200': + description: Metrics retrieved + content: + application/json: + schema: + type: object + properties: + metrics: { type: array, items: { type: string } } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + /datasets/{dataset_id}/evaluation/files/{file_id}: + parameters: + - $ref: '#/components/parameters/DatasetId' + - name: file_id + in: path + required: true + schema: { type: string, format: uuid } + get: + tags: [Dataset Evaluation] + summary: Download evaluation file for knowledge base + operationId: download_dataset_evaluation_file + responses: + '200': + description: File info and download URL + content: + application/json: + schema: { $ref: '#/components/schemas/FileInfo' } + '403': { $ref: '#/components/responses/Forbidden' } + '404': { $ref: '#/components/responses/NotFound' } + + # ---------------------------------------------------------- + # Snippets + # ---------------------------------------------------------- + + /workspaces/current/customized-snippets: + get: + tags: [Snippets] + summary: List customized snippets + operationId: list_customized_snippets + parameters: + - $ref: '#/components/parameters/PageParam' + - $ref: '#/components/parameters/LimitParam' + - name: keyword + in: query + schema: { type: string } + - name: is_published + in: query + schema: { type: boolean } + responses: + '200': + description: Snippets retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/SnippetListItem' } } + page: { type: integer } + limit: { type: integer } + total: { type: integer } + has_more: { type: boolean } + post: + tags: [Snippets] + summary: Create a customized snippet + operationId: create_customized_snippet + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/CreateSnippetPayload' } + responses: + '201': + description: Snippet created + content: + application/json: + schema: { $ref: '#/components/schemas/Snippet' } + '400': { $ref: '#/components/responses/BadRequest' } + + /workspaces/current/customized-snippets/{snippet_id}: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippets] + summary: Get snippet details + operationId: get_customized_snippet + responses: + '200': + description: Snippet retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/Snippet' } + '404': { $ref: '#/components/responses/NotFound' } + patch: + tags: [Snippets] + summary: Update snippet + operationId: update_customized_snippet + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/UpdateSnippetPayload' } + responses: + '200': + description: Snippet updated + content: + application/json: + schema: { $ref: '#/components/schemas/Snippet' } + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + delete: + tags: [Snippets] + summary: Delete snippet + operationId: delete_customized_snippet + responses: + '204': { description: Snippet deleted } + '404': { $ref: '#/components/responses/NotFound' } + + /workspaces/current/customized-snippets/{snippet_id}/export: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippets] + summary: Export snippet as DSL (YAML) + operationId: export_customized_snippet + parameters: + - name: include_secret + in: query + schema: { type: string, enum: ['true', 'false'], default: 'false' } + responses: + '200': + description: Snippet DSL exported + content: + application/x-yaml: {} + '404': { $ref: '#/components/responses/NotFound' } + + /workspaces/current/customized-snippets/imports: + post: + tags: [Snippets] + summary: Import snippet from DSL + operationId: import_customized_snippet + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SnippetImportPayload' } + responses: + '200': { description: Import succeeded } + '202': { description: Import pending confirmation } + '400': { $ref: '#/components/responses/BadRequest' } + + /workspaces/current/customized-snippets/imports/{import_id}/confirm: + parameters: + - name: import_id + in: path + required: true + schema: { type: string } + post: + tags: [Snippets] + summary: Confirm a pending snippet import + operationId: confirm_snippet_import + responses: + '200': { description: Import confirmed } + '400': { $ref: '#/components/responses/BadRequest' } + + /workspaces/current/customized-snippets/{snippet_id}/check-dependencies: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippets] + summary: Check snippet dependencies + operationId: check_snippet_dependencies + responses: + '200': { description: Dependencies checked } + '404': { $ref: '#/components/responses/NotFound' } + + /workspaces/current/customized-snippets/{snippet_id}/use-count/increment: + parameters: + - $ref: '#/components/parameters/SnippetId' + post: + tags: [Snippets] + summary: Increment snippet use count + operationId: increment_snippet_use_count + responses: + '200': + description: Use count incremented + content: + application/json: + schema: + type: object + properties: + result: { type: string } + use_count: { type: integer } + '404': { $ref: '#/components/responses/NotFound' } + + # ---------------------------------------------------------- + # Snippet Workflows + # ---------------------------------------------------------- + + /snippets/{snippet_id}/workflows/draft: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippet Workflows] + summary: Get draft workflow for snippet + operationId: get_snippet_draft_workflow + responses: + '200': + description: Draft workflow retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/Workflow' } + '404': { $ref: '#/components/responses/NotFound' } + post: + tags: [Snippet Workflows] + summary: Sync draft workflow + operationId: sync_snippet_draft_workflow + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SnippetDraftSyncPayload' } + responses: + '200': + description: Draft synced + content: + application/json: + schema: + type: object + properties: + result: { type: string } + hash: { type: string } + updated_at: { type: number } + '400': { description: Hash mismatch } + + /snippets/{snippet_id}/workflows/draft/config: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippet Workflows] + summary: Get snippet draft workflow config limits + operationId: get_snippet_draft_config + responses: + '200': + description: Config retrieved + content: + application/json: + schema: + type: object + properties: + parallel_depth_limit: { type: integer } + + /snippets/{snippet_id}/workflows/publish: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippet Workflows] + summary: Get published workflow for snippet + operationId: get_snippet_published_workflow + responses: + '200': + description: Published workflow retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/Workflow' } + '404': { $ref: '#/components/responses/NotFound' } + post: + tags: [Snippet Workflows] + summary: Publish snippet workflow + operationId: publish_snippet_workflow + responses: + '200': + description: Workflow published + content: + application/json: + schema: + type: object + properties: + result: { type: string } + created_at: { type: number } + '400': { $ref: '#/components/responses/BadRequest' } + + /snippets/{snippet_id}/workflows/default-workflow-block-configs: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippet Workflows] + summary: Get default block configurations + operationId: get_snippet_default_block_configs + responses: + '200': { description: Default block configs retrieved } + + /snippets/{snippet_id}/workflow-runs: + parameters: + - $ref: '#/components/parameters/SnippetId' + get: + tags: [Snippet Workflows] + summary: List workflow runs for snippet + operationId: list_snippet_workflow_runs + parameters: + - name: last_id + in: query + schema: { type: string } + - $ref: '#/components/parameters/LimitParam' + responses: + '200': + description: Workflow runs retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/WorkflowRunPagination' } + + /snippets/{snippet_id}/workflow-runs/{run_id}: + parameters: + - $ref: '#/components/parameters/SnippetId' + - $ref: '#/components/parameters/RunId' + get: + tags: [Snippet Workflows] + summary: Get workflow run detail + operationId: get_snippet_workflow_run_detail + responses: + '200': + description: Run detail retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/WorkflowRunDetail' } + '404': { $ref: '#/components/responses/NotFound' } + + /snippets/{snippet_id}/workflow-runs/{run_id}/node-executions: + parameters: + - $ref: '#/components/parameters/SnippetId' + - $ref: '#/components/parameters/RunId' + get: + tags: [Snippet Workflows] + summary: List node executions for a workflow run + operationId: list_snippet_workflow_run_node_executions + responses: + '200': + description: Node executions retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/WorkflowNodeExecution' } } + + /snippets/{snippet_id}/workflows/draft/nodes/{node_id}/run: + parameters: + - $ref: '#/components/parameters/SnippetId' + - $ref: '#/components/parameters/NodeId' + post: + tags: [Snippet Workflows] + summary: Run a single node in draft workflow (single-step debug) + operationId: run_snippet_draft_node + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SnippetDraftNodeRunPayload' } + responses: + '200': + description: Node run result + content: + application/json: + schema: { $ref: '#/components/schemas/WorkflowNodeExecution' } + '404': { $ref: '#/components/responses/NotFound' } + + /snippets/{snippet_id}/workflows/draft/nodes/{node_id}/last-run: + parameters: + - $ref: '#/components/parameters/SnippetId' + - $ref: '#/components/parameters/NodeId' + get: + tags: [Snippet Workflows] + summary: Get last run result for a node + operationId: get_snippet_draft_node_last_run + responses: + '200': + description: Node last run retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/WorkflowNodeExecution' } + '404': { $ref: '#/components/responses/NotFound' } + + /snippets/{snippet_id}/workflows/draft/iteration/nodes/{node_id}/run: + parameters: + - $ref: '#/components/parameters/SnippetId' + - $ref: '#/components/parameters/NodeId' + post: + tags: [Snippet Workflows] + summary: Run iteration node (SSE stream) + operationId: run_snippet_draft_iteration_node + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SnippetIterationNodeRunPayload' } + responses: + '200': + description: SSE stream with iteration progress + content: + text/event-stream: {} + '404': { $ref: '#/components/responses/NotFound' } + + /snippets/{snippet_id}/workflows/draft/loop/nodes/{node_id}/run: + parameters: + - $ref: '#/components/parameters/SnippetId' + - $ref: '#/components/parameters/NodeId' + post: + tags: [Snippet Workflows] + summary: Run loop node (SSE stream) + operationId: run_snippet_draft_loop_node + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SnippetLoopNodeRunPayload' } + responses: + '200': + description: SSE stream with loop progress + content: + text/event-stream: {} + '404': { $ref: '#/components/responses/NotFound' } + + /snippets/{snippet_id}/workflows/draft/run: + parameters: + - $ref: '#/components/parameters/SnippetId' + post: + tags: [Snippet Workflows] + summary: Run draft workflow (SSE stream) + operationId: run_snippet_draft_workflow + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/SnippetDraftRunPayload' } + responses: + '200': + description: SSE stream with execution progress + content: + text/event-stream: {} + '404': { $ref: '#/components/responses/NotFound' } + + /snippets/{snippet_id}/workflow-runs/tasks/{task_id}/stop: + parameters: + - $ref: '#/components/parameters/SnippetId' + - name: task_id + in: path + required: true + schema: { type: string } + post: + tags: [Snippet Workflows] + summary: Stop a running snippet workflow task + operationId: stop_snippet_workflow_task + responses: + '200': + description: Task stopped + content: + application/json: + schema: + type: object + properties: + result: { type: string } + + # ---------------------------------------------------------- + # Evaluation (App / Snippet targets) + # ---------------------------------------------------------- + + /{evaluate_target_type}/{evaluate_target_id}/dataset-template/download: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + post: + tags: [Evaluation] + summary: Download evaluation dataset template + operationId: download_evaluation_dataset_template + responses: + '200': + description: XLSX template streamed as attachment + content: + application/vnd.openxmlformats-officedocument.spreadsheetml.sheet: {} + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + get: + tags: [Evaluation] + summary: Get evaluation configuration + operationId: get_evaluation_detail + responses: + '200': + description: Evaluation configuration retrieved + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationConfig' } + '404': { $ref: '#/components/responses/NotFound' } + put: + tags: [Evaluation] + summary: Save evaluation configuration + operationId: save_evaluation_detail + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationConfigData' } + responses: + '200': + description: Configuration saved + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationConfig' } + '404': { $ref: '#/components/responses/NotFound' } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/logs: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + get: + tags: [Evaluation] + summary: Get evaluation run history + operationId: get_evaluation_logs + parameters: + - $ref: '#/components/parameters/PageParam' + - name: page_size + in: query + schema: { type: integer, default: 20 } + responses: + '200': + description: Evaluation logs retrieved + content: + application/json: + schema: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/EvaluationRun' } } + total: { type: integer } + page: { type: integer } + page_size: { type: integer } + '404': { $ref: '#/components/responses/NotFound' } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/run: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + post: + tags: [Evaluation] + summary: Start an evaluation run + operationId: start_evaluation_run + requestBody: + required: true + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationRunRequest' } + responses: + '200': + description: Evaluation run started + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationRun' } + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + '429': { description: Max concurrent runs reached } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/runs/{run_id}: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + - $ref: '#/components/parameters/RunId' + get: + tags: [Evaluation] + summary: Get evaluation run detail with items + operationId: get_evaluation_run_detail + parameters: + - $ref: '#/components/parameters/PageParam' + - name: page_size + in: query + schema: { type: integer, default: 50 } + responses: + '200': + description: Run detail with paginated items + content: + application/json: + schema: + type: object + properties: + run: { $ref: '#/components/schemas/EvaluationRun' } + items: + type: object + properties: + data: { type: array, items: { $ref: '#/components/schemas/EvaluationRunItem' } } + total: { type: integer } + page: { type: integer } + page_size: { type: integer } + '404': { $ref: '#/components/responses/NotFound' } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/runs/{run_id}/cancel: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + - $ref: '#/components/parameters/RunId' + post: + tags: [Evaluation] + summary: Cancel a running evaluation + operationId: cancel_evaluation_run + responses: + '200': + description: Evaluation run cancelled + content: + application/json: + schema: { $ref: '#/components/schemas/EvaluationRun' } + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/metrics: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + get: + tags: [Evaluation] + summary: Get available evaluation metrics per category + operationId: get_evaluation_metrics + responses: + '200': + description: Metrics grouped by category + content: + application/json: + schema: + type: object + properties: + metrics: + type: object + additionalProperties: + type: array + items: { type: string } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/node-info: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + post: + tags: [Evaluation] + summary: Get workflow/snippet node info grouped by metric + operationId: get_evaluation_node_info + requestBody: + content: + application/json: + schema: + type: object + properties: + metrics: + type: array + items: { type: string } + description: Metric names to query. Omit or pass empty to get all nodes. + responses: + '200': + description: Node info grouped by metric or "all" + content: + application/json: + schema: + type: object + additionalProperties: + type: array + items: { $ref: '#/components/schemas/NodeInfo' } + '404': { $ref: '#/components/responses/NotFound' } + + /evaluation/available-metrics: + get: + tags: [Evaluation] + summary: Get centrally-defined list of evaluation metrics + operationId: get_available_evaluation_metrics + responses: + '200': + description: Available metrics list + content: + application/json: + schema: + type: object + properties: + metrics: + type: array + items: { type: string } + example: + - faithfulness + - answer_relevancy + - answer_correctness + - semantic_similarity + - context_precision + - context_recall + - context_relevance + - tool_correctness + - task_completion + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/files/{file_id}: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + - name: file_id + in: path + required: true + schema: { type: string, format: uuid } + get: + tags: [Evaluation] + summary: Download evaluation file + operationId: download_evaluation_file + responses: + '200': + description: File info and download URL + content: + application/json: + schema: { $ref: '#/components/schemas/FileInfo' } + '404': { $ref: '#/components/responses/NotFound' } + + /{evaluate_target_type}/{evaluate_target_id}/evaluation/version: + parameters: + - $ref: '#/components/parameters/EvaluateTargetType' + - $ref: '#/components/parameters/EvaluateTargetId' + get: + tags: [Evaluation] + summary: Get evaluation target version details + operationId: get_evaluation_version_detail + parameters: + - name: version + in: query + required: true + schema: { type: string } + responses: + '200': + description: Version graph retrieved + content: + application/json: + schema: + type: object + properties: + graph: { type: object } + '400': { $ref: '#/components/responses/BadRequest' } + '404': { $ref: '#/components/responses/NotFound' } + +# ============================================================ +# Components +# ============================================================ +components: + + # ---------- Parameters ---------- + parameters: + PageParam: + name: page + in: query + schema: { type: integer, default: 1 } + LimitParam: + name: limit + in: query + schema: { type: integer, default: 20 } + DatasetId: + name: dataset_id + in: path + required: true + schema: { type: string, format: uuid } + SnippetId: + name: snippet_id + in: path + required: true + schema: { type: string, format: uuid } + RunId: + name: run_id + in: path + required: true + schema: { type: string, format: uuid } + NodeId: + name: node_id + in: path + required: true + schema: { type: string } + EvaluateTargetType: + name: evaluate_target_type + in: path + required: true + schema: + type: string + enum: [app, snippets] + description: "Evaluation target type: app or snippets" + EvaluateTargetId: + name: evaluate_target_id + in: path + required: true + schema: { type: string, format: uuid } + description: Evaluation target ID (App ID or Snippet ID) + + # ---------- Responses ---------- + responses: + BadRequest: + description: Invalid request + content: + application/json: + schema: + type: object + properties: + message: { type: string } + NotFound: + description: Resource not found + content: + application/json: + schema: + type: object + properties: + message: { type: string } + Forbidden: + description: Permission denied + content: + application/json: + schema: + type: object + properties: + message: { type: string } + + # ---------- Schemas ---------- + schemas: + + # -- Dataset -- + DatasetDetail: + type: object + properties: + id: { type: string, format: uuid } + name: { type: string } + description: { type: string } + provider: { type: string } + permission: { type: string } + indexing_technique: { type: string } + embedding_model: { type: string } + embedding_model_provider: { type: string } + embedding_available: { type: boolean } + created_at: { type: number } + updated_at: { type: number } + + DatasetCreatePayload: + type: object + required: [name] + properties: + name: { type: string, minLength: 1, maxLength: 40 } + description: { type: string, maxLength: 400 } + indexing_technique: { type: string } + permission: { type: string, enum: [only_me, all_team_members, partial_members] } + provider: { type: string, default: vendor } + external_knowledge_api_id: { type: string } + external_knowledge_id: { type: string } + + DatasetUpdatePayload: + type: object + properties: + name: { type: string, minLength: 1, maxLength: 40 } + description: { type: string, maxLength: 400 } + permission: { type: string } + indexing_technique: { type: string } + embedding_model: { type: string } + embedding_model_provider: { type: string } + retrieval_model: { type: object } + icon_info: { type: object } + + IndexingEstimatePayload: + type: object + required: [info_list, process_rule, indexing_technique] + properties: + info_list: { type: object } + process_rule: { type: object } + indexing_technique: { type: string } + doc_form: { type: string, default: text_model } + dataset_id: { type: string } + doc_language: { type: string, default: English } + + ApiKeyItem: + type: object + properties: + id: { type: string, format: uuid } + type: { type: string } + token: { type: string } + created_at: { type: number } + + # -- Snippet -- + Snippet: + type: object + properties: + id: { type: string, format: uuid } + name: { type: string } + description: { type: string } + type: { type: string, enum: [node, group] } + is_published: { type: boolean } + version: { type: string } + use_count: { type: integer } + icon_info: { type: object } + input_fields: { type: array, items: { type: object } } + created_at: { type: number } + updated_at: { type: number } + + SnippetListItem: + type: object + properties: + id: { type: string, format: uuid } + name: { type: string } + description: { type: string } + type: { type: string } + is_published: { type: boolean } + use_count: { type: integer } + icon_info: { type: object } + created_at: { type: number } + updated_at: { type: number } + + CreateSnippetPayload: + type: object + required: [name] + properties: + name: { type: string } + description: { type: string } + type: { type: string, enum: [node, group] } + icon_info: { type: object } + input_fields: { type: array, items: { type: object } } + + UpdateSnippetPayload: + type: object + properties: + name: { type: string } + description: { type: string } + icon_info: { type: object } + + SnippetImportPayload: + type: object + properties: + mode: { type: string } + yaml_content: { type: string } + yaml_url: { type: string } + snippet_id: { type: string } + name: { type: string } + description: { type: string } + + # -- Snippet Workflow -- + Workflow: + type: object + properties: + id: { type: string, format: uuid } + graph: { type: object } + features: { type: object } + hash: { type: string } + created_at: { type: number } + updated_at: { type: number } + + SnippetDraftSyncPayload: + type: object + properties: + graph: { type: object } + hash: { type: string } + environment_variables: { type: array, items: { type: object } } + conversation_variables: { type: array, items: { type: object } } + input_variables: { type: array, items: { type: object } } + + SnippetDraftNodeRunPayload: + type: object + properties: + inputs: { type: object } + query: { type: string } + files: { type: array, items: { type: object } } + + SnippetDraftRunPayload: + type: object + properties: + inputs: { type: object } + files: { type: array, items: { type: object } } + + SnippetIterationNodeRunPayload: + type: object + properties: + inputs: { type: object } + + SnippetLoopNodeRunPayload: + type: object + properties: + inputs: { type: object } + + WorkflowRunPagination: + type: object + properties: + limit: { type: integer } + has_more: { type: boolean } + data: { type: array, items: { $ref: '#/components/schemas/WorkflowRunDetail' } } + + WorkflowRunDetail: + type: object + properties: + id: { type: string, format: uuid } + version: { type: string } + status: { type: string, enum: [running, succeeded, failed, stopped, partial-succeeded] } + elapsed_time: { type: number } + total_tokens: { type: integer } + total_steps: { type: integer } + created_at: { type: number } + finished_at: { type: number } + exceptions_count: { type: integer } + + WorkflowNodeExecution: + type: object + properties: + id: { type: string, format: uuid } + index: { type: integer } + node_id: { type: string } + node_type: { type: string } + title: { type: string } + inputs: { type: object } + process_data: { type: object } + outputs: { type: object } + status: { type: string } + error: { type: string } + elapsed_time: { type: number } + created_at: { type: number } + finished_at: { type: number } + + # -- Evaluation -- + EvaluationConfig: + type: object + properties: + evaluation_model: { type: string, nullable: true } + evaluation_model_provider: { type: string, nullable: true } + metrics_config: { type: object, nullable: true } + judgement_conditions: { type: object, nullable: true } + + EvaluationConfigData: + type: object + properties: + evaluation_model: { type: string } + evaluation_model_provider: { type: string } + default_metrics: + type: array + items: + type: object + properties: + metric: { type: string } + node_info_list: + type: array + items: { $ref: '#/components/schemas/NodeInfo' } + customized_metrics: + type: object + nullable: true + properties: + evaluation_workflow_id: { type: string } + input_fields: { type: object } + output_fields: { type: array, items: { type: object } } + judgment_config: + type: object + nullable: true + + EvaluationRunRequest: + allOf: + - $ref: '#/components/schemas/EvaluationConfigData' + - type: object + required: [file_id] + properties: + file_id: { type: string, format: uuid } + + EvaluationRun: + type: object + properties: + id: { type: string, format: uuid } + tenant_id: { type: string, format: uuid } + target_type: { type: string } + target_id: { type: string, format: uuid } + evaluation_config_id: { type: string, format: uuid } + status: { type: string, enum: [pending, running, completed, failed, cancelled] } + dataset_file_id: { type: string, format: uuid, nullable: true } + result_file_id: { type: string, format: uuid, nullable: true } + total_items: { type: integer } + completed_items: { type: integer } + failed_items: { type: integer } + progress: { type: number } + metrics_summary: { type: object } + error: { type: string, nullable: true } + created_by: { type: string, format: uuid } + started_at: { type: number, nullable: true } + completed_at: { type: number, nullable: true } + created_at: { type: number } + + EvaluationRunItem: + type: object + properties: + id: { type: string, format: uuid } + item_index: { type: integer } + inputs: { type: object } + expected_output: { type: string, nullable: true } + actual_output: { type: string, nullable: true } + metrics: + type: array + items: + type: object + properties: + name: { type: string } + value: {} + details: { type: object } + judgment: { type: object } + metadata: { type: object } + error: { type: string, nullable: true } + overall_score: { type: number, nullable: true } + + NodeInfo: + type: object + properties: + node_id: { type: string } + type: { type: string } + title: { type: string } + + FileInfo: + type: object + properties: + id: { type: string, format: uuid } + name: { type: string } + size: { type: integer } + extension: { type: string } + mime_type: { type: string } + created_at: { type: number } + download_url: { type: string } diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 8778f5cafe..b7e7a6e60f 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -3,6 +3,7 @@ import logging import time import click +from sqlalchemy import select from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner @@ -10,6 +11,7 @@ from events.document_index_event import document_index_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Document +from models.enums import IndexingStatus logger = logging.getLogger(__name__) @@ -23,19 +25,17 @@ def handle(sender, **kwargs): for document_id in document_ids: logger.info(click.style(f"Start process document: {document_id}", fg="green")) - document = ( - db.session.query(Document) - .where( + document = db.session.scalar( + select(Document).where( Document.id == document_id, Document.dataset_id == dataset_id, ) - .first() ) if not document: raise NotFound("Document not found") - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index bac2fbef47..c43e99f0f4 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -2,8 +2,8 @@ import logging from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.nodes import NodeType -from core.workflow.nodes.tool.entities import ToolEntity +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced logger = logging.getLogger(__name__) @@ -16,7 +16,7 @@ def handle(sender, **kwargs): if synced_draft_workflow is None: return for node_data in synced_draft_workflow.graph_dict.get("nodes", []): - if node_data.get("data", {}).get("type") == NodeType.TOOL: + if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) tool_runtime = ToolManager.get_tool_runtime( diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py index 69959acd19..4709534ae6 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py @@ -1,4 +1,6 @@ -from sqlalchemy import select +from typing import Any, cast + +from sqlalchemy import delete, select from events.app_event import app_model_config_was_updated from extensions.ext_database import db @@ -29,9 +31,9 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).where( - AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id - ).delete() + db.session.execute( + delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id) + ) if added_dataset_ids: for dataset_id in added_dataset_ids: @@ -54,9 +56,11 @@ def get_dataset_ids_from_model_config(app_model_config: AppModelConfig) -> set[s continue tool_type = list(tool.keys())[0] - tool_config = list(tool.values())[0] + tool_config = cast(dict[str, Any], list(tool.values())[0]) if tool_type == "dataset": - dataset_ids.add(tool_config.get("id")) + dataset_id = tool_config.get("id") + if isinstance(dataset_id, str): + dataset_ids.add(dataset_id) # get dataset from dataset_configs dataset_configs = app_model_config.dataset_configs_dict diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 53e0065f6e..20852b818e 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,9 +1,9 @@ from typing import cast -from sqlalchemy import select +from sqlalchemy import delete, select -from core.workflow.nodes import NodeType from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData +from dify_graph.nodes import BuiltinNodeTypes from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin @@ -31,9 +31,9 @@ def handle(sender, **kwargs): if removed_dataset_ids: for dataset_id in removed_dataset_ids: - db.session.query(AppDatasetJoin).where( - AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id - ).delete() + db.session.execute( + delete(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id) + ) if added_dataset_ids: for dataset_id in added_dataset_ids: @@ -53,7 +53,7 @@ def get_dataset_ids_from_workflow(published_workflow: Workflow) -> set[str]: # fetch all knowledge retrieval nodes knowledge_retrieval_nodes = [ - node for node in nodes if node.get("data", {}).get("type") == NodeType.KNOWLEDGE_RETRIEVAL + node for node in nodes if node.get("data", {}).get("type") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL ] if not knowledge_retrieval_nodes: diff --git a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py index 430514ada2..b3917d5622 100644 --- a/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_triggers_when_app_published_workflow_updated.py @@ -3,7 +3,7 @@ from typing import cast from sqlalchemy import select from sqlalchemy.orm import Session -from core.workflow.nodes import NodeType +from core.trigger.constants import TRIGGER_NODE_TYPES from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models import AppMode @@ -98,7 +98,7 @@ def get_trigger_infos_from_workflow(published_workflow: Workflow) -> list[dict]: return [] nodes = graph.get("nodes", []) - trigger_types = {NodeType.TRIGGER_WEBHOOK.value, NodeType.TRIGGER_SCHEDULE.value, NodeType.TRIGGER_PLUGIN.value} + trigger_types = TRIGGER_NODE_TYPES trigger_infos = [ { diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index 46885761a1..fe95cc5816 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -13,6 +13,7 @@ def init_app(app: DifyApp): convert_to_agent_apps, create_tenant, delete_archived_workflow_runs, + export_app_messages, extract_plugins, extract_unique_plugins, file_usage, @@ -66,6 +67,7 @@ def init_app(app: DifyApp): restore_workflow_runs, clean_workflow_runs, clean_expired_messages, + export_app_messages, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/ext_fastopenapi.py b/api/extensions/ext_fastopenapi.py index ab4d23a072..569203e974 100644 --- a/api/extensions/ext_fastopenapi.py +++ b/api/extensions/ext_fastopenapi.py @@ -1,3 +1,5 @@ +from typing import Protocol, cast + from fastopenapi.routers import FlaskRouter from flask_cors import CORS @@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS DOCS_PREFIX = "/fastopenapi" +class SupportsIncludeRouter(Protocol): + def include_router(self, router: object, *, prefix: str = "") -> None: ... + + def init_app(app: DifyApp) -> None: docs_enabled = dify_config.SWAGGER_UI_ENABLED docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None @@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None: _ = remote_files _ = setup - router.include_router(console_router, prefix="/console/api") + cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api") CORS( app, resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 74299956c0..02e50a90fc 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -3,6 +3,7 @@ import json import flask_login from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in +from sqlalchemy import select from werkzeug.exceptions import NotFound, Unauthorized from configs import dify_config @@ -34,16 +35,15 @@ def load_user_from_request(request_from_flask_login): if admin_api_key and admin_api_key == auth_token: workspace_id = request.headers.get("X-WORKSPACE-ID") if workspace_id: - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == workspace_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role == "owner") - .one_or_none() - ) + ).one_or_none() if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).filter_by(id=ta.account_id).first() + account = db.session.scalar(select(Account).where(Account.id == ta.account_id)) if account: account.current_tenant = tenant return account @@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login): end_user_id = decoded.get("end_user_id") if not end_user_id: raise Unauthorized("Invalid Authorization token.") - end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if not end_user: raise NotFound("End user not found.") return end_user @@ -80,7 +80,7 @@ def load_user_from_request(request_from_flask_login): decoded = PassportService().verify(auth_token) end_user_id = decoded.get("end_user_id") if end_user_id: - end_user = db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id)) if not end_user: raise NotFound("End user not found.") return end_user @@ -90,11 +90,11 @@ def load_user_from_request(request_from_flask_login): server_code = request.view_args.get("server_code") if request.view_args else None if not server_code: raise Unauthorized("Invalid Authorization token.") - app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first() + app_mcp_server = db.session.scalar(select(AppMCPServer).where(AppMCPServer.server_code == server_code).limit(1)) if not app_mcp_server: raise NotFound("App MCP server not found.") - end_user = ( - db.session.query(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").first() + end_user = db.session.scalar( + select(EndUser).where(EndUser.session_id == app_mcp_server.id, EndUser.type == "mcp").limit(1) ) if not end_user: raise NotFound("End user not found.") diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 658e6a0738..26262484f9 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -18,6 +18,7 @@ from dify_app import DifyApp from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel +from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel if TYPE_CHECKING: from redis.lock import Lock @@ -181,13 +182,18 @@ def _create_sentinel_client(redis_params: dict[str, Any]) -> Union[redis.Redis, sentinel_hosts = [(node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",")] + sentinel_kwargs = { + "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, + "username": dify_config.REDIS_SENTINEL_USERNAME, + "password": dify_config.REDIS_SENTINEL_PASSWORD, + } + + if dify_config.REDIS_MAX_CONNECTIONS: + sentinel_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + sentinel = Sentinel( sentinel_hosts, - sentinel_kwargs={ - "socket_timeout": dify_config.REDIS_SENTINEL_SOCKET_TIMEOUT, - "username": dify_config.REDIS_SENTINEL_USERNAME, - "password": dify_config.REDIS_SENTINEL_PASSWORD, - }, + sentinel_kwargs=sentinel_kwargs, ) master: redis.Redis = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) @@ -204,12 +210,15 @@ def _create_cluster_client() -> Union[redis.Redis, RedisCluster]: for node in dify_config.REDIS_CLUSTERS.split(",") ] - cluster: RedisCluster = RedisCluster( - startup_nodes=nodes, - password=dify_config.REDIS_CLUSTERS_PASSWORD, - protocol=dify_config.REDIS_SERIALIZATION_PROTOCOL, - cache_config=_get_cache_configuration(), - ) + cluster_kwargs: dict[str, Any] = { + "startup_nodes": nodes, + "password": dify_config.REDIS_CLUSTERS_PASSWORD, + "protocol": dify_config.REDIS_SERIALIZATION_PROTOCOL, + "cache_config": _get_cache_configuration(), + } + if dify_config.REDIS_MAX_CONNECTIONS: + cluster_kwargs["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + cluster: RedisCluster = RedisCluster(**cluster_kwargs) return cluster @@ -225,6 +234,9 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis } ) + if dify_config.REDIS_MAX_CONNECTIONS: + redis_params["max_connections"] = dify_config.REDIS_MAX_CONNECTIONS + if ssl_kwargs: redis_params.update(ssl_kwargs) @@ -234,9 +246,17 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> redis.Redis | RedisCluster: + max_conns = dify_config.REDIS_MAX_CONNECTIONS if use_clusters: - return RedisCluster.from_url(pubsub_url) - return redis.Redis.from_url(pubsub_url) + if max_conns: + return RedisCluster.from_url(pubsub_url, max_connections=max_conns) + else: + return RedisCluster.from_url(pubsub_url) + + if max_conns: + return redis.Redis.from_url(pubsub_url, max_connections=max_conns) + else: + return redis.Redis.from_url(pubsub_url) def init_app(app: DifyApp): @@ -269,6 +289,11 @@ def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol: assert _pubsub_redis_client is not None, "PubSub redis Client should be initialized here." if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded": return ShardedRedisBroadcastChannel(_pubsub_redis_client) + if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "streams": + return StreamsBroadcastChannel( + _pubsub_redis_client, + retention_seconds=dify_config.PUBSUB_STREAMS_RETENTION_SECONDS, + ) return RedisBroadcastChannel(_pubsub_redis_client) diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index c3aa8edf80..9a34acb0c1 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -10,7 +10,7 @@ def init_app(app: DifyApp): from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from core.model_runtime.errors.invoke import InvokeRateLimitError + from dify_graph.model_runtime.errors.invoke import InvokeRateLimitError def before_send(event, hint): if "exc_info" in hint: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index 817c8b0448..a94d75ec76 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -13,11 +13,12 @@ from typing import Any from sqlalchemy.orm import sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value -from models.workflow import WorkflowNodeExecutionModel +from models.enums import CreatorUserRole +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository logger = logging.getLogger(__name__) @@ -47,12 +48,28 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.triggered_from = data.get("triggered_from") or "" + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowNodeExecutionTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to WORKFLOW_RUN", triggered_from_val) + model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" model.status = data.get("status") or "running" # Default status if missing model.title = data.get("title") or "" - model.created_by_role = data.get("created_by_role") or "" + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.index = safe_int(data.get("index", 0)) diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index 14382ed876..bdfc81bd1c 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -22,12 +22,13 @@ from typing import Any, cast from sqlalchemy.orm import sessionmaker +from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.workflow import WorkflowRun +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.types import ( AverageInteractionStats, @@ -59,11 +60,37 @@ def _dict_to_workflow_run(data: dict[str, Any]) -> WorkflowRun: model.tenant_id = data.get("tenant_id") or "" model.app_id = data.get("app_id") or "" model.workflow_id = data.get("workflow_id") or "" - model.type = data.get("type") or "" - model.triggered_from = data.get("triggered_from") or "" + type_val = data.get("type") + try: + model.type = WorkflowType(str(type_val)) if type_val else WorkflowType.WORKFLOW + except ValueError: + logger.warning("Invalid type value: %s, falling back to WORKFLOW", type_val) + model.type = WorkflowType.WORKFLOW + triggered_from_val = data.get("triggered_from") + try: + model.triggered_from = ( + WorkflowRunTriggeredFrom(str(triggered_from_val)) + if triggered_from_val + else WorkflowRunTriggeredFrom.APP_RUN + ) + except ValueError: + logger.warning("Invalid triggered_from value: %s, falling back to APP_RUN", triggered_from_val) + model.triggered_from = WorkflowRunTriggeredFrom.APP_RUN model.version = data.get("version") or "" - model.status = data.get("status") or "running" # Default status if missing - model.created_by_role = data.get("created_by_role") or "" + status_val = data.get("status") + try: + model.status = WorkflowExecutionStatus(str(status_val)) if status_val else WorkflowExecutionStatus.RUNNING + except ValueError: + logger.warning("Invalid status value: %s, falling back to RUNNING", status_val) + model.status = WorkflowExecutionStatus.RUNNING + created_by_role_val = data.get("created_by_role") + try: + model.created_by_role = ( + CreatorUserRole(str(created_by_role_val)) if created_by_role_val else CreatorUserRole.ACCOUNT + ) + except ValueError: + logger.warning("Invalid created_by_role value: %s, falling back to ACCOUNT", created_by_role_val) + model.created_by_role = CreatorUserRole.ACCOUNT model.created_by = data.get("created_by") or "" model.total_tokens = safe_int(data.get("total_tokens", 0)) diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index 9928879a7b..c58aa6adbb 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -8,9 +8,9 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from core.workflow.entities import WorkflowExecution -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowExecution +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index 4897171b12..d84c0bc432 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -16,13 +16,12 @@ from typing import Any, Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import WorkflowNodeExecution -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.enums import NodeType -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier @@ -78,7 +77,7 @@ def _dict_to_workflow_node_execution(data: dict[str, Any]) -> WorkflowNodeExecut index=safe_int(data.get("index", 0)), predecessor_node_id=data.get("predecessor_node_id"), node_id=data.get("node_id", ""), - node_type=NodeType(data.get("node_type", "start")), + node_type=data.get("node_type", "start"), title=data.get("title", ""), inputs=inputs, process_data=process_data, @@ -185,7 +184,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): ("predecessor_node_id", domain_model.predecessor_node_id or ""), ("node_execution_id", domain_model.node_execution_id or ""), ("node_id", domain_model.node_id), - ("node_type", domain_model.node_type.value), + ("node_type", domain_model.node_type), ("title", domain_model.title), ( "inputs", diff --git a/api/extensions/otel/celery_sqlcommenter.py b/api/extensions/otel/celery_sqlcommenter.py new file mode 100644 index 0000000000..8abb1ce15a --- /dev/null +++ b/api/extensions/otel/celery_sqlcommenter.py @@ -0,0 +1,114 @@ +""" +Celery SQL comment context for OpenTelemetry SQLCommenter. + +Injects Celery-specific metadata (framework, task_name, traceparent, celery_retries, +routing_key) into SQL comments for queries executed by Celery workers. This improves +trace-to-SQL correlation and debugging in production. + +Uses the OpenTelemetry context key SQLCOMMENTER_ORM_TAGS_AND_VALUES, which is read +by opentelemetry.instrumentation.sqlcommenter_utils._add_framework_tags() when the +SQLAlchemy instrumentor appends comments to SQL statements. +""" + +import logging +from typing import Any + +from celery.signals import task_postrun, task_prerun +from opentelemetry import context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + +logger = logging.getLogger(__name__) +_TRACE_PROPAGATOR = TraceContextTextMapPropagator() + +_SQLCOMMENTER_CONTEXT_KEY = "SQLCOMMENTER_ORM_TAGS_AND_VALUES" +_TOKEN_ATTR = "_dify_sqlcommenter_context_token" + + +def _build_celery_sqlcommenter_tags(task: Any) -> dict[str, str | int]: + """Build SQL commenter tags from the current Celery task and OpenTelemetry context.""" + tags: dict[str, str | int] = {} + + try: + tags["framework"] = f"celery:{_get_celery_version()}" + except Exception: + tags["framework"] = "celery:unknown" + + if task and getattr(task, "name", None): + tags["task_name"] = str(task.name) + + traceparent = _get_traceparent() + if traceparent: + tags["traceparent"] = traceparent + + if task and hasattr(task, "request"): + request = task.request + retries = getattr(request, "retries", None) + if retries is not None and retries > 0: + tags["celery_retries"] = int(retries) + + delivery_info = getattr(request, "delivery_info", None) or {} + if isinstance(delivery_info, dict): + routing_key = delivery_info.get("routing_key") + if routing_key: + tags["routing_key"] = str(routing_key) + + return tags + + +def _get_celery_version() -> str: + import celery + + return getattr(celery, "__version__", "unknown") + + +def _get_traceparent() -> str | None: + """Extract traceparent from the current OpenTelemetry context.""" + carrier: dict[str, str] = {} + _TRACE_PROPAGATOR.inject(carrier) + return carrier.get("traceparent") + + +def _on_task_prerun(*args: object, **kwargs: object) -> None: + task = kwargs.get("task") + if not task: + return + + tags = _build_celery_sqlcommenter_tags(task) + if not tags: + return + + current = context.get_current() + new_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, tags, current) + token = context.attach(new_ctx) + setattr(task, _TOKEN_ATTR, token) + + +def _on_task_postrun(*args: object, **kwargs: object) -> None: + task = kwargs.get("task") + if not task: + return + + token = getattr(task, _TOKEN_ATTR, None) + if token is None: + return + + try: + context.detach(token) + except Exception: + logger.debug("Failed to detach SQL commenter context", exc_info=True) + finally: + try: + delattr(task, _TOKEN_ATTR) + except AttributeError: + pass + + +def setup_celery_sqlcommenter() -> None: + """ + Connect Celery task_prerun and task_postrun handlers to inject SQL comment + context for worker queries. Call this from init_celery_worker after + CeleryInstrumentor().instrument() so our handlers run after the OTEL + instrumentor's and the trace context is already attached. + """ + task_prerun.connect(_on_task_prerun, weak=False) + task_postrun.connect(_on_task_postrun, weak=False) diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 66d1c977d6..544ef3fe18 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -9,11 +9,11 @@ from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from core.workflow.enums import NodeType -from core.workflow.file.models import File -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.variables import Segment +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.file.models import File +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.variables import Segment from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes @@ -84,21 +84,17 @@ class DefaultNodeOTelParser: span.set_attribute("node.id", node.id) if node.execution_id: span.set_attribute("node.execution_id", node.execution_id) - if hasattr(node, "node_type") and node.node_type: - span.set_attribute("node.type", node.node_type.value) + span.set_attribute("node.type", node.node_type) span.set_attribute(GenAIAttributes.FRAMEWORK, "dify") - node_type = getattr(node, "node_type", None) - if isinstance(node_type, NodeType): - if node_type == NodeType.LLM: - span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM") - elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: - span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER") - elif node_type == NodeType.TOOL: - span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL") - else: - span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK") + node_type = node.node_type + if node_type == BuiltinNodeTypes.LLM: + span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM") + elif node_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: + span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER") + elif node_type == BuiltinNodeTypes.TOOL: + span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL") else: span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK") diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 8556974080..3da9a9e97d 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -8,8 +8,8 @@ from typing import Any from opentelemetry.trace import Span -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index 82cb865b8b..dd658b250b 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -8,9 +8,9 @@ from typing import Any from opentelemetry.trace import Span -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.variables import Segment +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index b99180722b..f4e6a18b4d 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -4,10 +4,10 @@ Parser for tool nodes that captures tool-specific metadata. from opentelemetry.trace import Span -from core.workflow.enums import WorkflowNodeExecutionMetadataKey -from core.workflow.graph_events import GraphNodeEventBase -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.tool.entities import ToolNodeData +from dify_graph.enums import WorkflowNodeExecutionMetadataKey +from dify_graph.graph_events import GraphNodeEventBase +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.tool.entities import ToolNodeData from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes diff --git a/api/extensions/otel/runtime.py b/api/extensions/otel/runtime.py index a7181d2683..149d76b07b 100644 --- a/api/extensions/otel/runtime.py +++ b/api/extensions/otel/runtime.py @@ -5,9 +5,9 @@ from typing import Union from celery.signals import worker_init from flask_login import user_loaded_from_request, user_logged_in -from opentelemetry import trace +from opentelemetry import metrics, trace from opentelemetry.propagate import set_global_textmap -from opentelemetry.propagators.b3 import B3Format +from opentelemetry.propagators.b3 import B3MultiFormat from opentelemetry.propagators.composite import CompositePropagator from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator @@ -24,16 +24,36 @@ def setup_context_propagation() -> None: CompositePropagator( [ TraceContextTextMapPropagator(), - B3Format(), + B3MultiFormat(), ] ) ) def shutdown_tracer() -> None: + flush_telemetry() + + +def flush_telemetry() -> None: + """ + Best-effort flush for telemetry providers. + + This is mainly used by short-lived command processes (e.g. Kubernetes CronJob) + so counters/histograms are exported before the process exits. + """ provider = trace.get_tracer_provider() if hasattr(provider, "force_flush"): - provider.force_flush() + try: + provider.force_flush() + except Exception: + logger.exception("otel: failed to flush trace provider") + + metric_provider = metrics.get_meter_provider() + if hasattr(metric_provider, "force_flush"): + try: + metric_provider.force_flush() + except Exception: + logger.exception("otel: failed to flush metric provider") def is_celery_worker(): @@ -67,11 +87,14 @@ def init_celery_worker(*args, **kwargs): from opentelemetry.metrics import get_meter_provider from opentelemetry.trace import get_tracer_provider + from extensions.otel.celery_sqlcommenter import setup_celery_sqlcommenter + tracer_provider = get_tracer_provider() metric_provider = get_meter_provider() if dify_config.DEBUG: logger.info("Initializing OpenTelemetry for Celery worker") CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() + setup_celery_sqlcommenter() def is_instrument_flag_enabled() -> bool: diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index 83c5c2d12f..96f5915ff0 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -32,7 +32,7 @@ class OpenDALStorage(BaseStorage): kwargs = kwargs or _get_opendal_kwargs(scheme=scheme) if scheme == "fs": - root = kwargs.get("root", "storage") + root = kwargs.setdefault("root", "storage") Path(root).mkdir(parents=True, exist_ok=True) retry_layer = opendal.layers.RetryLayer(max_times=3, factor=2.0, jitter=True) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 47396831fa..cb07ba58ae 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -14,7 +14,7 @@ from werkzeug.http import parse_options_header from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.helper import ssrf_proxy -from core.workflow.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers +from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers from extensions.ext_database import db from models import MessageFile, ToolFile, UploadFile @@ -424,13 +424,11 @@ def _build_from_datasource_file( datasource_file_id = mapping.get("datasource_file_id") if not datasource_file_id: raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = ( - db.session.query(UploadFile) - .where( + datasource_file = db.session.scalar( + select(UploadFile).where( UploadFile.id == datasource_file_id, UploadFile.tenant_id == tenant_id, ) - .first() ) if datasource_file is None: diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index b74d9517f4..14a56bf4a2 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -3,13 +3,13 @@ from typing import Any, cast from uuid import uuid4 from configs import dify_config -from core.workflow.constants import ( +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from core.workflow.file import File -from core.workflow.variables.exc import VariableError -from core.workflow.variables.segments import ( +from dify_graph.file import File +from dify_graph.variables.exc import VariableError +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayBooleanSegment, ArrayFileSegment, @@ -26,8 +26,8 @@ from core.workflow.variables.segments import ( Segment, StringSegment, ) -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import ( +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import ( ArrayAnyVariable, ArrayBooleanVariable, ArrayFileVariable, @@ -55,7 +55,7 @@ class TypeMismatchError(Exception): # Define the constant -SEGMENT_TO_VARIABLE_MAP = { +SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { ArrayAnySegment: ArrayAnyVariable, ArrayBooleanSegment: ArrayBooleanVariable, ArrayFileSegment: ArrayFileVariable, @@ -296,13 +296,11 @@ def segment_to_variable( raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return cast( - VariableBase, - variable_class( - id=id, - name=name, - description=description, - value=segment.value, - selector=list(selector), - ), + return variable_class( + id=id, + name=name, + description=description, + value_type=segment.value_type, + value=segment.value, + selector=list(selector), ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index 461c163e2f..ac7c5376fb 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from core.workflow.variables.segments import Segment -from core.workflow.variables.types import SegmentType +from dify_graph.variables.segments import Segment +from dify_graph.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index faa3606f0e..a5c7ddbb11 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -5,7 +5,7 @@ from typing import Any, TypeAlias from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -from core.workflow.file import File +from dify_graph.file import File JSONValue: TypeAlias = Any diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 29b9e40242..7ee628726b 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -5,7 +5,7 @@ from datetime import datetime from flask_restx import fields from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers simple_account_fields = { "id": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 55bd0a5fbd..428f92ed33 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -7,7 +7,7 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel -from core.workflow.file import File +from dify_graph.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile JSONValueType: TypeAlias = JSONValue diff --git a/api/fields/raws.py b/api/fields/raws.py index 33b47ba2c3..318dedc25c 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,6 +1,6 @@ from flask_restx import fields -from core.workflow.file import File +from dify_graph.file import File class FilesContainedField(fields.Raw): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 019949e105..7ce2139687 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields from core.helper import encrypter -from core.workflow.variables import SecretVariable, SegmentType, VariableBase +from dify_graph.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/broadcast_channel/redis/streams_channel.py b/api/libs/broadcast_channel/redis/streams_channel.py new file mode 100644 index 0000000000..d6ec5504ca --- /dev/null +++ b/api/libs/broadcast_channel/redis/streams_channel.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +import logging +import queue +import threading +from collections.abc import Iterator +from typing import Self + +from libs.broadcast_channel.channel import Producer, Subscriber, Subscription +from libs.broadcast_channel.exc import SubscriptionClosedError +from redis import Redis, RedisCluster + +logger = logging.getLogger(__name__) + + +class StreamsBroadcastChannel: + """ + Redis Streams based broadcast channel implementation. + + Characteristics: + - At-least-once delivery for late subscribers within the stream retention window. + - Each topic is stored as a dedicated Redis Stream key. + - The stream key expires `retention_seconds` after the last event is published (to bound storage). + """ + + def __init__(self, redis_client: Redis | RedisCluster, *, retention_seconds: int = 600): + self._client = redis_client + self._retention_seconds = max(int(retention_seconds or 0), 0) + + def topic(self, topic: str) -> StreamsTopic: + return StreamsTopic(self._client, topic, retention_seconds=self._retention_seconds) + + +class StreamsTopic: + def __init__(self, redis_client: Redis | RedisCluster, topic: str, *, retention_seconds: int = 600): + self._client = redis_client + self._topic = topic + self._key = f"stream:{topic}" + self._retention_seconds = retention_seconds + self.max_length = 5000 + + def as_producer(self) -> Producer: + return self + + def publish(self, payload: bytes) -> None: + self._client.xadd(self._key, {b"data": payload}, maxlen=self.max_length) + if self._retention_seconds > 0: + try: + self._client.expire(self._key, self._retention_seconds) + except Exception as e: + logger.warning("Failed to set expire for stream key %s: %s", self._key, e, exc_info=True) + + def as_subscriber(self) -> Subscriber: + return self + + def subscribe(self) -> Subscription: + return _StreamsSubscription(self._client, self._key) + + +class _StreamsSubscription(Subscription): + _SENTINEL = object() + + def __init__(self, client: Redis | RedisCluster, key: str): + self._client = client + self._key = key + self._closed = threading.Event() + self._last_id = "0-0" + self._queue: queue.Queue[object] = queue.Queue() + self._start_lock = threading.Lock() + self._listener: threading.Thread | None = None + + def _listen(self) -> None: + try: + while not self._closed.is_set(): + streams = self._client.xread({self._key: self._last_id}, block=1000, count=100) + + if not streams: + continue + + for _key, entries in streams: + for entry_id, fields in entries: + data = None + if isinstance(fields, dict): + data = fields.get(b"data") + data_bytes: bytes | None = None + if isinstance(data, str): + data_bytes = data.encode() + elif isinstance(data, (bytes, bytearray)): + data_bytes = bytes(data) + if data_bytes is not None: + self._queue.put_nowait(data_bytes) + self._last_id = entry_id + finally: + self._queue.put_nowait(self._SENTINEL) + self._listener = None + + def _start_if_needed(self) -> None: + if self._listener is not None: + return + # Ensure only one listener thread is created under concurrent calls + with self._start_lock: + if self._listener is not None or self._closed.is_set(): + return + self._listener = threading.Thread( + target=self._listen, + name=f"redis-streams-sub-{self._key}", + daemon=True, + ) + self._listener.start() + + def __iter__(self) -> Iterator[bytes]: + # Iterator delegates to receive with timeout; stops on closure. + self._start_if_needed() + while not self._closed.is_set(): + item = self.receive(timeout=1) + if item is not None: + yield item + + def receive(self, timeout: float | None = 0.1) -> bytes | None: + if self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + self._start_if_needed() + + try: + if timeout is None: + item = self._queue.get() + else: + item = self._queue.get(timeout=timeout) + except queue.Empty: + return None + + if item is self._SENTINEL or self._closed.is_set(): + raise SubscriptionClosedError("The Redis streams subscription is closed") + assert isinstance(item, (bytes, bytearray)), "Unexpected item type in stream queue" + return bytes(item) + + def close(self) -> None: + if self._closed.is_set(): + return + self._closed.set() + listener = self._listener + if listener is not None: + listener.join(timeout=2.0) + if listener.is_alive(): + logger.warning( + "Streams subscription listener for key %s did not stop within timeout; keeping reference.", + self._key, + ) + else: + self._listener = None + + # Context manager helpers + def __enter__(self) -> Self: + self._start_if_needed() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> bool | None: + self.close() + return None diff --git a/api/libs/helper.py b/api/libs/helper.py index 206bb8fd81..e7572cc025 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -21,8 +21,8 @@ from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: @@ -32,6 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +def _stream_with_request_context(response: object) -> Any: + """Bridge Flask's loosely-typed streaming helper without leaking casts into callers.""" + return cast(Any, stream_with_context)(response) + + def escape_like_pattern(pattern: str) -> str: """ Escape special characters in a string for safe use in SQL LIKE patterns. @@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str: return sha256(hash_text.encode()).hexdigest() -def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: - if isinstance(response, dict): +def compact_generate_response( + response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator, +) -> Response: + if isinstance(response, Mapping): return Response( response=json.dumps(jsonable_encoder(response)), status=200, content_type="application/json; charset=utf-8", ) else: + stream_response = response - def generate() -> Generator: - yield from response + def generate() -> Generator[str, None, None]: + yield from stream_response - return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + return Response( + _stream_with_request_context(generate()), + status=200, + mimetype="text/event-stream", + ) -def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: +def length_prefixed_response( + magic_number: int, + response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator, +) -> Response: """ This function is used to return a response with a length prefix. Magic number is a one byte number that indicates the type of the response. @@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data return struct.pack(" Generator: - for chunk in response: + stream_response = response + + def generate() -> Generator[bytes, None, None]: + for chunk in stream_response: if isinstance(chunk, str): yield pack_response_with_length_prefix(chunk.encode("utf-8")) else: yield pack_response_with_length_prefix(chunk) - return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + return Response( + _stream_with_request_context(generate()), + status=200, + mimetype="text/event-stream", + ) class TokenManager: diff --git a/api/libs/login.py b/api/libs/login.py index 69e2b58426..bd5cb5f30d 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue] @wraps(func) def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue: if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: - pass - elif current_user is not None and not current_user.is_authenticated: + return current_app.ensure_sync(func)(*args, **kwargs) + + user = _get_user() + if user is None or not user.is_authenticated: return current_app.login_manager.unauthorized() # type: ignore # we put csrf validation here for less conflicts # TODO: maybe find a better place for it. - check_csrf_token(request, current_user.id) + check_csrf_token(request, user.id) return current_app.ensure_sync(func)(*args, **kwargs) return decorated_view diff --git a/api/libs/module_loading.py b/api/libs/module_loading.py index 9f74943433..7063a115b0 100644 --- a/api/libs/module_loading.py +++ b/api/libs/module_loading.py @@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py import sys from importlib import import_module +from typing import Any -def cached_import(module_path: str, class_name: str): +def cached_import(module_path: str, class_name: str) -> Any: """ Import a module and return the named attribute/class from it, with caching. @@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str): Returns: The imported attribute/class """ - if not ( - (module := sys.modules.get(module_path)) - and (spec := getattr(module, "__spec__", None)) - and getattr(spec, "_initializing", False) is False - ): + module = sys.modules.get(module_path) + spec = getattr(module, "__spec__", None) if module is not None else None + if module is None or getattr(spec, "_initializing", False): module = import_module(module_path) return getattr(module, class_name) -def import_string(dotted_path: str): +def import_string(dotted_path: str) -> Any: """ Import a dotted module path and return the attribute/class designated by the last name in the path. Raise ImportError if the import failed. diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 889a5a3248..1afb42304d 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -1,7 +1,51 @@ +import logging +import sys import urllib.parse from dataclasses import dataclass +from typing import NotRequired import httpx +from pydantic import TypeAdapter, ValidationError + +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + +logger = logging.getLogger(__name__) + +JsonObject = dict[str, object] +JsonObjectList = list[JsonObject] + +JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject) +JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList) + + +class AccessTokenResponse(TypedDict, total=False): + access_token: str + + +class GitHubEmailRecord(TypedDict, total=False): + email: str + primary: bool + + +class GitHubRawUserInfo(TypedDict): + id: int | str + login: str + name: NotRequired[str | None] + email: NotRequired[str | None] + + +class GoogleRawUserInfo(TypedDict): + sub: str + email: str + + +ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse) +GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo) +GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord]) +GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo) @dataclass @@ -11,26 +55,38 @@ class OAuthUserInfo: email: str +def _json_object(response: httpx.Response) -> JsonObject: + return JSON_OBJECT_ADAPTER.validate_python(response.json()) + + +def _json_list(response: httpx.Response) -> JsonObjectList: + return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json()) + + class OAuth: + client_id: str + client_secret: str + redirect_uri: str + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self): + def get_authorization_url(self, invite_token: str | None = None) -> str: raise NotImplementedError() - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: raise NotImplementedError() - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: raise NotImplementedError() def get_user_info(self, token: str) -> OAuthUserInfo: raw_info = self.get_raw_user_info(token) return self._transform_user_info(raw_info) - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: raise NotImplementedError() @@ -40,7 +96,7 @@ class GitHubOAuth(OAuth): _USER_INFO_URL = "https://api.github.com/user" _EMAIL_INFO_URL = "https://api.github.com/user/emails" - def get_authorization_url(self, invite_token: str | None = None): + def get_authorization_url(self, invite_token: str | None = None) -> str: params = { "client_id": self.client_id, "redirect_uri": self.redirect_uri, @@ -50,7 +106,7 @@ class GitHubOAuth(OAuth): params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, "client_secret": self.client_secret, @@ -60,7 +116,7 @@ class GitHubOAuth(OAuth): headers = {"Accept": "application/json"} response = httpx.post(self._TOKEN_URL, data=data, headers=headers) - response_json = response.json() + response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") if not access_token: @@ -68,23 +124,32 @@ class GitHubOAuth(OAuth): return access_token - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"token {token}"} response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() - user_info = response.json() + user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) - email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) - email_info = email_response.json() - primary_email: dict = next((email for email in email_info if email["primary"] == True), {}) + try: + email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) + email_response.raise_for_status() + email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) + primary_email = next((email for email in email_info if email.get("primary") is True), None) + except (httpx.HTTPStatusError, ValidationError): + logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True) + primary_email = None - return {**user_info, "email": primary_email.get("email", "")} + return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - email = raw_info.get("email") + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: + payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) + email = payload.get("email") if not email: - email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com" - return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email) + raise ValueError( + 'Dify currently not supports the "Keep my email addresses private" feature,' + " please disable it and login again" + ) + return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email) class GoogleOAuth(OAuth): @@ -92,7 +157,7 @@ class GoogleOAuth(OAuth): _TOKEN_URL = "https://oauth2.googleapis.com/token" _USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" - def get_authorization_url(self, invite_token: str | None = None): + def get_authorization_url(self, invite_token: str | None = None) -> str: params = { "client_id": self.client_id, "response_type": "code", @@ -103,7 +168,7 @@ class GoogleOAuth(OAuth): params["state"] = invite_token return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> str: data = { "client_id": self.client_id, "client_secret": self.client_secret, @@ -114,7 +179,7 @@ class GoogleOAuth(OAuth): headers = {"Accept": "application/json"} response = httpx.post(self._TOKEN_URL, data=data, headers=headers) - response_json = response.json() + response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response)) access_token = response_json.get("access_token") if not access_token: @@ -122,11 +187,12 @@ class GoogleOAuth(OAuth): return access_token - def get_raw_user_info(self, token: str): + def get_raw_user_info(self, token: str) -> JsonObject: headers = {"Authorization": f"Bearer {token}"} response = httpx.get(self._USER_INFO_URL, headers=headers) response.raise_for_status() - return response.json() + return _json_object(response) - def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo: - return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"]) + def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: + payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info) + return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"]) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index ae0ae3bcb6..d5dc35ac97 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,25 +1,57 @@ +import sys import urllib.parse -from typing import Any +from typing import Any, Literal import httpx from flask_login import current_user +from pydantic import TypeAdapter from sqlalchemy import select from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding +if sys.version_info >= (3, 12): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class NotionPageSummary(TypedDict): + page_id: str + page_name: str + page_icon: dict[str, str] | None + parent_id: str + type: Literal["page", "database"] + + +class NotionSourceInfo(TypedDict): + workspace_name: str | None + workspace_icon: str | None + workspace_id: str | None + pages: list[NotionPageSummary] + total: int + + +SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object]) +NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo) +NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary) + class OAuthDataSource: + client_id: str + client_secret: str + redirect_uri: str + def __init__(self, client_id: str, client_secret: str, redirect_uri: str): self.client_id = client_id self.client_secret = client_secret self.redirect_uri = redirect_uri - def get_authorization_url(self): + def get_authorization_url(self) -> str: raise NotImplementedError() - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> None: raise NotImplementedError() @@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource): _NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks" _NOTION_BOT_USER = "https://api.notion.com/v1/users/me" - def get_authorization_url(self): + def get_authorization_url(self) -> str: params = { "client_id": self.client_id, "response_type": "code", @@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource): } return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" - def get_access_token(self, code: str): + def get_access_token(self, code: str) -> None: data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri} headers = {"Accept": "application/json"} auth = (self.client_id, self.client_secret) @@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource): workspace_id = response_json.get("workspace_id") # get all authorized pages pages = self.get_authorized_pages(access_token) - source_info = { - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - "workspace_id": workspace_id, - "pages": pages, - "total": len(pages), - } + source_info = self._build_source_info( + workspace_name=workspace_name, + workspace_icon=workspace_icon, + workspace_id=workspace_id, + pages=pages, + ) # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource): ) ) if data_source_binding: - data_source_binding.source_info = source_info + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() @@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource): new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=source_info, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() - def save_internal_access_token(self, access_token: str): + def save_internal_access_token(self, access_token: str) -> None: workspace_name = self.notion_workspace_name(access_token) workspace_icon = None workspace_id = current_user.current_tenant_id # get all authorized pages pages = self.get_authorized_pages(access_token) - source_info = { - "workspace_name": workspace_name, - "workspace_icon": workspace_icon, - "workspace_id": workspace_id, - "pages": pages, - "total": len(pages), - } + source_info = self._build_source_info( + workspace_name=workspace_name, + workspace_icon=workspace_icon, + workspace_id=workspace_id, + pages=pages, + ) # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource): ) ) if data_source_binding: - data_source_binding.source_info = source_info + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() @@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource): new_data_source_binding = DataSourceOauthBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=source_info, + source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info), provider="notion", ) db.session.add(new_data_source_binding) db.session.commit() - def sync_data_source(self, binding_id: str): + def sync_data_source(self, binding_id: str) -> None: # save data source binding data_source_binding = db.session.scalar( select(DataSourceOauthBinding).where( @@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: # get all authorized pages pages = self.get_authorized_pages(data_source_binding.access_token) - source_info = data_source_binding.source_info - new_source_info = { - "workspace_name": source_info["workspace_name"], - "workspace_icon": source_info["workspace_icon"], - "workspace_id": source_info["workspace_id"], - "pages": pages, - "total": len(pages), - } - data_source_binding.source_info = new_source_info + source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info) + new_source_info = self._build_source_info( + workspace_name=source_info["workspace_name"], + workspace_icon=source_info["workspace_icon"], + workspace_id=source_info["workspace_id"], + pages=pages, + ) + data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info) data_source_binding.disabled = False data_source_binding.updated_at = naive_utc_now() db.session.commit() else: raise ValueError("Data source binding not found") - def get_authorized_pages(self, access_token: str): - pages = [] + def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]: + pages: list[NotionPageSummary] = [] page_results = self.notion_page_search(access_token) database_results = self.notion_database_search(access_token) # get page detail @@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource): "parent_id": parent_id, "type": "page", } - pages.append(page) + pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page)) # get database detail for database_result in database_results: page_id = database_result["id"] @@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource): "parent_id": parent_id, "type": "database", } - pages.append(page) + pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page)) return pages - def notion_page_search(self, access_token: str): - results = [] + def notion_page_search(self, access_token: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] next_cursor = None has_more = True @@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource): return results - def notion_block_parent_page_id(self, access_token: str, block_id: str): + def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str: headers = { "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", @@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource): return self.notion_block_parent_page_id(access_token, parent[parent_type]) return parent[parent_type] - def notion_workspace_name(self, access_token: str): + def notion_workspace_name(self, access_token: str) -> str: headers = { "Authorization": f"Bearer {access_token}", "Notion-Version": "2022-06-28", @@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource): return user_info["workspace_name"] return "workspace" - def notion_database_search(self, access_token: str): - results = [] + def notion_database_search(self, access_token: str) -> list[dict[str, Any]]: + results: list[dict[str, Any]] = [] next_cursor = None has_more = True @@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource): next_cursor = response_json.get("next_cursor", None) return results + + @staticmethod + def _build_source_info( + *, + workspace_name: str | None, + workspace_icon: str | None, + workspace_id: str | None, + pages: list[NotionPageSummary], + ) -> NotionSourceInfo: + return { + "workspace_name": workspace_name, + "workspace_icon": workspace_icon, + "workspace_id": workspace_id, + "pages": pages, + "total": len(pages), + } diff --git a/api/migrations/env.py b/api/migrations/env.py index 66a4614e80..3b1fa7bb89 100644 --- a/api/migrations/env.py +++ b/api/migrations/env.py @@ -66,6 +66,7 @@ def run_migrations_offline(): context.configure( url=url, target_metadata=get_metadata(), literal_binds=True ) + logger.info("Generating offline migration SQL with url: %s", url) with context.begin_transaction(): context.run_migrations() diff --git a/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py b/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py new file mode 100644 index 0000000000..ed794178b3 --- /dev/null +++ b/api/migrations/versions/2026_02_26_1336-e288952f2994_add_partial_indexes_on_conversations_.py @@ -0,0 +1,37 @@ +"""add partial indexes on conversations for app_id with created_at and updated_at + +Revision ID: e288952f2994 +Revises: fce013ca180e +Create Date: 2026-02-26 13:36:45.928922 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = 'e288952f2994' +down_revision = 'fce013ca180e' +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.create_index( + 'conversation_app_created_at_idx', + ['app_id', sa.literal_column('created_at DESC')], + unique=False, + postgresql_where=sa.text('is_deleted IS false'), + ) + batch_op.create_index( + 'conversation_app_updated_at_idx', + ['app_id', sa.literal_column('updated_at DESC')], + unique=False, + postgresql_where=sa.text('is_deleted IS false'), + ) + + +def downgrade(): + with op.batch_alter_table('conversations', schema=None) as batch_op: + batch_op.drop_index('conversation_app_updated_at_idx') + batch_op.drop_index('conversation_app_created_at_idx') diff --git a/api/migrations/versions/2026_03_02_1805-0ec65df55790_add_indexes_for_human_input_forms.py b/api/migrations/versions/2026_03_02_1805-0ec65df55790_add_indexes_for_human_input_forms.py new file mode 100644 index 0000000000..63fd58b1bf --- /dev/null +++ b/api/migrations/versions/2026_03_02_1805-0ec65df55790_add_indexes_for_human_input_forms.py @@ -0,0 +1,68 @@ +"""add indexes for human_input_forms query patterns + +Revision ID: 0ec65df55790 +Revises: e288952f2994 +Create Date: 2026-03-02 18:05:00.000000 + +""" + +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "0ec65df55790" +down_revision = "e288952f2994" +branch_labels = None +depends_on = None + + +def upgrade(): + with op.batch_alter_table("human_input_forms", schema=None) as batch_op: + batch_op.create_index( + "human_input_forms_workflow_run_id_node_id_idx", + ["workflow_run_id", "node_id"], + unique=False, + ) + batch_op.create_index( + "human_input_forms_status_created_at_idx", + ["status", "created_at"], + unique=False, + ) + batch_op.create_index( + "human_input_forms_status_expiration_time_idx", + ["status", "expiration_time"], + unique=False, + ) + + with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("human_input_form_deliveries_form_id_idx"), + ["form_id"], + unique=False, + ) + + with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op: + batch_op.create_index( + batch_op.f("human_input_form_recipients_delivery_id_idx"), + ["delivery_id"], + unique=False, + ) + batch_op.create_index( + batch_op.f("human_input_form_recipients_form_id_idx"), + ["form_id"], + unique=False, + ) + + +def downgrade(): + with op.batch_alter_table("human_input_forms", schema=None) as batch_op: + batch_op.drop_index("human_input_forms_workflow_run_id_node_id_idx") + batch_op.drop_index("human_input_forms_status_expiration_time_idx") + batch_op.drop_index("human_input_forms_status_created_at_idx") + + with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("human_input_form_recipients_form_id_idx")) + batch_op.drop_index(batch_op.f("human_input_form_recipients_delivery_id_idx")) + + with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("human_input_form_deliveries_form_id_idx")) diff --git a/api/migrations/versions/2026_03_04_1600-6b5f9f8b1a2c_add_user_id_to_workflow_draft_variables.py b/api/migrations/versions/2026_03_04_1600-6b5f9f8b1a2c_add_user_id_to_workflow_draft_variables.py new file mode 100644 index 0000000000..432e4dadf5 --- /dev/null +++ b/api/migrations/versions/2026_03_04_1600-6b5f9f8b1a2c_add_user_id_to_workflow_draft_variables.py @@ -0,0 +1,69 @@ +"""add user_id and switch workflow_draft_variables unique key to user scope + +Revision ID: 6b5f9f8b1a2c +Revises: 0ec65df55790 +Create Date: 2026-03-04 16:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "6b5f9f8b1a2c" +down_revision = "0ec65df55790" +branch_labels = None +depends_on = None + + +def _is_pg(conn) -> bool: + return conn.dialect.name == "postgresql" + + +def upgrade(): + conn = op.get_bind() + table_name = "workflow_draft_variables" + + with op.batch_alter_table(table_name, schema=None) as batch_op: + batch_op.add_column(sa.Column("user_id", models.types.StringUUID(), nullable=True)) + + if _is_pg(conn): + with op.get_context().autocommit_block(): + op.create_index( + "workflow_draft_variables_app_id_user_id_key", + "workflow_draft_variables", + ["app_id", "user_id", "node_id", "name"], + unique=True, + postgresql_concurrently=True, + ) + else: + op.create_index( + "workflow_draft_variables_app_id_user_id_key", + "workflow_draft_variables", + ["app_id", "user_id", "node_id", "name"], + unique=True, + ) + + with op.batch_alter_table(table_name, schema=None) as batch_op: + batch_op.drop_constraint(op.f("workflow_draft_variables_app_id_key"), type_="unique") + + +def downgrade(): + conn = op.get_bind() + + with op.batch_alter_table("workflow_draft_variables", schema=None) as batch_op: + batch_op.create_unique_constraint( + op.f("workflow_draft_variables_app_id_key"), + ["app_id", "node_id", "name"], + ) + + if _is_pg(conn): + with op.get_context().autocommit_block(): + op.drop_index("workflow_draft_variables_app_id_user_id_key", postgresql_concurrently=True) + else: + op.drop_index("workflow_draft_variables_app_id_user_id_key", table_name="workflow_draft_variables") + + with op.batch_alter_table("workflow_draft_variables", schema=None) as batch_op: + batch_op.drop_column("user_id") diff --git a/api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py b/api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py new file mode 100644 index 0000000000..4b1b32c8d3 --- /dev/null +++ b/api/migrations/versions/2026_03_17_1721-4c60d8d3ee74_merge_migration_heads.py @@ -0,0 +1,25 @@ +"""merge migration heads + +Revision ID: 4c60d8d3ee74 +Revises: fce013ca180e, a1b2c3d4e5f6 +Create Date: 2026-03-17 17:21:12.105536 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '4c60d8d3ee74' +down_revision = ('fce013ca180e', 'a1b2c3d4e5f6') +branch_labels = None +depends_on = None + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/api/models/__init__.py b/api/models/__init__.py index 97a9f8ec6f..8a605d5195 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -30,7 +30,6 @@ from .enums import ( AppTriggerStatus, AppTriggerType, CreatorUserRole, - UserFrom, WorkflowRunTriggeredFrom, WorkflowTriggerStatus, ) @@ -219,7 +218,6 @@ __all__ = [ "TriggerOAuthTenantClient", "TriggerSubscription", "UploadFile", - "UserFrom", "Whitelist", "Workflow", "WorkflowAppLog", diff --git a/api/models/account.py b/api/models/account.py index f7a9c20026..5960ac6564 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,12 +8,12 @@ from uuid import uuid4 import sqlalchemy as sa from flask_login import UserMixin from sqlalchemy import DateTime, String, func, select -from sqlalchemy.orm import Mapped, Session, mapped_column, validates +from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import deprecated from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID class TenantAccountRole(enum.StrEnum): @@ -104,7 +104,9 @@ class Account(UserMixin, TypeBase): last_active_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False ) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'active'"), default="active") + status: Mapped[AccountStatus] = mapped_column( + EnumText(AccountStatus, length=16), server_default=sa.text("'active'"), default=AccountStatus.ACTIVE + ) initialized_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -116,12 +118,6 @@ class Account(UserMixin, TypeBase): role: TenantAccountRole | None = field(default=None, init=False) _current_tenant: "Tenant | None" = field(default=None, init=False) - @validates("status") - def _normalize_status(self, _key: str, value: str | AccountStatus) -> str: - if isinstance(value, AccountStatus): - return value.value - return value - @property def is_password_set(self): return self.password is not None @@ -177,18 +173,15 @@ class Account(UserMixin, TypeBase): return self.role def get_status(self) -> AccountStatus: - status_str = self.status - return AccountStatus(status_str) + return self.status @classmethod def get_by_openid(cls, provider: str, open_id: str): - account_integrate = ( - db.session.query(AccountIntegrate) - .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) - .one_or_none() - ) + account_integrate = db.session.execute( + select(AccountIntegrate).where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id) + ).scalar_one_or_none() if account_integrate: - return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none() + return db.session.scalar(select(Account).where(Account.id == account_integrate.account_id)) return None # check current_user.current_tenant.current_role in ['admin', 'owner'] @@ -249,7 +242,9 @@ class Tenant(TypeBase): name: Mapped[str] = mapped_column(String(255)) encrypt_public_key: Mapped[str | None] = mapped_column(LongText, default=None) plan: Mapped[str] = mapped_column(String(255), server_default=sa.text("'basic'"), default="basic") - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'"), default="normal") + status: Mapped[TenantStatus] = mapped_column( + EnumText(TenantStatus, length=255), server_default=sa.text("'normal'"), default=TenantStatus.NORMAL + ) custom_config: Mapped[str | None] = mapped_column(LongText, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -291,7 +286,9 @@ class TenantAccountJoin(TypeBase): tenant_id: Mapped[str] = mapped_column(StringUUID) account_id: Mapped[str] = mapped_column(StringUUID) current: Mapped[bool] = mapped_column(sa.Boolean, server_default=sa.text("false"), default=False) - role: Mapped[str] = mapped_column(String(16), server_default="normal", default="normal") + role: Mapped[TenantAccountRole] = mapped_column( + EnumText(TenantAccountRole, length=16), server_default="normal", default=TenantAccountRole.NORMAL + ) invited_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime, server_default=func.current_timestamp(), nullable=False, init=False @@ -324,6 +321,11 @@ class AccountIntegrate(TypeBase): ) +class InvitationCodeStatus(enum.StrEnum): + UNUSED = "unused" + USED = "used" + + class InvitationCode(TypeBase): __tablename__ = "invitation_codes" __table_args__ = ( @@ -335,7 +337,11 @@ class InvitationCode(TypeBase): id: Mapped[int] = mapped_column(sa.Integer, init=False) batch: Mapped[str] = mapped_column(String(255)) code: Mapped[str] = mapped_column(String(32)) - status: Mapped[str] = mapped_column(String(16), server_default=sa.text("'unused'"), default="unused") + status: Mapped[InvitationCodeStatus] = mapped_column( + EnumText(InvitationCodeStatus, length=16), + server_default=sa.text("'unused'"), + default=InvitationCodeStatus.UNUSED, + ) used_at: Mapped[datetime | None] = mapped_column(DateTime, default=None) used_by_tenant_id: Mapped[str | None] = mapped_column(StringUUID, default=None) used_by_account_id: Mapped[str | None] = mapped_column(StringUUID, default=None) @@ -367,10 +373,13 @@ class TenantPluginPermission(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) install_permission: Mapped[InstallPermission] = mapped_column( - String(16), nullable=False, server_default="everyone", default=InstallPermission.EVERYONE + EnumText(InstallPermission, length=16), + nullable=False, + server_default="everyone", + default=InstallPermission.EVERYONE, ) debug_permission: Mapped[DebugPermission] = mapped_column( - String(16), nullable=False, server_default="noone", default=DebugPermission.NOBODY + EnumText(DebugPermission, length=16), nullable=False, server_default="noone", default=DebugPermission.NOBODY ) @@ -396,10 +405,13 @@ class TenantPluginAutoUpgradeStrategy(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) strategy_setting: Mapped[StrategySetting] = mapped_column( - String(16), nullable=False, server_default="fix_only", default=StrategySetting.FIX_ONLY + EnumText(StrategySetting, length=16), + nullable=False, + server_default="fix_only", + default=StrategySetting.FIX_ONLY, ) upgrade_mode: Mapped[UpgradeMode] = mapped_column( - String(16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE + EnumText(UpgradeMode, length=16), nullable=False, server_default="exclude", default=UpgradeMode.EXCLUDE ) exclude_plugins: Mapped[list[str]] = mapped_column(sa.JSON, nullable=False, default_factory=list) include_plugins: Mapped[list[str]] = mapped_column(sa.JSON, nullable=False, default_factory=list) diff --git a/api/models/dataset.py b/api/models/dataset.py index e7da2961bc..d0163e6984 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -8,9 +8,10 @@ import os import pickle import re import time +from collections.abc import Sequence from datetime import datetime from json import JSONDecodeError -from typing import Any, cast +from typing import Any, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -19,6 +20,7 @@ from sqlalchemy.orm import Mapped, Session, mapped_column from configs import dify_config from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.constant.query_type import QueryType from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file @@ -29,12 +31,81 @@ from services.entities.knowledge_entities.knowledge_entities import ParentMode, from .account import Account from .base import Base, TypeBase from .engine import db +from .enums import ( + CollectionBindingType, + CreatorUserRole, + DatasetMetadataType, + DatasetQuerySource, + DatasetRuntimeMode, + DataSourceType, + DocumentCreatedFrom, + DocumentDocType, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, + SummaryStatus, +) from .model import App, Tag, TagBinding, UploadFile -from .types import AdjustedJSON, BinaryData, LongText, StringUUID, adjusted_json_index +from .types import AdjustedJSON, BinaryData, EnumText, LongText, StringUUID, adjusted_json_index logger = logging.getLogger(__name__) +class PreProcessingRuleItem(TypedDict): + id: str + enabled: bool + + +class SegmentationConfig(TypedDict): + delimiter: str + max_tokens: int + chunk_overlap: int + + +class AutomaticRulesConfig(TypedDict): + pre_processing_rules: list[PreProcessingRuleItem] + segmentation: SegmentationConfig + + +class ProcessRuleDict(TypedDict): + id: str + dataset_id: str + mode: str + rules: dict[str, Any] | None + + +class DocMetadataDetailItem(TypedDict): + id: str + name: str + type: str + value: Any + + +class AttachmentItem(TypedDict): + id: str + name: str + size: int + extension: str + mime_type: str + source_url: str + + +class DatasetBindingItem(TypedDict): + id: str + name: str + + +class ExternalKnowledgeApiDict(TypedDict): + id: str + tenant_id: str + name: str + description: str + settings: dict[str, Any] | None + dataset_bindings: list[DatasetBindingItem] + created_by: str + created_at: str + + class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" @@ -51,14 +122,19 @@ class Dataset(Base): INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None] PROVIDER_LIST = ["vendor", "external", None] + DOC_FORM_LIST = [member.value for member in IndexStructureType] id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) description = mapped_column(LongText, nullable=True) provider: Mapped[str] = mapped_column(String(255), server_default=sa.text("'vendor'")) - permission: Mapped[str] = mapped_column(String(255), server_default=sa.text("'only_me'")) - data_source_type = mapped_column(String(255)) + permission: Mapped[DatasetPermissionEnum] = mapped_column( + EnumText(DatasetPermissionEnum, length=255), + server_default=sa.text("'only_me'"), + default=DatasetPermissionEnum.ONLY_ME, + ) + data_source_type = mapped_column(EnumText(DataSourceType, length=255)) indexing_technique: Mapped[str | None] = mapped_column(String(255)) index_struct = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) @@ -75,7 +151,9 @@ class Dataset(Base): summary_index_setting = mapped_column(AdjustedJSON, nullable=True) built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) icon_info = mapped_column(AdjustedJSON, nullable=True) - runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'")) + runtime_mode = mapped_column( + EnumText(DatasetRuntimeMode, length=255), nullable=True, server_default=sa.text("'general'") + ) pipeline_id = mapped_column(StringUUID, nullable=True) chunk_structure = mapped_column(sa.String(255), nullable=True) enable_api = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) @@ -83,30 +161,25 @@ class Dataset(Base): @property def total_documents(self): - return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() + return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0 @property def total_available_documents(self): return ( - db.session.query(func.count(Document.id)) - .where( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, + db.session.scalar( + select(func.count(Document.id)).where( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) ) - .scalar() + or 0 ) @property def dataset_keyword_table(self): - dataset_keyword_table = ( - db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first() - ) - if dataset_keyword_table: - return dataset_keyword_table - - return None + return db.session.scalar(select(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id)) @property def index_struct_dict(self): @@ -133,64 +206,66 @@ class Dataset(Base): @property def latest_process_rule(self): - return ( - db.session.query(DatasetProcessRule) + return db.session.scalar( + select(DatasetProcessRule) .where(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) - .first() + .limit(1) ) @property def app_count(self): return ( - db.session.query(func.count(AppDatasetJoin.id)) - .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id) - .scalar() + db.session.scalar( + select(func.count(AppDatasetJoin.id)).where( + AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id + ) + ) + or 0 ) @property def document_count(self): - return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar() + return db.session.scalar(select(func.count(Document.id)).where(Document.dataset_id == self.id)) or 0 @property def available_document_count(self): return ( - db.session.query(func.count(Document.id)) - .where( - Document.dataset_id == self.id, - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, + db.session.scalar( + select(func.count(Document.id)).where( + Document.dataset_id == self.id, + Document.indexing_status == "completed", + Document.enabled == True, + Document.archived == False, + ) ) - .scalar() + or 0 ) @property def available_segment_count(self): return ( - db.session.query(func.count(DocumentSegment.id)) - .where( - DocumentSegment.dataset_id == self.id, - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.dataset_id == self.id, + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + ) ) - .scalar() + or 0 ) @property def word_count(self): - return ( - db.session.query(Document) - .with_entities(func.coalesce(func.sum(Document.word_count), 0)) - .where(Document.dataset_id == self.id) - .scalar() + return db.session.scalar( + select(func.coalesce(func.sum(Document.word_count), 0)).where(Document.dataset_id == self.id) ) @property def doc_form(self) -> str | None: if self.chunk_structure: return self.chunk_structure - document = db.session.query(Document).where(Document.dataset_id == self.id).first() + document = db.session.scalar(select(Document).where(Document.dataset_id == self.id).limit(1)) if document: return document.doc_form return None @@ -208,8 +283,8 @@ class Dataset(Base): @property def tags(self): - tags = ( - db.session.query(Tag) + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == self.id, @@ -217,8 +292,7 @@ class Dataset(Base): Tag.tenant_id == self.tenant_id, Tag.type == "knowledge", ) - .all() - ) + ).all() return tags or [] @@ -226,8 +300,8 @@ class Dataset(Base): def external_knowledge_info(self): if self.provider != "external": return None - external_knowledge_binding = ( - db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first() + external_knowledge_binding = db.session.scalar( + select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id) ) if not external_knowledge_binding: return None @@ -248,7 +322,7 @@ class Dataset(Base): @property def is_published(self): if self.pipeline_id: - pipeline = db.session.query(Pipeline).where(Pipeline.id == self.pipeline_id).first() + pipeline = db.session.scalar(select(Pipeline).where(Pipeline.id == self.pipeline_id)) if pipeline: return pipeline.is_published return False @@ -320,14 +394,14 @@ class DatasetProcessRule(Base): # bug id = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4())) dataset_id = mapped_column(StringUUID, nullable=False) - mode = mapped_column(String(255), nullable=False, server_default=sa.text("'automatic'")) + mode = mapped_column(EnumText(ProcessRuleMode, length=255), nullable=False, server_default=sa.text("'automatic'")) rules = mapped_column(LongText, nullable=True) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) MODES = ["automatic", "custom", "hierarchical"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] - AUTOMATIC_RULES: dict[str, Any] = { + AUTOMATIC_RULES: AutomaticRulesConfig = { "pre_processing_rules": [ {"id": "remove_extra_spaces", "enabled": True}, {"id": "remove_urls_emails", "enabled": False}, @@ -335,7 +409,7 @@ class DatasetProcessRule(Base): # bug "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, } - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ProcessRuleDict: return { "id": self.id, "dataset_id": self.dataset_id, @@ -366,12 +440,12 @@ class Document(Base): tenant_id = mapped_column(StringUUID, nullable=False) dataset_id = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - data_source_type: Mapped[str] = mapped_column(String(255), nullable=False) + data_source_type: Mapped[str] = mapped_column(EnumText(DataSourceType, length=255), nullable=False) data_source_info = mapped_column(LongText, nullable=True) dataset_process_rule_id = mapped_column(StringUUID, nullable=True) batch: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) - created_from: Mapped[str] = mapped_column(String(255), nullable=False) + created_from: Mapped[str] = mapped_column(EnumText(DocumentCreatedFrom, length=255), nullable=False) created_by = mapped_column(StringUUID, nullable=False) created_api_request_id = mapped_column(StringUUID, nullable=True) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) @@ -405,7 +479,9 @@ class Document(Base): stopped_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) # basic fields - indexing_status = mapped_column(String(255), nullable=False, server_default=sa.text("'waiting'")) + indexing_status = mapped_column( + EnumText(IndexingStatus, length=255), nullable=False, server_default=sa.text("'waiting'") + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) @@ -416,7 +492,7 @@ class Document(Base): updated_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) - doc_type = mapped_column(String(40), nullable=True) + doc_type = mapped_column(EnumText(DocumentDocType, length=40), nullable=True) doc_metadata = mapped_column(AdjustedJSON, nullable=True) doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'")) doc_language = mapped_column(String(255), nullable=True) @@ -459,10 +535,8 @@ class Document(Base): if self.data_source_info: if self.data_source_type == "upload_file": data_source_info_dict: dict[str, Any] = json.loads(self.data_source_info) - file_detail = ( - db.session.query(UploadFile) - .where(UploadFile.id == data_source_info_dict["upload_file_id"]) - .one_or_none() + file_detail = db.session.scalar( + select(UploadFile).where(UploadFile.id == data_source_info_dict["upload_file_id"]) ) if file_detail: return { @@ -495,24 +569,23 @@ class Document(Base): @property def dataset(self): - return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def segment_count(self): - return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count() + return ( + db.session.scalar(select(func.count(DocumentSegment.id)).where(DocumentSegment.document_id == self.id)) or 0 + ) @property def hit_count(self): - return ( - db.session.query(DocumentSegment) - .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0)) - .where(DocumentSegment.document_id == self.id) - .scalar() + return db.session.scalar( + select(func.coalesce(func.sum(DocumentSegment.hit_count), 0)).where(DocumentSegment.document_id == self.id) ) @property def uploader(self): - user = db.session.query(Account).where(Account.id == self.created_by).first() + user = db.session.scalar(select(Account).where(Account.id == self.created_by)) return user.name if user else None @property @@ -524,19 +597,18 @@ class Document(Base): return self.updated_at @property - def doc_metadata_details(self) -> list[dict[str, Any]] | None: + def doc_metadata_details(self) -> list[DocMetadataDetailItem] | None: if self.doc_metadata: - document_metadatas = ( - db.session.query(DatasetMetadata) + document_metadatas = db.session.scalars( + select(DatasetMetadata) .join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id) .where( DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id ) - .all() - ) - metadata_list: list[dict[str, Any]] = [] + ).all() + metadata_list: list[DocMetadataDetailItem] = [] for metadata in document_metadatas: - metadata_dict: dict[str, Any] = { + metadata_dict: DocMetadataDetailItem = { "id": metadata.id, "name": metadata.name, "type": metadata.type, @@ -550,13 +622,13 @@ class Document(Base): return None @property - def process_rule_dict(self) -> dict[str, Any] | None: + def process_rule_dict(self) -> ProcessRuleDict | None: if self.dataset_process_rule_id and self.dataset_process_rule: return self.dataset_process_rule.to_dict() return None - def get_built_in_fields(self) -> list[dict[str, Any]]: - built_in_fields: list[dict[str, Any]] = [] + def get_built_in_fields(self) -> list[DocMetadataDetailItem]: + built_in_fields: list[DocMetadataDetailItem] = [] built_in_fields.append( { "id": "built-in", @@ -729,7 +801,7 @@ class DocumentSegment(Base): enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) disabled_by = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'waiting'")) + status: Mapped[str] = mapped_column(EnumText(SegmentStatus, length=255), server_default=sa.text("'waiting'")) created_by = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -764,7 +836,7 @@ class DocumentSegment(Base): ) @property - def child_chunks(self) -> list[Any]: + def child_chunks(self) -> Sequence[Any]: if not self.document: return [] process_rule = self.document.dataset_process_rule @@ -773,16 +845,13 @@ class DocumentSegment(Base): if rules_dict: rules = Rule.model_validate(rules_dict) if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) + child_chunks = db.session.scalars( + select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc()) + ).all() return child_chunks or [] return [] - def get_child_chunks(self) -> list[Any]: + def get_child_chunks(self) -> Sequence[Any]: if not self.document: return [] process_rule = self.document.dataset_process_rule @@ -791,12 +860,9 @@ class DocumentSegment(Base): if rules_dict: rules = Rule.model_validate(rules_dict) if rules.parent_mode: - child_chunks = ( - db.session.query(ChildChunk) - .where(ChildChunk.segment_id == self.id) - .order_by(ChildChunk.position.asc()) - .all() - ) + child_chunks = db.session.scalars( + select(ChildChunk).where(ChildChunk.segment_id == self.id).order_by(ChildChunk.position.asc()) + ).all() return child_chunks or [] return [] @@ -870,7 +936,7 @@ class DocumentSegment(Base): return text @property - def attachments(self) -> list[dict[str, Any]]: + def attachments(self) -> list[AttachmentItem]: # Use JOIN to fetch attachments in a single query instead of two separate queries attachments_with_bindings = db.session.execute( select(SegmentAttachmentBinding, UploadFile) @@ -884,7 +950,7 @@ class DocumentSegment(Base): ).all() if not attachments_with_bindings: return [] - attachment_list = [] + attachment_list: list[AttachmentItem] = [] for _, attachment in attachments_with_bindings: upload_file_id = attachment.id nonce = os.urandom(16).hex() @@ -945,15 +1011,15 @@ class ChildChunk(Base): @property def dataset(self): - return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first() + return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) @property def document(self): - return db.session.query(Document).where(Document.id == self.document_id).first() + return db.session.scalar(select(Document).where(Document.id == self.document_id)) @property def segment(self): - return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first() + return db.session.scalar(select(DocumentSegment).where(DocumentSegment.id == self.segment_id)) class AppDatasetJoin(TypeBase): @@ -999,9 +1065,9 @@ class DatasetQuery(TypeBase): ) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) content: Mapped[str] = mapped_column(LongText, nullable=False) - source: Mapped[str] = mapped_column(String(255), nullable=False) + source: Mapped[str] = mapped_column(EnumText(DatasetQuerySource, length=255), nullable=False) source_app_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False @@ -1014,7 +1080,7 @@ class DatasetQuery(TypeBase): if isinstance(queries, list): for query in queries: if query["content_type"] == QueryType.IMAGE_QUERY: - file_info = db.session.query(UploadFile).filter_by(id=query["content"]).first() + file_info = db.session.scalar(select(UploadFile).where(UploadFile.id == query["content"])) if file_info: query["file_info"] = { "id": file_info.id, @@ -1079,7 +1145,7 @@ class DatasetKeywordTable(TypeBase): super().__init__(object_hook=object_hook, *args, **kwargs) # get dataset - dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() + dataset = db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id)) if not dataset: return None if self.data_source_type == "database": @@ -1144,7 +1210,9 @@ class DatasetCollectionBinding(TypeBase): ) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) model_name: Mapped[str] = mapped_column(String(255), nullable=False) - type: Mapped[str] = mapped_column(String(40), server_default=sa.text("'dataset'"), nullable=False) + type: Mapped[str] = mapped_column( + EnumText(CollectionBindingType, length=40), server_default=sa.text("'dataset'"), nullable=False + ) collection_name: Mapped[str] = mapped_column(String(64), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1254,7 +1322,7 @@ class ExternalKnowledgeApis(TypeBase): DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ExternalKnowledgeApiDict: return { "id": self.id, "tenant_id": self.tenant_id, @@ -1274,13 +1342,13 @@ class ExternalKnowledgeApis(TypeBase): return None @property - def dataset_bindings(self) -> list[dict[str, Any]]: + def dataset_bindings(self) -> list[DatasetBindingItem]: external_knowledge_bindings = db.session.scalars( select(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id) ).all() dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings] datasets = db.session.scalars(select(Dataset).where(Dataset.id.in_(dataset_ids))).all() - dataset_bindings: list[dict[str, Any]] = [] + dataset_bindings: list[DatasetBindingItem] = [] for dataset in datasets: dataset_bindings.append({"id": dataset.id, "name": dataset.name}) @@ -1371,7 +1439,7 @@ class DatasetMetadata(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[str] = mapped_column(EnumText(DatasetMetadataType, length=255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=sa.func.current_timestamp(), init=False @@ -1473,7 +1541,7 @@ class PipelineCustomizedTemplate(TypeBase): @property def created_user_name(self): - account = db.session.query(Account).where(Account.id == self.created_by).first() + account = db.session.scalar(select(Account).where(Account.id == self.created_by)) if account: return account.name return "" @@ -1508,7 +1576,7 @@ class Pipeline(TypeBase): ) def retrieve_dataset(self, session: Session): - return session.query(Dataset).where(Dataset.pipeline_id == self.id).first() + return session.scalar(select(Dataset).where(Dataset.pipeline_id == self.id)) class DocumentPipelineExecutionLog(TypeBase): @@ -1598,7 +1666,9 @@ class DocumentSegmentSummary(Base): summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True) summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True) tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True) - status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'")) + status: Mapped[str] = mapped_column( + EnumText(SummaryStatus, length=32), nullable=False, server_default=sa.text("'generating'") + ) error: Mapped[str] = mapped_column(LongText, nullable=True) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true")) disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True) diff --git a/api/models/enums.py b/api/models/enums.py index 2bc61120ce..4849099d30 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,16 +1,22 @@ from enum import StrEnum -from core.workflow.enums import NodeType +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" - -class UserFrom(StrEnum): - ACCOUNT = "account" - END_USER = "end-user" + @classmethod + def _missing_(cls, value): + if value == "end-user": + return cls.END_USER + else: + return super()._missing_(value) class WorkflowRunTriggeredFrom(StrEnum): @@ -71,9 +77,249 @@ class AppTriggerStatus(StrEnum): class AppTriggerType(StrEnum): """App Trigger Type Enum""" - TRIGGER_WEBHOOK = NodeType.TRIGGER_WEBHOOK.value - TRIGGER_SCHEDULE = NodeType.TRIGGER_SCHEDULE.value - TRIGGER_PLUGIN = NodeType.TRIGGER_PLUGIN.value + TRIGGER_WEBHOOK = TRIGGER_WEBHOOK_NODE_TYPE + TRIGGER_SCHEDULE = TRIGGER_SCHEDULE_NODE_TYPE + TRIGGER_PLUGIN = TRIGGER_PLUGIN_NODE_TYPE # for backward compatibility UNKNOWN = "unknown" + + +class AppStatus(StrEnum): + """App Status Enum""" + + NORMAL = "normal" + + +class AppMCPServerStatus(StrEnum): + """AppMCPServer Status Enum""" + + NORMAL = "normal" + ACTIVE = "active" + INACTIVE = "inactive" + + +class ConversationStatus(StrEnum): + """Conversation Status Enum""" + + NORMAL = "normal" + + +class DataSourceType(StrEnum): + """Data Source Type for Dataset and Document""" + + UPLOAD_FILE = "upload_file" + NOTION_IMPORT = "notion_import" + WEBSITE_CRAWL = "website_crawl" + LOCAL_FILE = "local_file" + ONLINE_DOCUMENT = "online_document" + + +class ProcessRuleMode(StrEnum): + """Dataset Process Rule Mode""" + + AUTOMATIC = "automatic" + CUSTOM = "custom" + HIERARCHICAL = "hierarchical" + + +class IndexingStatus(StrEnum): + """Document Indexing Status""" + + WAITING = "waiting" + PARSING = "parsing" + CLEANING = "cleaning" + SPLITTING = "splitting" + INDEXING = "indexing" + PAUSED = "paused" + COMPLETED = "completed" + ERROR = "error" + + +class DocumentCreatedFrom(StrEnum): + """Document Created From""" + + WEB = "web" + API = "api" + RAG_PIPELINE = "rag-pipeline" + + +class ConversationFromSource(StrEnum): + """Conversation / Message from_source""" + + API = "api" + CONSOLE = "console" + + +class FeedbackFromSource(StrEnum): + """MessageFeedback from_source""" + + USER = "user" + ADMIN = "admin" + + +class FeedbackRating(StrEnum): + """MessageFeedback rating""" + + LIKE = "like" + DISLIKE = "dislike" + + +class InvokeFrom(StrEnum): + """How a conversation/message was invoked""" + + SERVICE_API = "service-api" + WEB_APP = "web-app" + TRIGGER = "trigger" + EXPLORE = "explore" + DEBUGGER = "debugger" + PUBLISHED_PIPELINE = "published" + VALIDATION = "validation" + + @classmethod + def value_of(cls, value: str) -> "InvokeFrom": + return cls(value) + + def to_source(self) -> str: + source_mapping = { + InvokeFrom.WEB_APP: "web_app", + InvokeFrom.DEBUGGER: "dev", + InvokeFrom.EXPLORE: "explore_app", + InvokeFrom.TRIGGER: "trigger", + InvokeFrom.SERVICE_API: "api", + } + return source_mapping.get(self, "dev") + + +class DocumentDocType(StrEnum): + """Document doc_type classification""" + + BOOK = "book" + WEB_PAGE = "web_page" + PAPER = "paper" + SOCIAL_MEDIA_POST = "social_media_post" + WIKIPEDIA_ENTRY = "wikipedia_entry" + PERSONAL_DOCUMENT = "personal_document" + BUSINESS_DOCUMENT = "business_document" + IM_CHAT_LOG = "im_chat_log" + SYNCED_FROM_NOTION = "synced_from_notion" + SYNCED_FROM_GITHUB = "synced_from_github" + OTHERS = "others" + + +class TagType(StrEnum): + """Tag type""" + + KNOWLEDGE = "knowledge" + APP = "app" + + +class DatasetMetadataType(StrEnum): + """Dataset metadata value type""" + + STRING = "string" + NUMBER = "number" + TIME = "time" + + +class SegmentStatus(StrEnum): + """Document segment status""" + + WAITING = "waiting" + INDEXING = "indexing" + COMPLETED = "completed" + ERROR = "error" + PAUSED = "paused" + RE_SEGMENT = "re_segment" + + +class DatasetRuntimeMode(StrEnum): + """Dataset runtime mode""" + + GENERAL = "general" + RAG_PIPELINE = "rag_pipeline" + + +class CollectionBindingType(StrEnum): + """Dataset collection binding type""" + + DATASET = "dataset" + ANNOTATION = "annotation" + + +class DatasetQuerySource(StrEnum): + """Dataset query source""" + + HIT_TESTING = "hit_testing" + APP = "app" + + +class TidbAuthBindingStatus(StrEnum): + """TiDB auth binding status""" + + CREATING = "CREATING" + ACTIVE = "ACTIVE" + + +class MessageFileBelongsTo(StrEnum): + """MessageFile belongs_to""" + + USER = "user" + ASSISTANT = "assistant" + + +class CredentialSourceType(StrEnum): + """Load balancing credential source type""" + + PROVIDER = "provider" + CUSTOM_MODEL = "custom_model" + + +class PaymentStatus(StrEnum): + """Provider order payment status""" + + WAIT_PAY = "wait_pay" + PAID = "paid" + FAILED = "failed" + REFUNDED = "refunded" + + +class BannerStatus(StrEnum): + """ExporleBanner status""" + + ENABLED = "enabled" + DISABLED = "disabled" + + +class SummaryStatus(StrEnum): + """Document segment summary status""" + + NOT_STARTED = "not_started" + GENERATING = "generating" + COMPLETED = "completed" + ERROR = "error" + TIMEOUT = "timeout" + + +class MessageChainType(StrEnum): + """Message chain type""" + + SYSTEM = "system" + + +class ProviderQuotaType(StrEnum): + PAID = "paid" + """hosted paid quota""" + + FREE = "free" + """third-party free quota""" + + TRIAL = "trial" + """hosted trial quota""" + + @staticmethod + def value_of(value: str) -> "ProviderQuotaType": + for member in ProviderQuotaType: + if member.value == value: + return member + raise ValueError(f"No matching enum found for value '{value}'") diff --git a/api/models/human_input.py b/api/models/human_input.py index 5208461de1..48e7fbb9ea 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,7 +6,7 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( DeliveryMethodType, HumanInputFormKind, HumanInputFormStatus, @@ -30,6 +30,15 @@ def _generate_token() -> str: class HumanInputForm(DefaultFieldsMixin, Base): __tablename__ = "human_input_forms" + __table_args__ = ( + sa.Index( + "human_input_forms_workflow_run_id_node_id_idx", + "workflow_run_id", + "node_id", + ), + sa.Index("human_input_forms_status_expiration_time_idx", "status", "expiration_time"), + sa.Index("human_input_forms_status_created_at_idx", "status", "created_at"), + ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) @@ -84,6 +93,12 @@ class HumanInputForm(DefaultFieldsMixin, Base): class HumanInputDelivery(DefaultFieldsMixin, Base): __tablename__ = "human_input_form_deliveries" + __table_args__ = ( + sa.Index( + None, + "form_id", + ), + ) form_id: Mapped[str] = mapped_column( StringUUID, @@ -181,6 +196,10 @@ RecipientPayload = Annotated[ class HumanInputFormRecipient(DefaultFieldsMixin, Base): __tablename__ = "human_input_form_recipients" + __table_args__ = ( + sa.Index(None, "form_id"), + sa.Index(None, "delivery_id"), + ) form_id: Mapped[str] = mapped_column( StringUUID, diff --git a/api/models/model.py b/api/models/model.py index 4a95faf7f7..b098966052 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -7,7 +7,7 @@ from collections.abc import Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 import sqlalchemy as sa @@ -15,27 +15,296 @@ from flask import request from flask_login import UserMixin # type: ignore[import-untyped] from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column +from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from core.workflow.file import helpers as file_helpers +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from dify_graph.file import helpers as file_helpers +from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string from .engine import db -from .enums import CreatorUserRole +from .enums import ( + AppMCPServerStatus, + AppStatus, + BannerStatus, + ConversationFromSource, + ConversationStatus, + CreatorUserRole, + FeedbackFromSource, + FeedbackRating, + InvokeFrom, + MessageChainType, + MessageFileBelongsTo, + MessageStatus, + TagType, +) from .provider_ids import GenericProviderID -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from .workflow import Workflow +# --- TypedDict definitions for structured dict return types --- + + +class EnabledConfig(TypedDict): + enabled: bool + + +class EmbeddingModelInfo(TypedDict): + embedding_provider_name: str + embedding_model_name: str + + +class AnnotationReplyDisabledConfig(TypedDict): + enabled: Literal[False] + + +class AnnotationReplyEnabledConfig(TypedDict): + id: str + enabled: Literal[True] + score_threshold: float + embedding_model: EmbeddingModelInfo + + +AnnotationReplyConfig = AnnotationReplyEnabledConfig | AnnotationReplyDisabledConfig + + +class SensitiveWordAvoidanceConfig(TypedDict): + enabled: bool + type: str + config: dict[str, Any] + + +class AgentToolConfig(TypedDict): + provider_type: str + provider_id: str + tool_name: str + tool_parameters: dict[str, Any] + plugin_unique_identifier: NotRequired[str | None] + credential_id: NotRequired[str | None] + + +class AgentModeConfig(TypedDict): + enabled: bool + strategy: str | None + tools: list[AgentToolConfig | dict[str, Any]] + prompt: str | None + + +class ImageUploadConfig(TypedDict): + enabled: bool + number_limits: int + detail: str + transfer_methods: list[str] + + +class FileUploadConfig(TypedDict): + image: ImageUploadConfig + + +class DeletedToolInfo(TypedDict): + type: str + tool_name: str + provider_id: str + + +class ExternalDataToolConfig(TypedDict): + enabled: bool + variable: str + type: str + config: dict[str, Any] + + +class UserInputFormItemConfig(TypedDict): + variable: str + label: str + description: NotRequired[str] + required: NotRequired[bool] + max_length: NotRequired[int] + options: NotRequired[list[str]] + default: NotRequired[str] + type: NotRequired[str] + config: NotRequired[dict[str, Any]] + + +# Each item is a single-key dict, e.g. {"text-input": UserInputFormItemConfig} +UserInputFormItem = dict[str, UserInputFormItemConfig] + + +class DatasetConfigs(TypedDict): + retrieval_model: str + datasets: NotRequired[dict[str, Any]] + top_k: NotRequired[int] + score_threshold: NotRequired[float] + score_threshold_enabled: NotRequired[bool] + reranking_model: NotRequired[dict[str, Any] | None] + weights: NotRequired[dict[str, Any] | None] + reranking_enabled: NotRequired[bool] + reranking_mode: NotRequired[str] + metadata_filtering_mode: NotRequired[str] + metadata_model_config: NotRequired[dict[str, Any] | None] + metadata_filtering_conditions: NotRequired[dict[str, Any] | None] + + +class ChatPromptMessage(TypedDict): + text: str + role: str + + +class ChatPromptConfig(TypedDict, total=False): + prompt: list[ChatPromptMessage] + + +class CompletionPromptText(TypedDict): + text: str + + +class ConversationHistoriesRole(TypedDict): + user_prefix: str + assistant_prefix: str + + +class CompletionPromptConfig(TypedDict): + prompt: CompletionPromptText + conversation_histories_role: NotRequired[ConversationHistoriesRole] + + +class ModelConfig(TypedDict): + provider: str + name: str + mode: str + completion_params: NotRequired[dict[str, Any]] + + +class AppModelConfigDict(TypedDict): + opening_statement: str | None + suggested_questions: list[str] + suggested_questions_after_answer: EnabledConfig + speech_to_text: EnabledConfig + text_to_speech: EnabledConfig + retriever_resource: EnabledConfig + annotation_reply: AnnotationReplyConfig + more_like_this: EnabledConfig + sensitive_word_avoidance: SensitiveWordAvoidanceConfig + external_data_tools: list[ExternalDataToolConfig] + model: ModelConfig + user_input_form: list[UserInputFormItem] + dataset_query_variable: str | None + pre_prompt: str | None + agent_mode: AgentModeConfig + prompt_type: str + chat_prompt_config: ChatPromptConfig + completion_prompt_config: CompletionPromptConfig + dataset_configs: DatasetConfigs + file_upload: FileUploadConfig + # Added dynamically in Conversation.model_config + model_id: NotRequired[str | None] + provider: NotRequired[str | None] + + +class ConversationDict(TypedDict): + id: str + app_id: str + app_model_config_id: str | None + model_provider: str | None + override_model_configs: str | None + model_id: str | None + mode: str + name: str + summary: str | None + inputs: dict[str, Any] + introduction: str | None + system_instruction: str | None + system_instruction_tokens: int + status: str + invoke_from: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + read_at: datetime | None + read_account_id: str | None + dialogue_count: int + created_at: datetime + updated_at: datetime + + +class MessageDict(TypedDict): + id: str + app_id: str + conversation_id: str + model_id: str | None + inputs: dict[str, Any] + query: str + total_price: Decimal | None + message: dict[str, Any] + answer: str + status: str + error: str | None + message_metadata: dict[str, Any] + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + agent_based: bool + workflow_run_id: str | None + + +class MessageFeedbackDict(TypedDict): + id: str + app_id: str + conversation_id: str + message_id: str + rating: str + content: str | None + from_source: str + from_end_user_id: str | None + from_account_id: str | None + created_at: str + updated_at: str + + +class MessageFileInfo(TypedDict, total=False): + belongs_to: str | None + upload_file_id: str | None + id: str + tenant_id: str + type: str + transfer_method: str + remote_url: str | None + related_id: str | None + filename: str | None + extension: str | None + mime_type: str | None + size: int + dify_model_identity: str + url: str | None + + +class ExtraContentDict(TypedDict, total=False): + type: str + workflow_run_id: str + + +class TraceAppConfigDict(TypedDict): + id: str + app_id: str + tracing_provider: str | None + tracing_config: dict[str, Any] + is_active: bool + created_at: str | None + updated_at: str | None + + class DifySetup(TypeBase): __tablename__ = "dify_setups" __table_args__ = (sa.PrimaryKeyConstraint("version", name="dify_setup_pkey"),) @@ -83,13 +352,15 @@ class App(Base): tenant_id: Mapped[str] = mapped_column(StringUUID) name: Mapped[str] = mapped_column(String(255)) description: Mapped[str] = mapped_column(LongText, default=sa.text("''")) - mode: Mapped[str] = mapped_column(String(255)) - icon_type: Mapped[str | None] = mapped_column(String(255)) # image, emoji, link + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255)) icon = mapped_column(String(255)) icon_background: Mapped[str | None] = mapped_column(String(255)) app_model_config_id = mapped_column(StringUUID, nullable=True) workflow_id = mapped_column(StringUUID, nullable=True) - status: Mapped[str] = mapped_column(String(255), server_default=sa.text("'normal'")) + status: Mapped[AppStatus] = mapped_column( + EnumText(AppStatus, length=255), server_default=sa.text("'normal'"), default=AppStatus.NORMAL + ) enable_site: Mapped[bool] = mapped_column(sa.Boolean) enable_api: Mapped[bool] = mapped_column(sa.Boolean) api_rpm: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0")) @@ -124,13 +395,12 @@ class App(Base): @property def site(self) -> Site | None: - site = db.session.query(Site).where(Site.app_id == self.id).first() - return site + return db.session.scalar(select(Site).where(Site.app_id == self.id)) @property def app_model_config(self) -> AppModelConfig | None: if self.app_model_config_id: - return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() + return db.session.scalar(select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id)) return None @@ -139,7 +409,7 @@ class App(Base): if self.workflow_id: from .workflow import Workflow - return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) return None @@ -149,8 +419,7 @@ class App(Base): @property def tenant(self) -> Tenant | None: - tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - return tenant + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) @property def is_agent(self) -> bool: @@ -176,7 +445,7 @@ class App(Base): return str(self.mode) @property - def deleted_tools(self) -> list[dict[str, str]]: + def deleted_tools(self) -> list[DeletedToolInfo]: from core.tools.tool_manager import ToolManager, ToolProviderType from services.plugin.plugin_service import PluginService @@ -257,7 +526,7 @@ class App(Base): provider_id.provider_name: existence[i] for i, provider_id in enumerate(builtin_provider_ids) } - deleted_tools: list[dict[str, str]] = [] + deleted_tools: list[DeletedToolInfo] = [] for tool in tools: keys = list(tool.keys()) @@ -290,9 +559,9 @@ class App(Base): return deleted_tools @property - def tags(self) -> list[Tag]: - tags = ( - db.session.query(Tag) + def tags(self) -> Sequence[Tag]: + tags = db.session.scalars( + select(Tag) .join(TagBinding, Tag.id == TagBinding.tag_id) .where( TagBinding.target_id == self.id, @@ -300,15 +569,14 @@ class App(Base): Tag.tenant_id == self.tenant_id, Tag.type == "app", ) - .all() - ) + ).all() return tags or [] @property def author_name(self) -> str | None: if self.created_by: - account = db.session.query(Account).where(Account.id == self.created_by).first() + account = db.session.scalar(select(Account).where(Account.id == self.created_by)) if account: return account.name @@ -360,41 +628,43 @@ class AppModelConfig(TypeBase): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property - def model_dict(self) -> dict[str, Any]: - return json.loads(self.model) if self.model else {} + def model_dict(self) -> ModelConfig: + return cast(ModelConfig, json.loads(self.model) if self.model else {}) @property def suggested_questions_list(self) -> list[str]: return json.loads(self.suggested_questions) if self.suggested_questions else [] @property - def suggested_questions_after_answer_dict(self) -> dict[str, Any]: - return ( + def suggested_questions_after_answer_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.suggested_questions_after_answer) if self.suggested_questions_after_answer - else {"enabled": False} + else {"enabled": False}, ) @property - def speech_to_text_dict(self) -> dict[str, Any]: - return json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False} + def speech_to_text_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.speech_to_text) if self.speech_to_text else {"enabled": False}) @property - def text_to_speech_dict(self) -> dict[str, Any]: - return json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False} + def text_to_speech_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.text_to_speech) if self.text_to_speech else {"enabled": False}) @property - def retriever_resource_dict(self) -> dict[str, Any]: - return json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + def retriever_resource_dict(self) -> EnabledConfig: + return cast( + EnabledConfig, json.loads(self.retriever_resource) if self.retriever_resource else {"enabled": True} + ) @property - def annotation_reply_dict(self) -> dict[str, Any]: - annotation_setting = ( - db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first() + def annotation_reply_dict(self) -> AnnotationReplyConfig: + annotation_setting = db.session.scalar( + select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id) ) if annotation_setting: collection_binding_detail = annotation_setting.collection_binding_detail @@ -415,56 +685,62 @@ class AppModelConfig(TypeBase): return {"enabled": False} @property - def more_like_this_dict(self) -> dict[str, Any]: - return json.loads(self.more_like_this) if self.more_like_this else {"enabled": False} + def more_like_this_dict(self) -> EnabledConfig: + return cast(EnabledConfig, json.loads(self.more_like_this) if self.more_like_this else {"enabled": False}) @property - def sensitive_word_avoidance_dict(self) -> dict[str, Any]: - return ( + def sensitive_word_avoidance_dict(self) -> SensitiveWordAvoidanceConfig: + return cast( + SensitiveWordAvoidanceConfig, json.loads(self.sensitive_word_avoidance) if self.sensitive_word_avoidance - else {"enabled": False, "type": "", "configs": []} + else {"enabled": False, "type": "", "config": {}}, ) @property - def external_data_tools_list(self) -> list[dict[str, Any]]: + def external_data_tools_list(self) -> list[ExternalDataToolConfig]: return json.loads(self.external_data_tools) if self.external_data_tools else [] @property - def user_input_form_list(self) -> list[dict[str, Any]]: + def user_input_form_list(self) -> list[UserInputFormItem]: return json.loads(self.user_input_form) if self.user_input_form else [] @property - def agent_mode_dict(self) -> dict[str, Any]: - return ( + def agent_mode_dict(self) -> AgentModeConfig: + return cast( + AgentModeConfig, json.loads(self.agent_mode) if self.agent_mode - else {"enabled": False, "strategy": None, "tools": [], "prompt": None} + else {"enabled": False, "strategy": None, "tools": [], "prompt": None}, ) @property - def chat_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.chat_prompt_config) if self.chat_prompt_config else {} + def chat_prompt_config_dict(self) -> ChatPromptConfig: + return cast(ChatPromptConfig, json.loads(self.chat_prompt_config) if self.chat_prompt_config else {}) @property - def completion_prompt_config_dict(self) -> dict[str, Any]: - return json.loads(self.completion_prompt_config) if self.completion_prompt_config else {} + def completion_prompt_config_dict(self) -> CompletionPromptConfig: + return cast( + CompletionPromptConfig, + json.loads(self.completion_prompt_config) if self.completion_prompt_config else {}, + ) @property - def dataset_configs_dict(self) -> dict[str, Any]: + def dataset_configs_dict(self) -> DatasetConfigs: if self.dataset_configs: - dataset_configs: dict[str, Any] = json.loads(self.dataset_configs) + dataset_configs = json.loads(self.dataset_configs) if "retrieval_model" not in dataset_configs: return {"retrieval_model": "single"} else: - return dataset_configs + return cast(DatasetConfigs, dataset_configs) return { "retrieval_model": "multiple", } @property - def file_upload_dict(self) -> dict[str, Any]: - return ( + def file_upload_dict(self) -> FileUploadConfig: + return cast( + FileUploadConfig, json.loads(self.file_upload) if self.file_upload else { @@ -474,10 +750,10 @@ class AppModelConfig(TypeBase): "detail": "high", "transfer_methods": ["remote_url", "local_file"], } - } + }, ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> AppModelConfigDict: return { "opening_statement": self.opening_statement, "suggested_questions": self.suggested_questions_list, @@ -501,36 +777,42 @@ class AppModelConfig(TypeBase): "file_upload": self.file_upload_dict, } - def from_model_config_dict(self, model_config: Mapping[str, Any]): + def from_model_config_dict(self, model_config: AppModelConfigDict): self.opening_statement = model_config.get("opening_statement") self.suggested_questions = ( - json.dumps(model_config["suggested_questions"]) if model_config.get("suggested_questions") else None + json.dumps(model_config.get("suggested_questions")) if model_config.get("suggested_questions") else None ) self.suggested_questions_after_answer = ( - json.dumps(model_config["suggested_questions_after_answer"]) + json.dumps(model_config.get("suggested_questions_after_answer")) if model_config.get("suggested_questions_after_answer") else None ) - self.speech_to_text = json.dumps(model_config["speech_to_text"]) if model_config.get("speech_to_text") else None - self.text_to_speech = json.dumps(model_config["text_to_speech"]) if model_config.get("text_to_speech") else None - self.more_like_this = json.dumps(model_config["more_like_this"]) if model_config.get("more_like_this") else None + self.speech_to_text = ( + json.dumps(model_config.get("speech_to_text")) if model_config.get("speech_to_text") else None + ) + self.text_to_speech = ( + json.dumps(model_config.get("text_to_speech")) if model_config.get("text_to_speech") else None + ) + self.more_like_this = ( + json.dumps(model_config.get("more_like_this")) if model_config.get("more_like_this") else None + ) self.sensitive_word_avoidance = ( - json.dumps(model_config["sensitive_word_avoidance"]) + json.dumps(model_config.get("sensitive_word_avoidance")) if model_config.get("sensitive_word_avoidance") else None ) self.external_data_tools = ( - json.dumps(model_config["external_data_tools"]) if model_config.get("external_data_tools") else None + json.dumps(model_config.get("external_data_tools")) if model_config.get("external_data_tools") else None ) - self.model = json.dumps(model_config["model"]) if model_config.get("model") else None + self.model = json.dumps(model_config.get("model")) if model_config.get("model") else None self.user_input_form = ( - json.dumps(model_config["user_input_form"]) if model_config.get("user_input_form") else None + json.dumps(model_config.get("user_input_form")) if model_config.get("user_input_form") else None ) self.dataset_query_variable = model_config.get("dataset_query_variable") - self.pre_prompt = model_config["pre_prompt"] - self.agent_mode = json.dumps(model_config["agent_mode"]) if model_config.get("agent_mode") else None + self.pre_prompt = model_config.get("pre_prompt") + self.agent_mode = json.dumps(model_config.get("agent_mode")) if model_config.get("agent_mode") else None self.retriever_resource = ( - json.dumps(model_config["retriever_resource"]) if model_config.get("retriever_resource") else None + json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) self.prompt_type = model_config.get("prompt_type", "simple") self.chat_prompt_config = ( @@ -574,8 +856,7 @@ class RecommendedApp(Base): # bug @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) class InstalledApp(TypeBase): @@ -602,13 +883,11 @@ class InstalledApp(TypeBase): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def tenant(self) -> Tenant | None: - tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() - return tenant + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) class TrialApp(Base): @@ -628,8 +907,7 @@ class TrialApp(Base): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) class AccountTrialAppRecord(Base): @@ -648,13 +926,11 @@ class AccountTrialAppRecord(Base): @property def app(self) -> App | None: - app = db.session.query(App).where(App.id == self.app_id).first() - return app + return db.session.scalar(select(App).where(App.id == self.app_id)) @property def user(self) -> Account | None: - user = db.session.query(Account).where(Account.id == self.account_id).first() - return user + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class ExporleBanner(TypeBase): @@ -664,8 +940,11 @@ class ExporleBanner(TypeBase): content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) - status: Mapped[str] = mapped_column( - sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + status: Mapped[BannerStatus] = mapped_column( + EnumText(BannerStatus, length=255), + nullable=False, + server_default=sa.text("'enabled'::character varying"), + default=BannerStatus.ENABLED, ) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -711,6 +990,18 @@ class Conversation(Base): __table_args__ = ( sa.PrimaryKeyConstraint("id", name="conversation_pkey"), sa.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), + sa.Index( + "conversation_app_created_at_idx", + "app_id", + sa.text("created_at DESC"), + postgresql_where=sa.text("is_deleted IS false"), + ), + sa.Index( + "conversation_app_updated_at_idx", + "app_id", + sa.text("updated_at DESC"), + postgresql_where=sa.text("is_deleted IS false"), + ), ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) @@ -719,23 +1010,27 @@ class Conversation(Base): model_provider = mapped_column(String(255), nullable=True) override_model_configs = mapped_column(LongText) model_id = mapped_column(String(255), nullable=True) - mode: Mapped[str] = mapped_column(String(255)) + mode: Mapped[AppMode] = mapped_column(EnumText(AppMode, length=255)) name: Mapped[str] = mapped_column(String(255), nullable=False) summary = mapped_column(LongText) _inputs: Mapped[dict[str, Any]] = mapped_column("inputs", sa.JSON) introduction = mapped_column(LongText) system_instruction = mapped_column(LongText) system_instruction_tokens: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0")) - status: Mapped[str] = mapped_column(String(255), nullable=False) + status: Mapped[ConversationStatus] = mapped_column( + EnumText(ConversationStatus, length=255), nullable=False, default=ConversationStatus.NORMAL + ) # The `invoke_from` records how the conversation is created. # # Its value corresponds to the members of `InvokeFrom`. # (api/core/app/entities/app_invoke_entities.py) - invoke_from = mapped_column(String(255), nullable=True) + invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True) # ref: ConversationSource. - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + from_source: Mapped[ConversationFromSource] = mapped_column( + EnumText(ConversationFromSource, length=255), nullable=False + ) from_end_user_id = mapped_column(StringUUID) from_account_id = mapped_column(StringUUID) read_at = mapped_column(sa.DateTime) @@ -811,27 +1106,29 @@ class Conversation(Base): self._inputs = inputs @property - def model_config(self): - model_config = {} + def model_config(self) -> AppModelConfigDict: + model_config = cast(AppModelConfigDict, {}) app_model_config: AppModelConfig | None = None if self.mode == AppMode.ADVANCED_CHAT: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) - model_config = override_model_configs + model_config = cast(AppModelConfigDict, override_model_configs) else: if self.override_model_configs: override_model_configs = json.loads(self.override_model_configs) if "model" in override_model_configs: # where is app_id? - app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict(override_model_configs) + app_model_config = AppModelConfig(app_id=self.app_id).from_model_config_dict( + cast(AppModelConfigDict, override_model_configs) + ) model_config = app_model_config.to_dict() else: - model_config["configs"] = override_model_configs + model_config["configs"] = override_model_configs # type: ignore[typeddict-unknown-key] else: - app_model_config = ( - db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first() + app_model_config = db.session.scalar( + select(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id) ) if app_model_config: model_config = app_model_config.to_dict() @@ -854,36 +1151,43 @@ class Conversation(Base): @property def annotated(self): - return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0 + return ( + db.session.scalar( + select(func.count(MessageAnnotation.id)).where(MessageAnnotation.conversation_id == self.id) + ) + or 0 + ) > 0 @property def annotation(self): - return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first() + return db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).limit(1)) @property def message_count(self): - return db.session.query(Message).where(Message.conversation_id == self.id).count() + return db.session.scalar(select(func.count(Message.id)).where(Message.conversation_id == self.id)) or 0 @property def user_feedback_stats(self): like = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "user", - MessageFeedback.rating == "like", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == FeedbackRating.LIKE, + ) ) - .count() + or 0 ) dislike = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "user", - MessageFeedback.rating == "dislike", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "user", + MessageFeedback.rating == FeedbackRating.DISLIKE, + ) ) - .count() + or 0 ) return {"like": like, "dislike": dislike} @@ -891,23 +1195,25 @@ class Conversation(Base): @property def admin_feedback_stats(self): like = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "admin", - MessageFeedback.rating == "like", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == FeedbackRating.LIKE, + ) ) - .count() + or 0 ) dislike = ( - db.session.query(MessageFeedback) - .where( - MessageFeedback.conversation_id == self.id, - MessageFeedback.from_source == "admin", - MessageFeedback.rating == "dislike", + db.session.scalar( + select(func.count(MessageFeedback.id)).where( + MessageFeedback.conversation_id == self.id, + MessageFeedback.from_source == "admin", + MessageFeedback.rating == FeedbackRating.DISLIKE, + ) ) - .count() + or 0 ) return {"like": like, "dislike": dislike} @@ -969,22 +1275,19 @@ class Conversation(Base): @property def first_message(self): - return ( - db.session.query(Message) - .where(Message.conversation_id == self.id) - .order_by(Message.created_at.asc()) - .first() + return db.session.scalar( + select(Message).where(Message.conversation_id == self.id).order_by(Message.created_at.asc()) ) @property def app(self) -> App | None: with Session(db.engine, expire_on_commit=False) as session: - return session.query(App).where(App.id == self.app_id).first() + return session.scalar(select(App).where(App.id == self.app_id)) @property def from_end_user_session_id(self): if self.from_end_user_id: - end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first() + end_user = db.session.scalar(select(EndUser).where(EndUser.id == self.from_end_user_id)) if end_user: return end_user.session_id @@ -993,7 +1296,7 @@ class Conversation(Base): @property def from_account_name(self) -> str | None: if self.from_account_id: - account = db.session.query(Account).where(Account.id == self.from_account_id).first() + account = db.session.scalar(select(Account).where(Account.id == self.from_account_id)) if account: return account.name @@ -1003,7 +1306,7 @@ class Conversation(Base): def in_debug_mode(self) -> bool: return self.override_model_configs is not None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> ConversationDict: return { "id": self.id, "app_id": self.app_id, @@ -1068,11 +1371,18 @@ class Message(Base): provider_response_latency: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) total_price: Mapped[Decimal | None] = mapped_column(sa.Numeric(10, 7)) currency: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[MessageStatus] = mapped_column( + EnumText(MessageStatus, length=255), + nullable=False, + server_default=sa.text("'normal'"), + default=MessageStatus.NORMAL, + ) error: Mapped[str | None] = mapped_column(LongText) message_metadata: Mapped[str | None] = mapped_column(LongText) - invoke_from: Mapped[str | None] = mapped_column(String(255), nullable=True) - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + invoke_from: Mapped[InvokeFrom | None] = mapped_column(EnumText(InvokeFrom, length=255), nullable=True) + from_source: Mapped[ConversationFromSource] = mapped_column( + EnumText(ConversationFromSource, length=255), nullable=False + ) from_end_user_id: Mapped[str | None] = mapped_column(StringUUID) from_account_id: Mapped[str | None] = mapped_column(StringUUID) created_at: Mapped[datetime] = mapped_column(sa.DateTime, server_default=func.current_timestamp()) @@ -1081,7 +1391,7 @@ class Message(Base): ) agent_based: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) - app_mode: Mapped[str | None] = mapped_column(String(255), nullable=True) + app_mode: Mapped[AppMode | None] = mapped_column(EnumText(AppMode, length=255), nullable=True) @property def inputs(self) -> dict[str, Any]: @@ -1213,21 +1523,15 @@ class Message(Base): @property def user_feedback(self): - feedback = ( - db.session.query(MessageFeedback) - .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") - .first() + return db.session.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user") ) - return feedback @property def admin_feedback(self): - feedback = ( - db.session.query(MessageFeedback) - .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") - .first() + return db.session.scalar( + select(MessageFeedback).where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin") ) - return feedback @property def feedbacks(self): @@ -1236,28 +1540,27 @@ class Message(Base): @property def annotation(self): - annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first() + annotation = db.session.scalar(select(MessageAnnotation).where(MessageAnnotation.message_id == self.id)) return annotation @property def annotation_hit_history(self): - annotation_history = ( - db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first() + annotation_history = db.session.scalar( + select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id) ) if annotation_history: - annotation = ( - db.session.query(MessageAnnotation) - .where(MessageAnnotation.id == annotation_history.annotation_id) - .first() + return db.session.scalar( + select(MessageAnnotation).where(MessageAnnotation.id == annotation_history.annotation_id) ) - return annotation return None @property def app_model_config(self): - conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first() + conversation = db.session.scalar(select(Conversation).where(Conversation.id == self.conversation_id)) if conversation: - return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first() + return db.session.scalar( + select(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id) + ) return None @@ -1270,24 +1573,23 @@ class Message(Base): return json.loads(self.message_metadata) if self.message_metadata else {} @property - def agent_thoughts(self) -> list[MessageAgentThought]: - return ( - db.session.query(MessageAgentThought) + def agent_thoughts(self) -> Sequence[MessageAgentThought]: + return db.session.scalars( + select(MessageAgentThought) .where(MessageAgentThought.message_id == self.id) .order_by(MessageAgentThought.position.asc()) - .all() - ) + ).all() @property def retriever_resources(self) -> Any: return self.message_metadata_dict.get("retriever_resources") if self.message_metadata else [] @property - def message_files(self) -> list[dict[str, Any]]: + def message_files(self) -> list[MessageFileInfo]: from factories import file_factory message_files = db.session.scalars(select(MessageFile).where(MessageFile.message_id == self.id)).all() - current_app = db.session.query(App).where(App.id == self.app_id).first() + current_app = db.session.scalar(select(App).where(App.id == self.app_id)) if not current_app: raise ValueError(f"App {self.app_id} not found") @@ -1338,10 +1640,13 @@ class Message(Base): ) files.append(file) - result: list[dict[str, Any]] = [ - {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} - for (file, message_file) in zip(files, message_files) - ] + result = cast( + list[MessageFileInfo], + [ + {"belongs_to": message_file.belongs_to, "upload_file_id": message_file.upload_file_id, **file.to_dict()} + for (file, message_file) in zip(files, message_files) + ], + ) db.session.commit() return result @@ -1351,7 +1656,7 @@ class Message(Base): self._extra_contents = list(contents) @property - def extra_contents(self) -> list[dict[str, Any]]: + def extra_contents(self) -> list[ExtraContentDict]: return getattr(self, "_extra_contents", []) @property @@ -1367,7 +1672,7 @@ class Message(Base): return None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageDict: return { "id": self.id, "app_id": self.app_id, @@ -1391,7 +1696,7 @@ class Message(Base): } @classmethod - def from_dict(cls, data: dict[str, Any]) -> Message: + def from_dict(cls, data: MessageDict) -> Message: return cls( id=data["id"], app_id=data["app_id"], @@ -1430,8 +1735,8 @@ class MessageFeedback(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - rating: Mapped[str] = mapped_column(String(255), nullable=False) - from_source: Mapped[str] = mapped_column(String(255), nullable=False) + rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False) + from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False) content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) @@ -1448,10 +1753,9 @@ class MessageFeedback(TypeBase): @property def from_account(self) -> Account | None: - account = db.session.query(Account).where(Account.id == self.from_account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.from_account_id)) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> MessageFeedbackDict: return { "id": str(self.id), "app_id": str(self.app_id), @@ -1480,10 +1784,14 @@ class MessageFile(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) type: Mapped[str] = mapped_column(String(255), nullable=False) - transfer_method: Mapped[FileTransferMethod] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[CreatorUserRole] = mapped_column(String(255), nullable=False) + transfer_method: Mapped[FileTransferMethod] = mapped_column( + EnumText(FileTransferMethod, length=255), nullable=False + ) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) - belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None) + belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column( + EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None + ) url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( @@ -1520,13 +1828,11 @@ class MessageAnnotation(Base): @property def account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) @property def annotation_create_account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class AppAnnotationHitHistory(TypeBase): @@ -1555,18 +1861,15 @@ class AppAnnotationHitHistory(TypeBase): @property def account(self): - account = ( - db.session.query(Account) + return db.session.scalar( + select(Account) .join(MessageAnnotation, MessageAnnotation.account_id == Account.id) .where(MessageAnnotation.id == self.annotation_id) - .first() ) - return account @property def annotation_create_account(self): - account = db.session.query(Account).where(Account.id == self.account_id).first() - return account + return db.session.scalar(select(Account).where(Account.id == self.account_id)) class AppAnnotationSetting(TypeBase): @@ -1599,12 +1902,9 @@ class AppAnnotationSetting(TypeBase): def collection_binding_detail(self): from .dataset import DatasetCollectionBinding - collection_binding_detail = ( - db.session.query(DatasetCollectionBinding) - .where(DatasetCollectionBinding.id == self.collection_binding_id) - .first() + return db.session.scalar( + select(DatasetCollectionBinding).where(DatasetCollectionBinding.id == self.collection_binding_id) ) - return collection_binding_detail class OperationLog(TypeBase): @@ -1690,7 +1990,9 @@ class AppMCPServer(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) description: Mapped[str] = mapped_column(String(255), nullable=False) server_code: Mapped[str] = mapped_column(String(255), nullable=False) - status: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[AppMCPServerStatus] = mapped_column( + EnumText(AppMCPServerStatus, length=255), nullable=False, server_default=sa.text("'normal'") + ) parameters: Mapped[str] = mapped_column(LongText, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -1708,14 +2010,16 @@ class AppMCPServer(TypeBase): def generate_server_code(n: int) -> str: while True: result = generate_string(n) - while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0: + while ( + db.session.scalar(select(func.count(AppMCPServer.id)).where(AppMCPServer.server_code == result)) or 0 + ) > 0: result = generate_string(n) return result @property - def parameters_dict(self) -> dict[str, Any]: - return cast(dict[str, Any], json.loads(self.parameters)) + def parameters_dict(self) -> dict[str, str]: + return cast(dict[str, str], json.loads(self.parameters)) class Site(Base): @@ -1729,7 +2033,7 @@ class Site(Base): id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=False) title: Mapped[str] = mapped_column(String(255), nullable=False) - icon_type = mapped_column(String(255), nullable=True) + icon_type: Mapped[IconType | None] = mapped_column(EnumText(IconType, length=255), nullable=True) icon = mapped_column(String(255)) icon_background = mapped_column(String(255)) description = mapped_column(LongText) @@ -1744,7 +2048,9 @@ class Site(Base): customize_domain = mapped_column(String(255)) customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) - status = mapped_column(String(255), nullable=False, server_default=sa.text("'normal'")) + status: Mapped[AppStatus] = mapped_column( + EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL + ) created_by = mapped_column(StringUUID, nullable=True) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) updated_by = mapped_column(StringUUID, nullable=True) @@ -1767,7 +2073,7 @@ class Site(Base): def generate_code(n: int) -> str: while True: result = generate_string(n) - while db.session.query(Site).where(Site.code == result).count() > 0: + while (db.session.scalar(select(func.count(Site.id)).where(Site.code == result)) or 0) > 0: result = generate_string(n) return result @@ -1815,7 +2121,7 @@ class UploadFile(Base): # The `server_default` serves as a fallback mechanism. id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - storage_type: Mapped[str] = mapped_column(String(255), nullable=False) + storage_type: Mapped[StorageType] = mapped_column(EnumText(StorageType, length=255), nullable=False) key: Mapped[str] = mapped_column(String(255), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) size: Mapped[int] = mapped_column(sa.Integer, nullable=False) @@ -1824,7 +2130,12 @@ class UploadFile(Base): # The `created_by_role` field indicates whether the file was created by an `Account` or an `EndUser`. # Its value is derived from the `CreatorUserRole` enumeration. - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'account'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), + nullable=False, + server_default=sa.text("'account'"), + default=CreatorUserRole.ACCOUNT, + ) # The `created_by` field stores the ID of the entity that created this upload file. # @@ -1854,7 +2165,7 @@ class UploadFile(Base): self, *, tenant_id: str, - storage_type: str, + storage_type: StorageType, key: str, name: str, size: int, @@ -1877,7 +2188,7 @@ class UploadFile(Base): self.size = size self.extension = extension self.mime_type = mime_type - self.created_by_role = created_by_role.value + self.created_by_role = created_by_role self.created_by = created_by self.created_at = created_at self.used = used @@ -1919,7 +2230,7 @@ class MessageChain(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[MessageChainType] = mapped_column(EnumText(MessageChainType, length=255), nullable=False) input: Mapped[str | None] = mapped_column(LongText, nullable=True) output: Mapped[str | None] = mapped_column(LongText, nullable=True) created_at: Mapped[datetime] = mapped_column( @@ -1940,7 +2251,7 @@ class MessageAgentThought(TypeBase): ) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) position: Mapped[int] = mapped_column(sa.Integer, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) message_chain_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) thought: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) @@ -2094,7 +2405,7 @@ class Tag(TypeBase): StringUUID, insert_default=lambda: str(uuid4()), default_factory=lambda: str(uuid4()), init=False ) tenant_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) - type: Mapped[str] = mapped_column(String(16), nullable=False) + type: Mapped[TagType] = mapped_column(EnumText(TagType, length=16), nullable=False) name: Mapped[str] = mapped_column(String(255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( @@ -2155,7 +2466,7 @@ class TraceAppConfig(TypeBase): def tracing_config_str(self) -> str: return json.dumps(self.tracing_config_dict) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> TraceAppConfigDict: return { "id": self.id, "app_id": self.app_id, diff --git a/api/models/provider.py b/api/models/provider.py index 6175a3ae88..afeee20b1e 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -6,14 +6,15 @@ from functools import cached_property from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func, text +from sqlalchemy import DateTime, String, func, select, text from sqlalchemy.orm import Mapped, mapped_column from libs.uuid_utils import uuidv7 from .base import TypeBase from .engine import db -from .types import LongText, StringUUID +from .enums import CredentialSourceType, PaymentStatus +from .types import EnumText, LongText, StringUUID class ProviderType(StrEnum): @@ -69,8 +70,8 @@ class Provider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - provider_type: Mapped[str] = mapped_column( - String(40), nullable=False, server_default=text("'custom'"), default="custom" + provider_type: Mapped[ProviderType] = mapped_column( + EnumText(ProviderType, length=40), nullable=False, server_default=text("'custom'"), default=ProviderType.CUSTOM ) is_valid: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("false"), default=False) last_used: Mapped[datetime | None] = mapped_column(DateTime, nullable=True, init=False) @@ -96,7 +97,7 @@ class Provider(TypeBase): @cached_property def credential(self): if self.credential_id: - return db.session.query(ProviderCredential).where(ProviderCredential.id == self.credential_id).first() + return db.session.scalar(select(ProviderCredential).where(ProviderCredential.id == self.credential_id)) @property def credential_name(self): @@ -159,10 +160,8 @@ class ProviderModel(TypeBase): @cached_property def credential(self): if self.credential_id: - return ( - db.session.query(ProviderModelCredential) - .where(ProviderModelCredential.id == self.credential_id) - .first() + return db.session.scalar( + select(ProviderModelCredential).where(ProviderModelCredential.id == self.credential_id) ) @property @@ -211,7 +210,7 @@ class TenantPreferredModelProvider(TypeBase): ) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) provider_name: Mapped[str] = mapped_column(String(255), nullable=False) - preferred_provider_type: Mapped[str] = mapped_column(String(40), nullable=False) + preferred_provider_type: Mapped[ProviderType] = mapped_column(EnumText(ProviderType, length=40), nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) @@ -239,7 +238,9 @@ class ProviderOrder(TypeBase): quantity: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=text("1")) currency: Mapped[str | None] = mapped_column(String(40)) total_amount: Mapped[int | None] = mapped_column(sa.Integer) - payment_status: Mapped[str] = mapped_column(String(40), nullable=False, server_default=text("'wait_pay'")) + payment_status: Mapped[PaymentStatus] = mapped_column( + EnumText(PaymentStatus, length=40), nullable=False, server_default=text("'wait_pay'") + ) paid_at: Mapped[datetime | None] = mapped_column(DateTime) pay_failed_at: Mapped[datetime | None] = mapped_column(DateTime) refunded_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -302,7 +303,9 @@ class LoadBalancingModelConfig(TypeBase): name: Mapped[str] = mapped_column(String(255), nullable=False) encrypted_config: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) credential_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) - credential_source_type: Mapped[str | None] = mapped_column(String(40), nullable=True, default=None) + credential_source_type: Mapped[CredentialSourceType | None] = mapped_column( + EnumText(CredentialSourceType, length=40), nullable=True, default=None + ) enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=text("true"), default=True) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False diff --git a/api/models/tools.py b/api/models/tools.py index e7b98dcf27..01182af867 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -8,17 +8,21 @@ from uuid import uuid4 import sqlalchemy as sa from deprecated import deprecated -from sqlalchemy import ForeignKey, String, func +from sqlalchemy import ForeignKey, String, func, select from sqlalchemy.orm import Mapped, mapped_column from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle -from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration +from core.tools.entities.tool_entities import ( + ApiProviderSchemaType, + ToolProviderType, + WorkflowToolParameterConfiguration, +) from .base import TypeBase from .engine import db from .model import Account, App, Tenant -from .types import LongText, StringUUID +from .types import EnumText, LongText, StringUUID if TYPE_CHECKING: from core.entities.mcp_provider import MCPProviderEntity @@ -184,11 +188,11 @@ class ApiToolProvider(TypeBase): def user(self) -> Account | None: if not self.user_id: return None - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) class ToolLabelBinding(TypeBase): @@ -208,7 +212,7 @@ class ToolLabelBinding(TypeBase): # tool id tool_id: Mapped[str] = mapped_column(String(64), nullable=False) # tool type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # label name label_name: Mapped[str] = mapped_column(String(40), nullable=False) @@ -262,11 +266,11 @@ class WorkflowToolProvider(TypeBase): @property def user(self) -> Account | None: - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def tenant(self) -> Tenant | None: - return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first() + return db.session.scalar(select(Tenant).where(Tenant.id == self.tenant_id)) @property def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]: @@ -277,7 +281,7 @@ class WorkflowToolProvider(TypeBase): @property def app(self) -> App | None: - return db.session.query(App).where(App.id == self.app_id).first() + return db.session.scalar(select(App).where(App.id == self.app_id)) class MCPToolProvider(TypeBase): @@ -334,7 +338,7 @@ class MCPToolProvider(TypeBase): encrypted_headers: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None) def load_user(self) -> Account | None: - return db.session.query(Account).where(Account.id == self.user_id).first() + return db.session.scalar(select(Account).where(Account.id == self.user_id)) @property def credentials(self) -> dict[str, Any]: @@ -386,7 +390,7 @@ class ToolModelInvoke(TypeBase): # provider provider: Mapped[str] = mapped_column(String(255), nullable=False) # type - tool_type: Mapped[str] = mapped_column(String(40), nullable=False) + tool_type: Mapped[ToolProviderType] = mapped_column(EnumText(ToolProviderType, length=40), nullable=False) # tool name tool_name: Mapped[str] = mapped_column(String(128), nullable=False) # invoke parameters diff --git a/api/models/trigger.py b/api/models/trigger.py index 209345eb84..627b854060 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -3,7 +3,7 @@ import time from collections.abc import Mapping from datetime import datetime from functools import cached_property -from typing import Any, cast +from typing import Any, TypedDict, cast from uuid import uuid4 import sqlalchemy as sa @@ -23,6 +23,47 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr from .model import Account from .types import EnumText, LongText, StringUUID +TriggerJsonObject = dict[str, object] +TriggerCredentials = dict[str, str] + + +class WorkflowTriggerLogDict(TypedDict): + id: str + tenant_id: str + app_id: str + workflow_id: str + workflow_run_id: str | None + root_node_id: str | None + trigger_metadata: Any + trigger_type: str + trigger_data: Any + inputs: Any + outputs: Any + status: str + error: str | None + queue_name: str + celery_task_id: str | None + retry_count: int + elapsed_time: float | None + total_tokens: int | None + created_by_role: str + created_by: str + created_at: str | None + triggered_at: str | None + finished_at: str | None + + +class WorkflowSchedulePlanDict(TypedDict): + id: str + app_id: str + node_id: str + tenant_id: str + cron_expression: str + timezone: str + next_run_at: str | None + created_at: str + updated_at: str + class TriggerSubscription(TypeBase): """ @@ -51,10 +92,14 @@ class TriggerSubscription(TypeBase): String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)" ) endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint") - parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON") - properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON") + parameters: Mapped[TriggerJsonObject] = mapped_column( + sa.JSON, nullable=False, comment="Subscription parameters JSON" + ) + properties: Mapped[TriggerJsonObject] = mapped_column( + sa.JSON, nullable=False, comment="Subscription properties JSON" + ) - credentials: Mapped[dict[str, Any]] = mapped_column( + credentials: Mapped[TriggerCredentials] = mapped_column( sa.JSON, nullable=False, comment="Subscription credentials JSON" ) credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") @@ -162,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase): ) @property - def oauth_params(self) -> Mapping[str, Any]: - return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}")) + def oauth_params(self) -> Mapping[str, object]: + return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}")) class WorkflowTriggerLog(TypeBase): @@ -227,7 +272,7 @@ class WorkflowTriggerLog(TypeBase): queue_name: Mapped[str] = mapped_column(String(100), nullable=False) celery_task_id: Mapped[str | None] = mapped_column(String(255), nullable=True) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(String(255), nullable=False) retry_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, default=0) elapsed_time: Mapped[float | None] = mapped_column(sa.Float, nullable=True, default=None) @@ -250,7 +295,7 @@ class WorkflowTriggerLog(TypeBase): created_by_role = CreatorUserRole(self.created_by_role) return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> WorkflowTriggerLogDict: """Convert to dictionary for API responses""" return { "id": self.id, @@ -481,7 +526,7 @@ class WorkflowSchedulePlan(TypeBase): DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp(), init=False ) - def to_dict(self) -> dict[str, Any]: + def to_dict(self) -> WorkflowSchedulePlanDict: """Convert to dictionary representation""" return { "id": self.id, diff --git a/api/models/web.py b/api/models/web.py index 5f6a7b40bf..1fb37340d7 100644 --- a/api/models/web.py +++ b/api/models/web.py @@ -2,13 +2,14 @@ from datetime import datetime from uuid import uuid4 import sqlalchemy as sa -from sqlalchemy import DateTime, String, func +from sqlalchemy import DateTime, func, select from sqlalchemy.orm import Mapped, mapped_column from .base import TypeBase from .engine import db +from .enums import CreatorUserRole from .model import Message -from .types import StringUUID +from .types import EnumText, StringUUID class SavedMessage(TypeBase): @@ -24,7 +25,9 @@ class SavedMessage(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) message_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False, server_default=sa.text("'end_user'")) + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'") + ) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, @@ -35,7 +38,7 @@ class SavedMessage(TypeBase): @property def message(self): - return db.session.query(Message).where(Message.id == self.message_id).first() + return db.session.scalar(select(Message).where(Message.id == self.message_id)) class PinnedConversation(TypeBase): @@ -50,8 +53,8 @@ class PinnedConversation(TypeBase): ) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) conversation_id: Mapped[str] = mapped_column(StringUUID) - created_by_role: Mapped[str] = mapped_column( - String(255), + created_by_role: Mapped[CreatorUserRole] = mapped_column( + EnumText(CreatorUserRole, length=255), nullable=False, server_default=sa.text("'end_user'"), ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 3af7ca236d..0b13c0a074 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,9 +1,10 @@ +import copy import json import logging from collections.abc import Generator, Mapping, Sequence from datetime import datetime from enum import StrEnum -from typing import TYPE_CHECKING, Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from uuid import uuid4 import sqlalchemy as sa @@ -19,20 +20,21 @@ from sqlalchemy import ( orm, select, ) -from sqlalchemy.orm import Mapped, declared_attr, mapped_column +from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated -from core.workflow.constants import ( +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from core.workflow.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import NodeType, WorkflowExecutionStatus -from core.workflow.file.constants import maybe_file_object -from core.workflow.file.models import File -from core.workflow.variables import utils as variable_utils -from core.workflow.variables.variables import FloatVariable, IntegerVariable, StringVariable +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey +from dify_graph.file.constants import maybe_file_object +from dify_graph.file.models import File +from dify_graph.variables import utils as variable_utils +from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -46,18 +48,37 @@ if TYPE_CHECKING: from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.workflow.variables import SecretVariable, Segment, SegmentType, VariableBase +from dify_graph.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper from .account import Account from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db -from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType +from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID logger = logging.getLogger(__name__) +SerializedWorkflowValue = dict[str, Any] +SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] + + +class WorkflowContentDict(TypedDict): + graph: Mapping[str, Any] + features: dict[str, Any] + environment_variables: list[dict[str, Any]] + conversation_variables: list[dict[str, Any]] + rag_pipeline_variables: list[dict[str, Any]] + + +class WorkflowRunSummaryDict(TypedDict): + id: str + status: str + triggered_from: str + elapsed_time: float + total_tokens: int + class WorkflowType(StrEnum): """ @@ -142,7 +163,7 @@ class Workflow(Base): # bug id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - type: Mapped[str] = mapped_column(String(255), nullable=False) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255), nullable=False) version: Mapped[str] = mapped_column(String(255), nullable=False) marked_name: Mapped[str] = mapped_column(String(255), default="", server_default="") marked_comment: Mapped[str] = mapped_column(String(255), default="", server_default="") @@ -189,7 +210,7 @@ class Workflow(Base): # bug workflow.id = str(uuid4()) workflow.tenant_id = tenant_id workflow.app_id = app_id - workflow.type = type + workflow.type = WorkflowType(type) workflow.version = version workflow.graph = graph workflow.features = features @@ -234,8 +255,11 @@ class Workflow(Base): # bug def get_node_config_by_id(self, node_id: str) -> NodeConfigDict: """Extract a node configuration from the workflow graph by node ID. - A node configuration is a dictionary containing the node's properties, including - the node's id, title, and its data as a dict. + + A node configuration includes the node id and a typed `BaseNodeData` for `data`. + `BaseNodeData` keeps a dict-like `get`/`__getitem__` compatibility layer backed by + model fields plus Pydantic extra storage for legacy consumers, but callers should + prefer attribute access. """ workflow_graph = self.graph_dict @@ -253,12 +277,9 @@ class Workflow(Base): # bug return NodeConfigDictAdapter.validate_python(node_config) @staticmethod - def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType: + def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: """Extract type of a node from the node configuration returned by `get_node_config_by_id`.""" - node_config_data = node_config.get("data", {}) - # Get node class - node_type = NodeType(node_config_data.get("type")) - return node_type + return node_config["data"].type @staticmethod def get_enclosing_node_type_and_id( @@ -270,12 +291,12 @@ class Workflow(Base): # bug loop_id = node_config.get("loop_id") if loop_id is None: raise _InvalidGraphDefinitionError("invalid graph") - return NodeType.LOOP, loop_id + return BuiltinNodeTypes.LOOP, loop_id elif in_iteration: iteration_id = node_config.get("iteration_id") if iteration_id is None: raise _InvalidGraphDefinitionError("invalid graph") - return NodeType.ITERATION, iteration_id + return BuiltinNodeTypes.ITERATION, iteration_id else: return None @@ -283,26 +304,40 @@ class Workflow(Base): # bug def features(self) -> str: """ Convert old features structure to new features structure. + + This property avoids rewriting the underlying JSON when normalization + produces no effective change, to prevent marking the row dirty on read. """ if not self._features: return self._features - features = json.loads(self._features) - if features.get("file_upload", {}).get("image", {}).get("enabled", False): - image_enabled = True - image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) - image_transfer_methods = features["file_upload"]["image"].get( - "transfer_methods", ["remote_url", "local_file"] - ) - features["file_upload"]["enabled"] = image_enabled - features["file_upload"]["number_limits"] = image_number_limits - features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods - features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) - features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( - "allowed_file_extensions", [] - ) - del features["file_upload"]["image"] - self._features = json.dumps(features) + # Parse once and deep-copy before normalization to detect in-place changes. + original_dict = self._decode_features_payload(self._features) + if original_dict is None: + return self._features + + # Fast-path: if the legacy file_upload.image.enabled shape is absent, skip + # deep-copy and normalization entirely and return the stored JSON. + file_upload_payload = original_dict.get("file_upload") + if not isinstance(file_upload_payload, dict): + return self._features + file_upload = cast(dict[str, Any], file_upload_payload) + + image_payload = file_upload.get("image") + if not isinstance(image_payload, dict): + return self._features + image = cast(dict[str, Any], image_payload) + if "enabled" not in image: + return self._features + + normalized_dict = self._normalize_features_payload(copy.deepcopy(original_dict)) + + if normalized_dict == original_dict: + # No effective change; return stored JSON unchanged. + return self._features + + # Normalization changed the payload: persist the normalized JSON. + self._features = json.dumps(normalized_dict) return self._features @features.setter @@ -313,6 +348,44 @@ class Workflow(Base): # bug def features_dict(self) -> dict[str, Any]: return json.loads(self.features) if self.features else {} + @property + def serialized_features(self) -> str: + """Return the stored features JSON without triggering compatibility rewrites.""" + return self._features + + @property + def normalized_features_dict(self) -> dict[str, Any]: + """Decode features with legacy normalization without mutating the model state.""" + if not self._features: + return {} + + features = self._decode_features_payload(self._features) + return self._normalize_features_payload(features) if features is not None else {} + + @staticmethod + def _decode_features_payload(features: str) -> dict[str, Any] | None: + """Decode workflow features JSON when it contains an object payload.""" + payload = json.loads(features) + return cast(dict[str, Any], payload) if isinstance(payload, dict) else None + + @staticmethod + def _normalize_features_payload(features: dict[str, Any]) -> dict[str, Any]: + if features.get("file_upload", {}).get("image", {}).get("enabled", False): + image_number_limits = int(features["file_upload"]["image"].get("number_limits", DEFAULT_FILE_NUMBER_LIMITS)) + image_transfer_methods = features["file_upload"]["image"].get( + "transfer_methods", ["remote_url", "local_file"] + ) + features["file_upload"]["enabled"] = True + features["file_upload"]["number_limits"] = image_number_limits + features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods + features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) + features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( + "allowed_file_extensions", [] + ) + del features["file_upload"]["image"] + + return features + def walk_nodes( self, specific_node_type: NodeType | None = None ) -> Generator[tuple[str, Mapping[str, Any]], None, None]: @@ -346,7 +419,7 @@ class Workflow(Base): # bug "selected": false, } - For specific node type, refer to `core.workflow.nodes` + For specific node type, refer to `dify_graph.nodes` """ graph_dict = self.graph_dict if "nodes" not in graph_dict: @@ -354,9 +427,7 @@ class Workflow(Base): # bug if specific_node_type: yield from ( - (node["id"], node["data"]) - for node in graph_dict["nodes"] - if node["data"]["type"] == specific_node_type.value + (node["id"], node["data"]) for node in graph_dict["nodes"] if node["data"]["type"] == specific_node_type ) else: yield from ((node["id"], node["data"]) for node in graph_dict["nodes"]) @@ -391,7 +462,7 @@ class Workflow(Base): # bug def rag_pipeline_user_input_form(self) -> list: # get user_input_form from start node - variables: list[Any] = self.rag_pipeline_variables + variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables return variables @@ -434,17 +505,13 @@ class Workflow(Base): # bug def environment_variables( self, ) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: - # TODO: find some way to init `self._environment_variables` when instance created. - if self._environment_variables is None: - self._environment_variables = "{}" - # Use workflow.tenant_id to avoid relying on request user in background threads tenant_id = self.tenant_id if not tenant_id: return [] - environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}") + environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}")) results = [ variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values() ] @@ -504,14 +571,39 @@ class Workflow(Base): # bug ) self._environment_variables = environment_variables_json - def to_dict(self, *, include_secret: bool = False) -> Mapping[str, Any]: + @staticmethod + def normalize_environment_variable_mappings( + mappings: Sequence[Mapping[str, Any]], + ) -> list[dict[str, Any]]: + """Convert masked secret placeholders into the draft hidden sentinel. + + Regular draft sync requests should preserve existing secrets without shipping + plaintext values back from the client. The dedicated restore endpoint now + copies published secrets server-side, so draft sync only needs to normalize + the UI mask into `HIDDEN_VALUE`. + """ + masked_secret_value = encrypter.full_mask_token() + normalized_mappings: list[dict[str, Any]] = [] + + for mapping in mappings: + normalized_mapping = dict(mapping) + if ( + normalized_mapping.get("value_type") == SegmentType.SECRET.value + and normalized_mapping.get("value") == masked_secret_value + ): + normalized_mapping["value"] = HIDDEN_VALUE + normalized_mappings.append(normalized_mapping) + + return normalized_mappings + + def to_dict(self, *, include_secret: bool = False) -> WorkflowContentDict: environment_variables = list(self.environment_variables) environment_variables = [ v if not isinstance(v, SecretVariable) or include_secret else v.model_copy(update={"value": ""}) for v in environment_variables ] - result = { + result: WorkflowContentDict = { "graph": self.graph_dict, "features": self.features_dict, "environment_variables": [var.model_dump(mode="json") for var in environment_variables], @@ -522,11 +614,7 @@ class Workflow(Base): # bug @property def conversation_variables(self) -> Sequence[VariableBase]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._conversation_variables is None: - self._conversation_variables = "{}" - - variables_dict: dict[str, Any] = json.loads(self._conversation_variables) + variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}")) results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] return results @@ -538,22 +626,29 @@ class Workflow(Base): # bug ) @property - def rag_pipeline_variables(self) -> list[dict]: - # TODO: find some way to init `self._conversation_variables` when instance created. - if self._rag_pipeline_variables is None: - self._rag_pipeline_variables = "{}" - - variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables) - results = list(variables_dict.values()) - return results + def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]: + variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}")) + return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()] @rag_pipeline_variables.setter - def rag_pipeline_variables(self, values: list[dict]) -> None: + def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None: self._rag_pipeline_variables = json.dumps( - {item["variable"]: item for item in values}, + { + rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json") + for rag_pipeline_variable in ( + item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item) + for item in values + ) + }, ensure_ascii=False, ) + def copy_serialized_variable_storage_from(self, source_workflow: "Workflow") -> None: + """Copy stored variable JSON directly for same-tenant restore flows.""" + self._environment_variables = source_workflow._environment_variables + self._conversation_variables = source_workflow._conversation_variables + self._rag_pipeline_variables = source_workflow._rag_pipeline_variables + @staticmethod def version_from_datetime(d: datetime) -> str: return str(d) @@ -609,8 +704,8 @@ class WorkflowRun(Base): app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - type: Mapped[str] = mapped_column(String(255)) - triggered_from: Mapped[str] = mapped_column(String(255)) + type: Mapped[WorkflowType] = mapped_column(EnumText(WorkflowType, length=255)) + triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column(EnumText(WorkflowRunTriggeredFrom, length=255)) version: Mapped[str] = mapped_column(String(255)) graph: Mapped[str | None] = mapped_column(LongText) inputs: Mapped[str | None] = mapped_column(LongText) @@ -669,14 +764,14 @@ class WorkflowRun(Base): def message(self): from .model import Message - return ( - db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first() + return db.session.scalar( + select(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id) ) @property @deprecated("This method is retained for historical reasons; avoid using it if possible.") def workflow(self): - return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first() + return db.session.scalar(select(Workflow).where(Workflow.id == self.workflow_id)) def to_dict(self): return { @@ -788,50 +883,44 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo __tablename__ = "workflow_node_executions" - @declared_attr.directive - @classmethod - def __table_args__(cls) -> Any: - return ( - PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), - Index( - "workflow_node_execution_workflow_run_id_idx", - "workflow_run_id", - ), - Index( - "workflow_node_execution_node_run_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_id", - ), - Index( - "workflow_node_execution_id_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_execution_id", - ), - Index( - # The first argument is the index name, - # which we leave as `None`` to allow auto-generation by the ORM. - None, - cls.tenant_id, - cls.workflow_id, - cls.node_id, - # MyPy may flag the following line because it doesn't recognize that - # the `declared_attr` decorator passes the receiving class as the first - # argument to this method, allowing us to reference class attributes. - cls.created_at.desc(), - ), - ) + __table_args__ = ( + PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + Index( + "workflow_node_execution_workflow_run_id_idx", + "workflow_run_id", + ), + Index( + "workflow_node_execution_node_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_id", + ), + Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_execution_id", + ), + Index( + None, + "tenant_id", + "workflow_id", + "node_id", + sa.desc("created_at"), + ), + ) id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4())) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) workflow_id: Mapped[str] = mapped_column(StringUUID) - triggered_from: Mapped[str] = mapped_column(String(255)) + triggered_from: Mapped[WorkflowNodeExecutionTriggeredFrom] = mapped_column( + EnumText(WorkflowNodeExecutionTriggeredFrom, length=255) + ) workflow_run_id: Mapped[str | None] = mapped_column(StringUUID) index: Mapped[int] = mapped_column(sa.Integer) predecessor_node_id: Mapped[str | None] = mapped_column(String(255)) @@ -847,7 +936,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) created_at: Mapped[datetime] = mapped_column(DateTime, server_default=func.current_timestamp()) - created_by_role: Mapped[str] = mapped_column(String(255)) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) created_by: Mapped[str] = mapped_column(StringUUID) finished_at: Mapped[datetime | None] = mapped_column(DateTime) @@ -922,18 +1011,21 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo extras: dict[str, Any] = {} execution_metadata = self.execution_metadata_dict if execution_metadata: - if self.node_type == NodeType.TOOL and "tool_info" in execution_metadata: + if self.node_type == BuiltinNodeTypes.TOOL and "tool_info" in execution_metadata: tool_info: dict[str, Any] = execution_metadata["tool_info"] extras["icon"] = ToolManager.get_tool_icon( tenant_id=self.tenant_id, provider_type=tool_info["provider_type"], provider_id=tool_info["provider_id"], ) - elif self.node_type == NodeType.DATASOURCE and "datasource_info" in execution_metadata: + elif self.node_type == BuiltinNodeTypes.DATASOURCE and "datasource_info" in execution_metadata: datasource_info = execution_metadata["datasource_info"] extras["icon"] = datasource_info.get("icon") - elif self.node_type == NodeType.TRIGGER_PLUGIN and "trigger_info" in execution_metadata: - trigger_info = execution_metadata["trigger_info"] or {} + elif ( + self.node_type == TRIGGER_PLUGIN_NODE_TYPE + and WorkflowNodeExecutionMetadataKey.TRIGGER_INFO in execution_metadata + ): + trigger_info = execution_metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] or {} provider_id = trigger_info.get("provider_id") if provider_id: extras["icon"] = TriggerManager.get_trigger_plugin_icon( @@ -1131,7 +1223,7 @@ class WorkflowAppLog(TypeBase): workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) created_from: Mapped[str] = mapped_column(String(255), nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) created_at: Mapped[datetime] = mapped_column( DateTime, nullable=False, server_default=func.current_timestamp(), init=False @@ -1205,7 +1297,7 @@ class WorkflowArchiveLog(TypeBase): app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - created_by_role: Mapped[str] = mapped_column(String(255), nullable=False) + created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False) created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) log_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True) @@ -1214,7 +1306,9 @@ class WorkflowArchiveLog(TypeBase): run_version: Mapped[str] = mapped_column(String(255), nullable=False) run_status: Mapped[str] = mapped_column(String(255), nullable=False) - run_triggered_from: Mapped[str] = mapped_column(String(255), nullable=False) + run_triggered_from: Mapped[WorkflowRunTriggeredFrom] = mapped_column( + EnumText(WorkflowRunTriggeredFrom, length=255), nullable=False + ) run_error: Mapped[str | None] = mapped_column(LongText, nullable=True) run_elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0")) run_total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) @@ -1229,7 +1323,7 @@ class WorkflowArchiveLog(TypeBase): ) @property - def workflow_run_summary(self) -> dict[str, Any]: + def workflow_run_summary(self) -> WorkflowRunSummaryDict: return { "id": self.workflow_run_id, "status": self.run_status, @@ -1284,16 +1378,17 @@ class WorkflowDraftVariable(Base): """ @staticmethod - def unique_app_id_node_id_name() -> list[str]: + def unique_app_id_user_id_node_id_name() -> list[str]: return [ "app_id", + "user_id", "node_id", "name", ] __tablename__ = "workflow_draft_variables" __table_args__ = ( - UniqueConstraint(*unique_app_id_node_id_name()), + UniqueConstraint(*unique_app_id_user_id_node_id_name()), Index("workflow_draft_variable_file_id_idx", "file_id"), ) # Required for instance variable annotation. @@ -1319,6 +1414,11 @@ class WorkflowDraftVariable(Base): # "`app_id` maps to the `id` field in the `model.App` model." app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + # Owner of this draft variable. + # + # This field is nullable during migration and will be migrated to NOT NULL + # in a follow-up release. + user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None) # `last_edited_at` records when the value of a given draft variable # is edited. @@ -1345,7 +1445,7 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. # - # ref: api/core/workflow/entities/variable_pool.py:18 + # ref: api/dify_graph/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), @@ -1571,6 +1671,7 @@ class WorkflowDraftVariable(Base): cls, *, app_id: str, + user_id: str | None, node_id: str, name: str, value: Segment, @@ -1584,6 +1685,7 @@ class WorkflowDraftVariable(Base): variable.updated_at = naive_utc_now() variable.description = description variable.app_id = app_id + variable.user_id = user_id variable.node_id = node_id variable.name = name variable.set_value(value) @@ -1597,12 +1699,14 @@ class WorkflowDraftVariable(Base): cls, *, app_id: str, + user_id: str | None = None, name: str, value: Segment, description: str = "", ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, + user_id=user_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, value=value, @@ -1617,6 +1721,7 @@ class WorkflowDraftVariable(Base): cls, *, app_id: str, + user_id: str | None = None, name: str, value: Segment, node_execution_id: str, @@ -1624,6 +1729,7 @@ class WorkflowDraftVariable(Base): ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, + user_id=user_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, node_execution_id=node_execution_id, @@ -1637,6 +1743,7 @@ class WorkflowDraftVariable(Base): cls, *, app_id: str, + user_id: str | None = None, node_id: str, name: str, value: Segment, @@ -1647,6 +1754,7 @@ class WorkflowDraftVariable(Base): ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, + user_id=user_id, node_id=node_id, name=name, node_execution_id=node_execution_id, diff --git a/api/pyproject.toml b/api/pyproject.toml index a314eab737..7e70bc76d3 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,99 +1,98 @@ [project] name = "dify-api" -version = "1.13.0" +version = "1.13.2" requires-python = ">=3.11,<3.13" dependencies = [ "aliyun-log-python-sdk~=0.9.37", - "arize-phoenix-otel~=0.9.2", - "azure-identity==1.16.1", - "beautifulsoup4==4.12.2", - "boto3==1.35.99", + "arize-phoenix-otel~=0.15.0", + "azure-identity==1.25.3", + "beautifulsoup4==4.14.3", + "boto3==1.42.73", "bs4~=0.0.1", "cachetools~=5.3.0", - "celery~=5.5.2", + "celery~=5.6.2", "charset-normalizer>=3.4.4", "flask~=3.1.2", - "flask-compress>=1.17,<1.18", + "flask-compress>=1.17,<1.24", "flask-cors~=6.0.0", "flask-login~=0.6.3", - "flask-migrate~=4.0.7", + "flask-migrate~=4.1.0", "flask-orjson~=2.0.0", "flask-sqlalchemy~=3.1.1", "gevent~=25.9.1", "gmpy2~=2.3.0", - "google-api-core==2.18.0", - "google-api-python-client==2.189.0", - "google-auth==2.29.0", - "google-auth-httplib2==0.2.0", - "google-cloud-aiplatform==1.49.0", - "googleapis-common-protos==1.63.0", - "gunicorn~=23.0.0", - "httpx[socks]~=0.27.0", + "google-api-core>=2.19.1", + "google-api-python-client==2.193.0", + "google-auth>=2.47.0", + "google-auth-httplib2==0.3.0", + "google-cloud-aiplatform>=1.123.0", + "googleapis-common-protos>=1.65.0", + "gunicorn~=25.1.0", + "httpx[socks]~=0.28.0", "jieba==0.42.1", "json-repair>=0.55.1", "jsonschema>=4.25.1", "langfuse~=2.51.3", - "langsmith~=0.1.77", - "markdown~=3.5.1", + "langsmith~=0.7.16", + "markdown~=3.10.2", "mlflow-skinny>=3.0.0", "numpy~=1.26.4", "openpyxl~=3.1.5", - "opik~=1.8.72", - "litellm==1.77.1", # Pinned to avoid madoka dependency issue - "opentelemetry-api==1.27.0", - "opentelemetry-distro==0.48b0", - "opentelemetry-exporter-otlp==1.27.0", - "opentelemetry-exporter-otlp-proto-common==1.27.0", - "opentelemetry-exporter-otlp-proto-grpc==1.27.0", - "opentelemetry-exporter-otlp-proto-http==1.27.0", - "opentelemetry-instrumentation==0.48b0", - "opentelemetry-instrumentation-celery==0.48b0", - "opentelemetry-instrumentation-flask==0.48b0", - "opentelemetry-instrumentation-httpx==0.48b0", - "opentelemetry-instrumentation-redis==0.48b0", - "opentelemetry-instrumentation-httpx==0.48b0", - "opentelemetry-instrumentation-sqlalchemy==0.48b0", - "opentelemetry-propagator-b3==1.27.0", - # opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0), - # which is conflict with googleapis-common-protos (1.63.0) - "opentelemetry-proto==1.27.0", - "opentelemetry-sdk==1.27.0", - "opentelemetry-semantic-conventions==0.48b0", - "opentelemetry-util-http==0.48b0", - "pandas[excel,output-formatting,performance]~=2.2.2", + "opik~=1.10.37", + "litellm==1.82.6", # Pinned to avoid madoka dependency issue + "opentelemetry-api==1.28.0", + "opentelemetry-distro==0.49b0", + "opentelemetry-exporter-otlp==1.28.0", + "opentelemetry-exporter-otlp-proto-common==1.28.0", + "opentelemetry-exporter-otlp-proto-grpc==1.28.0", + "opentelemetry-exporter-otlp-proto-http==1.28.0", + "opentelemetry-instrumentation==0.49b0", + "opentelemetry-instrumentation-celery==0.49b0", + "opentelemetry-instrumentation-flask==0.49b0", + "opentelemetry-instrumentation-httpx==0.49b0", + "opentelemetry-instrumentation-redis==0.49b0", + "opentelemetry-instrumentation-sqlalchemy==0.49b0", + "opentelemetry-propagator-b3==1.40.0", + "opentelemetry-proto==1.28.0", + "opentelemetry-sdk==1.28.0", + "opentelemetry-semantic-conventions==0.49b0", + "opentelemetry-util-http==0.49b0", + "pandas[excel,output-formatting,performance]~=3.0.1", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", "pycryptodome==3.23.0", "pydantic~=2.12.5", - "pydantic-extra-types~=2.10.3", - "pydantic-settings~=2.12.0", - "pyjwt~=2.11.0", - "pypdfium2==5.2.0", + "pydantic-extra-types~=2.11.0", + "pydantic-settings~=2.13.1", + "pyjwt~=2.12.0", + "pypdfium2==5.6.0", "python-docx~=1.2.0", - "python-dotenv==1.0.1", + "python-dotenv==1.2.2", "pyyaml~=6.0.1", "readabilipy~=0.3.0", - "redis[hiredis]~=7.2.0", - "resend~=2.9.0", - "sentry-sdk[flask]~=2.28.0", + "redis[hiredis]~=7.3.0", + "resend~=2.26.0", + "sentry-sdk[flask]~=2.55.0", "sqlalchemy~=2.0.29", - "starlette==0.49.1", - "tiktoken~=0.9.0", - "transformers~=4.56.1", - "unstructured[docx,epub,md,ppt,pptx]~=0.18.18", - "yarl~=1.18.3", + "starlette==1.0.0", + "tiktoken~=0.12.0", + "transformers~=5.3.0", + "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", + "pypandoc~=1.13", + "yarl~=1.23.0", "webvtt-py~=0.5.1", - "sseclient-py~=1.8.0", + "sseclient-py~=1.9.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", "flask-restx~=1.3.2", "packaging~=23.2", "croniter>=6.0.0", - "weaviate-client==4.17.0", + "weaviate-client==4.20.4", "apscheduler>=3.11.0", "weave>=0.52.16", "fastopenapi[flask]>=0.7.0", + "bleach~=6.3.0", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. @@ -112,46 +111,46 @@ package = false # Required for development and running tests ############################################################ dev = [ - "coverage~=7.2.4", - "dotenv-linter~=0.5.0", - "faker~=38.2.0", + "coverage~=7.13.4", + "dotenv-linter~=0.7.0", + "faker~=40.11.0", "lxml-stubs~=0.5.1", - "basedpyright~=1.31.0", - "ruff~=0.14.0", - "pytest~=8.3.2", - "pytest-benchmark~=4.0.0", - "pytest-cov~=4.1.0", - "pytest-env~=1.1.3", - "pytest-mock~=3.14.0", - "testcontainers~=4.13.2", + "basedpyright~=1.38.2", + "ruff~=0.15.5", + "pytest~=9.0.2", + "pytest-benchmark~=5.2.3", + "pytest-cov~=7.1.0", + "pytest-env~=1.6.0", + "pytest-mock~=3.15.1", + "testcontainers~=4.14.1", "types-aiofiles~=25.1.0", "types-beautifulsoup4~=4.12.0", - "types-cachetools~=5.5.0", + "types-cachetools~=6.2.0", "types-colorama~=0.4.15", "types-defusedxml~=0.7.0", - "types-deprecated~=1.2.15", - "types-docutils~=0.21.0", - "types-jsonschema~=4.23.0", - "types-flask-cors~=5.0.0", + "types-deprecated~=1.3.1", + "types-docutils~=0.22.3", + "types-jsonschema~=4.26.0", + "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", "types-gevent~=25.9.0", "types-greenlet~=3.3.0", "types-html5lib~=1.1.11", "types-markdown~=3.10.2", - "types-oauthlib~=3.2.0", + "types-oauthlib~=3.3.0", "types-objgraph~=3.6.0", "types-olefile~=0.47.0", "types-openpyxl~=3.1.5", "types-pexpect~=4.9.0", - "types-protobuf~=5.29.1", + "types-protobuf~=6.32.1", "types-psutil~=7.2.2", "types-psycopg2~=2.9.21", "types-pygments~=2.19.0", "types-pymysql~=1.1.0", "types-python-dateutil~=2.9.0", - "types-pywin32~=310.0.0", + "types-pywin32~=311.0.0", "types-pyyaml~=6.0.12", - "types-regex~=2024.11.6", + "types-regex~=2026.2.28", "types-shapely~=2.1.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", @@ -164,18 +163,18 @@ dev = [ "types_pyOpenSSL>=24.1.0", "types_cffi>=1.17.0", "types_setuptools>=80.9.0", - "pandas-stubs~=2.2.3", + "pandas-stubs~=3.0.0", "scipy-stubs>=1.15.3.0", "types-python-http-client>=3.3.7.20240910", "import-linter>=2.3", "types-redis>=4.6.0.20241004", "celery-types>=0.23.0", - "mypy~=1.17.1", + "mypy~=1.19.1", # "locust>=2.40.4", # Temporarily removed due to compatibility issues. Uncomment when resolved. "sseclient-py>=1.8.0", "pytest-timeout>=2.4.0", "pytest-xdist>=3.8.0", - "pyrefly>=0.54.0", + "pyrefly>=0.55.0", ] ############################################################ @@ -183,13 +182,13 @@ dev = [ # Required for storage clients ############################################################ storage = [ - "azure-storage-blob==12.26.0", + "azure-storage-blob==12.28.0", "bce-python-sdk~=0.9.23", - "cos-python-sdk-v5==1.9.38", - "esdk-obs-python==3.25.8", - "google-cloud-storage==2.16.0", + "cos-python-sdk-v5==1.9.41", + "esdk-obs-python==3.26.2", + "google-cloud-storage>=3.0.0", "opendal~=0.46.0", - "oss2==2.18.5", + "oss2==2.19.1", "supabase~=2.18.1", "tos~=2.9.0", ] @@ -211,30 +210,31 @@ evaluation = ["ragas>=0.2.0", "deepeval>=2.0.0"] ############################################################ vdb = [ "alibabacloud_gpdb20160503~=3.8.0", - "alibabacloud_tea_openapi~=0.3.9", + "alibabacloud_tea_openapi~=0.4.3", "chromadb==0.5.20", - "clickhouse-connect~=0.10.0", + "clickhouse-connect~=0.14.1", "clickzetta-connector-python>=0.8.102", - "couchbase~=4.3.0", + "couchbase~=4.5.0", "elasticsearch==8.14.0", "opensearch-py==3.1.0", - "oracledb==3.3.0", + "oracledb==3.4.2", "pgvecto-rs[sqlalchemy]~=0.2.1", - "pgvector==0.2.5", - "pymilvus~=2.5.0", - "pymochow==2.2.9", + "pgvector==0.4.2", + "pymilvus~=2.6.10", + "pymochow==2.3.6", "pyobvector~=0.2.17", "qdrant-client==1.9.0", "intersystems-irispython>=5.1.0", - "tablestore==6.3.7", - "tcvectordb~=1.6.4", - "tidb-vector==0.0.9", - "upstash-vector==0.6.0", + "tablestore==6.4.1", + "tcvectordb~=2.0.0", + "tidb-vector==0.0.15", + "upstash-vector==0.8.0", "volcengine-compat~=1.0.0", - "weaviate-client==4.17.0", - "xinference-client~=1.2.2", + "weaviate-client==4.20.4", + "xinference-client~=2.3.1", "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", + "holo-search-sdk>=0.4.1", ] [tool.mypy] @@ -247,7 +247,7 @@ module = [ "configs.middleware.cache.redis_pubsub_config", "extensions.ext_redis", "tasks.workflow_execution_tasks", - "core.workflow.nodes.base.node", + "dify_graph.nodes.base.node", "services.human_input_delivery_test_service", "core.app.apps.advanced_chat.app_generator", "controllers.console.human_input_form", @@ -256,3 +256,10 @@ module = [ "extensions.logstore.repositories.logstore_api_workflow_run_repository", ] ignore_errors = true + +[tool.pyrefly] +project-includes = ["."] +project-excludes = [".venv", "migrations/"] +python-platform = "linux" +python-version = "3.11.0" +infer-with-first-use = false diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt new file mode 100644 index 0000000000..ad3c1e8389 --- /dev/null +++ b/api/pyrefly-local-excludes.txt @@ -0,0 +1,190 @@ +controllers/console/app/annotation.py +controllers/console/app/app.py +controllers/console/app/app_import.py +controllers/console/app/mcp_server.py +controllers/console/app/site.py +controllers/console/auth/email_register.py +controllers/console/human_input_form.py +controllers/console/init_validate.py +controllers/console/ping.py +controllers/console/setup.py +controllers/console/version.py +controllers/console/workspace/trigger_providers.py +controllers/service_api/app/annotation.py +controllers/web/workflow_events.py +core/agent/fc_agent_runner.py +core/app/apps/advanced_chat/app_generator.py +core/app/apps/advanced_chat/app_runner.py +core/app/apps/advanced_chat/generate_task_pipeline.py +core/app/apps/agent_chat/app_generator.py +core/app/apps/base_app_generate_response_converter.py +core/app/apps/base_app_generator.py +core/app/apps/chat/app_generator.py +core/app/apps/common/workflow_response_converter.py +core/app/apps/completion/app_generator.py +core/app/apps/pipeline/pipeline_generator.py +core/app/apps/pipeline/pipeline_runner.py +core/app/apps/workflow/app_generator.py +core/app/apps/workflow/app_runner.py +core/app/apps/workflow/generate_task_pipeline.py +core/app/apps/workflow_app_runner.py +core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +core/datasource/datasource_manager.py +core/external_data_tool/api/api.py +core/llm_generator/llm_generator.py +core/llm_generator/output_parser/structured_output.py +core/mcp/mcp_client.py +core/ops/aliyun_trace/data_exporter/traceclient.py +core/ops/arize_phoenix_trace/arize_phoenix_trace.py +core/ops/mlflow_trace/mlflow_trace.py +core/ops/ops_trace_manager.py +core/ops/tencent_trace/client.py +core/ops/tencent_trace/utils.py +core/plugin/backwards_invocation/base.py +core/plugin/backwards_invocation/model.py +core/prompt/utils/extract_thread_messages.py +core/rag/datasource/keyword/jieba/jieba.py +core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py +core/rag/datasource/vdb/analyticdb/analyticdb_vector.py +core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +core/rag/datasource/vdb/baidu/baidu_vector.py +core/rag/datasource/vdb/chroma/chroma_vector.py +core/rag/datasource/vdb/clickzetta/clickzetta_vector.py +core/rag/datasource/vdb/couchbase/couchbase_vector.py +core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +core/rag/datasource/vdb/huawei/huawei_cloud_vector.py +core/rag/datasource/vdb/lindorm/lindorm_vector.py +core/rag/datasource/vdb/matrixone/matrixone_vector.py +core/rag/datasource/vdb/milvus/milvus_vector.py +core/rag/datasource/vdb/myscale/myscale_vector.py +core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +core/rag/datasource/vdb/opensearch/opensearch_vector.py +core/rag/datasource/vdb/oracle/oraclevector.py +core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py +core/rag/datasource/vdb/relyt/relyt_vector.py +core/rag/datasource/vdb/tablestore/tablestore_vector.py +core/rag/datasource/vdb/tencent/tencent_vector.py +core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +core/rag/datasource/vdb/tidb_vector/tidb_vector.py +core/rag/datasource/vdb/upstash/upstash_vector.py +core/rag/datasource/vdb/vikingdb/vikingdb_vector.py +core/rag/datasource/vdb/weaviate/weaviate_vector.py +core/rag/extractor/csv_extractor.py +core/rag/extractor/excel_extractor.py +core/rag/extractor/firecrawl/firecrawl_app.py +core/rag/extractor/firecrawl/firecrawl_web_extractor.py +core/rag/extractor/html_extractor.py +core/rag/extractor/jina_reader_extractor.py +core/rag/extractor/markdown_extractor.py +core/rag/extractor/notion_extractor.py +core/rag/extractor/pdf_extractor.py +core/rag/extractor/text_extractor.py +core/rag/extractor/unstructured/unstructured_doc_extractor.py +core/rag/extractor/unstructured/unstructured_eml_extractor.py +core/rag/extractor/unstructured/unstructured_epub_extractor.py +core/rag/extractor/unstructured/unstructured_markdown_extractor.py +core/rag/extractor/unstructured/unstructured_msg_extractor.py +core/rag/extractor/unstructured/unstructured_ppt_extractor.py +core/rag/extractor/unstructured/unstructured_pptx_extractor.py +core/rag/extractor/unstructured/unstructured_xml_extractor.py +core/rag/extractor/watercrawl/client.py +core/rag/extractor/watercrawl/extractor.py +core/rag/extractor/watercrawl/provider.py +core/rag/extractor/word_extractor.py +core/rag/index_processor/processor/paragraph_index_processor.py +core/rag/index_processor/processor/parent_child_index_processor.py +core/rag/index_processor/processor/qa_index_processor.py +core/rag/retrieval/router/multi_dataset_function_call_router.py +core/rag/summary_index/summary_index.py +core/repositories/sqlalchemy_workflow_execution_repository.py +core/repositories/sqlalchemy_workflow_node_execution_repository.py +core/tools/__base/tool.py +core/tools/mcp_tool/provider.py +core/tools/plugin_tool/provider.py +core/tools/utils/message_transformer.py +core/tools/utils/web_reader_tool.py +core/tools/workflow_as_tool/provider.py +core/trigger/debug/event_selectors.py +core/trigger/entities/entities.py +core/trigger/provider.py +core/workflow/workflow_entry.py +dify_graph/entities/workflow_execution.py +dify_graph/file/file_manager.py +dify_graph/graph_engine/error_handler.py +dify_graph/graph_engine/layers/execution_limits.py +dify_graph/nodes/agent/agent_node.py +dify_graph/nodes/base/node.py +dify_graph/nodes/code/code_node.py +dify_graph/nodes/datasource/datasource_node.py +dify_graph/nodes/document_extractor/node.py +dify_graph/nodes/human_input/human_input_node.py +dify_graph/nodes/if_else/if_else_node.py +dify_graph/nodes/iteration/iteration_node.py +dify_graph/nodes/knowledge_index/knowledge_index_node.py +core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +dify_graph/nodes/list_operator/node.py +dify_graph/nodes/llm/node.py +dify_graph/nodes/loop/loop_node.py +dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +dify_graph/nodes/question_classifier/question_classifier_node.py +dify_graph/nodes/start/start_node.py +dify_graph/nodes/template_transform/template_transform_node.py +dify_graph/nodes/tool/tool_node.py +dify_graph/nodes/trigger_plugin/trigger_event_node.py +dify_graph/nodes/trigger_schedule/trigger_schedule_node.py +dify_graph/nodes/trigger_webhook/node.py +dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +dify_graph/nodes/variable_assigner/v1/node.py +dify_graph/nodes/variable_assigner/v2/node.py +extensions/logstore/repositories/logstore_api_workflow_run_repository.py +extensions/otel/instrumentation.py +extensions/otel/runtime.py +extensions/storage/aliyun_oss_storage.py +extensions/storage/aws_s3_storage.py +extensions/storage/azure_blob_storage.py +extensions/storage/baidu_obs_storage.py +extensions/storage/clickzetta_volume/clickzetta_volume_storage.py +extensions/storage/clickzetta_volume/file_lifecycle.py +extensions/storage/google_cloud_storage.py +extensions/storage/huawei_obs_storage.py +extensions/storage/opendal_storage.py +extensions/storage/oracle_oci_storage.py +extensions/storage/supabase_storage.py +extensions/storage/tencent_cos_storage.py +extensions/storage/volcengine_tos_storage.py +libs/gmpy2_pkcs10aep_cipher.py +schedule/queue_monitor_task.py +services/account_service.py +services/audio_service.py +services/auth/firecrawl/firecrawl.py +services/auth/jina.py +services/auth/jina/jina.py +services/auth/watercrawl/watercrawl.py +services/conversation_service.py +services/dataset_service.py +services/document_indexing_proxy/document_indexing_task_proxy.py +services/document_indexing_proxy/duplicate_document_indexing_task_proxy.py +services/external_knowledge_service.py +services/plugin/plugin_migration.py +services/recommend_app/buildin/buildin_retrieval.py +services/recommend_app/database/database_retrieval.py +services/recommend_app/remote/remote_retrieval.py +services/summary_index_service.py +services/tools/tools_transform_service.py +services/trigger/trigger_provider_service.py +services/trigger/trigger_subscription_builder_service.py +services/trigger/webhook_service.py +services/workflow_draft_variable_service.py +services/workflow_event_snapshot_service.py +services/workflow_service.py +tasks/app_generate/workflow_execute_task.py +tasks/regenerate_summary_index_task.py +tasks/trigger_processing_tasks.py +tasks/workflow_cfs_scheduler/cfs_scheduler.py +tasks/add_document_to_index_task.py +tasks/create_segment_to_index_task.py +tasks/disable_segment_from_index_task.py +tasks/enable_segment_to_index_task.py +tasks/remove_document_from_index_task.py +tasks/workflow_execution_tasks.py diff --git a/api/pyrefly.toml b/api/pyrefly.toml deleted file mode 100644 index 01f4c5a529..0000000000 --- a/api/pyrefly.toml +++ /dev/null @@ -1,8 +0,0 @@ -project-includes = ["."] -project-excludes = [ - ".venv", - "migrations/", -] -python-platform = "linux" -python-version = "3.11.0" -infer-with-first-use = false diff --git a/api/pyrightconfig.json b/api/pyrightconfig.json index 007c49ddb0..48271aab61 100644 --- a/api/pyrightconfig.json +++ b/api/pyrightconfig.json @@ -35,7 +35,8 @@ "tos", "gmpy2", "sendgrid", - "sendgrid.helpers.mail" + "sendgrid.helpers.mail", + "holo_search_sdk.types" ], "reportUnknownMemberType": "hint", "reportUnknownParameterType": "hint", diff --git a/api/pytest.ini b/api/pytest.ini index 4a9470fa0c..4d5d0ab6e0 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,5 +1,6 @@ [pytest] -addopts = --cov=./api --cov-report=json +pythonpath = . +addopts = --cov=./api --cov-report=json --import-mode=importlib --cov-branch --cov-report=xml env = ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz AZURE_OPENAI_API_BASE = https://difyai-openai.openai.azure.com @@ -19,7 +20,7 @@ env = GOOGLE_API_KEY = abcdefghijklmnopqrstuvwxyz HUGGINGFACE_API_KEY = hf-awuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwuwu HUGGINGFACE_EMBEDDINGS_ENDPOINT_URL = c - HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b + HUGGINGFACE_TEXT2TEXT_GEN_ENDPOINT_URL = b HUGGINGFACE_TEXT_GEN_ENDPOINT_URL = a MIXEDBREAD_API_KEY = mk-aaaaaaaaaaaaaaaaaaaa MOCK_SWITCH = true diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 6446eb0d6e..2fa065bcc8 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index ffa87b209f..a96c4acb31 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -40,9 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session -from core.workflow.entities.pause_reason import PauseReason -from core.workflow.enums import WorkflowType -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.entities.pause_reason import PauseReason +from dify_graph.enums import WorkflowType +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index a3c4039aaa..be28b7e613 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from core.workflow.entities.pause_reason import PauseReason +from dify_graph.entities.pause_reason import PauseReason class WorkflowPauseEntity(ABC): diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 6c696b6478..77e40fc6fc 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -8,13 +8,13 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. import json from collections.abc import Sequence from datetime import datetime -from typing import cast +from typing import Protocol, cast from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, @@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import ( ) +class _WorkflowNodeExecutionSnapshotRow(Protocol): + id: str + node_execution_id: str | None + node_id: str + node_type: str + title: str + index: int + status: WorkflowNodeExecutionStatus + elapsed_time: float | None + created_at: datetime + finished_at: datetime | None + execution_metadata: str | None + + class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): """ SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. @@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut - Thread-safe database operations using session-per-request pattern """ + _session_maker: sessionmaker[Session] + def __init__(self, session_maker: sessionmaker[Session]): """ Initialize the repository with a sessionmaker. @@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut ) with self._session_maker() as session: - rows = session.execute(stmt).all() + rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all()) return [self._row_to_snapshot(row) for row in rows] @staticmethod - def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot: + def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot: metadata: dict[str, object] = {} execution_metadata = getattr(row, "execution_metadata", None) if execution_metadata: diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index 5ba7a7e7e8..fdd3e123e4 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -33,9 +33,9 @@ from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus, WorkflowType -from core.workflow.nodes.human_input.entities import FormDefinition +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from dify_graph.enums import WorkflowExecutionStatus, WorkflowType +from dify_graph.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 5a2c0ea46f..508db22eb0 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -18,9 +18,9 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from core.workflow.nodes.human_input.entities import FormDefinition -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.human_input.entities import FormDefinition +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py index 13d2f24ca0..cf223f6e9e 100644 --- a/api/schedule/check_upgradable_plugin_task.py +++ b/api/schedule/check_upgradable_plugin_task.py @@ -3,6 +3,7 @@ import math import time import click +from sqlalchemy import select import app from core.helper.marketplace import fetch_global_plugin_manifest @@ -28,17 +29,15 @@ def check_upgradable_plugin_task(): now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC click.echo(click.style(f"Now seconds of day: {now_seconds_of_day}", fg="green")) - strategies = ( - db.session.query(TenantPluginAutoUpgradeStrategy) - .where( + strategies = db.session.scalars( + select(TenantPluginAutoUpgradeStrategy).where( TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day, TenantPluginAutoUpgradeStrategy.upgrade_time_of_day < now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL, TenantPluginAutoUpgradeStrategy.strategy_setting != TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED, ) - .all() - ) + ).all() total_strategies = len(strategies) click.echo(click.style(f"Total strategies: {total_strategies}", fg="green")) diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py index 2b74fb2dd0..04c954875f 100644 --- a/api/schedule/clean_embedding_cache_task.py +++ b/api/schedule/clean_embedding_cache_task.py @@ -2,7 +2,7 @@ import datetime import time import click -from sqlalchemy import text +from sqlalchemy import select, text from sqlalchemy.exc import SQLAlchemyError import app @@ -19,14 +19,12 @@ def clean_embedding_cache_task(): thirty_days_ago = datetime.datetime.now() - datetime.timedelta(days=clean_days) while True: try: - embedding_ids = ( - db.session.query(Embedding.id) + embedding_ids = db.session.scalars( + select(Embedding.id) .where(Embedding.created_at < thirty_days_ago) .order_by(Embedding.created_at.desc()) .limit(100) - .all() - ) - embedding_ids = [embedding_id[0] for embedding_id in embedding_ids] + ).all() except SQLAlchemyError: raise if embedding_ids: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index d9fb6a24f1..0b0fc1b229 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -3,7 +3,7 @@ import time from typing import TypedDict import click -from sqlalchemy import func, select +from sqlalchemy import func, select, update from sqlalchemy.exc import SQLAlchemyError import app @@ -51,7 +51,7 @@ def clean_unused_datasets_task(): try: # Subquery for counting new documents document_subquery_new = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + select(Document.dataset_id, func.count(Document.id).label("document_count")) .where( Document.indexing_status == "completed", Document.enabled == True, @@ -64,7 +64,7 @@ def clean_unused_datasets_task(): # Subquery for counting old documents document_subquery_old = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("document_count")) + select(Document.dataset_id, func.count(Document.id).label("document_count")) .where( Document.indexing_status == "completed", Document.enabled == True, @@ -142,8 +142,8 @@ def clean_unused_datasets_task(): index_processor.clean(dataset, None) # Update document - db.session.query(Document).filter_by(dataset_id=dataset.id).update( - {Document.enabled: False} + db.session.execute( + update(Document).where(Document.dataset_id == dataset.id).values(enabled=False) ) db.session.commit() click.echo(click.style(f"Cleaned unused dataset {dataset.id} from db success!", fg="green")) diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index ed46c1c70a..8b9d973d6d 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -1,6 +1,7 @@ import time import click +from sqlalchemy import func, select import app from configs import dify_config @@ -20,7 +21,7 @@ def create_tidb_serverless_task(): try: # check the number of idle tidb serverless idle_tidb_serverless_number = ( - db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count() + db.session.scalar(select(func.count(TidbAuthBinding.id)).where(TidbAuthBinding.active == False)) or 0 ) if idle_tidb_serverless_number >= tidb_serverless_number: break diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index d738bf46fa..8479cdfb0c 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -49,16 +49,18 @@ def mail_clean_document_notify_task(): if plan != CloudPlan.SANDBOX: knowledge_details = [] # check tenant - tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first() + tenant = db.session.scalar(select(Tenant).where(Tenant.id == tenant_id)) if not tenant: continue # check current owner - current_owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join = db.session.scalar( + select(TenantAccountJoin) + .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role == "owner") + .limit(1) ) if not current_owner_join: continue - account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first() + account = db.session.scalar(select(Account).where(Account.id == current_owner_join.account_id)) if not account: continue @@ -71,7 +73,7 @@ def mail_clean_document_notify_task(): ) for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = db.session.scalar(select(Dataset).where(Dataset.id == dataset_id)) if dataset: document_count = len(document_ids) knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py index 77d6b5a138..01642e397e 100644 --- a/api/schedule/queue_monitor_task.py +++ b/api/schedule/queue_monitor_task.py @@ -21,6 +21,10 @@ celery_redis = Redis( ssl_cert_reqs=getattr(dify_config, "REDIS_SSL_CERT_REQS", None) if dify_config.BROKER_USE_SSL else None, ssl_certfile=getattr(dify_config, "REDIS_SSL_CERTFILE", None) if dify_config.BROKER_USE_SSL else None, ssl_keyfile=getattr(dify_config, "REDIS_SSL_KEYFILE", None) if dify_config.BROKER_USE_SSL else None, + # Add conservative socket timeouts and health checks to avoid long-lived half-open sockets + socket_timeout=5, + socket_connect_timeout=5, + health_check_interval=30, ) logger = logging.getLogger(__name__) diff --git a/api/schedule/trigger_provider_refresh_task.py b/api/schedule/trigger_provider_refresh_task.py index 3b3e478793..df5058d70a 100644 --- a/api/schedule/trigger_provider_refresh_task.py +++ b/api/schedule/trigger_provider_refresh_task.py @@ -3,6 +3,7 @@ import math import time from collections.abc import Iterable, Sequence +from celery import group from sqlalchemy import ColumnElement, and_, func, or_, select from sqlalchemy.engine.row import Row from sqlalchemy.orm import Session @@ -85,20 +86,25 @@ def trigger_provider_refresh() -> None: lock_keys: list[str] = build_trigger_refresh_lock_keys(subscriptions) acquired: list[bool] = _acquire_locks(keys=lock_keys, ttl_seconds=lock_ttl) - enqueued: int = 0 - for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired): - if not is_locked: - continue - trigger_subscription_refresh.delay(tenant_id=tenant_id, subscription_id=subscription_id) - enqueued += 1 + if not any(acquired): + continue + + jobs = [ + trigger_subscription_refresh.s(tenant_id=tenant_id, subscription_id=subscription_id) + for (tenant_id, subscription_id), is_locked in zip(subscriptions, acquired) + if is_locked + ] + result = group(jobs).apply_async() + enqueued = len(jobs) logger.info( - "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d", + "Trigger refresh page %d/%d: scanned=%d locks_acquired=%d enqueued=%d result=%s", page + 1, pages, len(subscriptions), sum(1 for x in acquired if x), enqueued, + result, ) logger.info("Trigger refresh scan done: due=%d", total_due) diff --git a/api/schedule/workflow_schedule_task.py b/api/schedule/workflow_schedule_task.py index d68b9565ec..2fee9e467d 100644 --- a/api/schedule/workflow_schedule_task.py +++ b/api/schedule/workflow_schedule_task.py @@ -1,6 +1,6 @@ import logging -from celery import group, shared_task +from celery import current_app, group, shared_task from sqlalchemy import and_, select from sqlalchemy.orm import Session, sessionmaker @@ -29,31 +29,27 @@ def poll_workflow_schedules() -> None: with session_factory() as session: total_dispatched = 0 - # Process in batches until we've handled all due schedules or hit the limit while True: due_schedules = _fetch_due_schedules(session) if not due_schedules: break - dispatched_count = _process_schedules(session, due_schedules) - total_dispatched += dispatched_count + with current_app.producer_or_acquire() as producer: # type: ignore + dispatched_count = _process_schedules(session, due_schedules, producer) + total_dispatched += dispatched_count - logger.debug("Batch processed: %d dispatched", dispatched_count) - - # Circuit breaker: check if we've hit the per-tick limit (if enabled) - if ( - dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK > 0 - and total_dispatched >= dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK - ): - logger.warning( - "Circuit breaker activated: reached dispatch limit (%d), will continue next tick", - dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK, - ) - break + logger.debug("Batch processed: %d dispatched", dispatched_count) + # Circuit breaker: check if we've hit the per-tick limit (if enabled) + if 0 < dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK <= total_dispatched: + logger.warning( + "Circuit breaker activated: reached dispatch limit (%d), will continue next tick", + dify_config.WORKFLOW_SCHEDULE_MAX_DISPATCH_PER_TICK, + ) + break if total_dispatched > 0: - logger.info("Total processed: %d dispatched", total_dispatched) + logger.info("Total processed: %d workflow schedule(s) dispatched", total_dispatched) def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]: @@ -90,7 +86,7 @@ def _fetch_due_schedules(session: Session) -> list[WorkflowSchedulePlan]: return list(due_schedules) -def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) -> int: +def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan], producer=None) -> int: """Process schedules: check quota, update next run time and dispatch to Celery in parallel.""" if not schedules: return 0 @@ -107,7 +103,7 @@ def _process_schedules(session: Session, schedules: list[WorkflowSchedulePlan]) if tasks_to_dispatch: job = group(run_schedule_trigger.s(schedule_id) for schedule_id in tasks_to_dispatch) - job.apply_async() + job.apply_async(producer=producer) logger.debug("Dispatched %d tasks in parallel", len(tasks_to_dispatch)) diff --git a/api/services/account_service.py b/api/services/account_service.py index 648b5e834f..bd520f54cf 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -74,6 +74,16 @@ from tasks.mail_reset_password_task import ( logger = logging.getLogger(__name__) +def _try_join_enterprise_default_workspace(account_id: str) -> None: + """Best-effort join to enterprise default workspace.""" + if not dify_config.ENTERPRISE_ENABLED: + return + + from services.enterprise.enterprise_service import try_join_default_workspace + + try_join_default_workspace(account_id) + + class TokenPair(BaseModel): access_token: str refresh_token: str @@ -287,13 +297,14 @@ class AccountService: email=email, name=name, interface_language=interface_language, password=password ) - TenantService.create_owner_tenant_if_not_exist(account=account) + try: + TenantService.create_owner_tenant_if_not_exist(account=account) + except Exception: + # Enterprise-only side-effect should run independently from personal workspace creation. + _try_join_enterprise_default_workspace(str(account.id)) + raise - # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). - if dify_config.ENTERPRISE_ENABLED: - from services.enterprise.enterprise_service import try_join_default_workspace - - try_join_default_workspace(str(account.id)) + _try_join_enterprise_default_workspace(str(account.id)) return account @@ -1078,9 +1089,9 @@ class TenantService: ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: - ta.role = role + ta.role = TenantAccountRole(role) else: - ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole(role)) db.session.add(ta) db.session.commit() @@ -1308,10 +1319,10 @@ class TenantService: db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() ) if current_owner_join: - current_owner_join.role = "admin" + current_owner_join.role = TenantAccountRole.ADMIN # Update the role of the target member - target_member_join.role = new_role + target_member_join.role = TenantAccountRole(new_role) db.session.commit() @staticmethod @@ -1407,18 +1418,18 @@ class RegisterService: and create_workspace_required and FeatureService.get_system_features().license.workspaces.is_available() ): - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") - account.current_tenant = tenant - tenant_was_created.send(tenant) + try: + tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(tenant, account, role="owner") + account.current_tenant = tenant + tenant_was_created.send(tenant) + except Exception: + _try_join_enterprise_default_workspace(str(account.id)) + raise db.session.commit() - # Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace). - if dify_config.ENTERPRISE_ENABLED: - from services.enterprise.enterprise_service import try_join_default_workspace - - try_join_default_workspace(str(account.id)) + _try_join_enterprise_default_workspace(str(account.id)) except WorkSpaceNotAllowedCreateError: db.session.rollback() logger.exception("Register failed") diff --git a/api/services/agent_service.py b/api/services/agent_service.py index b2db895a5a..2b8a3ee594 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -11,7 +11,7 @@ from core.tools.tool_manager import ToolManager from extensions.ext_database import db from libs.login import current_user from models import Account -from models.model import App, Conversation, EndUser, Message, MessageAgentThought +from models.model import App, Conversation, EndUser, Message class AgentService: @@ -47,7 +47,7 @@ class AgentService: if not message: raise ValueError(f"Message not found: {message_id}") - agent_thoughts: list[MessageAgentThought] = message.agent_thoughts + agent_thoughts = message.agent_thoughts if conversation.from_end_user_id: # only select name field diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 9400362605..68cb3438ca 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -4,6 +4,7 @@ import logging import uuid from collections.abc import Mapping from enum import StrEnum +from typing import cast from urllib.parse import urlparse from uuid import uuid4 @@ -18,21 +19,26 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper import ssrf_proxy -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import PluginDependency -from core.workflow.enums import NodeType +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes.llm.entities import LLMNodeData +from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData +from dify_graph.nodes.tool.entities import ToolNodeData from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory from libs.datetime_utils import naive_utc_now from models import Account, App, AppMode -from models.model import AppModelConfig, IconType +from models.model import AppModelConfig, AppModelConfigDict, IconType from models.workflow import Workflow from services.plugin.dependencies_analysis import DependenciesAnalysisService from services.workflow_draft_variable_service import WorkflowDraftVariableService @@ -298,7 +304,7 @@ class AppDslService: ) draft_var_srv = WorkflowDraftVariableService(session=self._session) - draft_var_srv.delete_workflow_variables(app_id=app.id) + draft_var_srv.delete_app_workflow_variables(app_id=app.id) return Import( id=import_id, status=status, @@ -428,17 +434,18 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") + resolved_icon_type: IconType if icon_type_value in [IconType.EMOJI, IconType.IMAGE, IconType.LINK]: - icon_type = icon_type_value + resolved_icon_type = IconType(icon_type_value) else: - icon_type = IconType.EMOJI + resolved_icon_type = IconType.EMOJI icon = icon or str(app_data.get("icon", "")) if app: # Update existing app app.name = name or app_data.get("name", app.name) app.description = description or app_data.get("description", app.description) - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", app.icon_background) app.updated_by = account.id @@ -451,10 +458,10 @@ class AppDslService: app = App() app.id = str(uuid4()) app.tenant_id = account.current_tenant_id - app.mode = app_mode.value + app.mode = app_mode app.name = name or app_data.get("name", "") app.description = description or app_data.get("description", "") - app.icon_type = icon_type + app.icon_type = resolved_icon_type app.icon = icon app.icon_background = icon_background or app_data.get("icon_background", "#FFFFFF") app.enable_site = True @@ -498,7 +505,7 @@ class AppDslService: unique_hash = None graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -523,7 +530,7 @@ class AppDslService: if not app.app_model_config: app_model_config = AppModelConfig( app_id=app.id, created_by=account.id, updated_by=account.id - ).from_model_config_dict(model_config) + ).from_model_config_dict(cast(AppModelConfigDict, model_config)) app_model_config.id = str(uuid4()) app.app_model_config_id = app_model_config.id @@ -548,9 +555,12 @@ class AppDslService: "kind": "app", "app": { "name": app_model.name, - "mode": app_model.mode, - "icon": app_model.icon if app_model.icon_type == "image" else "🤖", - "icon_background": "#FFEAD5" if app_model.icon_type == "image" else app_model.icon_background, + "mode": app_model.mode.value if isinstance(app_model.mode, AppMode) else app_model.mode, + "icon": app_model.icon, + "icon_type": ( + app_model.icon_type.value if isinstance(app_model.icon_type, IconType) else app_model.icon_type + ), + "icon_background": app_model.icon_background, "description": app_model.description, "use_icon_as_answer_icon": app_model.use_icon_as_answer_icon, }, @@ -586,27 +596,27 @@ class AppDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL: + if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL: + if not include_secret and data_type == BuiltinNodeTypes.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT: + if not include_secret and data_type == BuiltinNodeTypes.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) - if data_type == NodeType.TRIGGER_SCHEDULE.value: + if data_type == TRIGGER_SCHEDULE_NODE_TYPE: # override the config with the default config node_data["config"] = TriggerScheduleNode.get_default_config()["config"] - if data_type == NodeType.TRIGGER_WEBHOOK.value: + if data_type == TRIGGER_WEBHOOK_NODE_TYPE: # clear the webhook_url node_data["webhook_url"] = "" node_data["webhook_debug_url"] = "" - if data_type == NodeType.TRIGGER_PLUGIN.value: + if data_type == TRIGGER_PLUGIN_NODE_TYPE: # clear the subscription_id node_data["subscription_id"] = "" @@ -670,31 +680,31 @@ class AppDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL: + case BuiltinNodeTypes.TOOL: tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.LLM: + case BuiltinNodeTypes.LLM: llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER: + case BuiltinNodeTypes.QUESTION_CLASSIFIER: question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR: + case BuiltinNodeTypes.PARAMETER_EXTRACTOR: parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_RETRIEVAL: + case BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 31003cb8f7..40013f2b66 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -38,6 +38,13 @@ if TYPE_CHECKING: class AppGenerateService: @staticmethod def _build_streaming_task_on_subscribe(start_task: Callable[[], None]) -> Callable[[], None]: + """ + Build a subscription callback that coordinates when the background task starts. + + - streams transport: start immediately (events are durable; late subscribers can replay). + - pubsub/sharded transport: start on first subscribe, with a short fallback timer so the task + still runs if the client never connects. + """ started = False lock = threading.Lock() @@ -54,10 +61,18 @@ class AppGenerateService: started = True return True - # XXX(QuantumGhost): dirty hacks to avoid a race between publisher and SSE subscriber. - # The Celery task may publish the first event before the API side actually subscribes, - # causing an "at most once" drop with Redis Pub/Sub. We start the task on subscribe, - # but also use a short fallback timer so the task still runs if the client never consumes. + channel_type = dify_config.PUBSUB_REDIS_CHANNEL_TYPE + if channel_type == "streams": + # With Redis Streams, we can safely start right away; consumers can read past events. + _try_start() + + # Keep return type Callable[[], None] consistent while allowing an extra (no-op) call. + def _on_subscribe_streams() -> None: + _try_start() + + return _on_subscribe_streams + + # Pub/Sub modes (at-most-once): subscribe-gated start with a tiny fallback. timer = threading.Timer(SSE_TASK_START_FALLBACK_MS / 1000.0, _try_start) timer.daemon = True timer.start() diff --git a/api/services/app_model_config_service.py b/api/services/app_model_config_service.py index 6f54f90734..3bc30cb323 100644 --- a/api/services/app_model_config_service.py +++ b/api/services/app_model_config_service.py @@ -1,12 +1,12 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager -from models.model import AppMode +from models.model import AppMode, AppModelConfigDict class AppModelConfigService: @classmethod - def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode): + def validate_configuration(cls, tenant_id: str, config: dict, app_mode: AppMode) -> AppModelConfigDict: if app_mode == AppMode.CHAT: return ChatAppConfigManager.config_validate(tenant_id, config) elif app_mode == AppMode.AGENT_CHAT: diff --git a/api/services/app_service.py b/api/services/app_service.py index e57253f8b6..c5d1479a20 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from typing import TypedDict, cast +from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination @@ -10,16 +10,16 @@ from constants.model_template import default_app_templates from core.agent.entities import AgentToolEntity from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.app_event import app_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user from models import Account -from models.model import App, AppMode, AppModelConfig, Site +from models.model import App, AppMode, AppModelConfig, IconType, Site from models.tools import ApiToolProvider from services.billing_service import BillingService from services.enterprise.enterprise_service import EnterpriseService @@ -187,7 +187,10 @@ class AppService: for tool in agent_mode.get("tools") or []: if not isinstance(tool, dict) or len(tool.keys()) <= 3: continue - agent_tool_entity = AgentToolEntity(**tool) + typed_tool = {key: value for key, value in tool.items() if isinstance(key, str)} + if len(typed_tool) != len(tool): + continue + agent_tool_entity = AgentToolEntity.model_validate(typed_tool) # get tool try: tool_runtime = ToolManager.get_agent_tool_runtime( @@ -254,7 +257,7 @@ class AppService: assert current_user is not None app.name = args["name"] app.description = args["description"] - app.icon_type = args["icon_type"] + app.icon_type = IconType(args["icon_type"]) if args["icon_type"] else None app.icon = args["icon"] app.icon_background = args["icon_background"] app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) @@ -388,7 +391,7 @@ class AppService: agent_config = app_model_config.agent_mode_dict # get all tools - tools = agent_config.get("tools", []) + tools = cast(list[dict[str, Any]], agent_config.get("tools", [])) url_prefix = dify_config.CONSOLE_API_URL + "/console/api/workspaces/current/tool-provider/builtin/" diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index 5ae1fba2e8..d556230044 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -7,7 +7,7 @@ new GraphEngine command channel mechanism. from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.manager import GraphEngineManager from extensions.ext_redis import redis_client from models.model import AppMode diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index 94452482b3..0133634e5a 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -18,7 +18,7 @@ from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowTriggerStatus from models.model import App, EndUser -from models.trigger import WorkflowTriggerLog +from models.trigger import WorkflowTriggerLog, WorkflowTriggerLogDict from models.workflow import Workflow from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError @@ -224,7 +224,9 @@ class AsyncWorkflowService: return cls.trigger_workflow_async(session, user, trigger_data) @classmethod - def get_trigger_log(cls, workflow_trigger_log_id: str, tenant_id: str | None = None) -> dict[str, Any] | None: + def get_trigger_log( + cls, workflow_trigger_log_id: str, tenant_id: str | None = None + ) -> WorkflowTriggerLogDict | None: """ Get trigger log by ID @@ -247,7 +249,7 @@ class AsyncWorkflowService: @classmethod def get_recent_logs( cls, tenant_id: str, app_id: str, hours: int = 24, limit: int = 100, offset: int = 0 - ) -> list[dict[str, Any]]: + ) -> list[WorkflowTriggerLogDict]: """ Get recent trigger logs @@ -272,7 +274,7 @@ class AsyncWorkflowService: @classmethod def get_failed_logs_for_retry( cls, tenant_id: str, max_retry_count: int = 3, limit: int = 100 - ) -> list[dict[str, Any]]: + ) -> list[WorkflowTriggerLogDict]: """ Get failed logs eligible for retry diff --git a/api/services/audio_service.py b/api/services/audio_service.py index a95361cebd..1794ea9947 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -2,13 +2,14 @@ import io import logging import uuid from collections.abc import Generator +from typing import cast from flask import Response, stream_with_context from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.enums import MessageStatus from models.model import App, AppMode, Message @@ -106,7 +107,7 @@ class AudioService: if not text_to_speech_dict.get("enabled"): raise ValueError("TTS is not enabled") - voice = text_to_speech_dict.get("voice") + voice = cast(str | None, text_to_speech_dict.get("voice")) model_manager = ModelManager() model_instance = model_manager.get_default_model_instance( diff --git a/api/services/billing_service.py b/api/services/billing_service.py index 946b8cdfdb..70d4ce1ee6 100644 --- a/api/services/billing_service.py +++ b/api/services/billing_service.py @@ -335,7 +335,11 @@ class BillingService: # Redis returns bytes, decode to string and parse JSON json_str = cached_value.decode("utf-8") if isinstance(cached_value, bytes) else cached_value plan_dict = json.loads(json_str) + # NOTE (hj24): New billing versions may return timestamp as str, and validate_python + # in non-strict mode will coerce it to the expected int type. + # To preserve compatibility, always keep non-strict mode here and avoid strict mode. subscription_plan = subscription_adapter.validate_python(plan_dict) + # NOTE END tenant_plans[tenant_id] = subscription_plan except Exception: logger.exception( @@ -393,3 +397,78 @@ class BillingService: for item in data: tenant_whitelist.append(item["tenant_id"]) return tenant_whitelist + + @classmethod + def get_account_notification(cls, account_id: str) -> dict: + """Return the active in-product notification for account_id, if any. + + Calling this endpoint also marks the notification as seen; subsequent + calls will return should_show=false when frequency='once'. + + Response shape (mirrors GetAccountNotificationReply): + { + "should_show": bool, + "notification": { # present only when should_show=true + "notification_id": str, + "contents": { # lang -> LangContent + "en": {"lang": "en", "title": ..., "subtitle": ..., "body": ..., "title_pic_url": ...}, + ... + }, + "frequency": "once" | "every_page_load" + } + } + """ + return cls._send_request("GET", "/notifications/active", params={"account_id": account_id}) + + @classmethod + def upsert_notification( + cls, + contents: list[dict], + frequency: str = "once", + status: str = "active", + notification_id: str | None = None, + start_time: str | None = None, + end_time: str | None = None, + ) -> dict: + """Create or update a notification. + + contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str} + start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional. + Returns {"notification_id": str}. + """ + payload: dict = { + "contents": contents, + "frequency": frequency, + "status": status, + } + if notification_id: + payload["notification_id"] = notification_id + if start_time: + payload["start_time"] = start_time + if end_time: + payload["end_time"] = end_time + return cls._send_request("POST", "/notifications", json=payload) + + @classmethod + def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict: + """Register target account IDs for a notification (max 1000 per call). + + Returns {"count": int}. + """ + return cls._send_request( + "POST", + f"/notifications/{notification_id}/accounts", + json={"account_ids": account_ids}, + ) + + @classmethod + def dismiss_notification(cls, notification_id: str, account_id: str) -> dict: + """Mark a notification as dismissed for an account. + + Returns {"success": bool}. + """ + return cls._send_request( + "POST", + f"/notifications/{notification_id}/dismiss", + json={"account_id": account_id}, + ) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index aefc34fcae..0e0eab00ad 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -10,7 +10,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from core.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 4c87150cf7..566c27c0f3 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -10,7 +10,7 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from core.workflow.variables.types import SegmentType +from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index b0012d6f6a..f00e3fe01e 100644 --- a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from core.workflow.variables.variables import VariableBase +from dify_graph.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 35b20f7601..cdab90a3dc 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -20,12 +20,12 @@ from core.db.session_factory import session_factory from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.helper.name_generator import generate_incremental_name from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -51,6 +51,14 @@ from models.dataset import ( Pipeline, SegmentAttachmentBinding, ) +from models.enums import ( + DatasetRuntimeMode, + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, +) from models.model import UploadFile from models.provider_ids import ModelProviderID from models.source import DataSourceOauthBinding @@ -254,7 +262,7 @@ class DatasetService: dataset.embedding_model_provider = embedding_model.provider if embedding_model else None dataset.embedding_model = embedding_model.model_name if embedding_model else None dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None - dataset.permission = permission or DatasetPermissionEnum.ONLY_ME + dataset.permission = DatasetPermissionEnum(permission) if permission else DatasetPermissionEnum.ONLY_ME dataset.provider = provider if summary_index_setting is not None: dataset.summary_index_setting = summary_index_setting @@ -319,7 +327,7 @@ class DatasetService: description=rag_pipeline_dataset_create_entity.description, permission=rag_pipeline_dataset_create_entity.permission, provider="vendor", - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, icon_info=rag_pipeline_dataset_create_entity.icon_info.model_dump(), created_by=current_user.id, pipeline_id=pipeline.id, @@ -614,7 +622,7 @@ class DatasetService: """ Update pipeline knowledge base node data. """ - if dataset.runtime_mode != "rag_pipeline": + if dataset.runtime_mode != DatasetRuntimeMode.RAG_PIPELINE: return pipeline = db.session.query(Pipeline).filter_by(id=dataset.pipeline_id).first() @@ -1229,10 +1237,15 @@ class DocumentService: "enabled": "available", } - _INDEXING_STATUSES: tuple[str, ...] = ("parsing", "cleaning", "splitting", "indexing") + _INDEXING_STATUSES: tuple[IndexingStatus, ...] = ( + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + ) DISPLAY_STATUS_FILTERS: dict[str, tuple[Any, ...]] = { - "queuing": (Document.indexing_status == "waiting",), + "queuing": (Document.indexing_status == IndexingStatus.WAITING,), "indexing": ( Document.indexing_status.in_(_INDEXING_STATUSES), Document.is_paused.is_not(True), @@ -1241,19 +1254,19 @@ class DocumentService: Document.indexing_status.in_(_INDEXING_STATUSES), Document.is_paused.is_(True), ), - "error": (Document.indexing_status == "error",), + "error": (Document.indexing_status == IndexingStatus.ERROR,), "available": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(False), Document.enabled.is_(True), ), "disabled": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(False), Document.enabled.is_(False), ), "archived": ( - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived.is_(True), ), } @@ -1536,7 +1549,7 @@ class DocumentService: """ Normalize and validate `Document -> UploadFile` linkage for download flows. """ - if document.data_source_type != "upload_file": + if document.data_source_type != DataSourceType.UPLOAD_FILE: raise NotFound(invalid_source_message) data_source_info: dict[str, Any] = document.data_source_info_dict or {} @@ -1617,7 +1630,7 @@ class DocumentService: select(Document).where( Document.id.in_(document_ids), Document.enabled == True, - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived == False, ) ).all() @@ -1640,7 +1653,7 @@ class DocumentService: select(Document).where( Document.dataset_id == dataset_id, Document.enabled == True, - Document.indexing_status == "completed", + Document.indexing_status == IndexingStatus.COMPLETED, Document.archived == False, ) ).all() @@ -1650,7 +1663,10 @@ class DocumentService: @staticmethod def get_error_documents_by_dataset_id(dataset_id: str) -> Sequence[Document]: documents = db.session.scalars( - select(Document).where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"])) + select(Document).where( + Document.dataset_id == dataset_id, + Document.indexing_status.in_([IndexingStatus.ERROR, IndexingStatus.PAUSED]), + ) ).all() return documents @@ -1683,7 +1699,7 @@ class DocumentService: def delete_document(document): # trigger document_was_deleted signal file_id = None - if document.data_source_type == "upload_file": + if document.data_source_type == DataSourceType.UPLOAD_FILE: if document.data_source_info: data_source_info = document.data_source_info_dict if data_source_info and "upload_file_id" in data_source_info: @@ -1704,7 +1720,7 @@ class DocumentService: file_ids = [ document.data_source_info_dict.get("upload_file_id", "") for document in documents - if document.data_source_type == "upload_file" and document.data_source_info_dict + if document.data_source_type == DataSourceType.UPLOAD_FILE and document.data_source_info_dict ] # Delete documents first, then dispatch cleanup task after commit @@ -1753,7 +1769,13 @@ class DocumentService: @staticmethod def pause_document(document): - if document.indexing_status not in {"waiting", "parsing", "cleaning", "splitting", "indexing"}: + if document.indexing_status not in { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + }: raise DocumentIndexingError() # update document to be paused assert current_user is not None @@ -1793,7 +1815,7 @@ class DocumentService: if cache_result is not None: raise ValueError("Document is being retried, please try again later") # retry document indexing - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) db.session.commit() @@ -1812,7 +1834,7 @@ class DocumentService: if cache_result is not None: raise ValueError("Document is being synced, please try again later") # sync document indexing - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING data_source_info = document.data_source_info_dict if data_source_info: data_source_info["mode"] = "scrape" @@ -1840,7 +1862,7 @@ class DocumentService: knowledge_config: KnowledgeConfig, account: Account | Any, dataset_process_rule: DatasetProcessRule | None = None, - created_from: str = "web", + created_from: str = DocumentCreatedFrom.WEB, ) -> tuple[list[Document], str]: # check doc_form DatasetService.check_doc_form(dataset, knowledge_config.doc_form) @@ -1932,7 +1954,7 @@ class DocumentService: if not dataset_process_rule: process_rule = knowledge_config.process_rule if process_rule: - if process_rule.mode in ("custom", "hierarchical"): + if process_rule.mode in (ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL): if process_rule.rules: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, @@ -1944,7 +1966,7 @@ class DocumentService: dataset_process_rule = dataset.latest_process_rule if not dataset_process_rule: raise ValueError("No process rule found.") - elif process_rule.mode == "automatic": + elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, @@ -1967,7 +1989,7 @@ class DocumentService: if not dataset_process_rule: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, rules=json.dumps(DatasetProcessRule.AUTOMATIC_RULES), created_by=account.id, ) @@ -2001,7 +2023,7 @@ class DocumentService: .where( Document.dataset_id == dataset.id, Document.tenant_id == current_user.current_tenant_id, - Document.data_source_type == "upload_file", + Document.data_source_type == DataSourceType.UPLOAD_FILE, Document.enabled == True, Document.name.in_(file_names), ) @@ -2021,7 +2043,7 @@ class DocumentService: document.doc_language = knowledge_config.doc_language document.data_source_info = json.dumps(data_source_info) document.batch = batch - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) documents.append(document) duplicate_document_ids.append(document.id) @@ -2056,7 +2078,7 @@ class DocumentService: .filter_by( dataset_id=dataset.id, tenant_id=current_user.current_tenant_id, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, enabled=True, ) .all() @@ -2507,7 +2529,7 @@ class DocumentService: document_data: KnowledgeConfig, account: Account, dataset_process_rule: DatasetProcessRule | None = None, - created_from: str = "web", + created_from: str = DocumentCreatedFrom.WEB, ): assert isinstance(current_user, Account) @@ -2520,14 +2542,14 @@ class DocumentService: # save process rule if document_data.process_rule: process_rule = document_data.process_rule - if process_rule.mode in {"custom", "hierarchical"}: + if process_rule.mode in {ProcessRuleMode.CUSTOM, ProcessRuleMode.HIERARCHICAL}: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, rules=process_rule.rules.model_dump_json() if process_rule.rules else None, created_by=account.id, ) - elif process_rule.mode == "automatic": + elif process_rule.mode == ProcessRuleMode.AUTOMATIC: dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, mode=process_rule.mode, @@ -2609,7 +2631,7 @@ class DocumentService: if document_data.name: document.name = document_data.name # update document to be waiting - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING document.completed_at = None document.processing_started_at = None document.parsing_completed_at = None @@ -2623,7 +2645,7 @@ class DocumentService: # update document segment db.session.query(DocumentSegment).filter_by(document_id=document.id).update( - {DocumentSegment.status: "re_segment"} + {DocumentSegment.status: SegmentStatus.RE_SEGMENT} ) db.session.commit() # trigger async task @@ -2754,7 +2776,7 @@ class DocumentService: if knowledge_config.process_rule.mode not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if knowledge_config.process_rule.mode == "automatic": + if knowledge_config.process_rule.mode == ProcessRuleMode.AUTOMATIC: knowledge_config.process_rule.rules = None else: if not knowledge_config.process_rule.rules: @@ -2785,7 +2807,7 @@ class DocumentService: raise ValueError("Process rule segmentation separator is invalid") if not ( - knowledge_config.process_rule.mode == "hierarchical" + knowledge_config.process_rule.mode == ProcessRuleMode.HIERARCHICAL and knowledge_config.process_rule.rules.parent_mode == "full-doc" ): if not knowledge_config.process_rule.rules.segmentation.max_tokens: @@ -2814,7 +2836,7 @@ class DocumentService: if args["process_rule"]["mode"] not in DatasetProcessRule.MODES: raise ValueError("Process rule mode is invalid") - if args["process_rule"]["mode"] == "automatic": + if args["process_rule"]["mode"] == ProcessRuleMode.AUTOMATIC: args["process_rule"]["rules"] = {} else: if "rules" not in args["process_rule"] or not args["process_rule"]["rules"]: @@ -3021,7 +3043,7 @@ class DocumentService: @staticmethod def _prepare_disable_update(document, user, now): """Prepare updates for disabling a document.""" - if not document.completed_at or document.indexing_status != "completed": + if not document.completed_at or document.indexing_status != IndexingStatus.COMPLETED: raise DocumentIndexingError(f"Document: {document.name} is not completed.") if not document.enabled: @@ -3130,7 +3152,7 @@ class SegmentService: content=content, word_count=len(content), tokens=tokens, - status="completed", + status=SegmentStatus.COMPLETED, indexing_at=naive_utc_now(), completed_at=naive_utc_now(), created_by=current_user.id, @@ -3167,7 +3189,7 @@ class SegmentService: logger.exception("create segment index failed") segment_document.enabled = False segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" + segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first() @@ -3227,7 +3249,7 @@ class SegmentService: word_count=len(content), tokens=tokens, keywords=segment_item.get("keywords", []), - status="completed", + status=SegmentStatus.COMPLETED, indexing_at=naive_utc_now(), completed_at=naive_utc_now(), created_by=current_user.id, @@ -3259,7 +3281,7 @@ class SegmentService: for segment_document in segment_data_list: segment_document.enabled = False segment_document.disabled_at = naive_utc_now() - segment_document.status = "error" + segment_document.status = SegmentStatus.ERROR segment_document.error = str(e) db.session.commit() return segment_data_list @@ -3405,7 +3427,7 @@ class SegmentService: segment.index_node_hash = segment_hash segment.word_count = len(content) segment.tokens = tokens - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.indexing_at = naive_utc_now() segment.completed_at = naive_utc_now() segment.updated_by = current_user.id @@ -3530,7 +3552,7 @@ class SegmentService: logger.exception("update segment index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) db.session.commit() new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first() diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index eeb14072bd..f3b2adb965 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -10,11 +10,11 @@ from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper import encrypter from core.helper.name_generator import generate_incremental_name from core.helper.provider_cache import NoOpProviderCredentialCache -from core.model_runtime.entities.provider_entities import FormType from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter +from dify_graph.model_runtime.entities.provider_entities import FormType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider @@ -824,6 +824,7 @@ class DatasourceProviderService: "langgenius/firecrawl_datasource", "langgenius/notion_datasource", "langgenius/jina_datasource", + "watercrawl/watercrawl_datasource", ]: datasource_provider_id = DatasourceProviderID(f"{datasource.plugin_id}/{datasource.provider}") credentials = self.list_datasource_credentials( diff --git a/api/services/enterprise/base.py b/api/services/enterprise/base.py index 744b7992f8..68835e76d0 100644 --- a/api/services/enterprise/base.py +++ b/api/services/enterprise/base.py @@ -6,6 +6,13 @@ from typing import Any import httpx from core.helper.trace_id_helper import generate_traceparent_header +from services.errors.enterprise import ( + EnterpriseAPIBadRequestError, + EnterpriseAPIError, + EnterpriseAPIForbiddenError, + EnterpriseAPINotFoundError, + EnterpriseAPIUnauthorizedError, +) logger = logging.getLogger(__name__) @@ -64,10 +71,51 @@ class BaseRequest: request_kwargs["timeout"] = timeout response = client.request(method, url, **request_kwargs) - if raise_for_status: - response.raise_for_status() + + # Validate HTTP status and raise domain-specific errors + if not response.is_success: + cls._handle_error_response(response) return response.json() + @classmethod + def _handle_error_response(cls, response: httpx.Response) -> None: + """ + Handle non-2xx HTTP responses by raising appropriate domain errors. + + Attempts to extract error message from JSON response body, + falls back to status text if parsing fails. + """ + error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}" + + # Try to extract error message from JSON response + try: + error_data = response.json() + if isinstance(error_data, dict): + # Common error response formats: + # {"error": "...", "message": "..."} + # {"message": "..."} + # {"detail": "..."} + error_message = ( + error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message + ) + except Exception: + # If JSON parsing fails, use the default message + logger.debug( + "Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True + ) + + # Raise specific error based on status code + if response.status_code == 400: + raise EnterpriseAPIBadRequestError(error_message) + elif response.status_code == 401: + raise EnterpriseAPIUnauthorizedError(error_message) + elif response.status_code == 403: + raise EnterpriseAPIForbiddenError(error_message) + elif response.status_code == 404: + raise EnterpriseAPINotFoundError(error_message) + else: + raise EnterpriseAPIError(error_message, status_code=response.status_code) + class EnterpriseRequest(BaseRequest): base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL") diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 71d456aa2d..5040fcc7e3 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,15 +1,26 @@ +from __future__ import annotations + import logging import uuid from datetime import datetime +from typing import TYPE_CHECKING from pydantic import BaseModel, ConfigDict, Field, model_validator from configs import dify_config +from extensions.ext_redis import redis_client from services.enterprise.base import EnterpriseRequest +if TYPE_CHECKING: + from services.feature_service import LicenseStatus + logger = logging.getLogger(__name__) DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0 +# License status cache configuration +LICENSE_STATUS_CACHE_KEY = "enterprise:license:status" +VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable +INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly class WebAppSettings(BaseModel): @@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel): model_config = ConfigDict(extra="forbid", populate_by_name=True) @model_validator(mode="after") - def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult": + def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult: if self.joined and not self.workspace_id: raise ValueError("workspace_id must be non-empty when joined is True") return self @@ -115,7 +126,6 @@ class EnterpriseService: "/default-workspace/members", json={"account_id": account_id}, timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS, - raise_for_status=True, ) if not isinstance(data, dict): raise ValueError("Invalid response format from enterprise default workspace API") @@ -223,3 +233,64 @@ class EnterpriseService: params = {"appId": app_id} EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params) + + @classmethod + def get_cached_license_status(cls) -> LicenseStatus | None: + """Get enterprise license status with Redis caching to reduce HTTP calls. + + Caches valid statuses (active/expiring) for 10 minutes and invalid statuses + (inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses + balances prompt license-fix detection against DoS mitigation — without + caching, every request on an expired license would hit the enterprise API. + + Returns: + LicenseStatus enum value, or None if enterprise is disabled / unreachable. + """ + if not dify_config.ENTERPRISE_ENABLED: + return None + + cached = cls._read_cached_license_status() + if cached is not None: + return cached + + return cls._fetch_and_cache_license_status() + + @classmethod + def _read_cached_license_status(cls) -> LicenseStatus | None: + """Read license status from Redis cache, returning None on miss or failure.""" + from services.feature_service import LicenseStatus + + try: + raw = redis_client.get(LICENSE_STATUS_CACHE_KEY) + if raw: + value = raw.decode("utf-8") if isinstance(raw, bytes) else raw + return LicenseStatus(value) + except Exception: + logger.debug("Failed to read license status from cache", exc_info=True) + return None + + @classmethod + def _fetch_and_cache_license_status(cls) -> LicenseStatus | None: + """Fetch license status from enterprise API and cache the result.""" + from services.feature_service import LicenseStatus + + try: + info = cls.get_info() + license_info = info.get("License") + if not license_info: + return None + + status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + ttl = ( + VALID_LICENSE_CACHE_TTL + if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING) + else INVALID_LICENSE_CACHE_TTL + ) + try: + redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status) + except Exception: + logger.debug("Failed to cache license status", exc_info=True) + return status + except Exception: + logger.debug("Failed to fetch enterprise license status", exc_info=True) + return None diff --git a/api/services/enterprise/plugin_manager_service.py b/api/services/enterprise/plugin_manager_service.py index 817dbd95f8..d4be36305e 100644 --- a/api/services/enterprise/plugin_manager_service.py +++ b/api/services/enterprise/plugin_manager_service.py @@ -3,6 +3,7 @@ import logging from pydantic import BaseModel +from configs import dify_config from services.enterprise.base import EnterprisePluginManagerRequest from services.errors.base import BaseServiceError @@ -28,6 +29,11 @@ class CheckCredentialPolicyComplianceRequest(BaseModel): return data +class PreUninstallPluginRequest(BaseModel): + tenant_id: str + plugin_unique_identifier: str + + class CredentialPolicyViolationError(BaseServiceError): pass @@ -55,3 +61,20 @@ class PluginManagerService: body.dify_credential_id, ret.get("result", False), ) + + @classmethod + def try_pre_uninstall_plugin(cls, body: PreUninstallPluginRequest): + try: + # the invocation must be synchronous. + EnterprisePluginManagerRequest.send_request( + "POST", + "/pre-uninstall-plugin", + json=body.model_dump(), + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + except Exception: + logger.exception( + "failed to perform pre uninstall plugin hook. tenant_id: %s, plugin_unique_identifier: %s", + body.tenant_id, + body.plugin_unique_identifier, + ) diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 8dc5b93501..66309f0e59 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -1,8 +1,9 @@ from enum import StrEnum from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, field_validator +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -127,6 +128,18 @@ class KnowledgeConfig(BaseModel): name: str | None = None is_multimodal: bool = False + @field_validator("doc_form") + @classmethod + def validate_doc_form(cls, value: str) -> str: + valid_forms = [ + IndexStructureType.PARAGRAPH_INDEX, + IndexStructureType.QA_INDEX, + IndexStructureType.PARENT_CHILD_INDEX, + ] + if value not in valid_forms: + raise ValueError("Invalid doc_form.") + return value + class SegmentCreateArgs(BaseModel): content: str | None = None diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index a29d848ac5..9dd595f516 100644 --- a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -15,9 +15,9 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, ModelCredentialSchema, ProviderCredentialSchema, diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index 697e691224..15f004463d 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -7,6 +7,7 @@ from . import ( conversation, dataset, document, + enterprise, file, index, message, @@ -21,6 +22,7 @@ __all__ = [ "conversation", "dataset", "document", + "enterprise", "file", "index", "message", diff --git a/api/services/errors/enterprise.py b/api/services/errors/enterprise.py new file mode 100644 index 0000000000..c9126199fd --- /dev/null +++ b/api/services/errors/enterprise.py @@ -0,0 +1,45 @@ +"""Enterprise service errors.""" + +from services.errors.base import BaseServiceError + + +class EnterpriseServiceError(BaseServiceError): + """Base exception for enterprise service errors.""" + + def __init__(self, description: str | None = None, status_code: int | None = None): + super().__init__(description) + self.status_code = status_code + + +class EnterpriseAPIError(EnterpriseServiceError): + """Generic enterprise API error (non-2xx response).""" + + pass + + +class EnterpriseAPINotFoundError(EnterpriseServiceError): + """Enterprise API returned 404 Not Found.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=404) + + +class EnterpriseAPIForbiddenError(EnterpriseServiceError): + """Enterprise API returned 403 Forbidden.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=403) + + +class EnterpriseAPIUnauthorizedError(EnterpriseServiceError): + """Enterprise API returned 401 Unauthorized.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=401) + + +class EnterpriseAPIBadRequestError(EnterpriseServiceError): + """Enterprise API returned 400 Bad Request.""" + + def __init__(self, description: str | None = None): + super().__init__(description, status_code=400) diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 65dd41af43..4cf42b7f44 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -9,7 +9,7 @@ from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition -from core.workflow.nodes.http_request.exc import InvalidHttpMethodError +from dify_graph.nodes.http_request.exc import InvalidHttpMethodError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import ( diff --git a/api/services/feature_service.py b/api/services/feature_service.py index fda3a15144..f38e1762d1 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -379,14 +379,19 @@ class FeatureService: ) features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if is_authenticated and (license_info := enterprise_info.get("License")): + # SECURITY NOTE: Only license *status* is exposed to unauthenticated callers + # so the login page can detect an expired/inactive license after force-logout. + # All other license details (expiry date, workspace usage) remain auth-gated. + # This behavior reflects prior internal review of information-leakage risks. + if license_info := enterprise_info.get("License"): features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - features.license.expired_at = license_info.get("expiredAt", "") - if workspaces_info := license_info.get("workspaces"): - features.license.workspaces.enabled = workspaces_info.get("enabled", False) - features.license.workspaces.limit = workspaces_info.get("limit", 0) - features.license.workspaces.size = workspaces_info.get("used", 0) + if is_authenticated: + features.license.expired_at = license_info.get("expiredAt", "") + if workspaces_info := license_info.get("workspaces"): + features.license.workspaces.enabled = workspaces_info.get("enabled", False) + features.license.workspaces.limit = workspaces_info.get("limit", 0) + features.license.workspaces.size = workspaces_info.get("used", 0) if "PluginInstallationPermission" in enterprise_info: plugin_installation_info = enterprise_info["PluginInstallationPermission"] diff --git a/api/services/feedback_service.py b/api/services/feedback_service.py index 1a1cbbb450..e7473d371b 100644 --- a/api/services/feedback_service.py +++ b/api/services/feedback_service.py @@ -7,6 +7,7 @@ from flask import Response from sqlalchemy import or_ from extensions.ext_database import db +from models.enums import FeedbackRating from models.model import Account, App, Conversation, Message, MessageFeedback @@ -100,7 +101,7 @@ class FeedbackService: "ai_response": message.answer[:500] + "..." if len(message.answer) > 500 else message.answer, # Truncate long responses - "feedback_rating": "👍" if feedback.rating == "like" else "👎", + "feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎", "feedback_rating_raw": feedback.rating, "feedback_comment": feedback.content or "", "feedback_source": feedback.from_source, diff --git a/api/services/file_service.py b/api/services/file_service.py index da99a66bb9..a7060f3b92 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -20,9 +20,10 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from core.workflow.file import helpers as file_helpers +from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from models import Account @@ -58,8 +59,9 @@ class FileService: # get file extension extension = os.path.splitext(filename)[1].lstrip(".").lower() - # check if filename contains invalid characters - if any(c in filename for c in ["/", "\\", ":", "*", "?", '"', "<", ">", "|"]): + # Only reject path separators here. The original filename is stored as metadata, + # while the storage key is UUID-based. + if any(c in filename for c in ["/", "\\"]): raise ValueError("Filename contains invalid characters") if len(filename) > 200: @@ -92,7 +94,7 @@ class FileService: # save file to db upload_file = UploadFile( tenant_id=current_tenant_id or "", - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=filename, size=file_size, @@ -151,7 +153,7 @@ class FileService: # save file to db upload_file = UploadFile( tenant_id=tenant_id, - storage_type=dify_config.STORAGE_TYPE, + storage_type=StorageType(dify_config.STORAGE_TYPE), key=file_key, name=text_name, size=len(text), diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 8cbf3a25c3..9993d24c70 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -4,15 +4,16 @@ import time from typing import Any from core.app.app_config.entities import ModelConfig -from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery +from models.enums import CreatorUserRole, DatasetQuerySource logger = logging.getLogger(__name__) @@ -96,9 +97,9 @@ class HitTestingService: dataset_query = DatasetQuery( dataset_id=dataset.id, content=json.dumps(dataset_queries), - source="hit_testing", + source=DatasetQuerySource.HIT_TESTING, source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) db.session.add(dataset_query) @@ -136,9 +137,9 @@ class HitTestingService: dataset_query = DatasetQuery( dataset_id=dataset.id, content=query, - source="hit_testing", + source=DatasetQuerySource.HIT_TESTING, source_app_id=None, - created_by_role="account", + created_by_role=CreatorUserRole.ACCOUNT, created_by=account.id, ) diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index ff37ff098f..229e6608da 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,14 +8,14 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, ExternalRecipient, MemberRecipient, ) -from core.workflow.runtime import VariablePool +from dify_graph.runtime import VariablePool from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_template_renderer import render_email_template @@ -155,13 +155,15 @@ class EmailDeliveryTestHandler: context=context, recipient_email=recipient_email, ) - subject = render_email_template(method.config.subject, substitutions) + subject_template = render_email_template(method.config.subject, substitutions) + subject = EmailDeliveryConfig.sanitize_subject(subject_template) templated_body = EmailDeliveryConfig.render_body_template( body=method.config.body, url=substitutions.get("form_link"), variable_pool=context.variable_pool, ) body = render_email_template(templated_body, substitutions) + body = EmailDeliveryConfig.render_markdown_body(body) mail.send( to=recipient_email, @@ -245,5 +247,6 @@ class EmailDeliveryTestHandler: ) if token: substitutions["form_token"] = token - substitutions["form_link"] = _build_form_link(token) or "" + link = _build_form_link(token) + substitutions["form_link"] = link if link is not None else f"/form/{token}" return substitutions diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 87816643f6..2e74c50963 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -11,12 +11,12 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormDefinition, HumanInputSubmissionValidationError, validate_human_input_submission, ) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType @@ -130,7 +130,7 @@ class HumanInputService: if isinstance(session_factory, Engine): session_factory = sessionmaker(bind=session_factory) self._session_factory = session_factory - self._form_repository = form_repository or HumanInputFormSubmissionRepository(session_factory) + self._form_repository = form_repository or HumanInputFormSubmissionRepository() def get_form_by_token(self, form_token: str) -> Form | None: record = self._form_repository.get_by_token(form_token) diff --git a/api/services/message_service.py b/api/services/message_service.py index ce699e79d4..fc87802f51 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -9,13 +9,14 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback from repositories.execution_extra_content_repository import ExecutionExtraContentRepository from repositories.sqlalchemy_execution_extra_content_repository import ( @@ -172,7 +173,7 @@ class MessageService: app_model: App, message_id: str, user: Union[Account, EndUser] | None, - rating: str | None, + rating: FeedbackRating | None, content: str | None, ): if not user: @@ -197,7 +198,7 @@ class MessageService: message_id=message.id, rating=rating, content=content, - from_source=("user" if isinstance(user, EndUser) else "admin"), + from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN), from_end_user_id=(user.id if isinstance(user, EndUser) else None), from_account_id=(user.id if isinstance(user, Account) else None), ) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 859fc1902b..2f47a647a8 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -7,6 +7,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from libs.login import current_account_with_tenant from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding +from models.enums import DatasetMetadataType from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ( MetadataArgs, @@ -130,11 +131,11 @@ class MetadataService: @staticmethod def get_built_in_fields(): return [ - {"name": BuiltInField.document_name, "type": "string"}, - {"name": BuiltInField.uploader, "type": "string"}, - {"name": BuiltInField.upload_date, "type": "time"}, - {"name": BuiltInField.last_update_date, "type": "time"}, - {"name": BuiltInField.source, "type": "string"}, + {"name": BuiltInField.document_name, "type": DatasetMetadataType.STRING}, + {"name": BuiltInField.uploader, "type": DatasetMetadataType.STRING}, + {"name": BuiltInField.upload_date, "type": DatasetMetadataType.TIME}, + {"name": BuiltInField.last_update_date, "type": DatasetMetadataType.TIME}, + {"name": BuiltInField.source, "type": DatasetMetadataType.STRING}, ] @staticmethod diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index 69da3bfb79..bf3b6db3ed 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,15 +10,16 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ModelCredentialSchema, ProviderCredentialSchema, ) -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.provider_manager import ProviderManager +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now +from models.enums import CredentialSourceType from models.provider import LoadBalancingModelConfig, ProviderCredential, ProviderModelCredential logger = logging.getLogger(__name__) @@ -103,9 +104,9 @@ class ModelLoadBalancingService: is_load_balancing_enabled = True if config_from == "predefined-model": - credential_source_type = "provider" + credential_source_type = CredentialSourceType.PROVIDER else: - credential_source_type = "custom_model" + credential_source_type = CredentialSourceType.CUSTOM_MODEL # Get load balancing configurations load_balancing_configs = ( @@ -421,7 +422,11 @@ class ModelLoadBalancingService: raise ValueError("Invalid load balancing config name") if credential_id: - credential_source = "provider" if config_from == "predefined-model" else "custom_model" + credential_source = ( + CredentialSourceType.PROVIDER + if config_from == "predefined-model" + else CredentialSourceType.CUSTOM_MODEL + ) assert credential_record is not None load_balancing_model_config = LoadBalancingModelConfig( tenant_id=tenant_id, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index edd1004b82..0ddd6b9b1a 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,9 @@ import logging from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity -from core.model_runtime.entities.model_entities import ModelType, ParameterRule -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 6eed3a6b38..ca83742d65 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -30,8 +30,12 @@ from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_database import db from extensions.ext_redis import redis_client -from models.provider import Provider, ProviderCredential +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from models.provider_ids import GenericProviderID +from services.enterprise.plugin_manager_service import ( + PluginManagerService, + PreUninstallPluginRequest, +) from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import FeatureService, PluginInstallationScope @@ -519,10 +523,24 @@ class PluginService: if not plugin: return manager.uninstall(tenant_id, plugin_installation_id) + if dify_config.ENTERPRISE_ENABLED: + PluginManagerService.try_pre_uninstall_plugin( + PreUninstallPluginRequest( + tenant_id=tenant_id, + plugin_unique_identifier=plugin.plugin_unique_identifier, + ) + ) with Session(db.engine) as session, session.begin(): plugin_id = plugin.plugin_id logger.info("Deleting credentials for plugin: %s", plugin_id) + session.execute( + delete(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"), + ) + ) + # Delete provider credentials that match this plugin credential_ids = session.scalars( select(ProviderCredential.id).where( diff --git a/api/services/rag_pipeline/pipeline_generate_service.py b/api/services/rag_pipeline/pipeline_generate_service.py index f397b28283..07e1b8f20e 100644 --- a/api/services/rag_pipeline/pipeline_generate_service.py +++ b/api/services/rag_pipeline/pipeline_generate_service.py @@ -6,6 +6,7 @@ from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from models.dataset import Document, Pipeline +from models.enums import IndexingStatus from models.model import Account, App, EndUser from models.workflow import Workflow from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -111,6 +112,6 @@ class PipelineGenerateService: """ document = db.session.query(Document).where(Document.id == document_id).first() if document: - document.indexing_status = "waiting" + document.indexing_status = IndexingStatus.WAITING db.session.add(document) db.session.commit() diff --git a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py index 571ca6c7a6..f996db11dc 100644 --- a/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py +++ b/api/services/rag_pipeline/pipeline_template/remote/remote_retrieval.py @@ -15,7 +15,8 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): Retrieval recommended app from dify official """ - def get_pipeline_template_detail(self, template_id: str): + def get_pipeline_template_detail(self, template_id: str) -> dict | None: + result: dict | None try: result = self.fetch_pipeline_template_detail_from_dify_official(template_id) except Exception as e: @@ -35,17 +36,23 @@ class RemotePipelineTemplateRetrieval(PipelineTemplateRetrievalBase): return PipelineTemplateType.REMOTE @classmethod - def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict | None: + def fetch_pipeline_template_detail_from_dify_official(cls, template_id: str) -> dict: """ Fetch pipeline template detail from dify official. - :param template_id: Pipeline ID - :return: + + :param template_id: Pipeline template ID + :return: Template detail dict + :raises ValueError: When upstream returns a non-200 status code """ domain = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_REMOTE_DOMAIN url = f"{domain}/pipeline-templates/{template_id}" response = httpx.get(url, timeout=httpx.Timeout(10.0, connect=3.0)) if response.status_code != 200: - return None + raise ValueError( + "fetch pipeline template detail failed," + + f" status_code: {response.status_code}," + + f" response: {response.text[:1000]}" + ) data: dict = response.json() return data diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index c0f9e4f323..296b9f0890 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -36,23 +36,23 @@ from core.rag.entities.event import ( ) from core.repositories.factory import DifyCoreRepositoryFactory from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities.workflow_node_execution import ( +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import ErrorStrategy, NodeType, SystemVariableKey -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from core.workflow.graph_events.base import GraphNodeEventBase -from core.workflow.node_events.base import NodeRunResult -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.variables import VariableBase -from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.graph_events.base import GraphNodeEventBase +from dify_graph.node_events.base import NodeRunResult +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.variables import VariableBase from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -64,7 +64,7 @@ from models.dataset import ( # type: ignore PipelineCustomizedTemplate, PipelineRecommendedPlugin, ) -from models.enums import WorkflowRunTriggeredFrom +from models.enums import IndexingStatus, WorkflowRunTriggeredFrom from models.model import EndUser from models.workflow import ( Workflow, @@ -79,10 +79,11 @@ from services.entities.knowledge_entities.rag_pipeline_entities import ( KnowledgeConfiguration, PipelineTemplateInfoEntity, ) -from services.errors.app import WorkflowHashNotEqualError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.rag_pipeline.pipeline_template.pipeline_template_factory import PipelineTemplateRetrievalFactory from services.tools.builtin_tools_manage_service import BuiltinToolManageService from services.workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader +from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) @@ -117,13 +118,21 @@ class RagPipelineService: def get_pipeline_template_detail(cls, template_id: str, type: str = "built-in") -> dict | None: """ Get pipeline template detail. + :param template_id: template id - :return: + :param type: template type, "built-in" or "customized" + :return: template detail dict, or None if not found """ if type == "built-in": mode = dify_config.HOSTED_FETCH_PIPELINE_TEMPLATES_MODE retrieval_instance = PipelineTemplateRetrievalFactory.get_pipeline_template_factory(mode)() built_in_result: dict | None = retrieval_instance.get_pipeline_template_detail(template_id) + if built_in_result is None: + logger.warning( + "pipeline template retrieval returned empty result, template_id: %s, mode: %s", + template_id, + mode, + ) return built_in_result else: mode = "customized" @@ -226,6 +235,21 @@ class RagPipelineService: return workflow + def get_published_workflow_by_id(self, pipeline: Pipeline, workflow_id: str) -> Workflow | None: + """Fetch a published workflow snapshot by ID for restore operations.""" + workflow = ( + db.session.query(Workflow) + .where( + Workflow.tenant_id == pipeline.tenant_id, + Workflow.app_id == pipeline.id, + Workflow.id == workflow_id, + ) + .first() + ) + if workflow and workflow.version == Workflow.VERSION_DRAFT: + raise IsDraftWorkflowError("source workflow must be published") + return workflow + def get_all_published_workflow( self, *, @@ -319,6 +343,42 @@ class RagPipelineService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + pipeline: Pipeline, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published pipeline workflow snapshot into the draft workflow. + + Pipelines reuse the shared draft-restore field copy helper, but still own + the pipeline-specific flush/link step that wires a newly created draft + back onto ``pipeline.workflow_id``. + """ + source_workflow = self.get_published_workflow_by_id(pipeline=pipeline, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + draft_workflow = self.get_draft_workflow(pipeline=pipeline) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=pipeline.tenant_id, + app_id=pipeline.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=lambda: datetime.now(UTC).replace(tzinfo=None), + ) + + if is_new_draft: + db.session.add(draft_workflow) + db.session.flush() + pipeline.workflow_id = draft_workflow.id + + db.session.commit() + + return draft_workflow + def publish_workflow( self, *, @@ -381,10 +441,10 @@ class RagPipelineService: """ # return default block config default_block_configs: list[dict[str, Any]] = [] - for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items(): + for node_type, node_class_mapping in get_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None - if node_type is NodeType.HTTP_REQUEST: + if node_type == BuiltinNodeTypes.HTTP_REQUEST: filters = { HTTP_REQUEST_CONFIG_FILTER_KEY: build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -410,14 +470,15 @@ class RagPipelineService: :return: """ node_type_enum = NodeType(node_type) + node_mapping = get_node_type_classes_mapping() # return default block config - if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + if node_type_enum not in node_mapping: return None - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + node_class = node_mapping[node_type_enum][LATEST_VERSION] final_filters = dict(filters) if filters else {} - if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters: + if node_type_enum == BuiltinNodeTypes.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in final_filters: final_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, @@ -471,6 +532,7 @@ class RagPipelineService: engine=db.engine, app_id=pipeline.id, tenant_id=pipeline.tenant_id, + user_id=account.id, ), ), start_at=start_at, @@ -499,7 +561,7 @@ class RagPipelineService: session=session, app_id=pipeline.id, node_id=workflow_node_execution.node_id, - node_type=NodeType(workflow_node_execution.node_type), + node_type=workflow_node_execution.node_type, enclosing_node_id=enclosing_node_id, node_execution_id=workflow_node_execution.id, user=account, @@ -904,7 +966,7 @@ class RagPipelineService: if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = error db.session.add(document) db.session.commit() @@ -1236,6 +1298,7 @@ class RagPipelineService: engine=db.engine, app_id=pipeline.id, tenant_id=pipeline.tenant_id, + user_id=current_user.id, ), ), start_at=start_at, @@ -1261,7 +1324,7 @@ class RagPipelineService: session=session, app_id=pipeline.id, node_id=workflow_node_execution_db_model.node_id, - node_type=NodeType(workflow_node_execution_db_model.node_type), + node_type=workflow_node_execution_db_model.node_type, enclosing_node_id=enclosing_node_id, node_execution_id=workflow_node_execution.id, user=current_user, diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index be1ce834f6..deb59da8d3 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -21,19 +21,21 @@ from sqlalchemy.orm import Session from core.helper import ssrf_proxy from core.helper.name_generator import generate_incremental_name -from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import PluginDependency -from core.workflow.enums import NodeType from core.workflow.nodes.datasource.entities import DatasourceNodeData +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from core.workflow.nodes.llm.entities import LLMNodeData -from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData -from core.workflow.nodes.tool.entities import ToolNodeData +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes.llm.entities import LLMNodeData +from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData +from dify_graph.nodes.tool.entities import ToolNodeData from extensions.ext_redis import redis_client from factories import variable_factory from models import Account from models.dataset import Dataset, DatasetCollectionBinding, Pipeline +from models.enums import CollectionBindingType, DatasetRuntimeMode from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import ( IconInfo, @@ -287,7 +289,7 @@ class RagPipelineDslService: nodes = graph.get("nodes", []) dataset_id = None for node in nodes: - if node.get("data", {}).get("type") == "knowledge-index": + if node.get("data", {}).get("type") == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if ( dataset @@ -312,7 +314,7 @@ class RagPipelineDslService: indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) if knowledge_configuration.indexing_technique == "high_quality": @@ -322,7 +324,7 @@ class RagPipelineDslService: DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, - DatasetCollectionBinding.type == "dataset", + DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) .first() @@ -333,7 +335,7 @@ class RagPipelineDslService: provider_name=knowledge_configuration.embedding_model_provider, model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type="dataset", + type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) self._session.commit() @@ -428,7 +430,7 @@ class RagPipelineDslService: nodes = graph.get("nodes", []) dataset_id = None for node in nodes: - if node.get("data", {}).get("type") == "knowledge-index": + if node.get("data", {}).get("type") == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) if not dataset: dataset = Dataset( @@ -444,13 +446,13 @@ class RagPipelineDslService: indexing_technique=knowledge_configuration.indexing_technique, created_by=account.id, retrieval_model=knowledge_configuration.retrieval_model.model_dump(), - runtime_mode="rag_pipeline", + runtime_mode=DatasetRuntimeMode.RAG_PIPELINE, chunk_structure=knowledge_configuration.chunk_structure, ) else: dataset.indexing_technique = knowledge_configuration.indexing_technique dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump() - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.chunk_structure = knowledge_configuration.chunk_structure if knowledge_configuration.indexing_technique == "high_quality": dataset_collection_binding = ( @@ -459,7 +461,7 @@ class RagPipelineDslService: DatasetCollectionBinding.provider_name == knowledge_configuration.embedding_model_provider, DatasetCollectionBinding.model_name == knowledge_configuration.embedding_model, - DatasetCollectionBinding.type == "dataset", + DatasetCollectionBinding.type == CollectionBindingType.DATASET, ) .order_by(DatasetCollectionBinding.created_at) .first() @@ -470,7 +472,7 @@ class RagPipelineDslService: provider_name=knowledge_configuration.embedding_model_provider, model_name=knowledge_configuration.embedding_model, collection_name=Dataset.gen_collection_name_by_id(str(uuid.uuid4())), - type="dataset", + type=CollectionBindingType.DATASET, ) self._session.add(dataset_collection_binding) self._session.commit() @@ -562,7 +564,7 @@ class RagPipelineDslService: graph = workflow_data.get("graph", {}) for node in graph.get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL: + if node.get("data", {}).get("type", "") == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node["data"].get("dataset_ids", []) node["data"]["dataset_ids"] = [ decrypted_id @@ -696,17 +698,17 @@ class RagPipelineDslService: if not node_data: continue data_type = node_data.get("type", "") - if data_type == NodeType.KNOWLEDGE_RETRIEVAL: + if data_type == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: dataset_ids = node_data.get("dataset_ids", []) node["data"]["dataset_ids"] = [ self.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id) for dataset_id in dataset_ids ] # filter credential id from tool node - if not include_secret and data_type == NodeType.TOOL: + if not include_secret and data_type == BuiltinNodeTypes.TOOL: node_data.pop("credential_id", None) # filter credential id from agent node - if not include_secret and data_type == NodeType.AGENT: + if not include_secret and data_type == BuiltinNodeTypes.AGENT: for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): tool.pop("credential_id", None) @@ -740,35 +742,35 @@ class RagPipelineDslService: try: typ = node.get("data", {}).get("type") match typ: - case NodeType.TOOL: + case BuiltinNodeTypes.TOOL: tool_entity = ToolNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id), ) - case NodeType.DATASOURCE: + case BuiltinNodeTypes.DATASOURCE: datasource_entity = DatasourceNodeData.model_validate(node["data"]) if datasource_entity.provider_type != "local_file": dependencies.append(datasource_entity.plugin_id) - case NodeType.LLM: + case BuiltinNodeTypes.LLM: llm_entity = LLMNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider), ) - case NodeType.QUESTION_CLASSIFIER: + case BuiltinNodeTypes.QUESTION_CLASSIFIER: question_classifier_entity = QuestionClassifierNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( question_classifier_entity.model.provider ), ) - case NodeType.PARAMETER_EXTRACTOR: + case BuiltinNodeTypes.PARAMETER_EXTRACTOR: parameter_extractor_entity = ParameterExtractorNodeData.model_validate(node["data"]) dependencies.append( DependenciesAnalysisService.analyze_model_provider_dependency( parameter_extractor_entity.model.provider ), ) - case NodeType.KNOWLEDGE_INDEX: + case _ if typ == KNOWLEDGE_INDEX_NODE_TYPE: knowledge_index_entity = KnowledgeConfiguration.model_validate(node["data"]) if knowledge_index_entity.indexing_technique == "high_quality": if knowledge_index_entity.embedding_model_provider: @@ -789,7 +791,7 @@ class RagPipelineDslService: knowledge_index_entity.retrieval_model.reranking_model.reranking_provider_name ), ) - case NodeType.KNOWLEDGE_RETRIEVAL: + case BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: knowledge_retrieval_entity = KnowledgeRetrievalNodeData.model_validate(node["data"]) if knowledge_retrieval_entity.retrieval_mode == "multiple": if knowledge_retrieval_entity.multiple_retrieval_config: diff --git a/api/services/rag_pipeline/rag_pipeline_transform_service.py b/api/services/rag_pipeline/rag_pipeline_transform_service.py index d0dfbc1070..1d0aafd5fd 100644 --- a/api/services/rag_pipeline/rag_pipeline_transform_service.py +++ b/api/services/rag_pipeline/rag_pipeline_transform_service.py @@ -13,6 +13,7 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from factories import variable_factory from models.dataset import Dataset, Document, DocumentPipelineExecutionLog, Pipeline +from models.enums import DatasetRuntimeMode, DataSourceType from models.model import UploadFile from models.workflow import Workflow, WorkflowType from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration, RetrievalSetting @@ -27,7 +28,7 @@ class RagPipelineTransformService: dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first() if not dataset: raise ValueError("Dataset not found") - if dataset.pipeline_id and dataset.runtime_mode == "rag_pipeline": + if dataset.pipeline_id and dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE: return { "pipeline_id": dataset.pipeline_id, "dataset_id": dataset_id, @@ -63,7 +64,12 @@ class RagPipelineTransformService: ): node = self._deal_file_extensions(node) if node.get("data", {}).get("type") == "knowledge-index": - node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node) + knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {})) + if dataset.tenant_id != current_user.current_tenant_id: + raise ValueError("Unauthorized") + node = self._deal_knowledge_index( + knowledge_configuration, dataset, indexing_technique, retrieval_model, node + ) new_nodes.append(node) if new_nodes: graph["nodes"] = new_nodes @@ -80,7 +86,7 @@ class RagPipelineTransformService: else: raise ValueError("Unsupported doc form") - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.pipeline_id = pipeline.id # deal document data @@ -97,7 +103,7 @@ class RagPipelineTransformService: pipeline_yaml = {} if doc_form == "text_model": match datasource_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: if indexing_technique == "high_quality": # get graph from transform.file-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/file-general-high-quality.yml") as f: @@ -106,7 +112,7 @@ class RagPipelineTransformService: # get graph from transform.file-general-economy.yml with open(f"{Path(__file__).parent}/transform/file-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "notion_import": + case DataSourceType.NOTION_IMPORT: if indexing_technique == "high_quality": # get graph from transform.notion-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/notion-general-high-quality.yml") as f: @@ -115,7 +121,7 @@ class RagPipelineTransformService: # get graph from transform.notion-general-economy.yml with open(f"{Path(__file__).parent}/transform/notion-general-economy.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: if indexing_technique == "high_quality": # get graph from transform.website-crawl-general-high-quality.yml with open(f"{Path(__file__).parent}/transform/website-crawl-general-high-quality.yml") as f: @@ -128,15 +134,15 @@ class RagPipelineTransformService: raise ValueError("Unsupported datasource type") elif doc_form == "hierarchical_model": match datasource_type: - case "upload_file": + case DataSourceType.UPLOAD_FILE: # get graph from transform.file-parentchild.yml with open(f"{Path(__file__).parent}/transform/file-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "notion_import": + case DataSourceType.NOTION_IMPORT: # get graph from transform.notion-parentchild.yml with open(f"{Path(__file__).parent}/transform/notion-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) - case "website_crawl": + case DataSourceType.WEBSITE_CRAWL: # get graph from transform.website-crawl-parentchild.yml with open(f"{Path(__file__).parent}/transform/website-crawl-parentchild.yml") as f: pipeline_yaml = yaml.safe_load(f) @@ -155,14 +161,13 @@ class RagPipelineTransformService: def _deal_knowledge_index( self, + knowledge_configuration: KnowledgeConfiguration, dataset: Dataset, - doc_form: str, indexing_technique: str | None, retrieval_model: RetrievalSetting | None, node: dict, ): knowledge_configuration_dict = node.get("data", {}) - knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict) if indexing_technique == "high_quality": knowledge_configuration.embedding_model = dataset.embedding_model @@ -283,7 +288,7 @@ class RagPipelineTransformService: db.session.flush() dataset.pipeline_id = pipeline.id - dataset.runtime_mode = "rag_pipeline" + dataset.runtime_mode = DatasetRuntimeMode.RAG_PIPELINE dataset.updated_by = current_user.id dataset.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.add(dataset) @@ -306,8 +311,8 @@ class RagPipelineTransformService: data_source_info_dict = document.data_source_info_dict if not data_source_info_dict: continue - if document.data_source_type == "upload_file": - document.data_source_type = "local_file" + if document.data_source_type == DataSourceType.UPLOAD_FILE: + document.data_source_type = DataSourceType.LOCAL_FILE file_id = data_source_info_dict.get("upload_file_id") if file_id: file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() @@ -327,7 +332,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="local_file", + datasource_type=DataSourceType.LOCAL_FILE, datasource_info=data_source_info, input_data={}, created_by=document.created_by, @@ -336,8 +341,8 @@ class RagPipelineTransformService: document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) - elif document.data_source_type == "notion_import": - document.data_source_type = "online_document" + elif document.data_source_type == DataSourceType.NOTION_IMPORT: + document.data_source_type = DataSourceType.ONLINE_DOCUMENT data_source_info = json.dumps( { "workspace_id": data_source_info_dict.get("notion_workspace_id"), @@ -355,7 +360,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="online_document", + datasource_type=DataSourceType.ONLINE_DOCUMENT, datasource_info=data_source_info, input_data={}, created_by=document.created_by, @@ -364,8 +369,7 @@ class RagPipelineTransformService: document_pipeline_execution_log.created_at = document.created_at db.session.add(document) db.session.add(document_pipeline_execution_log) - elif document.data_source_type == "website_crawl": - document.data_source_type = "website_crawl" + elif document.data_source_type == DataSourceType.WEBSITE_CRAWL: data_source_info = json.dumps( { "source_url": data_source_info_dict.get("url"), @@ -384,7 +388,7 @@ class RagPipelineTransformService: document_pipeline_execution_log = DocumentPipelineExecutionLog( document_id=document.id, pipeline_id=dataset.pipeline_id, - datasource_type="website_crawl", + datasource_type=DataSourceType.WEBSITE_CRAWL, datasource_info=data_source_info, input_data={}, created_by=document.created_by, diff --git a/api/services/retention/conversation/message_export_service.py b/api/services/retention/conversation/message_export_service.py new file mode 100644 index 0000000000..fbe0d2795d --- /dev/null +++ b/api/services/retention/conversation/message_export_service.py @@ -0,0 +1,304 @@ +""" +Export app messages to JSONL.GZ format. + +Outputs: conversation_id, message_id, query, answer, inputs (raw JSON), +retriever_resources (from message_metadata), feedback (user feedbacks array). + +Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1. +Does NOT touch Message.inputs / Message.user_feedback properties. +""" + +import datetime +import gzip +import json +import logging +import tempfile +from collections import defaultdict +from collections.abc import Generator, Iterable +from pathlib import Path, PurePosixPath +from typing import Any, BinaryIO, cast + +import orjson +import sqlalchemy as sa +from pydantic import BaseModel, ConfigDict, Field +from sqlalchemy import select, tuple_ +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from extensions.ext_storage import storage +from models.model import Message, MessageFeedback + +logger = logging.getLogger(__name__) + +MAX_FILENAME_BASE_LENGTH = 1024 +FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz") + + +class AppMessageExportFeedback(BaseModel): + id: str + app_id: str + conversation_id: str + message_id: str + rating: str + content: str | None = None + from_source: str + from_end_user_id: str | None = None + from_account_id: str | None = None + created_at: str + updated_at: str + + model_config = ConfigDict(extra="forbid") + + +class AppMessageExportRecord(BaseModel): + conversation_id: str + message_id: str + query: str + answer: str + inputs: dict[str, Any] + retriever_resources: list[Any] = Field(default_factory=list) + feedback: list[AppMessageExportFeedback] = Field(default_factory=list) + + model_config = ConfigDict(extra="forbid") + + +class AppMessageExportStats(BaseModel): + batches: int = 0 + total_messages: int = 0 + messages_with_feedback: int = 0 + total_feedbacks: int = 0 + + model_config = ConfigDict(extra="forbid") + + +class AppMessageExportService: + @staticmethod + def validate_export_filename(filename: str) -> str: + normalized = filename.strip() + if not normalized: + raise ValueError("--filename must not be empty.") + + normalized_lower = normalized.lower() + if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES): + raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.") + + if normalized.startswith("/"): + raise ValueError("--filename must be a relative path; absolute paths are not allowed.") + + if "\\" in normalized: + raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.") + + if "//" in normalized: + raise ValueError("--filename must not contain empty path segments ('//').") + + if len(normalized) > MAX_FILENAME_BASE_LENGTH: + raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.") + + for ch in normalized: + if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127: + raise ValueError("--filename must not contain control characters or NUL.") + + parts = PurePosixPath(normalized).parts + if not parts: + raise ValueError("--filename must include a file name.") + + if any(part in (".", "..") for part in parts): + raise ValueError("--filename must not contain '.' or '..' path segments.") + + return normalized + + @property + def output_gz_name(self) -> str: + return f"{self._filename_base}.jsonl.gz" + + @property + def output_jsonl_name(self) -> str: + return f"{self._filename_base}.jsonl" + + def __init__( + self, + app_id: str, + end_before: datetime.datetime, + filename: str, + *, + start_from: datetime.datetime | None = None, + batch_size: int = 1000, + use_cloud_storage: bool = False, + dry_run: bool = False, + ) -> None: + if start_from and start_from >= end_before: + raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})") + + self._app_id = app_id + self._end_before = end_before + self._start_from = start_from + self._filename_base = self.validate_export_filename(filename) + self._batch_size = batch_size + self._use_cloud_storage = use_cloud_storage + self._dry_run = dry_run + + def run(self) -> AppMessageExportStats: + stats = AppMessageExportStats() + + logger.info( + "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s", + self._app_id, + self._start_from, + self._end_before, + self._dry_run, + self._use_cloud_storage, + self.output_gz_name, + ) + + if self._dry_run: + for _ in self._iter_records_with_stats(stats): + pass + self._finalize_stats(stats) + return stats + + if self._use_cloud_storage: + self._export_to_cloud(stats) + else: + self._export_to_local(stats) + + self._finalize_stats(stats) + return stats + + def iter_records(self) -> Generator[AppMessageExportRecord, None, None]: + for batch in self._iter_record_batches(): + yield from batch + + @staticmethod + def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None: + with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz: + for record in records: + gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n") + + def _export_to_local(self, stats: AppMessageExportStats) -> None: + output_path = Path.cwd() / self.output_gz_name + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("wb") as output_file: + self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file) + + def _export_to_cloud(self, stats: AppMessageExportStats) -> None: + with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp: + self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp)) + tmp.seek(0) + data = tmp.read() + + storage.save(self.output_gz_name, data) + logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name) + + def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]: + for record in self.iter_records(): + self._update_stats(stats, record) + yield record + + @staticmethod + def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None: + stats.total_messages += 1 + if record.feedback: + stats.messages_with_feedback += 1 + stats.total_feedbacks += len(record.feedback) + + def _finalize_stats(self, stats: AppMessageExportStats) -> None: + if stats.total_messages == 0: + stats.batches = 0 + return + stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size + + def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]: + cursor: tuple[datetime.datetime, str] | None = None + while True: + rows, cursor = self._fetch_batch(cursor) + if not rows: + break + + message_ids = [str(row.id) for row in rows] + feedbacks_map = self._fetch_feedbacks(message_ids) + yield [self._build_record(row, feedbacks_map) for row in rows] + + def _fetch_batch( + self, cursor: tuple[datetime.datetime, str] | None + ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]: + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select( + Message.id, + Message.conversation_id, + Message.query, + Message.answer, + Message._inputs, # pyright: ignore[reportPrivateUsage] + Message.message_metadata, + Message.created_at, + ) + .where( + Message.app_id == self._app_id, + Message.created_at < self._end_before, + ) + .order_by(Message.created_at, Message.id) + .limit(self._batch_size) + ) + + if self._start_from: + stmt = stmt.where(Message.created_at >= self._start_from) + + if cursor: + stmt = stmt.where( + tuple_(Message.created_at, Message.id) + > tuple_( + sa.literal(cursor[0], type_=sa.DateTime()), + sa.literal(cursor[1], type_=Message.id.type), + ) + ) + + rows = list(session.execute(stmt).all()) + + if not rows: + return [], cursor + + last = rows[-1] + return rows, (last.created_at, last.id) + + def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]: + if not message_ids: + return {} + + with Session(db.engine, expire_on_commit=False) as session: + stmt = ( + select(MessageFeedback) + .where( + MessageFeedback.message_id.in_(message_ids), + MessageFeedback.from_source == "user", + ) + .order_by(MessageFeedback.message_id, MessageFeedback.created_at) + ) + feedbacks = list(session.scalars(stmt).all()) + + result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list) + for feedback in feedbacks: + result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict())) + return result + + @staticmethod + def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord: + retriever_resources: list[Any] = [] + if row.message_metadata: + try: + metadata = json.loads(row.message_metadata) + value = metadata.get("retriever_resources", []) + if isinstance(value, list): + retriever_resources = value + except (json.JSONDecodeError, TypeError): + pass + + message_id = str(row.id) + return AppMessageExportRecord( + conversation_id=str(row.conversation_id), + message_id=message_id, + query=row.query, + answer=row.answer, + inputs=row._inputs if isinstance(row._inputs, dict) else {}, + retriever_resources=retriever_resources, + feedback=feedbacks_map.get(message_id, []), + ) diff --git a/api/services/retention/conversation/messages_clean_service.py b/api/services/retention/conversation/messages_clean_service.py index f7836a2b14..48c3e72af0 100644 --- a/api/services/retention/conversation/messages_clean_service.py +++ b/api/services/retention/conversation/messages_clean_service.py @@ -1,17 +1,18 @@ import datetime import logging -import os import random import time from collections.abc import Sequence -from typing import cast +from typing import TYPE_CHECKING, cast import sqlalchemy as sa from sqlalchemy import delete, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session +from configs import dify_config from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.model import ( App, AppAnnotationHitHistory, @@ -32,6 +33,131 @@ from services.retention.conversation.messages_clean_policy import ( logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from opentelemetry.metrics import Counter, Histogram + + +class MessagesCleanupMetrics: + """ + Records low-cardinality OpenTelemetry metrics for expired message cleanup jobs. + + We keep labels stable (dry_run/window_mode/task_label/status) so these metrics remain + dashboard-friendly for long-running CronJob executions. + """ + + _job_runs_total: "Counter | None" + _batches_total: "Counter | None" + _messages_scanned_total: "Counter | None" + _messages_filtered_total: "Counter | None" + _messages_deleted_total: "Counter | None" + _job_duration_seconds: "Histogram | None" + _batch_duration_seconds: "Histogram | None" + _base_attributes: dict[str, str] + + def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None: + self._job_runs_total = None + self._batches_total = None + self._messages_scanned_total = None + self._messages_filtered_total = None + self._messages_deleted_total = None + self._job_duration_seconds = None + self._batch_duration_seconds = None + self._base_attributes = { + "job_name": "messages_cleanup", + "dry_run": str(dry_run).lower(), + "window_mode": "between" if has_window else "before_cutoff", + "task_label": task_label, + } + self._init_instruments() + + def _init_instruments(self) -> None: + if not dify_config.ENABLE_OTEL: + return + + try: + from opentelemetry.metrics import get_meter + + meter = get_meter("messages_cleanup", version=dify_config.project.version) + self._job_runs_total = meter.create_counter( + "messages_cleanup_jobs_total", + description="Total number of expired message cleanup jobs by status.", + unit="{job}", + ) + self._batches_total = meter.create_counter( + "messages_cleanup_batches_total", + description="Total number of message cleanup batches processed.", + unit="{batch}", + ) + self._messages_scanned_total = meter.create_counter( + "messages_cleanup_scanned_messages_total", + description="Total messages scanned by cleanup jobs.", + unit="{message}", + ) + self._messages_filtered_total = meter.create_counter( + "messages_cleanup_filtered_messages_total", + description="Total messages selected by cleanup policy.", + unit="{message}", + ) + self._messages_deleted_total = meter.create_counter( + "messages_cleanup_deleted_messages_total", + description="Total messages deleted by cleanup jobs.", + unit="{message}", + ) + self._job_duration_seconds = meter.create_histogram( + "messages_cleanup_job_duration_seconds", + description="Duration of expired message cleanup jobs in seconds.", + unit="s", + ) + self._batch_duration_seconds = meter.create_histogram( + "messages_cleanup_batch_duration_seconds", + description="Duration of expired message cleanup batch processing in seconds.", + unit="s", + ) + except Exception: + logger.exception("messages_cleanup_metrics: failed to initialize instruments") + + def _attrs(self, **extra: str) -> dict[str, str]: + return {**self._base_attributes, **extra} + + @staticmethod + def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None: + if not counter or value <= 0: + return + try: + counter.add(value, attributes) + except Exception: + logger.exception("messages_cleanup_metrics: failed to add counter value") + + @staticmethod + def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None: + if not histogram: + return + try: + histogram.record(value, attributes) + except Exception: + logger.exception("messages_cleanup_metrics: failed to record histogram value") + + def record_batch( + self, + *, + scanned_messages: int, + filtered_messages: int, + deleted_messages: int, + batch_duration_seconds: float, + ) -> None: + attributes = self._attrs() + self._add(self._batches_total, 1, attributes) + self._add(self._messages_scanned_total, scanned_messages, attributes) + self._add(self._messages_filtered_total, filtered_messages, attributes) + self._add(self._messages_deleted_total, deleted_messages, attributes) + self._record(self._batch_duration_seconds, batch_duration_seconds, attributes) + + def record_completion(self, *, status: str, job_duration_seconds: float) -> None: + attributes = self._attrs(status=status) + self._add(self._job_runs_total, 1, attributes) + self._record(self._job_duration_seconds, job_duration_seconds, attributes) + + class MessagesCleanService: """ Service for cleaning expired messages based on retention policies. @@ -47,6 +173,7 @@ class MessagesCleanService: start_from: datetime.datetime | None = None, batch_size: int = 1000, dry_run: bool = False, + task_label: str = "custom", ) -> None: """ Initialize the service with cleanup parameters. @@ -57,12 +184,18 @@ class MessagesCleanService: start_from: Optional start time (inclusive) of the range batch_size: Number of messages to process per batch dry_run: Whether to perform a dry run (no actual deletion) + task_label: Optional task label for retention metrics """ self._policy = policy self._end_before = end_before self._start_from = start_from self._batch_size = batch_size self._dry_run = dry_run + self._metrics = MessagesCleanupMetrics( + dry_run=dry_run, + has_window=bool(start_from), + task_label=task_label, + ) @classmethod def from_time_range( @@ -72,6 +205,7 @@ class MessagesCleanService: end_before: datetime.datetime, batch_size: int = 1000, dry_run: bool = False, + task_label: str = "custom", ) -> "MessagesCleanService": """ Create a service instance for cleaning messages within a specific time range. @@ -84,6 +218,7 @@ class MessagesCleanService: end_before: End time (exclusive) of the range batch_size: Number of messages to process per batch dry_run: Whether to perform a dry run (no actual deletion) + task_label: Optional task label for retention metrics Returns: MessagesCleanService instance @@ -111,6 +246,7 @@ class MessagesCleanService: start_from=start_from, batch_size=batch_size, dry_run=dry_run, + task_label=task_label, ) @classmethod @@ -120,6 +256,7 @@ class MessagesCleanService: days: int = 30, batch_size: int = 1000, dry_run: bool = False, + task_label: str = "custom", ) -> "MessagesCleanService": """ Create a service instance for cleaning messages older than specified days. @@ -129,6 +266,7 @@ class MessagesCleanService: days: Number of days to look back from now batch_size: Number of messages to process per batch dry_run: Whether to perform a dry run (no actual deletion) + task_label: Optional task label for retention metrics Returns: MessagesCleanService instance @@ -142,7 +280,7 @@ class MessagesCleanService: if batch_size <= 0: raise ValueError(f"batch_size ({batch_size}) must be greater than 0") - end_before = datetime.datetime.now() - datetime.timedelta(days=days) + end_before = naive_utc_now() - datetime.timedelta(days=days) logger.info( "clean_messages: days=%s, end_before=%s, batch_size=%s, policy=%s", @@ -152,7 +290,14 @@ class MessagesCleanService: policy.__class__.__name__, ) - return cls(policy=policy, end_before=end_before, start_from=None, batch_size=batch_size, dry_run=dry_run) + return cls( + policy=policy, + end_before=end_before, + start_from=None, + batch_size=batch_size, + dry_run=dry_run, + task_label=task_label, + ) def run(self) -> dict[str, int]: """ @@ -161,7 +306,18 @@ class MessagesCleanService: Returns: Dict with statistics: batches, filtered_messages, total_deleted """ - return self._clean_messages_by_time_range() + status = "success" + run_start = time.monotonic() + try: + return self._clean_messages_by_time_range() + except Exception: + status = "failed" + raise + finally: + self._metrics.record_completion( + status=status, + job_duration_seconds=time.monotonic() - run_start, + ) def _clean_messages_by_time_range(self) -> dict[str, int]: """ @@ -196,11 +352,14 @@ class MessagesCleanService: self._end_before, ) - max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200)) + max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL while True: stats["batches"] += 1 batch_start = time.monotonic() + batch_scanned_messages = 0 + batch_filtered_messages = 0 + batch_deleted_messages = 0 # Step 1: Fetch a batch of messages using cursor with Session(db.engine, expire_on_commit=False) as session: @@ -239,9 +398,16 @@ class MessagesCleanService: # Track total messages fetched across all batches stats["total_messages"] += len(messages) + batch_scanned_messages = len(messages) if not messages: logger.info("clean_messages (batch %s): no more messages to process", stats["batches"]) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) break # Update cursor to the last message's (created_at, id) @@ -267,6 +433,12 @@ class MessagesCleanService: if not apps: logger.info("clean_messages (batch %s): no apps found, skip", stats["batches"]) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) continue # Build app_id -> tenant_id mapping @@ -285,9 +457,16 @@ class MessagesCleanService: if not message_ids_to_delete: logger.info("clean_messages (batch %s): no messages to delete, skip", stats["batches"]) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) continue stats["filtered_messages"] += len(message_ids_to_delete) + batch_filtered_messages = len(message_ids_to_delete) # Step 4: Batch delete messages and their relations if not self._dry_run: @@ -308,6 +487,7 @@ class MessagesCleanService: commit_ms = int((time.monotonic() - commit_start) * 1000) stats["total_deleted"] += messages_deleted + batch_deleted_messages = messages_deleted logger.info( "clean_messages (batch %s): processed %s messages, deleted %s messages", @@ -342,6 +522,13 @@ class MessagesCleanService: for msg_id in sampled_ids: logger.info("clean_messages (batch %s, dry_run) sample: message_id=%s", stats["batches"], msg_id) + self._metrics.record_batch( + scanned_messages=batch_scanned_messages, + filtered_messages=batch_filtered_messages, + deleted_messages=batch_deleted_messages, + batch_duration_seconds=time.monotonic() - batch_start, + ) + logger.info( "clean_messages completed: total batches: %s, total messages: %s, filtered messages: %s, total deleted: %s", stats["batches"], diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index ea5cbb7740..00a2144800 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -31,7 +31,7 @@ from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from core.workflow.enums import WorkflowType +from dify_graph.enums import WorkflowType from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.archive_storage import ( diff --git a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py index 2c94cb5324..62bc9f5f10 100644 --- a/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py +++ b/api/services/retention/workflow_run/clear_free_plan_expired_workflow_run_logs.py @@ -1,9 +1,9 @@ import datetime import logging -import os import random import time from collections.abc import Iterable, Sequence +from typing import TYPE_CHECKING import click from sqlalchemy.orm import Session, sessionmaker @@ -20,6 +20,159 @@ from services.billing_service import BillingService, SubscriptionPlan logger = logging.getLogger(__name__) +if TYPE_CHECKING: + from opentelemetry.metrics import Counter, Histogram + + +class WorkflowRunCleanupMetrics: + """ + Records low-cardinality OpenTelemetry metrics for workflow run cleanup jobs. + + Metrics are emitted with stable labels only (dry_run/window_mode/task_label/status) + to keep dashboard and alert cardinality predictable in production clusters. + """ + + _job_runs_total: "Counter | None" + _batches_total: "Counter | None" + _runs_scanned_total: "Counter | None" + _runs_targeted_total: "Counter | None" + _runs_deleted_total: "Counter | None" + _runs_skipped_total: "Counter | None" + _related_records_total: "Counter | None" + _job_duration_seconds: "Histogram | None" + _batch_duration_seconds: "Histogram | None" + _base_attributes: dict[str, str] + + def __init__(self, *, dry_run: bool, has_window: bool, task_label: str) -> None: + self._job_runs_total = None + self._batches_total = None + self._runs_scanned_total = None + self._runs_targeted_total = None + self._runs_deleted_total = None + self._runs_skipped_total = None + self._related_records_total = None + self._job_duration_seconds = None + self._batch_duration_seconds = None + self._base_attributes = { + "job_name": "workflow_run_cleanup", + "dry_run": str(dry_run).lower(), + "window_mode": "between" if has_window else "before_cutoff", + "task_label": task_label, + } + self._init_instruments() + + def _init_instruments(self) -> None: + if not dify_config.ENABLE_OTEL: + return + + try: + from opentelemetry.metrics import get_meter + + meter = get_meter("workflow_run_cleanup", version=dify_config.project.version) + self._job_runs_total = meter.create_counter( + "workflow_run_cleanup_jobs_total", + description="Total number of workflow run cleanup jobs by status.", + unit="{job}", + ) + self._batches_total = meter.create_counter( + "workflow_run_cleanup_batches_total", + description="Total number of processed cleanup batches.", + unit="{batch}", + ) + self._runs_scanned_total = meter.create_counter( + "workflow_run_cleanup_scanned_runs_total", + description="Total workflow runs scanned by cleanup jobs.", + unit="{run}", + ) + self._runs_targeted_total = meter.create_counter( + "workflow_run_cleanup_targeted_runs_total", + description="Total workflow runs targeted by cleanup policy.", + unit="{run}", + ) + self._runs_deleted_total = meter.create_counter( + "workflow_run_cleanup_deleted_runs_total", + description="Total workflow runs deleted by cleanup jobs.", + unit="{run}", + ) + self._runs_skipped_total = meter.create_counter( + "workflow_run_cleanup_skipped_runs_total", + description="Total workflow runs skipped because tenant is paid/unknown.", + unit="{run}", + ) + self._related_records_total = meter.create_counter( + "workflow_run_cleanup_related_records_total", + description="Total related records processed by cleanup jobs.", + unit="{record}", + ) + self._job_duration_seconds = meter.create_histogram( + "workflow_run_cleanup_job_duration_seconds", + description="Duration of workflow run cleanup jobs in seconds.", + unit="s", + ) + self._batch_duration_seconds = meter.create_histogram( + "workflow_run_cleanup_batch_duration_seconds", + description="Duration of workflow run cleanup batch processing in seconds.", + unit="s", + ) + except Exception: + logger.exception("workflow_run_cleanup_metrics: failed to initialize instruments") + + def _attrs(self, **extra: str) -> dict[str, str]: + return {**self._base_attributes, **extra} + + @staticmethod + def _add(counter: "Counter | None", value: int, attributes: dict[str, str]) -> None: + if not counter or value <= 0: + return + try: + counter.add(value, attributes) + except Exception: + logger.exception("workflow_run_cleanup_metrics: failed to add counter value") + + @staticmethod + def _record(histogram: "Histogram | None", value: float, attributes: dict[str, str]) -> None: + if not histogram: + return + try: + histogram.record(value, attributes) + except Exception: + logger.exception("workflow_run_cleanup_metrics: failed to record histogram value") + + def record_batch( + self, + *, + batch_rows: int, + targeted_runs: int, + skipped_runs: int, + deleted_runs: int, + related_counts: dict[str, int] | None, + related_action: str | None, + batch_duration_seconds: float, + ) -> None: + attributes = self._attrs() + self._add(self._batches_total, 1, attributes) + self._add(self._runs_scanned_total, batch_rows, attributes) + self._add(self._runs_targeted_total, targeted_runs, attributes) + self._add(self._runs_skipped_total, skipped_runs, attributes) + self._add(self._runs_deleted_total, deleted_runs, attributes) + self._record(self._batch_duration_seconds, batch_duration_seconds, attributes) + + if not related_counts or not related_action: + return + + for record_type, count in related_counts.items(): + self._add( + self._related_records_total, + count, + self._attrs(action=related_action, record_type=record_type), + ) + + def record_completion(self, *, status: str, job_duration_seconds: float) -> None: + attributes = self._attrs(status=status) + self._add(self._job_runs_total, 1, attributes) + self._record(self._job_duration_seconds, job_duration_seconds, attributes) + + class WorkflowRunCleanup: def __init__( self, @@ -29,6 +182,7 @@ class WorkflowRunCleanup: end_before: datetime.datetime | None = None, workflow_run_repo: APIWorkflowRunRepository | None = None, dry_run: bool = False, + task_label: str = "custom", ): if (start_from is None) ^ (end_before is None): raise ValueError("start_from and end_before must be both set or both omitted.") @@ -46,6 +200,11 @@ class WorkflowRunCleanup: self.batch_size = batch_size self._cleanup_whitelist: set[str] | None = None self.dry_run = dry_run + self._metrics = WorkflowRunCleanupMetrics( + dry_run=dry_run, + has_window=bool(start_from), + task_label=task_label, + ) self.free_plan_grace_period_days = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD self.workflow_run_repo: APIWorkflowRunRepository if workflow_run_repo: @@ -74,153 +233,193 @@ class WorkflowRunCleanup: related_totals = self._empty_related_counts() if self.dry_run else None batch_index = 0 last_seen: tuple[datetime.datetime, str] | None = None + status = "success" + run_start = time.monotonic() + max_batch_interval_ms = dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL - max_batch_interval_ms = int(os.environ.get("SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", 200)) + try: + while True: + batch_start = time.monotonic() - while True: - batch_start = time.monotonic() - - fetch_start = time.monotonic() - run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( - start_from=self.window_start, - end_before=self.window_end, - last_seen=last_seen, - batch_size=self.batch_size, - ) - if not run_rows: - logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1) - break - - batch_index += 1 - last_seen = (run_rows[-1].created_at, run_rows[-1].id) - logger.info( - "workflow_run_cleanup (batch #%s): fetched %s rows in %sms", - batch_index, - len(run_rows), - int((time.monotonic() - fetch_start) * 1000), - ) - - tenant_ids = {row.tenant_id for row in run_rows} - - filter_start = time.monotonic() - free_tenants = self._filter_free_tenants(tenant_ids) - logger.info( - "workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms", - batch_index, - len(free_tenants), - len(tenant_ids), - int((time.monotonic() - filter_start) * 1000), - ) - - free_runs = [row for row in run_rows if row.tenant_id in free_tenants] - paid_or_skipped = len(run_rows) - len(free_runs) - - if not free_runs: - skipped_message = ( - f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)" + fetch_start = time.monotonic() + run_rows = self.workflow_run_repo.get_runs_batch_by_time_range( + start_from=self.window_start, + end_before=self.window_end, + last_seen=last_seen, + batch_size=self.batch_size, ) - click.echo( - click.style( - skipped_message, - fg="yellow", - ) - ) - continue + if not run_rows: + logger.info("workflow_run_cleanup (batch #%s): no more rows to process", batch_index + 1) + break - total_runs_targeted += len(free_runs) - - if self.dry_run: - count_start = time.monotonic() - batch_counts = self.workflow_run_repo.count_runs_with_related( - free_runs, - count_node_executions=self._count_node_executions, - count_trigger_logs=self._count_trigger_logs, - ) + batch_index += 1 + last_seen = (run_rows[-1].created_at, run_rows[-1].id) logger.info( - "workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms", + "workflow_run_cleanup (batch #%s): fetched %s rows in %sms", batch_index, - int((time.monotonic() - count_start) * 1000), + len(run_rows), + int((time.monotonic() - fetch_start) * 1000), ) - if related_totals is not None: - for key in related_totals: - related_totals[key] += batch_counts.get(key, 0) - sample_ids = ", ".join(run.id for run in free_runs[:5]) + + tenant_ids = {row.tenant_id for row in run_rows} + + filter_start = time.monotonic() + free_tenants = self._filter_free_tenants(tenant_ids) + logger.info( + "workflow_run_cleanup (batch #%s): filtered %s free tenants from %s tenants in %sms", + batch_index, + len(free_tenants), + len(tenant_ids), + int((time.monotonic() - filter_start) * 1000), + ) + + free_runs = [row for row in run_rows if row.tenant_id in free_tenants] + paid_or_skipped = len(run_rows) - len(free_runs) + + if not free_runs: + skipped_message = ( + f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)" + ) + click.echo( + click.style( + skipped_message, + fg="yellow", + ) + ) + self._metrics.record_batch( + batch_rows=len(run_rows), + targeted_runs=0, + skipped_runs=paid_or_skipped, + deleted_runs=0, + related_counts=None, + related_action=None, + batch_duration_seconds=time.monotonic() - batch_start, + ) + continue + + total_runs_targeted += len(free_runs) + + if self.dry_run: + count_start = time.monotonic() + batch_counts = self.workflow_run_repo.count_runs_with_related( + free_runs, + count_node_executions=self._count_node_executions, + count_trigger_logs=self._count_trigger_logs, + ) + logger.info( + "workflow_run_cleanup (batch #%s, dry_run): counted related records in %sms", + batch_index, + int((time.monotonic() - count_start) * 1000), + ) + if related_totals is not None: + for key in related_totals: + related_totals[key] += batch_counts.get(key, 0) + sample_ids = ", ".join(run.id for run in free_runs[:5]) + click.echo( + click.style( + f"[batch #{batch_index}] would delete {len(free_runs)} runs " + f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown", + fg="yellow", + ) + ) + logger.info( + "workflow_run_cleanup (batch #%s, dry_run): batch total %sms", + batch_index, + int((time.monotonic() - batch_start) * 1000), + ) + self._metrics.record_batch( + batch_rows=len(run_rows), + targeted_runs=len(free_runs), + skipped_runs=paid_or_skipped, + deleted_runs=0, + related_counts={key: batch_counts.get(key, 0) for key in self._empty_related_counts()}, + related_action="would_delete", + batch_duration_seconds=time.monotonic() - batch_start, + ) + continue + + try: + delete_start = time.monotonic() + counts = self.workflow_run_repo.delete_runs_with_related( + free_runs, + delete_node_executions=self._delete_node_executions, + delete_trigger_logs=self._delete_trigger_logs, + ) + delete_ms = int((time.monotonic() - delete_start) * 1000) + except Exception: + logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) + raise + + total_runs_deleted += counts["runs"] click.echo( click.style( - f"[batch #{batch_index}] would delete {len(free_runs)} runs " - f"(sample ids: {sample_ids}) and skip {paid_or_skipped} paid/unknown", - fg="yellow", + f"[batch #{batch_index}] deleted runs: {counts['runs']} " + f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " + f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " + f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " + f"skipped {paid_or_skipped} paid/unknown", + fg="green", ) ) logger.info( - "workflow_run_cleanup (batch #%s, dry_run): batch total %sms", + "workflow_run_cleanup (batch #%s): delete %sms, batch total %sms", batch_index, + delete_ms, int((time.monotonic() - batch_start) * 1000), ) - continue - - try: - delete_start = time.monotonic() - counts = self.workflow_run_repo.delete_runs_with_related( - free_runs, - delete_node_executions=self._delete_node_executions, - delete_trigger_logs=self._delete_trigger_logs, + self._metrics.record_batch( + batch_rows=len(run_rows), + targeted_runs=len(free_runs), + skipped_runs=paid_or_skipped, + deleted_runs=counts["runs"], + related_counts={key: counts.get(key, 0) for key in self._empty_related_counts()}, + related_action="deleted", + batch_duration_seconds=time.monotonic() - batch_start, ) - delete_ms = int((time.monotonic() - delete_start) * 1000) - except Exception: - logger.exception("Failed to delete workflow runs batch ending at %s", last_seen[0]) - raise - total_runs_deleted += counts["runs"] - click.echo( - click.style( - f"[batch #{batch_index}] deleted runs: {counts['runs']} " - f"(nodes {counts['node_executions']}, offloads {counts['offloads']}, " - f"app_logs {counts['app_logs']}, trigger_logs {counts['trigger_logs']}, " - f"pauses {counts['pauses']}, pause_reasons {counts['pause_reasons']}); " - f"skipped {paid_or_skipped} paid/unknown", - fg="green", - ) - ) - logger.info( - "workflow_run_cleanup (batch #%s): delete %sms, batch total %sms", - batch_index, - delete_ms, - int((time.monotonic() - batch_start) * 1000), - ) + # Random sleep between batches to avoid overwhelming the database + sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311 + logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms) + time.sleep(sleep_ms / 1000) - # Random sleep between batches to avoid overwhelming the database - sleep_ms = random.uniform(0, max_batch_interval_ms) # noqa: S311 - logger.info("workflow_run_cleanup (batch #%s): sleeping for %.2fms", batch_index, sleep_ms) - time.sleep(sleep_ms / 1000) - - if self.dry_run: - if self.window_start: - summary_message = ( - f"Dry run complete. Would delete {total_runs_targeted} workflow runs " - f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" - ) + if self.dry_run: + if self.window_start: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Dry run complete. Would delete {total_runs_targeted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + if related_totals is not None: + summary_message = ( + f"{summary_message}; related records: {self._format_related_counts(related_totals)}" + ) + summary_color = "yellow" else: - summary_message = ( - f"Dry run complete. Would delete {total_runs_targeted} workflow runs " - f"before {self.window_end.isoformat()}" - ) - if related_totals is not None: - summary_message = f"{summary_message}; related records: {self._format_related_counts(related_totals)}" - summary_color = "yellow" - else: - if self.window_start: - summary_message = ( - f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " - f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" - ) - else: - summary_message = ( - f"Cleanup complete. Deleted {total_runs_deleted} workflow runs before {self.window_end.isoformat()}" - ) - summary_color = "white" + if self.window_start: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"between {self.window_start.isoformat()} and {self.window_end.isoformat()}" + ) + else: + summary_message = ( + f"Cleanup complete. Deleted {total_runs_deleted} workflow runs " + f"before {self.window_end.isoformat()}" + ) + summary_color = "white" - click.echo(click.style(summary_message, fg=summary_color)) + click.echo(click.style(summary_message, fg=summary_color)) + except Exception: + status = "failed" + raise + finally: + self._metrics.record_completion( + status=status, + job_duration_seconds=time.monotonic() - run_start, + ) def _filter_free_tenants(self, tenant_ids: Iterable[str]) -> set[str]: tenant_id_list = list(tenant_ids) diff --git a/api/services/retention/workflow_run/restore_archived_workflow_run.py b/api/services/retention/workflow_run/restore_archived_workflow_run.py index d4a6e87585..64dad7ba52 100644 --- a/api/services/retention/workflow_run/restore_archived_workflow_run.py +++ b/api/services/retention/workflow_run/restore_archived_workflow_run.py @@ -358,21 +358,19 @@ class WorkflowRunRestore: self, model: type[DeclarativeBase] | Any, ) -> tuple[set[str], set[str], set[str]]: - columns = list(model.__table__.columns) + table = model.__table__ + columns = list(table.columns) + autoincrement_column = getattr(table, "autoincrement_column", None) + + def has_insert_default(column: Any) -> bool: + # SQLAlchemy may set column.autoincrement to "auto" on non-PK columns. + # Only treat the resolved autoincrement column as DB-generated. + return column.default is not None or column.server_default is not None or column is autoincrement_column + column_names = {column.key for column in columns} - required_columns = { - column.key - for column in columns - if not column.nullable - and column.default is None - and column.server_default is None - and not column.autoincrement - } + required_columns = {column.key for column in columns if not column.nullable and not has_insert_default(column)} non_nullable_with_default = { - column.key - for column in columns - if not column.nullable - and (column.default is not None or column.server_default is not None or column.autoincrement) + column.key for column in columns if not column.nullable and has_insert_default(column) } return column_names, required_columns, non_nullable_with_default diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py index 4dd6c8107b..d0f4f27968 100644 --- a/api/services/saved_message_service.py +++ b/api/services/saved_message_service.py @@ -3,6 +3,7 @@ from typing import Union from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import SavedMessage from services.message_service import MessageService @@ -54,7 +55,7 @@ class SavedMessageService: saved_message = SavedMessage( app_id=app_model.id, message_id=message.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index 7c03ceed5b..943dfc972b 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -10,14 +10,16 @@ from sqlalchemy.orm import Session from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument +from models.enums import SummaryStatus logger = logging.getLogger(__name__) @@ -29,7 +31,7 @@ class SummaryIndexService: def generate_summary_for_segment( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> tuple[str, LLMUsage]: """ Generate summary for a single segment. @@ -73,7 +75,7 @@ class SummaryIndexService: segment: DocumentSegment, dataset: Dataset, summary_content: str, - status: str = "generating", + status: SummaryStatus = SummaryStatus.GENERATING, ) -> DocumentSegmentSummary: """ Create or update a DocumentSegmentSummary record. @@ -83,7 +85,7 @@ class SummaryIndexService: segment: DocumentSegment to create summary for dataset: Dataset containing the segment summary_content: Generated summary content - status: Summary status (default: "generating") + status: Summary status (default: SummaryStatus.GENERATING) Returns: Created or updated DocumentSegmentSummary instance @@ -326,7 +328,7 @@ class SummaryIndexService: summary_index_node_id=summary_index_node_id, summary_index_node_hash=summary_hash, tokens=embedding_tokens, - status="completed", + status=SummaryStatus.COMPLETED, enabled=True, ) session.add(summary_record_in_session) @@ -362,7 +364,7 @@ class SummaryIndexService: summary_record_in_session.summary_index_node_id = summary_index_node_id summary_record_in_session.summary_index_node_hash = summary_hash summary_record_in_session.tokens = embedding_tokens # Save embedding tokens - summary_record_in_session.status = "completed" + summary_record_in_session.status = SummaryStatus.COMPLETED # Ensure summary_content is preserved (use the latest from summary_record parameter) # This is critical: use the parameter value, not the database value summary_record_in_session.summary_content = summary_content @@ -400,7 +402,7 @@ class SummaryIndexService: summary_record.summary_index_node_id = summary_index_node_id summary_record.summary_index_node_hash = summary_hash summary_record.tokens = embedding_tokens - summary_record.status = "completed" + summary_record.status = SummaryStatus.COMPLETED summary_record.summary_content = summary_content if summary_record_in_session.updated_at: summary_record.updated_at = summary_record_in_session.updated_at @@ -487,7 +489,7 @@ class SummaryIndexService: ) if summary_record_in_session: - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = f"Vectorization failed: {str(e)}" summary_record_in_session.updated_at = datetime.now(UTC).replace(tzinfo=None) error_session.add(summary_record_in_session) @@ -498,7 +500,7 @@ class SummaryIndexService: summary_record_in_session.id, ) # Update the original object for consistency - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = summary_record_in_session.error summary_record.updated_at = summary_record_in_session.updated_at else: @@ -514,7 +516,7 @@ class SummaryIndexService: def batch_create_summary_records( segments: list[DocumentSegment], dataset: Dataset, - status: str = "not_started", + status: SummaryStatus = SummaryStatus.NOT_STARTED, ) -> None: """ Batch create summary records for segments with specified status. @@ -523,7 +525,7 @@ class SummaryIndexService: Args: segments: List of DocumentSegment instances dataset: Dataset containing the segments - status: Initial status for the records (default: "not_started") + status: Initial status for the records (default: SummaryStatus.NOT_STARTED) """ segment_ids = [segment.id for segment in segments] if not segment_ids: @@ -588,7 +590,7 @@ class SummaryIndexService: ) if summary_record: - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = error session.add(summary_record) session.commit() @@ -599,7 +601,7 @@ class SummaryIndexService: def generate_and_vectorize_summary( segment: DocumentSegment, dataset: Dataset, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, ) -> DocumentSegmentSummary: """ Generate summary for a segment and vectorize it. @@ -631,14 +633,14 @@ class SummaryIndexService: document_id=segment.document_id, chunk_id=segment.id, summary_content="", - status="generating", + status=SummaryStatus.GENERATING, enabled=True, ) session.add(summary_record_in_session) session.flush() # Update status to "generating" - summary_record_in_session.status = "generating" + summary_record_in_session.status = SummaryStatus.GENERATING summary_record_in_session.error = None # type: ignore[assignment] session.add(summary_record_in_session) # Don't flush here - wait until after vectorization succeeds @@ -681,7 +683,7 @@ class SummaryIndexService: except Exception as vectorize_error: # If vectorization fails, update status to error in current session logger.exception("Failed to vectorize summary for segment %s", segment.id) - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = f"Vectorization failed: {str(vectorize_error)}" session.add(summary_record_in_session) session.commit() @@ -694,7 +696,7 @@ class SummaryIndexService: session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record_in_session: - summary_record_in_session.status = "error" + summary_record_in_session.status = SummaryStatus.ERROR summary_record_in_session.error = str(e) session.add(summary_record_in_session) session.commit() @@ -704,7 +706,7 @@ class SummaryIndexService: def generate_summaries_for_document( dataset: Dataset, document: DatasetDocument, - summary_index_setting: dict, + summary_index_setting: SummaryIndexSettingDict, segment_ids: list[str] | None = None, only_parent_chunks: bool = False, ) -> list[DocumentSegmentSummary]: @@ -770,7 +772,7 @@ class SummaryIndexService: SummaryIndexService.batch_create_summary_records( segments=segments, dataset=dataset, - status="not_started", + status=SummaryStatus.NOT_STARTED, ) summary_records = [] @@ -1067,7 +1069,7 @@ class SummaryIndexService: # Update summary content summary_record.summary_content = summary_content - summary_record.status = "generating" + summary_record.status = SummaryStatus.GENERATING summary_record.error = None # type: ignore[assignment] # Clear any previous errors session.add(summary_record) # Flush to ensure summary_content is saved before vectorize_summary queries it @@ -1102,7 +1104,7 @@ class SummaryIndexService: # If vectorization fails, update status to error in current session # Don't raise the exception - just log it and return the record with error status # This allows the segment update to complete even if vectorization fails - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = f"Vectorization failed: {str(e)}" session.commit() logger.exception("Failed to vectorize summary for segment %s", segment.id) @@ -1112,7 +1114,7 @@ class SummaryIndexService: else: # Create new summary record if doesn't exist summary_record = SummaryIndexService.create_summary_record( - segment, dataset, summary_content, status="generating" + segment, dataset, summary_content, status=SummaryStatus.GENERATING ) # Re-vectorize summary (this will update status to "completed" and tokens in its own session) # Note: summary_record was created in a different session, @@ -1132,7 +1134,7 @@ class SummaryIndexService: # If vectorization fails, update status to error in current session # Merge the record into current session first error_record = session.merge(summary_record) - error_record.status = "error" + error_record.status = SummaryStatus.ERROR error_record.error = f"Vectorization failed: {str(e)}" session.commit() logger.exception("Failed to vectorize summary for segment %s", segment.id) @@ -1146,7 +1148,7 @@ class SummaryIndexService: session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first() ) if summary_record: - summary_record.status = "error" + summary_record.status = SummaryStatus.ERROR summary_record.error = str(e) session.add(summary_record) session.commit() @@ -1266,7 +1268,7 @@ class SummaryIndexService: # Check if there are any "not_started" or "generating" status summaries has_pending_summaries = any( summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summary_status_map[segment_id] in ("not_started", "generating") + and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING) for segment_id in segment_ids ) @@ -1330,7 +1332,7 @@ class SummaryIndexService: # it means the summary is disabled (enabled=False) or not created yet, ignore it has_pending_summaries = any( summary_status_map.get(segment_id) is not None # Ensure summary exists (enabled=True) - and summary_status_map[segment_id] in ("not_started", "generating") + and summary_status_map[segment_id] in (SummaryStatus.NOT_STARTED, SummaryStatus.GENERATING) for segment_id in segment_ids ) @@ -1393,17 +1395,17 @@ class SummaryIndexService: # Count statuses status_counts = { - "completed": 0, - "generating": 0, - "error": 0, - "not_started": 0, + SummaryStatus.COMPLETED: 0, + SummaryStatus.GENERATING: 0, + SummaryStatus.ERROR: 0, + SummaryStatus.NOT_STARTED: 0, } summary_list = [] for segment in segments: summary = summary_map.get(segment.id) if summary: - status = summary.status + status = SummaryStatus(summary.status) status_counts[status] = status_counts.get(status, 0) + 1 summary_list.append( { @@ -1421,12 +1423,12 @@ class SummaryIndexService: } ) else: - status_counts["not_started"] += 1 + status_counts[SummaryStatus.NOT_STARTED] += 1 summary_list.append( { "segment_id": segment.id, "segment_position": segment.position, - "status": "not_started", + "status": SummaryStatus.NOT_STARTED, "summary_preview": None, "error": None, "created_at": None, diff --git a/api/services/tag_service.py b/api/services/tag_service.py index bd3585acf4..70bf7f16f2 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -7,6 +7,7 @@ from werkzeug.exceptions import NotFound from extensions.ext_database import db from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding @@ -83,7 +84,7 @@ class TagService: raise ValueError("Tag name already exists") tag = Tag( name=args["name"], - type=args["type"], + type=TagType(args["type"]), created_by=current_user.id, tenant_id=current_user.current_tenant_id, ) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index c32157919b..408b1c22d1 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -1,13 +1,12 @@ import json import logging -from collections.abc import Mapping from typing import Any, cast from httpx import get from sqlalchemy import select +from typing_extensions import TypedDict from core.entities.provider_entities import ProviderConfig -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_runtime import ToolRuntime from core.tools.custom_tool.provider import ApiToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -21,6 +20,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.tools import ApiToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -28,9 +28,16 @@ from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) +class ApiSchemaParseResult(TypedDict): + schema_type: str + parameters_schema: list[dict[str, Any]] + credentials_schema: list[dict[str, Any]] + warning: dict[str, str] + + class ApiToolManageService: @staticmethod - def parser_api_schema(schema: str) -> Mapping[str, Any]: + def parser_api_schema(schema: str) -> ApiSchemaParseResult: """ parse api schema to tool bundle """ @@ -71,7 +78,7 @@ class ApiToolManageService: ] return cast( - Mapping, + ApiSchemaParseResult, jsonable_encoder( { "schema_type": schema_type, diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 0be106f597..deb26438a8 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -18,6 +18,7 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.auth.auth_flow import auth from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPAuthError, MCPError +from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.utils.encryption import ProviderConfigEncrypter from models.tools import MCPToolProvider @@ -681,7 +682,7 @@ class MCPToolManageService: raise ValueError(f"Failed to re-connect MCP server: {e}") from e def _build_tool_provider_response( - self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list + self, db_provider: MCPToolProvider, provider_entity: MCPProviderEntity, tools: list[MCPTool] ) -> ToolProviderApiEntity: """Build API response for tool provider.""" user = db_provider.load_user() @@ -703,7 +704,7 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_url} already exists") if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - raise + raise error def _is_valid_url(self, url: str) -> bool: """Validate URL format.""" diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index e323b3cda9..b6e5367023 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -33,6 +33,8 @@ logger = logging.getLogger(__name__) class ToolTransformService: + _MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10 + @classmethod def get_tool_provider_icon_url( cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str] @@ -435,6 +437,46 @@ class ToolTransformService: :return: list of ToolParameter instances """ + def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str: + """ + Resolve a JSON schema property type while guarding against cyclic or deeply nested unions. + """ + if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH: + return "string" + prop_type = prop.get("type") + if isinstance(prop_type, list): + non_null_types = [type_name for type_name in prop_type if type_name != "null"] + if non_null_types: + return non_null_types[0] + if prop_type: + return "string" + elif isinstance(prop_type, str): + if prop_type == "null": + return "string" + return prop_type + + for union_key in ("anyOf", "oneOf"): + union_schemas = prop.get(union_key) + if not isinstance(union_schemas, list): + continue + + for union_schema in union_schemas: + if not isinstance(union_schema, dict): + continue + union_type = resolve_property_type(union_schema, depth + 1) + if union_type != "null": + return union_type + + all_of_schemas = prop.get("allOf") + if isinstance(all_of_schemas, list): + for all_of_schema in all_of_schemas: + if not isinstance(all_of_schema, dict): + continue + all_of_type = resolve_property_type(all_of_schema, depth + 1) + if all_of_type != "null": + return all_of_type + return "string" + def create_parameter( name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None ) -> ToolParameter: @@ -461,10 +503,7 @@ class ToolTransformService: parameters = [] for name, prop in props.items(): current_description = prop.get("description", "") - prop_type = prop.get("type", "string") - - if isinstance(prop_type, list): - prop_type = prop_type[0] + prop_type = resolve_property_type(prop) if prop_type in TYPE_MAPPING: prop_type = TYPE_MAPPING[prop_type] input_schema = prop if prop_type in COMPLEX_TYPES else None diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index ff0b276f77..101b2fe5a2 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -5,7 +5,6 @@ from datetime import datetime from sqlalchemy import or_, select from sqlalchemy.orm import Session -from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool_provider import ToolProviderController from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration @@ -13,6 +12,7 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.model import App from models.tools import WorkflowToolProvider diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index b49d14f860..7e9d010d2f 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -1,15 +1,19 @@ import json import logging -from collections.abc import Mapping from datetime import datetime -from typing import Any from sqlalchemy import select from sqlalchemy.orm import Session -from core.workflow.nodes import NodeType -from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig, SchedulePlanUpdate, VisualConfig +from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE +from core.workflow.nodes.trigger_schedule.entities import ( + ScheduleConfig, + SchedulePlanUpdate, + TriggerScheduleNodeData, + VisualConfig, +) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError +from dify_graph.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan @@ -176,26 +180,26 @@ class ScheduleService: return next_run_at @staticmethod - def to_schedule_config(node_config: Mapping[str, Any]) -> ScheduleConfig: + def to_schedule_config(node_config: NodeConfigDict) -> ScheduleConfig: """ Converts user-friendly visual schedule settings to cron expression. Maintains consistency with frontend UI expectations while supporting croniter's extended syntax. """ - node_data = node_config.get("data", {}) - mode = node_data.get("mode", "visual") - timezone = node_data.get("timezone", "UTC") - node_id = node_config.get("id", "start") + node_data = TriggerScheduleNodeData.model_validate(node_config["data"], from_attributes=True) + mode = node_data.mode + timezone = node_data.timezone + node_id = node_config["id"] cron_expression = None if mode == "cron": - cron_expression = node_data.get("cron_expression") + cron_expression = node_data.cron_expression if not cron_expression: raise ScheduleConfigError("Cron expression is required for cron mode") elif mode == "visual": - frequency = str(node_data.get("frequency")) + frequency = str(node_data.frequency or "") if not frequency: raise ScheduleConfigError("Frequency is required for visual mode") - visual_config = VisualConfig(**node_data.get("visual_config", {})) + visual_config = VisualConfig.model_validate(node_data.visual_config or {}) cron_expression = ScheduleService.visual_to_cron(frequency=frequency, visual_config=visual_config) if not cron_expression: raise ScheduleConfigError("Cron expression is required for visual mode") @@ -236,22 +240,24 @@ class ScheduleService: for node in nodes: node_data = node.get("data", {}) - if node_data.get("type") != NodeType.TRIGGER_SCHEDULE.value: + if node_data.get("type") != TRIGGER_SCHEDULE_NODE_TYPE: continue - mode = node_data.get("mode", "visual") - timezone = node_data.get("timezone", "UTC") node_id = node.get("id", "start") + trigger_data = TriggerScheduleNodeData.model_validate(node_data) + mode = trigger_data.mode + timezone = trigger_data.timezone cron_expression = None if mode == "cron": - cron_expression = node_data.get("cron_expression") + cron_expression = trigger_data.cron_expression if not cron_expression: raise ScheduleConfigError("Cron expression is required for cron mode") elif mode == "visual": - frequency = node_data.get("frequency") - visual_config_dict = node_data.get("visual_config", {}) - visual_config = VisualConfig(**visual_config_dict) + frequency = trigger_data.frequency + if not frequency: + raise ScheduleConfigError("Frequency is required for visual mode") + visual_config = VisualConfig.model_validate(trigger_data.visual_config or {}) cron_expression = ScheduleService.visual_to_cron(frequency, visual_config) else: raise ScheduleConfigError(f"Invalid schedule mode: {mode}") diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 7f12c2e19c..24bbeda329 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -12,12 +12,13 @@ from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerDispatchResponse, TriggerInvokeEventResponse from core.plugin.impl.exc import PluginNotFoundError +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.debug.events import PluginTriggerDebugEvent from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription -from core.workflow.enums import NodeType from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData +from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import App @@ -41,7 +42,7 @@ class TriggerService: @classmethod def invoke_trigger_event( - cls, tenant_id: str, user_id: str, node_config: Mapping[str, Any], event: PluginTriggerDebugEvent + cls, tenant_id: str, user_id: str, node_config: NodeConfigDict, event: PluginTriggerDebugEvent ) -> TriggerInvokeEventResponse: """Invoke a trigger event.""" subscription: TriggerSubscription | None = TriggerProviderService.get_subscription_by_id( @@ -50,7 +51,7 @@ class TriggerService: ) if not subscription: raise ValueError("Subscription not found") - node_data: TriggerEventNodeData = TriggerEventNodeData.model_validate(node_config.get("data", {})) + node_data = TriggerEventNodeData.model_validate(node_config["data"], from_attributes=True) request = TriggerHttpRequestCachingService.get_request(event.request_id) payload = TriggerHttpRequestCachingService.get_payload(event.request_id) # invoke triger @@ -178,7 +179,7 @@ class TriggerService: # Walk nodes to find plugin triggers nodes_in_graph: list[Mapping[str, Any]] = [] - for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN): + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): # Extract plugin trigger configuration from node plugin_id = node_config.get("plugin_id", "") provider_id = node_config.get("provider_id", "") diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 75a1350e60..3c1a4cc747 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -2,7 +2,7 @@ import json import logging import mimetypes import secrets -from collections.abc import Mapping +from collections.abc import Callable, Mapping, Sequence from typing import Any import orjson @@ -16,9 +16,16 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.tool_file_manager import ToolFileManager -from core.workflow.enums import NodeType -from core.workflow.file.models import FileTransferMethod -from core.workflow.variables.types import SegmentType +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE +from core.workflow.nodes.trigger_webhook.entities import ( + ContentType, + WebhookBodyParameter, + WebhookData, + WebhookParameter, +) +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.file.models import FileTransferMethod +from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -57,7 +64,7 @@ class WebhookService: @classmethod def get_webhook_trigger_and_workflow( cls, webhook_id: str, is_debug: bool = False - ) -> tuple[WorkflowWebhookTrigger, Workflow, Mapping[str, Any]]: + ) -> tuple[WorkflowWebhookTrigger, Workflow, NodeConfigDict]: """Get webhook trigger, workflow, and node configuration. Args: @@ -135,7 +142,7 @@ class WebhookService: @classmethod def extract_and_validate_webhook_data( - cls, webhook_trigger: WorkflowWebhookTrigger, node_config: Mapping[str, Any] + cls, webhook_trigger: WorkflowWebhookTrigger, node_config: NodeConfigDict ) -> dict[str, Any]: """Extract and validate webhook data in a single unified process. @@ -153,7 +160,7 @@ class WebhookService: raw_data = cls.extract_webhook_data(webhook_trigger) # Validate HTTP metadata (method, content-type) - node_data = node_config.get("data", {}) + node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) validation_result = cls._validate_http_metadata(raw_data, node_data) if not validation_result["valid"]: raise ValueError(validation_result["error"]) @@ -192,7 +199,7 @@ class WebhookService: content_type = cls._extract_content_type(dict(request.headers)) # Route to appropriate extractor based on content type - extractors = { + extractors: dict[str, Callable[[], tuple[dict[str, Any], dict[str, Any]]]] = { "application/json": cls._extract_json_body, "application/x-www-form-urlencoded": cls._extract_form_body, "multipart/form-data": lambda: cls._extract_multipart_body(webhook_trigger), @@ -214,7 +221,7 @@ class WebhookService: return data @classmethod - def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + def _process_and_validate_data(cls, raw_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: """Process and validate webhook data according to node configuration. Args: @@ -230,18 +237,13 @@ class WebhookService: result = raw_data.copy() # Validate and process headers - cls._validate_required_headers(raw_data["headers"], node_data.get("headers", [])) + cls._validate_required_headers(raw_data["headers"], node_data.headers) # Process query parameters with type conversion and validation - result["query_params"] = cls._process_parameters( - raw_data["query_params"], node_data.get("params", []), is_form_data=True - ) + result["query_params"] = cls._process_parameters(raw_data["query_params"], node_data.params, is_form_data=True) # Process body parameters based on content type - configured_content_type = node_data.get("content_type", "application/json").lower() - result["body"] = cls._process_body_parameters( - raw_data["body"], node_data.get("body", []), configured_content_type - ) + result["body"] = cls._process_body_parameters(raw_data["body"], node_data.body, node_data.content_type) return result @@ -424,7 +426,11 @@ class WebhookService: @classmethod def _process_parameters( - cls, raw_params: dict[str, str], param_configs: list, is_form_data: bool = False + cls, + raw_params: dict[str, str], + param_configs: Sequence[WebhookParameter], + *, + is_form_data: bool = False, ) -> dict[str, Any]: """Process parameters with unified validation and type conversion. @@ -440,13 +446,13 @@ class WebhookService: ValueError: If required parameters are missing or validation fails """ processed = {} - configured_params = {config.get("name", ""): config for config in param_configs} + configured_params = {config.name: config for config in param_configs} # Process configured parameters for param_config in param_configs: - name = param_config.get("name", "") - param_type = param_config.get("type", SegmentType.STRING) - required = param_config.get("required", False) + name = param_config.name + param_type = param_config.type + required = param_config.required # Check required parameters if required and name not in raw_params: @@ -465,7 +471,10 @@ class WebhookService: @classmethod def _process_body_parameters( - cls, raw_body: dict[str, Any], body_configs: list, content_type: str + cls, + raw_body: dict[str, Any], + body_configs: Sequence[WebhookBodyParameter], + content_type: ContentType, ) -> dict[str, Any]: """Process body parameters based on content type and configuration. @@ -480,25 +489,28 @@ class WebhookService: Raises: ValueError: If required body parameters are missing or validation fails """ - if content_type in ["text/plain", "application/octet-stream"]: - # For text/plain and octet-stream, validate required content exists - if body_configs and any(config.get("required", False) for config in body_configs): - raw_content = raw_body.get("raw") - if not raw_content: - raise ValueError(f"Required body content missing for {content_type} request") - return raw_body + match content_type: + case ContentType.TEXT | ContentType.BINARY: + # For text/plain and octet-stream, validate required content exists + if body_configs and any(config.required for config in body_configs): + raw_content = raw_body.get("raw") + if not raw_content: + raise ValueError(f"Required body content missing for {content_type} request") + return raw_body + case _: + pass # For structured data (JSON, form-data, etc.) processed = {} - configured_params = {config.get("name", ""): config for config in body_configs} + configured_params: dict[str, WebhookBodyParameter] = {config.name: config for config in body_configs} for body_config in body_configs: - name = body_config.get("name", "") - param_type = body_config.get("type", SegmentType.STRING) - required = body_config.get("required", False) + name = body_config.name + param_type = body_config.type + required = body_config.required # Handle file parameters for multipart data - if param_type == SegmentType.FILE and content_type == "multipart/form-data": + if param_type == SegmentType.FILE and content_type == ContentType.FORM_DATA: # File validation is handled separately in extract phase continue @@ -508,7 +520,7 @@ class WebhookService: if name in raw_body: raw_value = raw_body[name] - is_form_data = content_type in ["application/x-www-form-urlencoded", "multipart/form-data"] + is_form_data = content_type in [ContentType.FORM_URLENCODED, ContentType.FORM_DATA] processed[name] = cls._validate_and_convert_value(name, raw_value, param_type, is_form_data) # Include unconfigured parameters @@ -519,7 +531,9 @@ class WebhookService: return processed @classmethod - def _validate_and_convert_value(cls, param_name: str, value: Any, param_type: str, is_form_data: bool) -> Any: + def _validate_and_convert_value( + cls, param_name: str, value: Any, param_type: SegmentType | str, is_form_data: bool + ) -> Any: """Unified validation and type conversion for parameter values. Args: @@ -532,7 +546,8 @@ class WebhookService: Any: The validated and converted value Raises: - ValueError: If validation or conversion fails + ValueError: If validation or conversion fails. The original validation + error is preserved as ``__cause__`` for debugging. """ try: if is_form_data: @@ -542,10 +557,10 @@ class WebhookService: # JSON data should already be in correct types, just validate return cls._validate_json_value(param_name, value, param_type) except Exception as e: - raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") + raise ValueError(f"Parameter '{param_name}' validation failed: {str(e)}") from e @classmethod - def _convert_form_value(cls, param_name: str, value: str, param_type: str) -> Any: + def _convert_form_value(cls, param_name: str, value: str, param_type: SegmentType | str) -> Any: """Convert form data string values to specified types. Args: @@ -576,7 +591,7 @@ class WebhookService: raise ValueError(f"Unsupported type '{param_type}' for form data parameter '{param_name}'") @classmethod - def _validate_json_value(cls, param_name: str, value: Any, param_type: str) -> Any: + def _validate_json_value(cls, param_name: str, value: Any, param_type: SegmentType | str) -> Any: """Validate JSON values against expected types. Args: @@ -590,43 +605,43 @@ class WebhookService: Raises: ValueError: If the value type doesn't match the expected type """ - type_validators = { - SegmentType.STRING: (lambda v: isinstance(v, str), "string"), - SegmentType.NUMBER: (lambda v: isinstance(v, (int, float)), "number"), - SegmentType.BOOLEAN: (lambda v: isinstance(v, bool), "boolean"), - SegmentType.OBJECT: (lambda v: isinstance(v, dict), "object"), - SegmentType.ARRAY_STRING: ( - lambda v: isinstance(v, list) and all(isinstance(item, str) for item in v), - "array of strings", - ), - SegmentType.ARRAY_NUMBER: ( - lambda v: isinstance(v, list) and all(isinstance(item, (int, float)) for item in v), - "array of numbers", - ), - SegmentType.ARRAY_BOOLEAN: ( - lambda v: isinstance(v, list) and all(isinstance(item, bool) for item in v), - "array of booleans", - ), - SegmentType.ARRAY_OBJECT: ( - lambda v: isinstance(v, list) and all(isinstance(item, dict) for item in v), - "array of objects", - ), - } - - validator_info = type_validators.get(SegmentType(param_type)) - if not validator_info: - logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + param_type_enum = cls._coerce_segment_type(param_type, param_name=param_name) + if param_type_enum is None: return value - validator, expected_type = validator_info - if not validator(value): + if not param_type_enum.is_valid(value, array_validation=ArrayValidation.ALL): actual_type = type(value).__name__ + expected_type = cls._expected_type_label(param_type_enum) raise ValueError(f"Expected {expected_type}, got {actual_type}") return value @classmethod - def _validate_required_headers(cls, headers: dict[str, Any], header_configs: list) -> None: + def _coerce_segment_type(cls, param_type: SegmentType | str, *, param_name: str) -> SegmentType | None: + if isinstance(param_type, SegmentType): + return param_type + try: + return SegmentType(param_type) + except Exception: + logger.warning("Unknown parameter type: %s for parameter %s", param_type, param_name) + return None + + @staticmethod + def _expected_type_label(param_type: SegmentType) -> str: + match param_type: + case SegmentType.ARRAY_STRING: + return "array of strings" + case SegmentType.ARRAY_NUMBER: + return "array of numbers" + case SegmentType.ARRAY_BOOLEAN: + return "array of booleans" + case SegmentType.ARRAY_OBJECT: + return "array of objects" + case _: + return param_type.value + + @classmethod + def _validate_required_headers(cls, headers: dict[str, Any], header_configs: Sequence[WebhookParameter]) -> None: """Validate required headers are present. Args: @@ -639,14 +654,14 @@ class WebhookService: headers_lower = {k.lower(): v for k, v in headers.items()} headers_sanitized = {cls._sanitize_key(k).lower(): v for k, v in headers.items()} for header_config in header_configs: - if header_config.get("required", False): - header_name = header_config.get("name", "") + if header_config.required: + header_name = header_config.name sanitized_name = cls._sanitize_key(header_name).lower() if header_name.lower() not in headers_lower and sanitized_name not in headers_sanitized: raise ValueError(f"Required header missing: {header_name}") @classmethod - def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: dict[str, Any]) -> dict[str, Any]: + def _validate_http_metadata(cls, webhook_data: dict[str, Any], node_data: WebhookData) -> dict[str, Any]: """Validate HTTP method and content-type. Args: @@ -657,13 +672,13 @@ class WebhookService: dict[str, Any]: Validation result with 'valid' key and optional 'error' key """ # Validate HTTP method - configured_method = node_data.get("method", "get").upper() + configured_method = node_data.method.value.upper() request_method = webhook_data["method"].upper() if configured_method != request_method: return cls._validation_error(f"HTTP method mismatch. Expected {configured_method}, got {request_method}") # Validate Content-type - configured_content_type = node_data.get("content_type", "application/json").lower() + configured_content_type = node_data.content_type.value.lower() request_content_type = cls._extract_content_type(webhook_data["headers"]) if configured_content_type != request_content_type: @@ -788,7 +803,7 @@ class WebhookService: raise @classmethod - def generate_webhook_response(cls, node_config: Mapping[str, Any]) -> tuple[dict[str, Any], int]: + def generate_webhook_response(cls, node_config: NodeConfigDict) -> tuple[dict[str, Any], int]: """Generate HTTP response based on node configuration. Args: @@ -797,11 +812,11 @@ class WebhookService: Returns: tuple[dict[str, Any], int]: Response data and HTTP status code """ - node_data = node_config.get("data", {}) + node_data = WebhookData.model_validate(node_config["data"], from_attributes=True) # Get configured status code and response body - status_code = node_data.get("status_code", 200) - response_body = node_data.get("response_body", "") + status_code = node_data.status_code + response_body = node_data.response_body # Parse response body as JSON if it's valid JSON, otherwise return as text try: @@ -847,7 +862,7 @@ class WebhookService: node_id: str webhook_id: str - nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(NodeType.TRIGGER_WEBHOOK)] + nodes_id_in_graph = [node_id for node_id, _ in workflow.walk_nodes(TRIGGER_WEBHOOK_NODE_TYPE)] # Check webhook node limit if len(nodes_id_in_graph) > cls.MAX_WEBHOOK_NODES_PER_WORKFLOW: diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 12be12776a..60dc1dedb8 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -6,9 +6,9 @@ from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload from configs import dify_config -from core.workflow.file.models import File -from core.workflow.nodes.variable_assigner.common.helpers import UpdatedVariable -from core.workflow.variables.segments import ( +from dify_graph.file.models import File +from dify_graph.nodes.variable_assigner.common.helpers import UpdatedVariable +from dify_graph.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -20,7 +20,7 @@ from core.workflow.variables.segments import ( Segment, StringSegment, ) -from core.workflow.variables.utils import dumps_with_segments +from dify_graph.variables.utils import dumps_with_segments _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index f1fa33cb75..b66fdd7a20 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,7 +1,6 @@ import logging from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector from core.rag.index_processor.constant.doc_type import DocType @@ -9,6 +8,7 @@ from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding @@ -156,7 +156,8 @@ class VectorService: ) # use full doc mode to generate segment's child chunk processing_rule_dict = processing_rule.to_dict() - processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC + if processing_rule_dict["rules"] is not None: + processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC documents = index_processor.transform( [document], embedding_model_instance=embedding_model_instance, diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py index 560aec2330..e028e3e5e3 100644 --- a/api/services/web_conversation_service.py +++ b/api/services/web_conversation_service.py @@ -7,6 +7,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account +from models.enums import CreatorUserRole from models.model import App, EndUser from models.web import PinnedConversation from services.conversation_service import ConversationService @@ -84,7 +85,7 @@ class WebConversationService: pinned_conversation = PinnedConversation( app_id=app_model.id, conversation_id=conversation.id, - created_by_role="account" if isinstance(user, Account) else "end_user", + created_by_role=CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER, created_by=user.id, ) diff --git a/api/services/website_service.py b/api/services/website_service.py index fe48c3b08e..b2917ba152 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -9,7 +9,7 @@ import httpx from flask_login import current_user from core.helper import encrypter -from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp +from core.rag.extractor.firecrawl.firecrawl_app import CrawlStatusResponse, FirecrawlApp, FirecrawlDocumentData from core.rag.extractor.watercrawl.provider import WaterCrawlProvider from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -124,7 +124,7 @@ class WebsiteService: if provider == "firecrawl": plugin_id = "langgenius/firecrawl_datasource" elif provider == "watercrawl": - plugin_id = "langgenius/watercrawl_datasource" + plugin_id = "watercrawl/watercrawl_datasource" elif provider == "jinareader": plugin_id = "langgenius/jina_datasource" else: @@ -216,8 +216,10 @@ class WebsiteService: "max_depth": request.options.max_depth, "use_sitemap": request.options.use_sitemap, } - return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( - url=request.url, options=options + return dict( + WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( + url=request.url, options=options + ) ) @classmethod @@ -270,13 +272,13 @@ class WebsiteService: @classmethod def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) - result = firecrawl_app.check_crawl_status(job_id) - crawl_status_data = { - "status": result.get("status", "active"), + result: CrawlStatusResponse = firecrawl_app.check_crawl_status(job_id) + crawl_status_data: dict[str, Any] = { + "status": result["status"], "job_id": job_id, - "total": result.get("total", 0), - "current": result.get("current", 0), - "data": result.get("data", []), + "total": result["total"] or 0, + "current": result["current"] or 0, + "data": result["data"], } if crawl_status_data["status"] == "completed": website_crawl_time_cache_key = f"website_crawl_{job_id}" @@ -289,8 +291,8 @@ class WebsiteService: return crawl_status_data @classmethod - def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: - return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id) + def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict[str, Any]) -> dict[str, Any]: + return dict(WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)) @classmethod def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: @@ -343,7 +345,7 @@ class WebsiteService: @classmethod def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: - crawl_data: list[dict[str, Any]] | None = None + crawl_data: list[FirecrawlDocumentData] | None = None file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): stored_data = storage.load_once(file_key) @@ -352,19 +354,22 @@ class WebsiteService: else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) result = firecrawl_app.check_crawl_status(job_id) - if result.get("status") != "completed": + if result["status"] != "completed": raise ValueError("Crawl job is not completed") - crawl_data = result.get("data") + crawl_data = result["data"] if crawl_data: for item in crawl_data: - if item.get("source_url") == url: + if item["source_url"] == url: return dict(item) return None @classmethod - def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: - return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + def _get_watercrawl_url_data( + cls, job_id: str, url: str, api_key: str, config: dict[str, Any] + ) -> dict[str, Any] | None: + result = WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + return dict(result) if result is not None else None @classmethod def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: @@ -416,8 +421,8 @@ class WebsiteService: def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) params = {"onlyMainContent": request.only_main_content} - return firecrawl_app.scrape_url(url=request.url, params=params) + return dict(firecrawl_app.scrape_url(url=request.url, params=params)) @classmethod - def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: - return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url) + def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict[str, Any]) -> dict[str, Any]: + return dict(WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)) diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index 5527c108a2..f0596e44c8 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,5 +1,7 @@ import json -from typing import Any, TypedDict +from typing import Any + +from typing_extensions import TypedDict from core.app.app_config.entities import ( DatasetEntity, @@ -13,18 +15,18 @@ from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManage from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.helper import encrypter -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file.models import FileUploadConfig -from core.workflow.nodes import NodeType -from core.workflow.variables.input_entities import VariableEntity +from dify_graph.file.models import FileUploadConfig +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db from models import Account from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, IconType from models.workflow import Workflow, WorkflowType @@ -34,6 +36,17 @@ class _NodeType(TypedDict): data: dict[str, Any] +class _EdgeType(TypedDict): + id: str + source: str + target: str + + +class WorkflowGraph(TypedDict): + nodes: list[_NodeType] + edges: list[_EdgeType] + + class WorkflowConverter: """ App Convert to Workflow Mode @@ -72,7 +85,7 @@ class WorkflowConverter: new_app.tenant_id = app_model.tenant_id new_app.name = name or app_model.name + "(workflow)" new_app.mode = AppMode.ADVANCED_CHAT if app_model.mode == AppMode.CHAT else AppMode.WORKFLOW - new_app.icon_type = icon_type or app_model.icon_type + new_app.icon_type = IconType(icon_type) if icon_type else app_model.icon_type new_app.icon = icon or app_model.icon new_app.icon_background = icon_background or app_model.icon_background new_app.enable_site = app_model.enable_site @@ -107,7 +120,7 @@ class WorkflowConverter: app_config = self._convert_to_app_config(app_model=app_model, app_model_config=app_model_config) # init workflow graph - graph: dict[str, Any] = {"nodes": [], "edges": []} + graph: WorkflowGraph = {"nodes": [], "edges": []} # Convert list: # - variables -> start @@ -234,7 +247,7 @@ class WorkflowConverter: "position": None, "data": { "title": "START", - "type": NodeType.START, + "type": BuiltinNodeTypes.START, "variables": [jsonable_encoder(v) for v in variables], }, } @@ -296,7 +309,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"HTTP REQUEST {api_based_extension.name}", - "type": NodeType.HTTP_REQUEST, + "type": BuiltinNodeTypes.HTTP_REQUEST, "method": "post", "url": api_based_extension.api_endpoint, "authorization": {"type": "api-key", "config": {"type": "bearer", "api_key": api_key}}, @@ -314,7 +327,7 @@ class WorkflowConverter: "position": None, "data": { "title": f"Parse {api_based_extension.name} Response", - "type": NodeType.CODE, + "type": BuiltinNodeTypes.CODE, "variables": [{"variable": "response_json", "value_selector": [http_request_node["id"], "body"]}], "code_language": "python3", "code": "import json\n\ndef main(response_json: str) -> str:\n response_body = json.loads(" @@ -354,7 +367,7 @@ class WorkflowConverter: "position": None, "data": { "title": "KNOWLEDGE RETRIEVAL", - "type": NodeType.KNOWLEDGE_RETRIEVAL, + "type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "query_variable_selector": query_variable_selector, "dataset_ids": dataset_config.dataset_ids, "retrieval_mode": retrieve_config.retrieve_strategy.value, @@ -385,7 +398,7 @@ class WorkflowConverter: self, original_app_mode: AppMode, new_app_mode: AppMode, - graph: dict, + graph: WorkflowGraph, model_config: ModelConfigEntity, prompt_template: PromptTemplateEntity, file_upload: FileUploadConfig | None = None, @@ -402,9 +415,9 @@ class WorkflowConverter: :param external_data_variable_node_mapping: external data variable node mapping """ # fetch start and knowledge retrieval node - start_node = next(filter(lambda n: n["data"]["type"] == NodeType.START, graph["nodes"])) + start_node = next(filter(lambda n: n["data"]["type"] == BuiltinNodeTypes.START, graph["nodes"])) knowledge_retrieval_node = next( - filter(lambda n: n["data"]["type"] == NodeType.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None + filter(lambda n: n["data"]["type"] == BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, graph["nodes"]), None ) role_prefix = None @@ -523,7 +536,7 @@ class WorkflowConverter: "position": None, "data": { "title": "LLM", - "type": NodeType.LLM, + "type": BuiltinNodeTypes.LLM, "model": { "provider": model_config.provider, "name": model_config.model, @@ -578,7 +591,7 @@ class WorkflowConverter: "position": None, "data": { "title": "END", - "type": NodeType.END, + "type": BuiltinNodeTypes.END, "outputs": [{"variable": "result", "value_selector": ["llm", "text"]}], }, } @@ -592,10 +605,10 @@ class WorkflowConverter: return { "id": "answer", "position": None, - "data": {"title": "ANSWER", "type": NodeType.ANSWER, "answer": "{{#llm.text#}}"}, + "data": {"title": "ANSWER", "type": BuiltinNodeTypes.ANSWER, "answer": "{{#llm.text#}}"}, } - def _create_edge(self, source: str, target: str): + def _create_edge(self, source: str, target: str) -> _EdgeType: """ Create Edge :param source: source node id @@ -604,7 +617,7 @@ class WorkflowConverter: """ return {"id": f"{source}-{target}", "source": source, "target": target} - def _append_node(self, graph: dict[str, Any], node: _NodeType): + def _append_node(self, graph: WorkflowGraph, node: _NodeType): """ Append Node to Graph diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 04eaeba9f3..f1601cd6be 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -5,15 +5,20 @@ from typing import Any from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session +from typing_extensions import TypedDict -from core.workflow.enums import WorkflowExecutionStatus -from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun +from dify_graph.enums import WorkflowExecutionStatus +from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog from services.plugin.plugin_service import PluginService from services.workflow.entities import TriggerMetadata +class LogViewDetails(TypedDict): + trigger_metadata: dict[str, Any] | None + + # Since the workflow_app_log table has exceeded 100 million records, we use an additional details field to extend it class LogView: """Lightweight wrapper for WorkflowAppLog with computed details. @@ -26,7 +31,7 @@ class LogView: def __init__( self, log: WorkflowAppLog, - details: dict | None, + details: LogViewDetails | None, evaluation: list[dict] | None = None, ): self.log = log @@ -34,7 +39,7 @@ class LogView: self.evaluation_ = evaluation @property - def details(self) -> dict | None: + def details(self) -> LogViewDetails | None: return self.details_ @property @@ -143,7 +148,14 @@ class WorkflowAppService: ), ) if created_by_account: - account = session.scalar(select(Account).where(Account.email == created_by_account)) + account = session.scalar( + select(Account) + .join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id) + .where( + Account.email == created_by_account, + TenantAccountJoin.tenant_id == app_model.tenant_id, + ) + ) if not account: raise ValueError(f"Account not found: {created_by_account}") diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 18ad6c5c16..f124e137c3 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,27 +14,28 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import SystemVariableKey -from core.workflow.file.models import File -from core.workflow.nodes import NodeType -from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables -from core.workflow.variable_loader import VariableLoader -from core.workflow.variables import Segment, StringSegment, VariableBase -from core.workflow.variables.consts import SELECTORS_LENGTH -from core.workflow.variables.segments import ( +from core.trigger.constants import is_trigger_node_type +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import NodeType, SystemVariableKey +from dify_graph.file.models import File +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables +from dify_graph.variable_loader import VariableLoader +from dify_graph.variables import Segment, StringSegment, VariableBase +from dify_graph.variables.consts import SELECTORS_LENGTH +from dify_graph.variables.segments import ( ArrayFileSegment, FileSegment, ) -from core.workflow.variables.types import SegmentType -from core.workflow.variables.utils import dumps_with_segments +from dify_graph.variables.types import SegmentType +from dify_graph.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation -from models.enums import DraftVariableType +from models.enums import ConversationFromSource, DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory from services.file_service import FileService @@ -70,12 +71,13 @@ class UpdateNotSupportedError(WorkflowDraftVariableError): class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # - # ref: core.workflow.variable_loader.VariableLoader + # ref: dify_graph.variable_loader.VariableLoader # Database engine used for loading variables. _engine: Engine # Application ID for which variables are being loaded. _app_id: str + _user_id: str _tenant_id: str _fallback_variables: Sequence[VariableBase] @@ -84,10 +86,12 @@ class DraftVarLoader(VariableLoader): engine: Engine, app_id: str, tenant_id: str, + user_id: str, fallback_variables: Sequence[VariableBase] | None = None, ): self._engine = engine self._app_id = app_id + self._user_id = user_id self._tenant_id = tenant_id self._fallback_variables = fallback_variables or [] @@ -103,7 +107,7 @@ class DraftVarLoader(VariableLoader): with Session(bind=self._engine, expire_on_commit=False) as session: srv = WorkflowDraftVariableService(session) - draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors) + draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors, user_id=self._user_id) # Important: files: list[File] = [] @@ -217,6 +221,7 @@ class WorkflowDraftVariableService: self, app_id: str, selectors: Sequence[list[str]], + user_id: str, ) -> list[WorkflowDraftVariable]: """ Retrieve WorkflowDraftVariable instances based on app_id and selectors. @@ -237,22 +242,30 @@ class WorkflowDraftVariableService: # Alternatively, a `SELECT` statement could be constructed for each selector and # combined using `UNION` to fetch all rows. # Benchmarking indicates that both approaches yield comparable performance. - variables = ( + query = ( self._session.query(WorkflowDraftVariable) .options( orm.selectinload(WorkflowDraftVariable.variable_file).selectinload( WorkflowDraftVariableFile.upload_file ) ) - .where(WorkflowDraftVariable.app_id == app_id, or_(*ors)) - .all() + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + or_(*ors), + ) ) - return variables + return query.all() - def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList: - criteria = WorkflowDraftVariable.app_id == app_id + def list_variables_without_values( + self, app_id: str, page: int, limit: int, user_id: str + ) -> WorkflowDraftVariableList: + criteria = [ + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + ] total = None - query = self._session.query(WorkflowDraftVariable).where(criteria) + query = self._session.query(WorkflowDraftVariable).where(*criteria) if page == 1: total = query.count() variables = ( @@ -268,11 +281,12 @@ class WorkflowDraftVariableService: return WorkflowDraftVariableList(variables=variables, total=total) - def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: - criteria = ( + def _list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList: + criteria = [ WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, - ) + WorkflowDraftVariable.user_id == user_id, + ] query = self._session.query(WorkflowDraftVariable).where(*criteria) variables = ( query.options(orm.selectinload(WorkflowDraftVariable.variable_file)) @@ -281,36 +295,36 @@ class WorkflowDraftVariableService: ) return WorkflowDraftVariableList(variables=variables) - def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList: - return self._list_node_variables(app_id, node_id) + def list_node_variables(self, app_id: str, node_id: str, user_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, node_id, user_id=user_id) - def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList: - return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID) + def list_conversation_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID, user_id=user_id) - def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList: - return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID) + def list_system_variables(self, app_id: str, user_id: str) -> WorkflowDraftVariableList: + return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID, user_id=user_id) - def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: - return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name) + def get_conversation_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, user_id=user_id) - def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None: - return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name) + def get_system_variable(self, app_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, user_id=user_id) - def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: - return self._get_variable(app_id, node_id, name) + def get_node_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: + return self._get_variable(app_id, node_id, name, user_id=user_id) - def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None: - variable = ( + def _get_variable(self, app_id: str, node_id: str, name: str, user_id: str) -> WorkflowDraftVariable | None: + return ( self._session.query(WorkflowDraftVariable) .options(orm.selectinload(WorkflowDraftVariable.variable_file)) .where( WorkflowDraftVariable.app_id == app_id, WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name, + WorkflowDraftVariable.user_id == user_id, ) .first() ) - return variable def update_variable( self, @@ -386,7 +400,7 @@ class WorkflowDraftVariableService: # # This implementation must remain synchronized with the `_build_from_variable_assigner_mapping` # and `save` methods. - if node_type == NodeType.VARIABLE_ASSIGNER: + if node_type == BuiltinNodeTypes.VARIABLE_ASSIGNER: return variable output_value = outputs_dict.get(variable.name, absent) else: @@ -461,7 +475,17 @@ class WorkflowDraftVariableService: self._session.delete(upload_file) self._session.delete(variable) - def delete_workflow_variables(self, app_id: str): + def delete_user_workflow_variables(self, app_id: str, user_id: str): + ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.user_id == user_id, + ) + .delete(synchronize_session=False) + ) + + def delete_app_workflow_variables(self, app_id: str): ( self._session.query(WorkflowDraftVariable) .where(WorkflowDraftVariable.app_id == app_id) @@ -500,28 +524,35 @@ class WorkflowDraftVariableService: self._session.delete(upload_file) self._session.delete(variable_file) - def delete_node_variables(self, app_id: str, node_id: str): - return self._delete_node_variables(app_id, node_id) + def delete_node_variables(self, app_id: str, node_id: str, user_id: str): + return self._delete_node_variables(app_id, node_id, user_id=user_id) - def _delete_node_variables(self, app_id: str, node_id: str): - self._session.query(WorkflowDraftVariable).where( - WorkflowDraftVariable.app_id == app_id, - WorkflowDraftVariable.node_id == node_id, - ).delete() + def _delete_node_variables(self, app_id: str, node_id: str, user_id: str): + ( + self._session.query(WorkflowDraftVariable) + .where( + WorkflowDraftVariable.app_id == app_id, + WorkflowDraftVariable.node_id == node_id, + WorkflowDraftVariable.user_id == user_id, + ) + .delete(synchronize_session=False) + ) - def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None: + def _get_conversation_id_from_draft_variable(self, app_id: str, user_id: str) -> str | None: draft_var = self._get_variable( app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=str(SystemVariableKey.CONVERSATION_ID), + user_id=user_id, ) if draft_var is None: return None segment = draft_var.get_value() if not isinstance(segment, StringSegment): logger.warning( - "sys.conversation_id variable is not a string: app_id=%s, id=%s", + "sys.conversation_id variable is not a string: app_id=%s, user_id=%s, id=%s", app_id, + user_id, draft_var.id, ) return None @@ -542,7 +573,7 @@ class WorkflowDraftVariableService: If no such conversation exists, a new conversation is created and its ID is returned. """ - conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id) + conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id, account_id) if conv_id is not None: conversation = ( @@ -570,7 +601,7 @@ class WorkflowDraftVariableService: system_instruction_tokens=0, status="normal", invoke_from=InvokeFrom.DEBUGGER, - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account_id, ) @@ -579,12 +610,13 @@ class WorkflowDraftVariableService: self._session.flush() return conversation.id - def prefill_conversation_variable_default_values(self, workflow: Workflow): + def prefill_conversation_variable_default_values(self, workflow: Workflow, user_id: str): """""" draft_conv_vars: list[WorkflowDraftVariable] = [] for conv_var in workflow.conversation_variables: draft_var = WorkflowDraftVariable.new_conversation_variable( app_id=workflow.app_id, + user_id=user_id, name=conv_var.name, value=conv_var, description=conv_var.description, @@ -634,7 +666,7 @@ def _batch_upsert_draft_variable( stmt = pg_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) if policy == _UpsertPolicy.OVERWRITE: stmt = stmt.on_conflict_do_update( - index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(), + index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name(), set_={ # Refresh creation timestamp to ensure updated variables # appear first in chronologically sorted result sets. @@ -651,7 +683,9 @@ def _batch_upsert_draft_variable( }, ) elif policy == _UpsertPolicy.IGNORE: - stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name()) + stmt = stmt.on_conflict_do_nothing( + index_elements=WorkflowDraftVariable.unique_app_id_user_id_node_id_name() + ) else: stmt = mysql_insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars]) # type: ignore[assignment] if policy == _UpsertPolicy.OVERWRITE: @@ -681,6 +715,7 @@ def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]: d: dict[str, Any] = { "id": model.id, "app_id": model.app_id, + "user_id": model.user_id, "last_edited_at": None, "node_id": model.node_id, "name": model.name, @@ -753,8 +788,8 @@ class DraftVariableSaver: # technical variables from being exposed in the draft environment, particularly those # that aren't meant to be directly edited or viewed by users. _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = { - NodeType.LLM: frozenset(["finish_reason"]), - NodeType.LOOP: frozenset(["loop_round"]), + BuiltinNodeTypes.LLM: frozenset(["finish_reason"]), + BuiltinNodeTypes.LOOP: frozenset(["loop_round"]), } # Database session used for persisting draft variables. @@ -806,6 +841,7 @@ class DraftVariableSaver: def _create_dummy_output_variable(self): return WorkflowDraftVariable.new_node_variable( app_id=self._app_id, + user_id=self._user.id, node_id=self._node_id, name=self._DUMMY_OUTPUT_IDENTITY, node_execution_id=self._node_execution_id, @@ -815,7 +851,7 @@ class DraftVariableSaver: ) def _should_save_output_variables_for_draft(self) -> bool: - if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER: + if self._enclosing_node_id is not None and self._node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: # Currently we do not save output variables for nodes inside loop or iteration. return False return True @@ -841,6 +877,7 @@ class DraftVariableSaver: draft_vars.append( WorkflowDraftVariable.new_conversation_variable( app_id=self._app_id, + user_id=self._user.id, name=item.name, value=segment, ) @@ -861,6 +898,7 @@ class DraftVariableSaver: draft_vars.append( WorkflowDraftVariable.new_node_variable( app_id=self._app_id, + user_id=self._user.id, node_id=self._node_id, name=name, node_execution_id=self._node_execution_id, @@ -883,6 +921,7 @@ class DraftVariableSaver: draft_vars.append( WorkflowDraftVariable.new_sys_variable( app_id=self._app_id, + user_id=self._user.id, name=name, node_execution_id=self._node_execution_id, value=value_seg, @@ -1018,6 +1057,7 @@ class DraftVariableSaver: # Create the draft variable draft_var = WorkflowDraftVariable.new_node_variable( app_id=self._app_id, + user_id=self._user.id, node_id=self._node_id, name=name, node_execution_id=self._node_execution_id, @@ -1031,6 +1071,7 @@ class DraftVariableSaver: # Create the draft variable draft_var = WorkflowDraftVariable.new_node_variable( app_id=self._app_id, + user_id=self._user.id, node_id=self._node_id, name=name, node_execution_id=self._node_execution_id, @@ -1053,9 +1094,9 @@ class DraftVariableSaver: process_data = {} if not self._should_save_output_variables_for_draft(): return - if self._node_type == NodeType.VARIABLE_ASSIGNER: + if self._node_type == BuiltinNodeTypes.VARIABLE_ASSIGNER: draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data) - elif self._node_type == NodeType.START or self._node_type.is_trigger_node: + elif self._node_type == BuiltinNodeTypes.START or is_trigger_node_type(self._node_type): draft_vars = self._build_variables_from_start_mapping(outputs) else: draft_vars = self._build_variables_from_mapping(outputs) @@ -1071,7 +1112,7 @@ class DraftVariableSaver: @staticmethod def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool: - if node_type in NodeType.IF_ELSE: + if node_type == BuiltinNodeTypes.IF_ELSE: return False if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name): return False diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 09037a92ce..8f323ebb8b 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -22,10 +22,10 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.entities import WorkflowStartReason -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities import WorkflowStartReason +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_restore.py b/api/services/workflow_restore.py new file mode 100644 index 0000000000..083235d228 --- /dev/null +++ b/api/services/workflow_restore.py @@ -0,0 +1,58 @@ +"""Shared helpers for restoring published workflow snapshots into drafts. + +Both app workflows and RAG pipeline workflows restore the same workflow fields +from a published snapshot into a draft. Keeping that field-copy logic in one +place prevents the two restore paths from drifting when we add or adjust draft +state in the future. Restore stays within a tenant, so we can safely reuse the +serialized workflow storage blobs without decrypting and re-encrypting secrets. +""" + +from collections.abc import Callable +from datetime import datetime + +from models import Account +from models.workflow import Workflow, WorkflowType + +UpdatedAtFactory = Callable[[], datetime] + + +def apply_published_workflow_snapshot_to_draft( + *, + tenant_id: str, + app_id: str, + source_workflow: Workflow, + draft_workflow: Workflow | None, + account: Account, + updated_at_factory: UpdatedAtFactory, +) -> tuple[Workflow, bool]: + """Copy a published workflow snapshot into a draft workflow record. + + The caller remains responsible for source lookup, validation, flushing, and + post-commit side effects. This helper only centralizes the shared draft + creation/update semantics used by both restore entry points. Features are + copied from the stored JSON payload so restore does not normalize and dirty + the published source row before the caller commits. + """ + if not draft_workflow: + workflow_type = ( + source_workflow.type.value if isinstance(source_workflow.type, WorkflowType) else source_workflow.type + ) + draft_workflow = Workflow( + tenant_id=tenant_id, + app_id=app_id, + type=workflow_type, + version=Workflow.VERSION_DRAFT, + graph=source_workflow.graph, + features=source_workflow.serialized_features, + created_by=account.id, + ) + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + return draft_workflow, True + + draft_workflow.graph = source_workflow.graph + draft_workflow.features = source_workflow.serialized_features + draft_workflow.updated_by = account.id + draft_workflow.updated_at = updated_at_factory() + draft_workflow.copy_serialized_variable_storage_from(source_workflow) + + return draft_workflow, False diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 3b448423e8..66976058c0 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -11,37 +11,44 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.repositories import DifyCoreRepositoryFactory from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.entities import GraphInitParams, WorkflowNodeExecution -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.errors import WorkflowNodeRunFailedError -from core.workflow.file import File -from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes import NodeType -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from core.workflow.nodes.human_input.entities import ( +from core.trigger.constants import is_trigger_node_type +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities import GraphInitParams, WorkflowNodeExecution +from dify_graph.entities.graph_config import NodeConfigDict +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file import File +from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient, validate_human_input_submission, ) -from core.workflow.nodes.human_input.enums import HumanInputFormKind -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.repositories.human_input_form_repository import FormCreateParams -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variable_loader import load_into_variable_pool -from core.workflow.variables import VariableBase -from core.workflow.variables.input_entities import VariableEntityType -from core.workflow.variables.variables import Variable -from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.nodes.human_input.enums import HumanInputFormKind +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.repositories.human_input_form_repository import FormCreateParams +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variable_loader import load_into_variable_pool +from dify_graph.variables import VariableBase +from dify_graph.variables.input_entities import VariableEntityType +from dify_graph.variables.variables import Variable from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -49,7 +56,6 @@ from extensions.ext_storage import storage from factories.file_factory import build_from_mapping, build_from_mappings from libs.datetime_utils import naive_utc_now from models import Account -from models.enums import UserFrom from models.human_input import HumanInputFormRecipient, RecipientType from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -57,7 +63,12 @@ from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeEx from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService from services.enterprise.plugin_manager_service import PluginCredentialType -from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError +from services.errors.app import ( + IsDraftWorkflowError, + TriggerNodeLimitExceededError, + WorkflowHashNotEqualError, + WorkflowNotFoundError, +) from services.workflow.workflow_converter import WorkflowConverter from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError @@ -69,6 +80,7 @@ from .human_input_delivery_test_service import ( HumanInputDeliveryTestService, ) from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService +from .workflow_restore import apply_published_workflow_snapshot_to_draft class WorkflowService: @@ -273,6 +285,43 @@ class WorkflowService: # return draft workflow return workflow + def restore_published_workflow_to_draft( + self, + *, + app_model: App, + workflow_id: str, + account: Account, + ) -> Workflow: + """Restore a published workflow snapshot into the draft workflow. + + Secret environment variables are copied server-side from the selected + published workflow so the normal draft sync flow stays stateless. + """ + source_workflow = self.get_published_workflow_by_id(app_model=app_model, workflow_id=workflow_id) + if not source_workflow: + raise WorkflowNotFoundError("Workflow not found.") + + self.validate_features_structure(app_model=app_model, features=source_workflow.normalized_features_dict) + self.validate_graph_structure(graph=source_workflow.graph_dict) + + draft_workflow = self.get_draft_workflow(app_model=app_model) + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + source_workflow=source_workflow, + draft_workflow=draft_workflow, + account=account, + updated_at_factory=naive_utc_now, + ) + + if is_new_draft: + db.session.add(draft_workflow) + + db.session.commit() + app_draft_workflow_was_synced.send(app_model, synced_draft_workflow=draft_workflow) + + return draft_workflow + def publish_workflow( self, *, @@ -310,7 +359,7 @@ class WorkflowService: for _, node_data in draft_workflow.walk_nodes() if (node_type_str := node_data.get("type")) and isinstance(node_type_str, str) - and NodeType(node_type_str).is_trigger_node + and is_trigger_node_type(node_type_str) ) if trigger_node_count > 2: raise TriggerNodeLimitExceededError(count=trigger_node_count, limit=2) @@ -438,8 +487,8 @@ class WorkflowService: """ try: from core.model_manager import ModelManager - from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager + from dify_graph.model_runtime.entities.model_entities import ModelType # Get model instance to validate provider+model combination model_manager = ModelManager() @@ -558,8 +607,8 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager + from dify_graph.model_runtime.entities.model_entities import ModelType # Get provider configurations provider_manager = ProviderManager() @@ -619,10 +668,10 @@ class WorkflowService: """ # return default block config default_block_configs: list[Mapping[str, object]] = [] - for node_type, node_class_mapping in NODE_TYPE_CLASSES_MAPPING.items(): + for node_type, node_class_mapping in get_node_type_classes_mapping().items(): node_class = node_class_mapping[LATEST_VERSION] filters = None - if node_type is NodeType.HTTP_REQUEST: + if node_type == BuiltinNodeTypes.HTTP_REQUEST: filters = { HTTP_REQUEST_CONFIG_FILTER_KEY: build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -650,14 +699,15 @@ class WorkflowService: :return: """ node_type_enum = NodeType(node_type) + node_mapping = get_node_type_classes_mapping() # return default block config - if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: + if node_type_enum not in node_mapping: return {} - node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] + node_class = node_mapping[node_type_enum][LATEST_VERSION] resolved_filters = dict(filters) if filters else {} - if node_type_enum is NodeType.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters: + if node_type_enum == BuiltinNodeTypes.HTTP_REQUEST and HTTP_REQUEST_CONFIG_FILTER_KEY not in resolved_filters: resolved_filters[HTTP_REQUEST_CONFIG_FILTER_KEY] = build_http_request_config( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, max_read_timeout=dify_config.HTTP_REQUEST_MAX_READ_TIMEOUT, @@ -690,12 +740,12 @@ class WorkflowService: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(draft_workflow) + draft_var_srv.prefill_conversation_variable_default_values(draft_workflow, user_id=account.id) node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - node_data = node_config.get("data", {}) - if node_type.is_start_node: + node_data = node_config["data"] + if is_start_node_type(node_type): with Session(bind=db.engine) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) conversation_id = draft_var_srv.get_or_create_conversation( @@ -703,8 +753,8 @@ class WorkflowService: app=app_model, workflow=draft_workflow, ) - if node_type is NodeType.START: - start_data = StartNodeData.model_validate(node_data) + if node_type == BuiltinNodeTypes.START: + start_data = StartNodeData.model_validate(node_data, from_attributes=True) user_inputs = _rebuild_file_for_user_inputs_in_start_node( tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs ) @@ -733,6 +783,7 @@ class WorkflowService: engine=db.engine, app_id=app_model.id, tenant_id=app_model.tenant_id, + user_id=account.id, ) enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) @@ -782,7 +833,7 @@ class WorkflowService: session=session, app_id=app_model.id, node_id=workflow_node_execution.node_id, - node_type=NodeType(workflow_node_execution.node_type), + node_type=workflow_node_execution.node_type, enclosing_node_id=enclosing_node_id, node_execution_id=node_execution.id, user=account, @@ -815,7 +866,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: + if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") # inputs: values used to fill missing upstream variables referenced in form_content. @@ -824,6 +875,7 @@ class WorkflowService: workflow=draft_workflow, node_config=node_config, manual_inputs=inputs or {}, + user_id=account.id, ) node = self._build_human_input_node( workflow=draft_workflow, @@ -874,7 +926,7 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: + if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") # inputs: values used to fill missing upstream variables referenced in form_content. @@ -884,6 +936,7 @@ class WorkflowService: workflow=draft_workflow, node_config=node_config, manual_inputs=inputs or {}, + user_id=account.id, ) node = self._build_human_input_node( workflow=draft_workflow, @@ -914,7 +967,7 @@ class WorkflowService: session=session, app_id=app_model.id, node_id=node_id, - node_type=NodeType.HUMAN_INPUT, + node_type=BuiltinNodeTypes.HUMAN_INPUT, node_execution_id=str(uuid.uuid4()), user=account, enclosing_node_id=enclosing_node_id, @@ -939,10 +992,10 @@ class WorkflowService: node_config = draft_workflow.get_node_config_by_id(node_id) node_type = Workflow.get_node_type_from_node_config(node_config) - if node_type is not NodeType.HUMAN_INPUT: + if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config.get("data", {})) + node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, @@ -952,7 +1005,7 @@ class WorkflowService: delivery_method = apply_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id or "", + user_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -960,6 +1013,7 @@ class WorkflowService: workflow=draft_workflow, node_config=node_config, manual_inputs=inputs or {}, + user_id=account.id, ) node = self._build_human_input_node( workflow=draft_workflow, @@ -1016,7 +1070,7 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(session_factory=db.engine, tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) params = FormCreateParams( app_id=app_model.id, workflow_execution_id=None, @@ -1060,17 +1114,19 @@ class WorkflowService: *, workflow: Workflow, account: Account, - node_config: Mapping[str, Any], + node_config: NodeConfigDict, variable_pool: VariablePool, ) -> HumanInputNode: graph_init_params = GraphInitParams( - tenant_id=workflow.tenant_id, - app_id=workflow.app_id, workflow_id=workflow.id, graph_config=workflow.graph_dict, - user_id=account.id, - user_from=UserFrom.ACCOUNT.value, - invoke_from=InvokeFrom.DEBUGGER.value, + run_context=build_dify_run_context( + tenant_id=workflow.tenant_id, + app_id=workflow.app_id, + user_id=account.id, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ), call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -1078,10 +1134,11 @@ class WorkflowService: start_at=time.perf_counter(), ) node = HumanInputNode( - id=node_config.get("id", str(uuid.uuid4())), + id=node_config["id"], config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), ) return node @@ -1090,12 +1147,13 @@ class WorkflowService: *, app_model: App, workflow: Workflow, - node_config: Mapping[str, Any], + node_config: NodeConfigDict, manual_inputs: Mapping[str, Any], + user_id: str, ) -> VariablePool: with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.prefill_conversation_variable_default_values(workflow) + draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) variable_pool = VariablePool( system_variables=SystemVariable.default(), @@ -1108,6 +1166,7 @@ class WorkflowService: engine=db.engine, app_id=app_model.id, tenant_id=app_model.tenant_id, + user_id=user_id, ) variable_mapping = HumanInputNode.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, @@ -1324,18 +1383,18 @@ class WorkflowService: for node in node_configs: node_type = node.get("data", {}).get("type") if node_type: - node_types.add(NodeType(node_type)) + node_types.add(node_type) # start node and trigger node cannot coexist - if NodeType.START in node_types: - if any(nt.is_trigger_node for nt in node_types): + if BuiltinNodeTypes.START in node_types: + if any(is_trigger_node_type(nt) for nt in node_types): raise ValueError("Start node and trigger nodes cannot coexist in the same workflow") for node in node_configs: node_data = node.get("data", {}) node_type = node_data.get("type") - if node_type == NodeType.HUMAN_INPUT: + if node_type == BuiltinNodeTypes.HUMAN_INPUT: self._validate_human_input_node_data(node_data) def validate_features_structure(self, app_model: App, features: dict): @@ -1360,7 +1419,7 @@ class WorkflowService: Raises: ValueError: If the node data format is invalid """ - from core.workflow.nodes.human_input.entities import HumanInputNodeData + from dify_graph.nodes.human_input.entities import HumanInputNodeData try: HumanInputNodeData.model_validate(node_data) @@ -1457,7 +1516,7 @@ def _setup_variable_pool( conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. - if node_type == NodeType.START or node_type.is_trigger_node: + if is_start_node_type(node_type): system_variable = SystemVariable( user_id=user_id, app_id=workflow.app_id, diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index 2d3d00cd50..ae55c9ee03 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -13,6 +13,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DatasetAutoDisableLog, DocumentSegment from models.dataset import Document as DatasetDocument +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ def add_document_to_index_task(dataset_document_id: str): logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red")) return - if dataset_document.indexing_status != "completed": + if dataset_document.indexing_status != IndexingStatus.COMPLETED: return indexing_cache_key = f"document_{dataset_document.id}_indexing" @@ -48,7 +49,7 @@ def add_document_to_index_task(dataset_document_id: str): session.query(DocumentSegment) .where( DocumentSegment.document_id == dataset_document.id, - DocumentSegment.status == "completed", + DocumentSegment.status == SegmentStatus.COMPLETED, ) .order_by(DocumentSegment.position.asc()) .all() @@ -139,7 +140,7 @@ def add_document_to_index_task(dataset_document_id: str): logger.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = naive_utc_now() - dataset_document.indexing_status = "error" + dataset_document.indexing_status = IndexingStatus.ERROR dataset_document.error = str(e) session.commit() finally: diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index 4f8e2fec7a..1fe43c3d62 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -11,6 +11,7 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset +from models.enums import CollectionBindingType from models.model import App, AppAnnotationSetting, MessageAnnotation from services.dataset_service import DatasetCollectionBindingService @@ -47,7 +48,7 @@ def enable_annotation_reply_task( try: documents = [] dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_provider_name, embedding_model_name, "annotation" + embedding_provider_name, embedding_model_name, CollectionBindingType.ANNOTATION ) annotation_setting = ( session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first() @@ -56,7 +57,7 @@ def enable_annotation_reply_task( if dataset_collection_binding.id != annotation_setting.collection_binding_id: old_dataset_collection_binding = ( DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( - annotation_setting.collection_binding_id, "annotation" + annotation_setting.collection_binding_id, CollectionBindingType.ANNOTATION ) ) if old_dataset_collection_binding and annotations: diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index e58d334f41..174aa50343 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -21,7 +21,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from libs.flask_utils import set_login_user from models.account import Account @@ -321,7 +321,13 @@ def _resume_app_execution(payload: dict[str, Any]) -> None: return message = session.scalar( - select(Message).where(Message.workflow_run_id == workflow_run_id).order_by(Message.created_at.desc()) + select(Message) + .where( + Message.conversation_id == conversation.id, + Message.workflow_run_id == workflow_run_id, + ) + .order_by(Message.created_at.desc()) + .limit(1) ) if message is None: logger.warning("Message not found for workflow run %s", workflow_run_id) diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index cc96542d4b..d247cf5cf7 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -21,7 +21,7 @@ from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import DifyCoreRepositoryFactory -from core.workflow.runtime import GraphRuntimeState +from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index f69f17b16d..49dee00919 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -11,7 +11,7 @@ from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index b5e472d71e..b3cbc73d6e 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -10,6 +10,7 @@ from core.rag.models.document import Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - if segment.status != "waiting": + if segment.status != SegmentStatus.WAITING: return indexing_cache_key = f"segment_{segment.id}_indexing" @@ -40,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N # update segment status to indexing session.query(DocumentSegment).filter_by(id=segment.id).update( { - DocumentSegment.status: "indexing", + DocumentSegment.status: SegmentStatus.INDEXING, DocumentSegment.indexing_at: naive_utc_now(), } ) @@ -70,7 +71,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N if ( not dataset_document.enabled or dataset_document.archived - or dataset_document.indexing_status != "completed" + or dataset_document.indexing_status != IndexingStatus.COMPLETED ): logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return @@ -82,7 +83,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N # update segment to completed session.query(DocumentSegment).filter_by(id=segment.id).update( { - DocumentSegment.status: "completed", + DocumentSegment.status: SegmentStatus.COMPLETED, DocumentSegment.completed_at: naive_utc_now(), } ) @@ -94,7 +95,7 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N logger.exception("create segment to index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) session.commit() finally: diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index fddd9199d1..f99e90062f 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -12,6 +12,7 @@ from core.rag.extractor.notion_extractor import NotionExtractor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) @@ -37,7 +38,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - if document.indexing_status == "parsing": + if document.indexing_status == IndexingStatus.PARSING: logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) return @@ -88,7 +89,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): with session_factory.create_session() as session, session.begin(): document = session.query(Document).filter_by(id=document_id).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = "Datasource credential not found. Please reconnect your Notion workspace." document.stopped_at = naive_utc_now() return @@ -128,7 +129,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): data_source_info["last_edited_time"] = last_edited_time document.data_source_info = json.dumps(data_source_info) - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) @@ -151,6 +152,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): with session_factory.create_session() as session, session.begin(): document = session.query(Document).filter_by(id=document_id).first() if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 11edcf151f..e05d63426c 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -1,9 +1,10 @@ import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Sequence +from typing import Any, Protocol import click -from celery import shared_task +from celery import current_app, shared_task from configs import dify_config from core.db.session_factory import session_factory @@ -13,12 +14,19 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document +from models.enums import IndexingStatus from services.feature_service import FeatureService from tasks.generate_summary_index_task import generate_summary_index_task logger = logging.getLogger(__name__) +class CeleryTaskLike(Protocol): + def delay(self, *args: Any, **kwargs: Any) -> Any: ... + + def apply_async(self, *args: Any, **kwargs: Any) -> Any: ... + + @shared_task(queue="dataset") def document_indexing_task(dataset_id: str, document_ids: list): """ @@ -74,7 +82,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -89,7 +97,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): for document in documents: if document: - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) # Transaction committed and closed @@ -141,7 +149,7 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): document.need_summary, ) if ( - document.indexing_status == "completed" + document.indexing_status == IndexingStatus.COMPLETED and document.doc_form != "qa_model" and document.need_summary is True ): @@ -179,8 +187,8 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): def _document_indexing_with_tenant_queue( - tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: Callable[[str, str, Sequence[str]], None] -): + tenant_id: str, dataset_id: str, document_ids: Sequence[str], task_func: CeleryTaskLike +) -> None: try: _document_indexing(dataset_id, document_ids) except Exception: @@ -201,16 +209,20 @@ def _document_indexing_with_tenant_queue( logger.info("document indexing tenant isolation queue %s next tasks: %s", tenant_id, next_tasks) if next_tasks: - for next_task in next_tasks: - document_task = DocumentTask(**next_task) - # Process the next waiting task - # Keep the flag set to indicate a task is running - tenant_isolated_task_queue.set_task_waiting_time() - task_func.delay( # type: ignore - tenant_id=document_task.tenant_id, - dataset_id=document_task.dataset_id, - document_ids=document_task.document_ids, - ) + with current_app.producer_or_acquire() as producer: # type: ignore + for next_task in next_tasks: + document_task = DocumentTask(**next_task) + # Keep the flag set to indicate a task is running + tenant_isolated_task_queue.set_task_waiting_time() + task_func.apply_async( + kwargs={ + "tenant_id": document_task.tenant_id, + "dataset_id": document_task.dataset_id, + "document_ids": document_task.document_ids, + }, + producer=producer, + ) + else: # No more waiting tasks, clear the flag tenant_isolated_task_queue.delete_task_key() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index c7508c6d05..62bce24de4 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -10,6 +10,7 @@ from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Document not found: {document_id}", fg="red")) return - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 00a963255b..13c651753f 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -15,6 +15,7 @@ from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService logger = logging.getLogger(__name__) @@ -112,7 +113,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st ) for document in documents: if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -146,7 +147,7 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 41ebb0b076..5ad17d75d4 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -12,6 +12,7 @@ from core.rag.models.document import AttachmentDocument, ChildDocument, Document from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import DocumentSegment +from models.enums import IndexingStatus, SegmentStatus logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ def enable_segment_to_index_task(segment_id: str): logger.info(click.style(f"Segment not found: {segment_id}", fg="red")) return - if segment.status != "completed": + if segment.status != SegmentStatus.COMPLETED: logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red")) return @@ -65,7 +66,7 @@ def enable_segment_to_index_task(segment_id: str): if ( not dataset_document.enabled or dataset_document.archived - or dataset_document.indexing_status != "completed" + or dataset_document.indexing_status != IndexingStatus.COMPLETED ): logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan")) return @@ -123,7 +124,7 @@ def enable_segment_to_index_task(segment_id: str): logger.exception("enable segment to index failed") segment.enabled = False segment.disabled_at = naive_utc_now() - segment.status = "error" + segment.status = SegmentStatus.ERROR segment.error = str(e) session.commit() finally: diff --git a/api/tasks/generate_summary_index_task.py b/api/tasks/generate_summary_index_task.py index e4273e16b5..6493833edc 100644 --- a/api/tasks/generate_summary_index_task.py +++ b/api/tasks/generate_summary_index_task.py @@ -14,7 +14,7 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) -@shared_task(queue="dataset") +@shared_task(queue="dataset_summary") def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None): """ Async generate summary index for document segments. diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index 5413a33d6a..dd3b6a4530 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -7,8 +7,8 @@ from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import ensure_naive_utc, naive_utc_now @@ -58,7 +58,7 @@ def check_and_handle_human_input_timeouts(limit: int = 100) -> None: """Scan for expired human input forms and resume or end workflows.""" session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - form_repo = HumanInputFormSubmissionRepository(session_factory) + form_repo = HumanInputFormSubmissionRepository() service = HumanInputService(session_factory, form_repository=form_repo) now = naive_utc_now() global_timeout_seconds = dify_config.HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d1cd0fbadc..d241783359 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,8 +11,8 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod +from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_database import db from extensions.ext_mail import mail from models.human_input import ( @@ -111,7 +111,7 @@ def _render_body( url=form_link, variable_pool=variable_pool, ) - return body + return EmailDeliveryConfig.render_markdown_body(body) def _load_variable_pool(workflow_run_id: str | None) -> VariablePool | None: @@ -173,10 +173,11 @@ def dispatch_human_input_email_task(form_id: str, node_title: str | None = None, for recipient in job.recipients: form_link = _build_form_link(recipient.token) body = _render_body(job.body, form_link, variable_pool=variable_pool) + subject = EmailDeliveryConfig.sanitize_subject(job.subject) mail.send( to=recipient.email, - subject=job.subject, + subject=subject, html=body, ) diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py index 6ad04aab0d..5d201bd801 100644 --- a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py +++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py @@ -6,7 +6,6 @@ import typing import click from celery import shared_task -from core.helper.marketplace import record_install_plugin_event from core.plugin.entities.marketplace import MarketplacePluginSnapshot from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.impl.plugin import PluginInstaller @@ -166,7 +165,6 @@ def process_tenant_plugin_autoupgrade_check_task( # execute upgrade new_unique_identifier = manifest.latest_package_identifier - record_install_plugin_event(new_unique_identifier) click.echo( click.style( f"Upgrade plugin: {original_unique_identifier} -> {new_unique_identifier}", diff --git a/api/tasks/rag_pipeline/rag_pipeline_run_task.py b/api/tasks/rag_pipeline/rag_pipeline_run_task.py index 093342d1a3..52f66dddb8 100644 --- a/api/tasks/rag_pipeline/rag_pipeline_run_task.py +++ b/api/tasks/rag_pipeline/rag_pipeline_run_task.py @@ -3,12 +3,13 @@ import json import logging import time import uuid -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from concurrent.futures import ThreadPoolExecutor +from itertools import islice from typing import Any import click -from celery import shared_task # type: ignore +from celery import group, shared_task from flask import current_app, g from sqlalchemy.orm import Session, sessionmaker @@ -27,6 +28,11 @@ from services.file_service import FileService logger = logging.getLogger(__name__) +def chunked(iterable: Sequence, size: int): + it = iter(iterable) + return iter(lambda: list(islice(it, size)), []) + + @shared_task(queue="pipeline") def rag_pipeline_run_task( rag_pipeline_invoke_entities_file_id: str, @@ -83,16 +89,24 @@ def rag_pipeline_run_task( logger.info("rag pipeline tenant isolation queue %s next files: %s", tenant_id, next_file_ids) if next_file_ids: - for next_file_id in next_file_ids: - # Process the next waiting task - # Keep the flag set to indicate a task is running - tenant_isolated_task_queue.set_task_waiting_time() - rag_pipeline_run_task.delay( # type: ignore - rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8") - if isinstance(next_file_id, bytes) - else next_file_id, - tenant_id=tenant_id, - ) + for batch in chunked(next_file_ids, 100): + jobs = [] + for next_file_id in batch: + tenant_isolated_task_queue.set_task_waiting_time() + + file_id = ( + next_file_id.decode("utf-8") if isinstance(next_file_id, (bytes, bytearray)) else next_file_id + ) + + jobs.append( + rag_pipeline_run_task.s( + rag_pipeline_invoke_entities_file_id=file_id, + tenant_id=tenant_id, + ) + ) + + if jobs: + group(jobs).apply_async() else: # No more waiting tasks, clear the flag tenant_isolated_task_queue.delete_task_key() diff --git a/api/tasks/regenerate_summary_index_task.py b/api/tasks/regenerate_summary_index_task.py index cf8988d13e..39c2f4103e 100644 --- a/api/tasks/regenerate_summary_index_task.py +++ b/api/tasks/regenerate_summary_index_task.py @@ -16,7 +16,7 @@ from services.summary_index_service import SummaryIndexService logger = logging.getLogger(__name__) -@shared_task(queue="dataset") +@shared_task(queue="dataset_summary") def regenerate_summary_index_task( dataset_id: str, regenerate_reason: str = "summary_model_changed", diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index f20b15ac83..4fcb0cf804 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -12,6 +12,7 @@ from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models import Account, Tenant from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService from services.rag_pipeline.rag_pipeline import RagPipelineService @@ -63,7 +64,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ .first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -95,7 +96,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() @@ -108,7 +109,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_ indexing_runner.run([document]) redis_client.delete(retry_indexing_cache_key) except Exception as ex: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(ex) document.stopped_at = naive_utc_now() session.add(document) diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index f1c8c56995..aa6bce958b 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -11,6 +11,7 @@ from core.rag.index_processor.index_processor_factory import IndexProcessorFacto from extensions.ext_redis import redis_client from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.feature_service import FeatureService logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() ) if document: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(e) document.stopped_at = naive_utc_now() session.add(document) @@ -76,7 +77,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): session.execute(segment_delete_stmt) session.commit() - document.indexing_status = "parsing" + document.indexing_status = IndexingStatus.PARSING document.processing_started_at = naive_utc_now() session.add(document) session.commit() @@ -85,7 +86,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): indexing_runner.run([document]) redis_client.delete(sync_indexing_cache_key) except Exception as ex: - document.indexing_status = "error" + document.indexing_status = IndexingStatus.ERROR document.error = str(ex) document.stopped_at = naive_utc_now() session.add(document) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index d18ea2c23c..75ae1f6316 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -20,13 +20,14 @@ from core.db.session_factory import session_factory from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.entities.request import TriggerInvokeEventResponse from core.plugin.impl.exc import PluginInvokeError +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager -from core.workflow.enums import NodeType, WorkflowExecutionStatus from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData +from dify_graph.enums import WorkflowExecutionStatus from enums.quota_type import QuotaType, unlimited from models.enums import ( AppTriggerType, @@ -164,7 +165,7 @@ def _record_trigger_failure_log( elapsed_time=0.0, total_tokens=0, total_steps=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, created_at=now, finished_at=now, @@ -179,7 +180,7 @@ def _record_trigger_failure_log( workflow_id=workflow.id, workflow_run_id=workflow_run.id, created_from=WorkflowAppLogCreatedFrom.SERVICE_API.value, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, ) session.add(workflow_app_log) @@ -212,7 +213,7 @@ def _record_trigger_failure_log( error=error_message, queue_name=queue_name, retry_count=0, - created_by_role=created_by_role.value, + created_by_role=created_by_role, created_by=created_by, triggered_at=now, finished_at=now, @@ -278,7 +279,7 @@ def dispatch_triggered_workflow( # Find the trigger node in the workflow event_node = None - for node_id, node_config in workflow.walk_nodes(NodeType.TRIGGER_PLUGIN): + for node_id, node_config in workflow.walk_nodes(TRIGGER_PLUGIN_NODE_TYPE): if node_id == plugin_trigger.node_id: event_node = node_config break diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index 3b3c6e5313..f41118e592 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -12,8 +12,8 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from core.workflow.entities.workflow_execution import WorkflowExecution -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.entities.workflow_execution import WorkflowExecution +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom @@ -94,13 +94,15 @@ def _create_workflow_run_from_execution( workflow_run.tenant_id = tenant_id workflow_run.app_id = app_id workflow_run.workflow_id = execution.workflow_id - workflow_run.type = execution.workflow_type.value - workflow_run.triggered_from = triggered_from.value + from models.workflow import WorkflowType as ModelWorkflowType + + workflow_run.type = ModelWorkflowType(execution.workflow_type.value) + workflow_run.triggered_from = triggered_from workflow_run.version = execution.workflow_version json_converter = WorkflowRuntimeTypeConverter() workflow_run.graph = json.dumps(json_converter.to_json_encodable(execution.graph)) workflow_run.inputs = json.dumps(json_converter.to_json_encodable(execution.inputs)) - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) @@ -108,7 +110,7 @@ def _create_workflow_run_from_execution( workflow_run.elapsed_time = execution.elapsed_time workflow_run.total_tokens = execution.total_tokens workflow_run.total_steps = execution.total_steps - workflow_run.created_by_role = creator_user_role.value + workflow_run.created_by_role = creator_user_role workflow_run.created_by = creator_user_id workflow_run.created_at = execution.started_at workflow_run.finished_at = execution.finished_at @@ -121,7 +123,7 @@ def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: Wo Update a WorkflowRun database model from a WorkflowExecution domain entity. """ json_converter = WorkflowRuntimeTypeConverter() - workflow_run.status = execution.status.value + workflow_run.status = execution.status workflow_run.outputs = ( json.dumps(json_converter.to_json_encodable(execution.outputs)) if execution.outputs else "{}" ) diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index b30a4ff15b..466ef6c858 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -12,10 +12,10 @@ from celery import shared_task from sqlalchemy import select from core.db.session_factory import session_factory -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, ) -from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter +from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -98,12 +98,12 @@ def _create_node_execution_from_domain( node_execution.tenant_id = tenant_id node_execution.app_id = app_id node_execution.workflow_id = execution.workflow_id - node_execution.triggered_from = triggered_from.value + node_execution.triggered_from = triggered_from node_execution.workflow_run_id = execution.workflow_execution_id node_execution.index = execution.index node_execution.predecessor_node_id = execution.predecessor_node_id node_execution.node_id = execution.node_id - node_execution.node_type = execution.node_type.value + node_execution.node_type = execution.node_type node_execution.title = execution.title node_execution.node_execution_id = execution.node_execution_id @@ -128,7 +128,7 @@ def _create_node_execution_from_domain( node_execution.status = execution.status.value node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time - node_execution.created_by_role = creator_user_role.value + node_execution.created_by_role = creator_user_role node_execution.created_by = creator_user_id node_execution.created_at = execution.created_at node_execution.finished_at = execution.finished_at diff --git a/api/tests/integration_tests/.env.example b/api/tests/integration_tests/.env.example index 39effbab58..f84d39aeb5 100644 --- a/api/tests/integration_tests/.env.example +++ b/api/tests/integration_tests/.env.example @@ -60,7 +60,6 @@ VECTOR_STORE=weaviate # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih -WEAVIATE_GRPC_ENABLED=false WEAVIATE_BATCH_SIZE=100 WEAVIATE_TOKENIZATION=word @@ -78,6 +77,19 @@ IRIS_MAX_CONNECTION=3 IRIS_TEXT_INDEX=true IRIS_TEXT_INDEX_LANGUAGE=en +# Hologres configuration +HOLOGRES_HOST=localhost +HOLOGRES_PORT=80 +HOLOGRES_DATABASE=test_db +HOLOGRES_ACCESS_KEY_ID=test_access_key_id +HOLOGRES_ACCESS_KEY_SECRET=test_access_key_secret +HOLOGRES_SCHEMA=public +HOLOGRES_TOKENIZER=jieba +HOLOGRES_DISTANCE_METHOD=Cosine +HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq +HOLOGRES_MAX_DEGREE=64 +HOLOGRES_EF_CONSTRUCTION=400 + # Upload configuration UPLOAD_FILE_SIZE_LIMIT=15 diff --git a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py index 498ac56d5d..d10e5ed13c 100644 --- a/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py +++ b/api/tests/integration_tests/controllers/console/app/test_chat_message_permissions.py @@ -13,6 +13,7 @@ from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now from models import App, Tenant from models.account import Account, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -154,7 +155,7 @@ class TestChatMessageApiPermissions: re_sign_file_url_answer="", answer_tokens=0, provider_response_latency=0.0, - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=mock_account.id, feedbacks=[], @@ -165,7 +166,7 @@ class TestChatMessageApiPermissions: agent_thoughts=[], message_files=[], message_metadata_dict={}, - status="success", + status="normal", error="", parent_message_id=None, ) diff --git a/api/tests/integration_tests/controllers/console/app/test_description_validation.py b/api/tests/integration_tests/controllers/console/app/test_description_validation.py index 8160807e48..f36c596eb8 100644 --- a/api/tests/integration_tests/controllers/console/app/test_description_validation.py +++ b/api/tests/integration_tests/controllers/console/app/test_description_validation.py @@ -5,14 +5,10 @@ This test module validates the 400-character limit enforcement for App descriptions across all creation and editing endpoints. """ -import os import sys import pytest -# Add the API root to Python path for imports -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) - class TestAppDescriptionValidationUnit: """Unit tests for description validation function""" diff --git a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py index 0f8b42e98b..309a0b015a 100644 --- a/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py +++ b/api/tests/integration_tests/controllers/console/app/test_feedback_export_api.py @@ -14,6 +14,7 @@ from controllers.console.app import wraps from libs.datetime_utils import naive_utc_now from models import App, Tenant from models.account import Account, TenantAccountJoin, TenantAccountRole +from models.enums import FeedbackFromSource, FeedbackRating from models.model import AppMode, MessageFeedback from services.feedback_service import FeedbackService @@ -77,8 +78,8 @@ class TestFeedbackExportApi: app_id=app_id, conversation_id=conversation_id, message_id=message_id, - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content=None, from_end_user_id=str(uuid.uuid4()), from_account_id=None, @@ -90,8 +91,8 @@ class TestFeedbackExportApi: app_id=app_id, conversation_id=conversation_id, message_id=message_id, - rating="dislike", - from_source="admin", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.ADMIN, content="The response was not helpful", from_end_user_id=None, from_account_id=str(uuid.uuid4()), @@ -277,8 +278,8 @@ class TestFeedbackExportApi: # Verify service was called with correct parameters mock_export_feedbacks.assert_called_once_with( app_id=mock_app_model.id, - from_source="user", - rating="dislike", + from_source=FeedbackFromSource.USER, + rating=FeedbackRating.DISLIKE, has_comment=True, start_date="2024-01-01", end_date="2024-12-31", diff --git a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 003bb356e5..4fdbb7d9f3 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -2,7 +2,7 @@ from collections.abc import Generator from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from core.workflow.node_events import StreamCompletedEvent +from dify_graph.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 909d6377ce..3e79792b5b 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,6 +1,7 @@ -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult, StreamCompletedEvent class _Seg: @@ -28,13 +29,17 @@ class _GS: class _GP: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 @@ -61,6 +66,8 @@ def test_node_integration_minimal_stream(mocker): def get_upload_file_by_id(cls, **_): raise AssertionError + mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) + node = DatasourceNode( id="n", config={ @@ -77,7 +84,6 @@ def test_node_integration_minimal_stream(mocker): }, graph_init_params=_GP(), graph_runtime_state=_GS(vp), - datasource_manager=_Mgr, ) out = list(node._run()) diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index 16a66bc3f1..db4bbc1ca1 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -6,8 +6,9 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db +from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase): upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=storage_key, name="test_file.txt", size=1024, @@ -288,7 +289,7 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file for other tenant (but don't add to cleanup list) upload_file_other = UploadFile( tenant_id=other_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key="other_tenant_key", name="other_file.txt", size=1024, diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 5012defdad..4e184c93fd 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,20 +4,27 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient + # import monkeypatch -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from core.model_runtime.entities.model_entities import ( +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.llm_entities import ( + LLMMode, + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ( AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, ModelType, ) -from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity class MockModelClass(PluginModelClient): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 5faa002fff..9d3a869691 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,13 +6,14 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes import NodeType -from core.workflow.variables.segments import StringSegment -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import StringVariable +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import SegmentType +from dify_graph.variables.variables import StringVariable from extensions.ext_database import db from extensions.ext_storage import storage +from extensions.storage.storage_type import StorageType from factories.variable_factory import build_segment from libs import datetime_utils from models.enums import CreatorUserRole @@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import ( class TestWorkflowDraftVariableService(unittest.TestCase): _test_app_id: str _session: Session + _test_user_id: str _node1_id = "test_node_1" _node2_id = "test_node_2" _node_exec_id = str(uuid.uuid4()) @@ -99,13 +101,13 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_list_variables(self): srv = self._get_test_srv() - var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2) + var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2, user_id=self._test_user_id) assert var_list.total == 5 assert len(var_list.variables) == 2 page1_var_ids = {v.id for v in var_list.variables} assert page1_var_ids.issubset(self._variable_ids) - var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2) + var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2, user_id=self._test_user_id) assert var_list_2.total is None assert len(var_list_2.variables) == 2 page2_var_ids = {v.id for v in var_list_2.variables} @@ -114,7 +116,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_get_node_variable(self): srv = self._get_test_srv() - node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var") + node_var = srv.get_node_variable(self._test_app_id, self._node1_id, "str_var", user_id=self._test_user_id) assert node_var is not None assert node_var.id == self._node1_str_var_id assert node_var.name == "str_var" @@ -122,7 +124,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_get_system_variable(self): srv = self._get_test_srv() - sys_var = srv.get_system_variable(self._test_app_id, "sys_var") + sys_var = srv.get_system_variable(self._test_app_id, "sys_var", user_id=self._test_user_id) assert sys_var is not None assert sys_var.id == self._sys_var_id assert sys_var.name == "sys_var" @@ -130,7 +132,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_get_conversation_variable(self): srv = self._get_test_srv() - conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var") + conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var", user_id=self._test_user_id) assert conv_var is not None assert conv_var.id == self._conv_var_id assert conv_var.name == "conv_var" @@ -138,7 +140,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test_delete_node_variables(self): srv = self._get_test_srv() - srv.delete_node_variables(self._test_app_id, self._node2_id) + srv.delete_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) node2_var_count = ( self._session.query(WorkflowDraftVariable) .where( @@ -162,7 +164,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): def test__list_node_variables(self): srv = self._get_test_srv() - node_vars = srv._list_node_variables(self._test_app_id, self._node2_id) + node_vars = srv._list_node_variables(self._test_app_id, self._node2_id, user_id=self._test_user_id) assert len(node_vars.variables) == 2 assert {v.id for v in node_vars.variables} == set(self._node2_var_ids) @@ -173,7 +175,7 @@ class TestWorkflowDraftVariableService(unittest.TestCase): [self._node2_id, "str_var"], [self._node2_id, "int_var"], ] - variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors) + variables = srv.get_draft_variables_by_selectors(self._test_app_id, selectors, user_id=self._test_user_id) assert len(variables) == 3 assert {v.id for v in variables} == {self._node1_str_var_id} | set(self._node2_var_ids) @@ -206,19 +208,23 @@ class TestDraftVariableLoader(unittest.TestCase): def setUp(self): self._test_app_id = str(uuid.uuid4()) self._test_tenant_id = str(uuid.uuid4()) + self._test_user_id = str(uuid.uuid4()) sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="sys_var", value=build_segment("sys_value"), node_execution_id=self._node_exec_id, ) conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="conv_var", value=build_segment("conv_value"), ) node_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node1_id, name="str_var", value=build_segment("str_value"), @@ -248,12 +254,22 @@ class TestDraftVariableLoader(unittest.TestCase): session.commit() def test_variable_loader_with_empty_selector(self): - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=self._test_app_id, + tenant_id=self._test_tenant_id, + user_id=self._test_user_id, + ) variables = var_loader.load_variables([]) assert len(variables) == 0 def test_variable_loader_with_non_empty_selector(self): - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=self._test_app_id, + tenant_id=self._test_tenant_id, + user_id=self._test_user_id, + ) variables = var_loader.load_variables( [ [SYSTEM_VARIABLE_NODE_ID, "sys_var"], @@ -286,7 +302,7 @@ class TestDraftVariableLoader(unittest.TestCase): session=session, app_id=self._test_app_id, node_id="test_offload_node", - node_type=NodeType.LLM, # Use a real node type + node_type=BuiltinNodeTypes.LLM, # Use a real node type node_execution_id=node_execution_id, user=setup_account, ) @@ -296,7 +312,12 @@ class TestDraftVariableLoader(unittest.TestCase): session.commit() # Now test loading using DraftVarLoader - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=self._test_app_id, + tenant_id=self._test_tenant_id, + user_id=setup_account.id, + ) # Load the variable using the standard workflow variables = var_loader.load_variables([["test_offload_node", "offloaded_string_var"]]) @@ -313,7 +334,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Clean up - delete all draft variables for this app with Session(bind=db.engine) as session: service = WorkflowDraftVariableService(session) - service.delete_workflow_variables(self._test_app_id) + service.delete_app_workflow_variables(self._test_app_id) session.commit() def test_load_offloaded_variable_object_type_integration(self): @@ -327,7 +348,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Create an upload file record upload_file = UploadFile( tenant_id=self._test_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_offload_{uuid.uuid4()}.json", name="test_offload.json", size=len(content_bytes), @@ -364,6 +385,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Now create the offloaded draft variable with the correct file_id offloaded_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id="test_offload_node", name="offloaded_object_var", value=build_segment({"truncated": True}), @@ -379,7 +401,9 @@ class TestDraftVariableLoader(unittest.TestCase): # Use the service method that properly preloads relationships service = WorkflowDraftVariableService(session) draft_vars = service.get_draft_variables_by_selectors( - self._test_app_id, [["test_offload_node", "offloaded_object_var"]] + self._test_app_id, + [["test_offload_node", "offloaded_object_var"]], + user_id=self._test_user_id, ) assert len(draft_vars) == 1 @@ -387,7 +411,12 @@ class TestDraftVariableLoader(unittest.TestCase): assert loaded_var.is_truncated() # Create DraftVarLoader and test loading - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=self._test_app_id, + tenant_id=self._test_tenant_id, + user_id=self._test_user_id, + ) # Test the _load_offloaded_variable method selector_tuple, variable = var_loader._load_offloaded_variable(loaded_var) @@ -422,7 +451,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Create upload file record upload_file = UploadFile( tenant_id=self._test_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_integration_{uuid.uuid4()}.txt", name="test_integration.txt", size=len(content_bytes), @@ -459,6 +488,7 @@ class TestDraftVariableLoader(unittest.TestCase): # Now create the offloaded draft variable with the correct file_id offloaded_var = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id="test_integration_node", name="offloaded_integration_var", value=build_segment("truncated"), @@ -473,7 +503,12 @@ class TestDraftVariableLoader(unittest.TestCase): # Test load_variables with both regular and offloaded variables # This method should handle the relationship preloading internally - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=self._test_app_id, + tenant_id=self._test_tenant_id, + user_id=self._test_user_id, + ) variables = var_loader.load_variables( [ @@ -542,7 +577,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): index=1, node_execution_id=str(uuid.uuid4()), node_id=self._node_id, - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", inputs='{"input": "test input"}', process_data='{"test_var": "process_value", "other_var": "other_process"}', @@ -572,6 +607,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): # Create test variables self._node_var_with_exec = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node_id, name="test_var", value=build_segment("old_value"), @@ -581,6 +617,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): self._node_var_without_exec = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node_id, name="no_exec_var", value=build_segment("some_value"), @@ -591,6 +628,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): self._node_var_missing_exec = WorkflowDraftVariable.new_node_variable( app_id=self._test_app_id, + user_id=self._test_user_id, node_id=self._node_id, name="missing_exec_var", value=build_segment("some_value"), @@ -599,6 +637,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): self._conv_var = WorkflowDraftVariable.new_conversation_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="conv_var_1", value=build_segment("old_conv_value"), ) @@ -764,6 +803,7 @@ class TestWorkflowDraftVariableServiceResetVariable(unittest.TestCase): # Create a system variable sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, + user_id=self._test_user_id, name="sys_var", value=build_segment("sys_value"), node_execution_id=self._node_exec_id, diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index a259ccb2b9..bc83c6cc12 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -5,7 +5,8 @@ import pytest from sqlalchemy import delete from core.db.session_factory import session_factory -from core.workflow.variables.segments import StringSegment +from dify_graph.variables.segments import StringSegment +from extensions.storage.storage_type import StorageType from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile @@ -191,13 +192,13 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from core.workflow.variables.types import SegmentType + from dify_graph.variables.types import SegmentType from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: upload_file1 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file1.json", name="file1.json", size=1024, @@ -210,7 +211,7 @@ class TestDeleteDraftVariablesWithOffloadIntegration: ) upload_file2 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file2.json", name="file2.json", size=2048, @@ -422,7 +423,7 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from core.workflow.variables.types import SegmentType + from dify_graph.variables.types import SegmentType from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant @@ -430,7 +431,7 @@ class TestDeleteDraftVariablesSessionCommit: with session_factory.create_session() as session: upload_file1 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file1.json", name="file1.json", size=1024, @@ -443,7 +444,7 @@ class TestDeleteDraftVariablesSessionCommit: ) upload_file2 = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key="test/file2.json", name="file2.json", size=2048, diff --git a/api/tests/integration_tests/vdb/__mock/hologres.py b/api/tests/integration_tests/vdb/__mock/hologres.py new file mode 100644 index 0000000000..b60cf358c0 --- /dev/null +++ b/api/tests/integration_tests/vdb/__mock/hologres.py @@ -0,0 +1,209 @@ +import json +import os +from typing import Any + +import holo_search_sdk as holo +import pytest +from _pytest.monkeypatch import MonkeyPatch +from psycopg import sql as psql + +# Shared in-memory storage: {table_name: {doc_id: {"id", "text", "meta", "embedding"}}} +_mock_tables: dict[str, dict[str, dict[str, Any]]] = {} + + +class MockSearchQuery: + """Mock query builder for search_vector and search_text results.""" + + def __init__(self, table_name: str, search_type: str): + self._table_name = table_name + self._search_type = search_type + self._limit_val = 10 + self._filter_sql = None + + def select(self, columns): + return self + + def limit(self, n): + self._limit_val = n + return self + + def where(self, filter_sql): + self._filter_sql = filter_sql + return self + + def _apply_filter(self, row: dict[str, Any]) -> bool: + """Apply the filter SQL to check if a row matches.""" + if self._filter_sql is None: + return True + + # Extract literals (the document IDs) from the filter SQL + # Filter format: meta->>'document_id' IN ('doc1', 'doc2') + literals = [v for t, v in _extract_identifiers_and_literals(self._filter_sql) if t == "literal"] + if not literals: + return True + + # Get the document_id from the row's meta field + meta = row.get("meta", "{}") + if isinstance(meta, str): + meta = json.loads(meta) + doc_id = meta.get("document_id") + + return doc_id in literals + + def fetchall(self): + data = _mock_tables.get(self._table_name, {}) + results = [] + for row in list(data.values())[: self._limit_val]: + # Apply filter if present + if not self._apply_filter(row): + continue + + if self._search_type == "vector": + # row format expected by _process_vector_results: (distance, id, text, meta) + results.append((0.1, row["id"], row["text"], row["meta"])) + else: + # row format expected by _process_full_text_results: (id, text, meta, embedding, score) + results.append((row["id"], row["text"], row["meta"], row.get("embedding", []), 0.9)) + return results + + +class MockTable: + """Mock table object returned by client.open_table().""" + + def __init__(self, table_name: str): + self._table_name = table_name + + def upsert_multi(self, index_column, values, column_names, update=True, update_columns=None): + if self._table_name not in _mock_tables: + _mock_tables[self._table_name] = {} + id_idx = column_names.index("id") + for row in values: + doc_id = row[id_idx] + _mock_tables[self._table_name][doc_id] = dict(zip(column_names, row)) + + def search_vector(self, vector, column, distance_method, output_name): + return MockSearchQuery(self._table_name, "vector") + + def search_text(self, column, expression, return_score=False, return_score_name="score", return_all_columns=False): + return MockSearchQuery(self._table_name, "text") + + def set_vector_index( + self, column, distance_method, base_quantization_type, max_degree, ef_construction, use_reorder + ): + pass + + def create_text_index(self, index_name, column, tokenizer): + pass + + +def _extract_sql_template(query) -> str: + """Extract the SQL template string from a psycopg Composed object.""" + if isinstance(query, psql.Composed): + for part in query: + if isinstance(part, psql.SQL): + return part._obj + if isinstance(query, psql.SQL): + return query._obj + return "" + + +def _extract_identifiers_and_literals(query) -> list[Any]: + """Extract Identifier and Literal values from a psycopg Composed object.""" + values: list[Any] = [] + if isinstance(query, psql.Composed): + for part in query: + if isinstance(part, psql.Identifier): + values.append(("ident", part._obj[0] if part._obj else "")) + elif isinstance(part, psql.Literal): + values.append(("literal", part._obj)) + elif isinstance(part, psql.Composed): + # Handles SQL(...).join(...) for IN clauses + for sub in part: + if isinstance(sub, psql.Literal): + values.append(("literal", sub._obj)) + return values + + +class MockHologresClient: + """Mock holo_search_sdk client that stores data in memory.""" + + def connect(self): + pass + + def check_table_exist(self, table_name): + return table_name in _mock_tables + + def open_table(self, table_name): + return MockTable(table_name) + + def execute(self, query, fetch_result=False): + template = _extract_sql_template(query) + params = _extract_identifiers_and_literals(query) + + if "CREATE TABLE" in template.upper(): + # Extract table name from first identifier + table_name = next((v for t, v in params if t == "ident"), "unknown") + if table_name not in _mock_tables: + _mock_tables[table_name] = {} + return None + + if "SELECT 1" in template: + # text_exists: SELECT 1 FROM {table} WHERE id = {id} LIMIT 1 + table_name = next((v for t, v in params if t == "ident"), "") + doc_id = next((v for t, v in params if t == "literal"), "") + data = _mock_tables.get(table_name, {}) + return [(1,)] if doc_id in data else [] + + if "SELECT id" in template: + # get_ids_by_metadata_field: SELECT id FROM {table} WHERE meta->>{key} = {value} + table_name = next((v for t, v in params if t == "ident"), "") + literals = [v for t, v in params if t == "literal"] + key = literals[0] if len(literals) > 0 else "" + value = literals[1] if len(literals) > 1 else "" + data = _mock_tables.get(table_name, {}) + return [(doc_id,) for doc_id, row in data.items() if json.loads(row.get("meta", "{}")).get(key) == value] + + if "DELETE" in template.upper(): + table_name = next((v for t, v in params if t == "ident"), "") + if "id IN" in template: + # delete_by_ids + ids_to_delete = [v for t, v in params if t == "literal"] + for did in ids_to_delete: + _mock_tables.get(table_name, {}).pop(did, None) + elif "meta->>" in template: + # delete_by_metadata_field + literals = [v for t, v in params if t == "literal"] + key = literals[0] if len(literals) > 0 else "" + value = literals[1] if len(literals) > 1 else "" + data = _mock_tables.get(table_name, {}) + to_remove = [ + doc_id for doc_id, row in data.items() if json.loads(row.get("meta", "{}")).get(key) == value + ] + for did in to_remove: + data.pop(did, None) + return None + + return [] if fetch_result else None + + def drop_table(self, table_name): + _mock_tables.pop(table_name, None) + + +def mock_connect(**kwargs): + """Replacement for holo_search_sdk.connect() that returns a mock client.""" + return MockHologresClient() + + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +@pytest.fixture +def setup_hologres_mock(monkeypatch: MonkeyPatch): + if MOCK: + monkeypatch.setattr(holo, "connect", mock_connect) + + yield + + if MOCK: + _mock_tables.clear() + monkeypatch.undo() diff --git a/api/tests/integration_tests/vdb/hologres/__init__.py b/api/tests/integration_tests/vdb/hologres/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/vdb/hologres/test_hologres.py b/api/tests/integration_tests/vdb/hologres/test_hologres.py new file mode 100644 index 0000000000..ff2be88ef1 --- /dev/null +++ b/api/tests/integration_tests/vdb/hologres/test_hologres.py @@ -0,0 +1,149 @@ +import os +import uuid +from typing import cast + +from holo_search_sdk.types import BaseQuantizationType, DistanceType, TokenizerType + +from core.rag.datasource.vdb.hologres.hologres_vector import HologresVector, HologresVectorConfig +from core.rag.models.document import Document +from tests.integration_tests.vdb.__mock.hologres import setup_hologres_mock +from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis + +MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" + + +class HologresVectorTest(AbstractVectorTest): + def __init__(self): + super().__init__() + # Hologres requires collection names to be lowercase + self.collection_name = self.collection_name.lower() + self.vector = HologresVector( + collection_name=self.collection_name, + config=HologresVectorConfig( + host=os.environ.get("HOLOGRES_HOST", "localhost"), + port=int(os.environ.get("HOLOGRES_PORT", "80")), + database=os.environ.get("HOLOGRES_DATABASE", "test_db"), + access_key_id=os.environ.get("HOLOGRES_ACCESS_KEY_ID", "test_key"), + access_key_secret=os.environ.get("HOLOGRES_ACCESS_KEY_SECRET", "test_secret"), + schema_name=os.environ.get("HOLOGRES_SCHEMA", "public"), + tokenizer=cast(TokenizerType, os.environ.get("HOLOGRES_TOKENIZER", "jieba")), + distance_method=cast(DistanceType, os.environ.get("HOLOGRES_DISTANCE_METHOD", "Cosine")), + base_quantization_type=cast( + BaseQuantizationType, os.environ.get("HOLOGRES_BASE_QUANTIZATION_TYPE", "rabitq") + ), + max_degree=int(os.environ.get("HOLOGRES_MAX_DEGREE", "64")), + ef_construction=int(os.environ.get("HOLOGRES_EF_CONSTRUCTION", "400")), + ), + ) + + def search_by_full_text(self): + """Override: full-text index may not be immediately ready in real mode.""" + hits_by_full_text = self.vector.search_by_full_text(query=get_example_text()) + if MOCK: + # In mock mode, full-text search should return the document we inserted + assert len(hits_by_full_text) == 1 + assert hits_by_full_text[0].metadata["doc_id"] == self.example_doc_id + else: + # In real mode, full-text index may need time to become active + assert len(hits_by_full_text) >= 0 + + def search_by_vector_with_filter(self): + """Test vector search with document_ids_filter.""" + # Create another document with different document_id + other_doc_id = str(uuid.uuid4()) + other_doc = Document( + page_content="other_text", + metadata={ + "doc_id": other_doc_id, + "doc_hash": other_doc_id, + "document_id": other_doc_id, + "dataset_id": self.dataset_id, + }, + ) + self.vector.add_texts(documents=[other_doc], embeddings=[self.example_embedding]) + + # Search with filter - should only return the original document + hits = self.vector.search_by_vector( + query_vector=self.example_embedding, + document_ids_filter=[self.example_doc_id], + ) + assert len(hits) == 1 + assert hits[0].metadata["doc_id"] == self.example_doc_id + + # Search without filter - should return both + all_hits = self.vector.search_by_vector(query_vector=self.example_embedding, top_k=10) + assert len(all_hits) >= 2 + + def search_by_full_text_with_filter(self): + """Test full-text search with document_ids_filter.""" + # Create another document with different document_id + other_doc_id = str(uuid.uuid4()) + other_doc = Document( + page_content="unique_other_text", + metadata={ + "doc_id": other_doc_id, + "doc_hash": other_doc_id, + "document_id": other_doc_id, + "dataset_id": self.dataset_id, + }, + ) + self.vector.add_texts(documents=[other_doc], embeddings=[self.example_embedding]) + + # Search with filter - should only return the original document + hits = self.vector.search_by_full_text( + query=get_example_text(), + document_ids_filter=[self.example_doc_id], + ) + if MOCK: + assert len(hits) == 1 + assert hits[0].metadata["doc_id"] == self.example_doc_id + + def get_ids_by_metadata_field(self): + """Override: Hologres implements this method via JSONB query.""" + ids = self.vector.get_ids_by_metadata_field(key="document_id", value=self.example_doc_id) + assert ids is not None + assert len(ids) == 1 + + def run_all_tests(self): + # Clean up before running tests + self.vector.delete() + # Run base tests (create, search, text_exists, get_ids, add_texts, delete_by_ids, delete) + super().run_all_tests() + + # Additional filter tests require fresh data (table was deleted by base tests) + if MOCK: + # Recreate collection for filter tests + self.vector.create( + texts=[ + Document( + page_content=get_example_text(), + metadata={ + "doc_id": self.example_doc_id, + "doc_hash": self.example_doc_id, + "document_id": self.example_doc_id, + "dataset_id": self.dataset_id, + }, + ) + ], + embeddings=[self.example_embedding], + ) + self.search_by_vector_with_filter() + self.search_by_full_text_with_filter() + # Clean up + self.vector.delete() + + +def test_hologres_vector(setup_mock_redis, setup_hologres_mock): + """ + Test Hologres vector database implementation. + + This test covers: + - Creating collection with vector index + - Adding texts with embeddings + - Vector similarity search + - Full-text search + - Text existence check + - Batch deletion by IDs + - Collection deletion + """ + HologresVectorTest().run_all_tests() diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index cdecdf41d2..5b0f86fed1 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType diff --git a/api/tests/integration_tests/workflow/nodes/knowledge_index/__init__.py b/api/tests/integration_tests/workflow/nodes/knowledge_index/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/workflow/nodes/knowledge_index/test_knowledge_index_node_integration.py b/api/tests/integration_tests/workflow/nodes/knowledge_index/test_knowledge_index_node_integration.py new file mode 100644 index 0000000000..4edbf2b1e9 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/knowledge_index/test_knowledge_index_node_integration.py @@ -0,0 +1,69 @@ +""" +Integration tests for KnowledgeIndexNode. + +This module provides integration tests for KnowledgeIndexNode with real database interactions. + +Note: These tests require database setup and are more complex than unit tests. +For now, we focus on unit tests which provide better coverage for the node logic. +""" + +import pytest + + +class TestKnowledgeIndexNodeIntegration: + """ + Integration test suite for KnowledgeIndexNode. + + Note: Full integration tests require: + - Database setup with datasets and documents + - Vector store for embeddings + - Model providers for indexing and summarization + - IndexProcessor and SummaryIndexService implementations + + For now, unit tests provide comprehensive coverage of the node logic. + """ + + @pytest.mark.skip(reason="Integration tests require full database and vector store setup") + def test_end_to_end_knowledge_index_preview(self): + """Test end-to-end knowledge index workflow in preview mode.""" + # TODO: Implement with real database + # 1. Create a dataset + # 2. Create a document + # 3. Prepare chunks + # 4. Run KnowledgeIndexNode in preview mode + # 5. Verify preview output + pass + + @pytest.mark.skip(reason="Integration tests require full database and vector store setup") + def test_end_to_end_knowledge_index_production(self): + """Test end-to-end knowledge index workflow in production mode.""" + # TODO: Implement with real database + # 1. Create a dataset + # 2. Create a document + # 3. Prepare chunks + # 4. Run KnowledgeIndexNode in production mode + # 5. Verify indexing and summary generation + pass + + @pytest.mark.skip(reason="Integration tests require full database and vector store setup") + def test_knowledge_index_with_summary_enabled(self): + """Test knowledge index with summary index setting enabled.""" + # TODO: Implement with real database + # 1. Create a dataset + # 2. Create a document + # 3. Prepare chunks + # 4. Configure summary index setting + # 5. Run KnowledgeIndexNode + # 6. Verify summaries are generated and indexed + pass + + @pytest.mark.skip(reason="Integration tests require full database and vector store setup") + def test_knowledge_index_parent_child_structure(self): + """Test knowledge index with parent-child chunk structure.""" + # TODO: Implement with real database + # 1. Create a dataset + # 2. Create a document + # 3. Prepare parent-child chunks + # 4. Run KnowledgeIndexNode + # 5. Verify parent-child indexing + pass diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e0ea14b789..e3a2b6b866 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -4,18 +4,17 @@ import uuid import pytest from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.code.code_node import CodeNode +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +from tests.workflow_test_utils import build_test_graph_init_params CODE_MAX_STRING_LENGTH = dify_config.CODE_MAX_STRING_LENGTH @@ -32,11 +31,11 @@ def init_code_node(code_config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, code_config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -61,7 +60,7 @@ def init_code_node(code_config: dict): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = CodeNode( id=str(uuid.uuid4()), diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index e0f2363799..f885f69e55 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -5,19 +5,18 @@ from urllib.parse import urlencode import pytest from configs import dify_config -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file.file_manager import file_manager -from core.workflow.graph import Graph -from core.workflow.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file.file_manager import file_manager +from dify_graph.graph import Graph +from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -42,11 +41,11 @@ def init_http_node(config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -71,7 +70,7 @@ def init_http_node(config: dict): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( id=str(uuid.uuid4()), @@ -190,15 +189,16 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from core.workflow.nodes.http_request.entities import ( + from dify_graph.enums import BuiltinNodeTypes + from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) - from core.workflow.nodes.http_request.exc import AuthorizationConfigError - from core.workflow.nodes.http_request.executor import Executor - from core.workflow.runtime import VariablePool - from core.workflow.system_variable import SystemVariable + from dify_graph.nodes.http_request.exc import AuthorizationConfigError + from dify_graph.nodes.http_request.executor import Executor + from dify_graph.runtime import VariablePool + from dify_graph.system_variable import SystemVariable # Create variable pool variable_pool = VariablePool( @@ -210,6 +210,7 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): # Create node data with custom auth and empty api_key node_data = HttpRequestNodeData( + type=BuiltinNodeTypes.HTTP_REQUEST, title="http", desc="", url="http://example.com", @@ -686,11 +687,11 @@ def test_nested_object_variable_selector(setup_http_mock): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -716,7 +717,7 @@ def test_nested_object_variable_selector(setup_http_mock): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node = HttpRequestNode( id=str(uuid.uuid4()), diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index b5b0fb5334..d628348f1e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,18 +4,18 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.node_events import StreamCompletedEvent -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" @@ -38,11 +38,11 @@ def init_llm_node(config: dict) -> LLMNode: workflow_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056d" user_id = "9d2074fc-6f86-45a9-b09d-6ecc63b9056e" - init_params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, + init_params = build_test_graph_init_params( workflow_id=workflow_id, graph_config=graph_config, + tenant_id=tenant_id, + app_id=app_id, user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -75,6 +75,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), + template_renderer=MagicMock(spec=TemplateRenderer), + http_client=MagicMock(spec=HttpClientProtocol), ) return node @@ -113,8 +115,8 @@ def test_execute_llm(): from decimal import Decimal from unittest.mock import MagicMock - from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from core.model_runtime.entities.message_entities import AssistantPromptMessage + from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -157,8 +159,8 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(**_kwargs): - from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + def mock_fetch_prompt_messages_1(*_args, **_kwargs): + from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), @@ -229,8 +231,8 @@ def test_execute_llm_with_jinja2(): from decimal import Decimal from unittest.mock import MagicMock - from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from core.model_runtime.entities.message_entities import AssistantPromptMessage + from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -274,7 +276,7 @@ def test_execute_llm_with_jinja2(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): - from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 773074e92d..62d9af0196 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,18 +3,17 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance +from tests.workflow_test_utils import build_test_graph_init_params """FOR MOCK FIXTURES, DO NOT REMOVE""" from tests.integration_tests.model_runtime.__mock.plugin_daemon import setup_model_mock @@ -44,11 +43,11 @@ def init_parameter_extractor_node(config: dict, memory=None): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index bc03ce1b96..7bb4f905c3 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,22 +1,30 @@ import time import uuid -import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom -from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params -@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True) -def test_execute_code(setup_code_executor_mock): +class _SimpleJinja2Renderer: + """Minimal Jinja2-based renderer for integration tests (no code executor).""" + + def render_template(self, template: str, variables: dict[str, object]) -> str: + from jinja2 import Template + + try: + return Template(template).render(**variables) + except Exception as exc: + raise TemplateRenderError(str(exc)) from exc + + +def test_execute_template_transform(): code = """{{args2}}""" config = { "id": "1", @@ -45,11 +53,11 @@ def test_execute_code(setup_code_executor_mock): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -68,19 +76,21 @@ def test_execute_code(setup_code_executor_mock): graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - # Create node factory + # Create node factory (graph init path still works regardless of renderer choice below) node_factory = DifyNodeFactory( graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + assert graph is not None node = TemplateTransformNode( id=str(uuid.uuid4()), config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index cfbef52c93..a6717ada31 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -1,18 +1,18 @@ import time import uuid -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.node_events import StreamCompletedEvent -from core.workflow.nodes.tool.tool_node import ToolNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.protocols import ToolFileManagerProtocol +from dify_graph.nodes.tool.tool_node import ToolNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params def init_tool_node(config: dict): @@ -27,11 +27,11 @@ def init_tool_node(config: dict): "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}, config], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -54,13 +54,16 @@ def init_tool_node(config: dict): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + + tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( id=str(uuid.uuid4()), config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + tool_file_manager_factory=tool_file_manager_factory, ) return node @@ -84,17 +87,20 @@ def test_tool_variable_invoke(): } ) - ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) + with patch.object( + ToolParameterConfigurationManager, + "decrypt_tool_parameters", + return_value={"format": "%Y-%m-%d %H:%M:%S"}, + ): + node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") - node.graph_runtime_state.variable_pool.add(["1", "args1"], "1+1") - - # execute node - result = node._run() - for item in result: - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs is not None - assert item.node_run_result.outputs.get("text") is not None + # execute node + result = node._run() + for item in result: + if isinstance(item, StreamCompletedEvent): + assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.node_run_result.outputs is not None + assert item.node_run_result.outputs.get("text") is not None def test_tool_mixed_invoke(): @@ -118,12 +124,15 @@ def test_tool_mixed_invoke(): } ) - ToolParameterConfigurationManager.decrypt_tool_parameters = MagicMock(return_value={"format": "%Y-%m-%d %H:%M:%S"}) - - # execute node - result = node._run() - for item in result: - if isinstance(item, StreamCompletedEvent): - assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert item.node_run_result.outputs is not None - assert item.node_run_result.outputs.get("text") is not None + with patch.object( + ToolParameterConfigurationManager, + "decrypt_tool_parameters", + return_value={"format": "%Y-%m-%d %H:%M:%S"}, + ): + # execute node + result = node._run() + for item in result: + if isinstance(item, StreamCompletedEvent): + assert item.node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert item.node_run_result.outputs is not None + assert item.node_run_result.outputs.get("text") is not None diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index d6d2d30305..ef0ca4232d 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -10,8 +10,11 @@ more reliable and realistic test scenarios. import logging import os from collections.abc import Generator +from contextlib import contextmanager from pathlib import Path +from typing import Protocol, TypeVar +import psycopg2 import pytest from flask import Flask from flask.testing import FlaskClient @@ -31,6 +34,25 @@ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(level logger = logging.getLogger(__name__) +class _CloserProtocol(Protocol): + """_Closer is any type which implement the close() method.""" + + def close(self): + """close the current object, release any external resouece (file, transaction, connection etc.) + associated with it. + """ + pass + + +_Closer = TypeVar("_Closer", bound=_CloserProtocol) + + +@contextmanager +def _auto_close(closer: _Closer) -> Generator[_Closer, None, None]: + yield closer + closer.close() + + class DifyTestContainers: """ Manages all test containers required for Dify integration tests. @@ -97,45 +119,28 @@ class DifyTestContainers: wait_for_logs(self.postgres, "is ready to accept connections", timeout=30) logger.info("PostgreSQL container is ready and accepting connections") - # Install uuid-ossp extension for UUID generation - logger.info("Installing uuid-ossp extension...") - try: - import psycopg2 - - conn = psycopg2.connect( - host=db_host, - port=db_port, - user=self.postgres.username, - password=self.postgres.password, - database=self.postgres.dbname, - ) - conn.autocommit = True - cursor = conn.cursor() - cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') - cursor.close() - conn.close() + conn = psycopg2.connect( + host=db_host, + port=db_port, + user=self.postgres.username, + password=self.postgres.password, + database=self.postgres.dbname, + ) + conn.autocommit = True + with _auto_close(conn): + with conn.cursor() as cursor: + # Install uuid-ossp extension for UUID generation + logger.info("Installing uuid-ossp extension...") + cursor.execute('CREATE EXTENSION IF NOT EXISTS "uuid-ossp";') logger.info("uuid-ossp extension installed successfully") - except Exception as e: - logger.warning("Failed to install uuid-ossp extension: %s", e) - # Create plugin database for dify-plugin-daemon - logger.info("Creating plugin database...") - try: - conn = psycopg2.connect( - host=db_host, - port=db_port, - user=self.postgres.username, - password=self.postgres.password, - database=self.postgres.dbname, - ) - conn.autocommit = True - cursor = conn.cursor() - cursor.execute("CREATE DATABASE dify_plugin;") - cursor.close() - conn.close() + # NOTE: We cannot use `with conn.cursor() as cursor:` as it will wrap the statement + # inside a transaction. However, the `CREATE DATABASE` statement cannot run inside a transaction block. + with _auto_close(conn.cursor()) as cursor: + # Create plugin database for dify-plugin-daemon + logger.info("Creating plugin database...") + cursor.execute("CREATE DATABASE dify_plugin;") logger.info("Plugin database created successfully") - except Exception as e: - logger.warning("Failed to create plugin database: %s", e) # Set up storage environment variables os.environ.setdefault("STORAGE_TYPE", "opendal") @@ -160,8 +165,9 @@ class DifyTestContainers: # Start Dify Sandbox container for code execution environment # Dify Sandbox provides a secure environment for executing user code + # Use pinned version 0.2.12 to match production docker-compose configuration logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:latest").with_network(self.network) + self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -181,7 +187,7 @@ class DifyTestContainers: # Start Dify Plugin Daemon container for plugin management # Dify Plugin Daemon provides plugin lifecycle management and execution logger.info("Initializing Dify Plugin Daemon container...") - self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.3.0-local").with_network( + self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network( self.network ) self.dify_plugin_daemon.with_exposed_ports(5002) @@ -258,23 +264,16 @@ class DifyTestContainers: containers = [self.redis, self.postgres, self.dify_sandbox, self.dify_plugin_daemon] for container in containers: if container: - try: - container_name = container.image - logger.info("Stopping container: %s", container_name) - container.stop() - logger.info("Successfully stopped container: %s", container_name) - except Exception as e: - # Log error but don't fail the test cleanup - logger.warning("Failed to stop container %s: %s", container, e) + container_name = container.image + logger.info("Stopping container: %s", container_name) + container.stop() + logger.info("Successfully stopped container: %s", container_name) # Stop and remove the network if self.network: - try: - logger.info("Removing Docker network...") - self.network.remove() - logger.info("Successfully removed Docker network") - except Exception as e: - logger.warning("Failed to remove Docker network: %s", e) + logger.info("Removing Docker network...") + self.network.remove() + logger.info("Successfully removed Docker network") self._containers_started = False logger.info("All test containers stopped and cleaned up successfully") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 7fad603a6d..4f606dccb8 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -8,12 +8,12 @@ from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin from models.account import AccountStatus, TenantAccountRole -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole from models.model import App, AppMode, Conversation, Message from models.workflow import WorkflowRun from services.account_service import AccountService @@ -75,7 +75,7 @@ def _create_conversation(db_session: Session, app_id: str, account_id: str) -> C inputs={}, status="normal", mode=AppMode.CHAT, - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account_id, ) db_session.add(conversation) @@ -124,7 +124,7 @@ def _create_message( answer_price_unit=0.001, currency="USD", status="normal", - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account_id, workflow_run_id=workflow_run_id, inputs={"query": "Hello"}, diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index dcf31aeca7..96fb7ea293 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,16 +31,16 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.graph_engine.entities.commands import GraphEngineCommand -from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.runtime.graph_runtime_state import GraphRuntimeState -from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from core.workflow.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from core.workflow.runtime.variable_pool import SystemVariable, VariablePool +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.graph_engine.entities.commands import GraphEngineCommand +from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from dify_graph.graph_events.graph import GraphRunPausedEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime.graph_runtime_state import GraphRuntimeState +from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper +from dify_graph.runtime.variable_pool import SystemVariable, VariablePool from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account @@ -544,7 +544,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from core.workflow.graph_events.graph import ( + from dify_graph.graph_events.graph import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py index cdf390b327..a60159c66a 100644 --- a/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/pipeline/test_queue_integration.py @@ -18,7 +18,7 @@ from faker import Faker from core.rag.pipeline.queue import TaskWrapper, TenantIsolatedTaskQueue from extensions.ext_redis import redis_client -from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole, TenantStatus @dataclass @@ -47,7 +47,7 @@ class TestTenantIsolatedTaskQueueIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -55,7 +55,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Create tenant tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() @@ -101,7 +101,7 @@ class TestTenantIsolatedTaskQueueIntegration: # Create second tenant tenant2 = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant2) db_session_with_containers.commit() @@ -410,7 +410,7 @@ class TestTenantIsolatedTaskQueueCompatibility: email=fake.email(), name=fake.name(), interface_language="en-US", - status="active", + status=AccountStatus.ACTIVE, ) db_session_with_containers.add(account) db_session_with_containers.commit() @@ -418,7 +418,7 @@ class TestTenantIsolatedTaskQueueCompatibility: # Create tenant tenant = Tenant( name=fake.company(), - status="normal", + status=TenantStatus.NORMAL, ) db_session_with_containers.add(tenant) db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py index 4e6cc620ac..781e297fa4 100644 --- a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -5,9 +5,11 @@ import pytest from faker import Faker from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.account_service import AccountService, TenantService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestGetAvailableDatasetsIntegration: @@ -22,7 +24,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -34,7 +36,7 @@ class TestGetAvailableDatasetsIntegration: name=fake.company(), description=fake.text(max_nb_chars=100), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, indexing_technique="high_quality", ) @@ -48,14 +50,14 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field name=f"Document {i}", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -83,7 +85,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -93,7 +95,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -105,13 +107,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Archived Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Archived ) @@ -136,7 +138,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -146,7 +148,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -158,13 +160,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Disabled Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # Disabled archived=False, ) @@ -189,7 +191,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -199,21 +201,21 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) # Create documents with non-completed status - for i, status in enumerate(["indexing", "parsing", "splitting"]): + for i, status in enumerate([IndexingStatus.INDEXING, IndexingStatus.PARSING, IndexingStatus.SPLITTING]): document = Document( id=str(uuid.uuid4()), tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document {status}", created_by=account.id, doc_form="text_model", @@ -252,7 +254,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -262,7 +264,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=fake.company(), provider="external", # External provider - data_source_type="external", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -286,7 +288,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company()) tenant1 = account1.current_tenant @@ -295,7 +297,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company()) tenant2 = account2.current_tenant @@ -306,7 +308,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant1.id, name="Tenant 1 Dataset", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account1.id, ) db_session_with_containers.add(dataset1) @@ -317,7 +319,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant2.id, name="Tenant 2 Dataset", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account2.id, ) db_session_with_containers.add(dataset2) @@ -329,13 +331,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document for {dataset.name}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -362,7 +364,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -384,7 +386,7 @@ class TestGetAvailableDatasetsIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -397,7 +399,7 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, name=f"Dataset {i}", provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -409,13 +411,13 @@ class TestGetAvailableDatasetsIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=f"Document {i}", created_by=account.id, doc_form="text_model", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -445,7 +447,7 @@ class TestKnowledgeRetrievalIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -455,7 +457,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, indexing_technique="high_quality", ) @@ -466,12 +468,12 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=str(uuid.uuid4()), # Required field - created_from="web", + created_from=DocumentCreatedFrom.WEB, name=fake.sentence(), created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="text_model", @@ -513,7 +515,7 @@ class TestKnowledgeRetrievalIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -524,7 +526,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -561,7 +563,7 @@ class TestKnowledgeRetrievalIntegration: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -571,7 +573,7 @@ class TestKnowledgeRetrievalIntegration: tenant_id=tenant.id, name=fake.company(), provider="dify", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 079e4934bb..9d0fad4b12 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import Session from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -20,7 +20,7 @@ from core.workflow.nodes.human_input.entities import ( UserAction, WebAppDeliveryMethod, ) -from core.workflow.repositories.human_input_form_repository import FormCreateParams +from dify_graph.repositories.human_input_form_repository import FormCreateParams from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -100,7 +100,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,7 +129,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = _build_form_params( delivery_methods=[ _build_email_delivery( @@ -173,7 +173,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( app_id=str(uuid4()), @@ -210,7 +210,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = FormCreateParams( app_id=str(uuid4()), workflow_execution_id=str(uuid4()), diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 06d55177eb..9733735df3 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -12,27 +12,27 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowType -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.enums import WorkflowType +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.model import App, AppMode, IconType from models.workflow import Workflow, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun +from tests.workflow_test_utils import build_test_graph_init_params def _mock_form_repository_without_submission() -> HumanInputFormRepository: @@ -87,11 +87,11 @@ def _build_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id=tenant_id, - app_id=app_id, + params = build_test_graph_init_params( workflow_id=workflow_id, graph_config=graph_config, + tenant_id=tenant_id, + app_id=app_id, user_id=user_id, user_from="account", invoke_from="debugger", diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 3568a8b070..8e70fc0bb0 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -6,8 +6,9 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db +from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader from models import ToolFile, UploadFile from models.enums import CreatorUserRole @@ -53,7 +54,7 @@ class TestStorageKeyLoader(unittest.TestCase): upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=storage_key, name="test_file.txt", size=1024, @@ -289,7 +290,7 @@ class TestStorageKeyLoader(unittest.TestCase): # Create upload file for other tenant (but don't add to cleanup list) upload_file_other = UploadFile( tenant_id=other_tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key="other_tenant_key", name="other_file.txt", size=1024, diff --git a/api/tests/test_containers_integration_tests/helpers/__init__.py b/api/tests/test_containers_integration_tests/helpers/__init__.py index 40d03889a9..0b753abd1f 100644 --- a/api/tests/test_containers_integration_tests/helpers/__init__.py +++ b/api/tests/test_containers_integration_tests/helpers/__init__.py @@ -1 +1,24 @@ """Helper utilities for integration tests.""" + +import re + + +def generate_valid_password(fake, length: int = 12) -> str: + """Generate a password that always satisfies the project's password validation rules. + + The password validation rule in ``api/libs/password.py`` requires passwords to + contain **both letters and digits** with a minimum length of 8: + + ``^(?=.*[a-zA-Z])(?=.*\\d).{8,}$`` + + ``Faker.password()`` does **not** guarantee that the generated password will + contain both character types, which can cause intermittent test failures. + + This helper re-generates until the result is valid (typically first attempt). + """ + for _ in range(100): + pwd = fake.password(length=length) + if re.search(r"[a-zA-Z]", pwd) and re.search(r"\d", pwd): + return pwd + # Fallback: should never be reached in practice + return fake.password(length=max(length - 2, 6)) + "a1" diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index 19d7772c39..fb8d1808f9 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -5,8 +5,9 @@ from datetime import datetime, timedelta from decimal import Decimal from uuid import uuid4 -from core.workflow.nodes.human_input.entities import FormDefinition, UserAction +from dify_graph.nodes.human_input.entities import FormDefinition, UserAction from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent from models.human_input import HumanInputForm, HumanInputFormStatus from models.model import App, Conversation, Message @@ -78,8 +79,8 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: introduction="", system_instruction="", status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, from_end_user_id=None, ) @@ -101,7 +102,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: answer_unit_price=Decimal("0.001"), provider_response_latency=0.5, currency="USD", - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, workflow_run_id=workflow_run_id, ) diff --git a/api/tests/test_containers_integration_tests/models/test_app_model_config.py b/api/tests/test_containers_integration_tests/models/test_app_model_config.py new file mode 100644 index 0000000000..e8b36097e1 --- /dev/null +++ b/api/tests/test_containers_integration_tests/models/test_app_model_config.py @@ -0,0 +1,32 @@ +""" +Integration tests for AppModelConfig using testcontainers. + +These tests validate database-backed model behavior without mocking SQLAlchemy queries. +""" + +from uuid import uuid4 + +from sqlalchemy.orm import Session + +from models.model import AppModelConfig + + +class TestAppModelConfig: + """Integration tests for AppModelConfig.""" + + def test_annotation_reply_dict_disabled_without_setting(self, db_session_with_containers: Session) -> None: + """Return disabled annotation reply dict when no AppAnnotationSetting exists.""" + # Arrange + config = AppModelConfig(app_id=str(uuid4())) + db_session_with_containers.add(config) + db_session_with_containers.commit() + + # Act + result = config.annotation_reply_dict + + # Assert + assert result == {"enabled": False} + + # Cleanup + db_session_with_containers.delete(config) + db_session_with_containers.commit() diff --git a/api/tests/test_containers_integration_tests/models/test_dataset_models.py b/api/tests/test_containers_integration_tests/models/test_dataset_models.py index 6c541a8ad2..a3bbf19657 100644 --- a/api/tests/test_containers_integration_tests/models/test_dataset_models.py +++ b/api/tests/test_containers_integration_tests/models/test_dataset_models.py @@ -12,6 +12,7 @@ import pytest from sqlalchemy.orm import Session from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus class TestDatasetDocumentProperties: @@ -29,7 +30,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -39,10 +40,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=i + 1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name=f"doc_{i}.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -56,7 +57,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -65,12 +66,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="available.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -78,12 +79,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="pending.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, archived=False, ) @@ -91,12 +92,12 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="disabled.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, ) @@ -111,7 +112,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -121,10 +122,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=i + 1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name=f"doc_{i}.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=wc, ) @@ -139,7 +140,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -148,10 +149,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -166,7 +167,7 @@ class TestDatasetDocumentProperties: content=f"segment {i}", word_count=100, tokens=50, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, created_by=created_by, ) @@ -180,7 +181,7 @@ class TestDatasetDocumentProperties: content="waiting segment", word_count=100, tokens=50, - status="waiting", + status=SegmentStatus.WAITING, enabled=True, created_by=created_by, ) @@ -195,7 +196,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -204,10 +205,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -235,7 +236,7 @@ class TestDatasetDocumentProperties: created_by = str(uuid4()) dataset = Dataset( - tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by + tenant_id=tenant_id, name="Test Dataset", data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by ) db_session_with_containers.add(dataset) db_session_with_containers.flush() @@ -244,10 +245,10 @@ class TestDatasetDocumentProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="doc.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(doc) @@ -288,7 +289,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -298,10 +299,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -335,7 +336,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -345,10 +346,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -382,7 +383,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -392,10 +393,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) @@ -439,7 +440,7 @@ class TestDocumentSegmentNavigationProperties: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) db_session_with_containers.add(dataset) @@ -449,10 +450,10 @@ class TestDocumentSegmentNavigationProperties: tenant_id=tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) db_session_with_containers.add(document) diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 556c029b24..458862b0ec 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -8,7 +8,7 @@ from uuid import uuid4 from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 05a868c0c2..c3ed79656f 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,6 +2,7 @@ from __future__ import annotations +import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -11,16 +12,27 @@ import pytest from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from core.workflow.entities import WorkflowExecution -from core.workflow.entities.pause_reason import PauseReasonType -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.entities import WorkflowExecution +from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.human_input import ( + BackstageRecipientPayload, + HumanInputDelivery, + HumanInputForm, + HumanInputFormRecipient, + RecipientType, +) from models.workflow import WorkflowAppLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity from repositories.sqlalchemy_api_workflow_run_repository import ( DifyAPISQLAlchemyWorkflowRunRepository, + _build_human_input_required_reason, + _PrivateWorkflowPauseEntity, _WorkflowRunError, ) @@ -90,6 +102,19 @@ def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: WorkflowRun.app_id == scope.app_id, ) ) + + form_ids_subquery = select(HumanInputForm.id).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + session.execute(delete(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids_subquery))) + session.execute(delete(HumanInputDelivery).where(HumanInputDelivery.form_id.in_(form_ids_subquery))) + session.execute( + delete(HumanInputForm).where( + HumanInputForm.tenant_id == scope.tenant_id, + HumanInputForm.app_id == scope.app_id, + ) + ) session.commit() for state_key in scope.state_keys: @@ -504,3 +529,200 @@ class TestDeleteWorkflowPause: with pytest.raises(_WorkflowRunError, match="WorkflowPause not found"): repository.delete_workflow_pause(pause_entity=pause_entity) + + +class TestPrivateWorkflowPauseEntity: + """Integration tests for _PrivateWorkflowPauseEntity using real DB models.""" + + def test_properties( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Entity properties delegate to the persisted WorkflowPause model.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(pause.state_object_key) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + + assert entity.id == pause.id + assert entity.workflow_execution_id == workflow_run.id + assert entity.resumed_at is None + + def test_get_state( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state loads state data from storage using the persisted state_object_key.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result = entity.get_state() + + assert result == expected_state + + def test_get_state_caching( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """get_state caches the result so storage is only accessed once.""" + + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + state_key = f"workflow-state-{uuid4()}.json" + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=state_key, + ) + db_session_with_containers.add(pause) + db_session_with_containers.commit() + db_session_with_containers.refresh(pause) + test_scope.state_keys.add(state_key) + + expected_state = b'{"test": "state"}' + storage.save(state_key, expected_state) + + entity = _PrivateWorkflowPauseEntity(pause_model=pause, reason_models=[], human_input_form=[]) + result1 = entity.get_state() + # Delete from storage to prove second call uses cache + storage.delete(state_key) + test_scope.state_keys.discard(state_key) + result2 = entity.get_state() + + assert result1 == expected_state + assert result2 == expected_state + + +class TestBuildHumanInputRequiredReason: + """Integration tests for _build_human_input_required_reason using real DB models.""" + + def test_prefers_backstage_token_when_available( + self, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Use backstage token when multiple recipient types may exist.""" + + expiration_time = naive_utc_now() + form_definition = FormDefinition( + form_content="content", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], + user_actions=[UserAction(id="approve", title="Approve")], + rendered_content="rendered", + expiration_time=expiration_time, + default_values={"name": "Alice"}, + node_title="Ask Name", + display_in_ui=True, + ) + + form_model = HumanInputForm( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + workflow_run_id=str(uuid4()), + node_id="node-1", + form_definition=form_definition.model_dump_json(), + rendered_content="rendered", + status=HumanInputFormStatus.WAITING, + expiration_time=expiration_time, + ) + db_session_with_containers.add(form_model) + db_session_with_containers.flush() + + delivery = HumanInputDelivery( + form_id=form_model.id, + delivery_method_type=DeliveryMethodType.WEBAPP, + channel_payload="{}", + ) + db_session_with_containers.add(delivery) + db_session_with_containers.flush() + + access_token = secrets.token_urlsafe(8) + recipient = HumanInputFormRecipient( + form_id=form_model.id, + delivery_id=delivery.id, + recipient_type=RecipientType.BACKSTAGE, + recipient_payload=BackstageRecipientPayload().model_dump_json(), + access_token=access_token, + ) + db_session_with_containers.add(recipient) + db_session_with_containers.flush() + + # Create a pause so the reason has a valid pause_id + workflow_run = _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.RUNNING, + ) + pause = WorkflowPause( + id=str(uuid4()), + workflow_id=test_scope.workflow_id, + workflow_run_id=workflow_run.id, + state_object_key=f"workflow-state-{uuid4()}.json", + ) + db_session_with_containers.add(pause) + db_session_with_containers.flush() + test_scope.state_keys.add(pause.state_object_key) + + reason_model = WorkflowPauseReason( + pause_id=pause.id, + type_=PauseReasonType.HUMAN_INPUT_REQUIRED, + form_id=form_model.id, + node_id="node-1", + message="", + ) + db_session_with_containers.add(reason_model) + db_session_with_containers.commit() + + # Refresh to ensure we have DB-round-tripped objects + db_session_with_containers.refresh(form_model) + db_session_with_containers.refresh(reason_model) + db_session_with_containers.refresh(recipient) + + reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + + assert isinstance(reason, HumanInputRequired) + assert reason.form_token == access_token + assert reason.node_title == "Ask Name" + assert reason.form_content == "content" + assert reason.inputs[0].output_variable_name == "name" + assert reason.actions[0].id == "approve" diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py new file mode 100644 index 0000000000..1568d5d65c --- /dev/null +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -0,0 +1,391 @@ +"""Integration tests for get_paginated_workflow_runs and get_workflow_runs_count using testcontainers.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import timedelta +from uuid import uuid4 + +import pytest +from sqlalchemy import Engine, delete +from sqlalchemy import exc as sa_exc +from sqlalchemy.orm import Session, sessionmaker + +from dify_graph.entities import WorkflowExecution +from dify_graph.enums import WorkflowExecutionStatus +from libs.datetime_utils import naive_utc_now +from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom +from models.workflow import WorkflowRun, WorkflowType +from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository + + +class _TestWorkflowRunRepository(DifyAPISQLAlchemyWorkflowRunRepository): + """Concrete repository for tests where save() is not under test.""" + + def save(self, execution: WorkflowExecution) -> None: + return None + + +@dataclass +class _TestScope: + """Per-test data scope used to isolate DB rows.""" + + tenant_id: str = field(default_factory=lambda: str(uuid4())) + app_id: str = field(default_factory=lambda: str(uuid4())) + workflow_id: str = field(default_factory=lambda: str(uuid4())) + user_id: str = field(default_factory=lambda: str(uuid4())) + + +def _create_workflow_run( + session: Session, + scope: _TestScope, + *, + status: WorkflowExecutionStatus, + triggered_from: WorkflowRunTriggeredFrom = WorkflowRunTriggeredFrom.DEBUGGING, + created_at_offset: timedelta | None = None, +) -> WorkflowRun: + """Create and persist a workflow run bound to the current test scope.""" + now = naive_utc_now() + workflow_run = WorkflowRun( + id=str(uuid4()), + tenant_id=scope.tenant_id, + app_id=scope.app_id, + workflow_id=scope.workflow_id, + type=WorkflowType.WORKFLOW, + triggered_from=triggered_from, + version="draft", + graph="{}", + inputs="{}", + status=status, + created_by_role=CreatorUserRole.ACCOUNT, + created_by=scope.user_id, + created_at=now + created_at_offset if created_at_offset is not None else now, + ) + session.add(workflow_run) + session.commit() + return workflow_run + + +def _cleanup_scope_data(session: Session, scope: _TestScope) -> None: + """Remove test-created DB rows for a test scope.""" + session.execute( + delete(WorkflowRun).where( + WorkflowRun.tenant_id == scope.tenant_id, + WorkflowRun.app_id == scope.app_id, + ) + ) + session.commit() + + +@pytest.fixture +def repository(db_session_with_containers: Session) -> DifyAPISQLAlchemyWorkflowRunRepository: + """Build a repository backed by the testcontainers database engine.""" + engine = db_session_with_containers.get_bind() + assert isinstance(engine, Engine) + return _TestWorkflowRunRepository(session_maker=sessionmaker(bind=engine, expire_on_commit=False)) + + +@pytest.fixture +def test_scope(db_session_with_containers: Session) -> _TestScope: + """Provide an isolated scope and clean related data after each test.""" + scope = _TestScope() + yield scope + _cleanup_scope_data(db_session_with_containers, scope) + + +class TestGetPaginatedWorkflowRuns: + """Integration tests for get_paginated_workflow_runs.""" + + def test_returns_runs_without_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return all runs for the given tenant/app when no status filter is applied.""" + for status in ( + WorkflowExecutionStatus.SUCCEEDED, + WorkflowExecutionStatus.FAILED, + WorkflowExecutionStatus.RUNNING, + ): + _create_workflow_run(db_session_with_containers, test_scope, status=status) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + assert len(result.data) == 3 + assert result.limit == 20 + assert result.has_more is False + + def test_filters_by_status( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return only runs matching the requested status.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status="succeeded", + ) + + assert len(result.data) == 2 + assert all(run.status == WorkflowExecutionStatus.SUCCEEDED for run in result.data) + + def test_pagination_has_more( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Return has_more=True when more records exist beyond the limit.""" + for i in range(5): + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(seconds=i), + ) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=None, + status=None, + ) + + assert len(result.data) == 3 + assert result.has_more is True + + def test_cursor_based_pagination( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Cursor-based pagination returns the next page of results.""" + for i in range(5): + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(seconds=i), + ) + + # First page + page1 = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=None, + status=None, + ) + assert len(page1.data) == 3 + assert page1.has_more is True + + # Second page using cursor + page2 = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=3, + last_id=page1.data[-1].id, + status=None, + ) + assert len(page2.data) == 2 + assert page2.has_more is False + + # No overlap between pages + page1_ids = {r.id for r in page1.data} + page2_ids = {r.id for r in page2.data} + assert page1_ids.isdisjoint(page2_ids) + + def test_invalid_last_id_raises( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + test_scope: _TestScope, + ) -> None: + """Raise ValueError when last_id refers to a non-existent run.""" + with pytest.raises(ValueError, match="Last workflow run not exists"): + repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=str(uuid4()), + status=None, + ) + + def test_tenant_isolation( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Runs from other tenants are not returned.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + other_scope = _TestScope(app_id=test_scope.app_id) + try: + _create_workflow_run(db_session_with_containers, other_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + result = repository.get_paginated_workflow_runs( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + limit=20, + last_id=None, + status=None, + ) + + assert len(result.data) == 1 + assert result.data[0].tenant_id == test_scope.tenant_id + finally: + _cleanup_scope_data(db_session_with_containers, other_scope) + + +class TestGetWorkflowRunsCount: + """Integration tests for get_workflow_runs_count.""" + + def test_count_without_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count all runs grouped by status when no status filter is applied.""" + for _ in range(3): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + for _ in range(2): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.RUNNING) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + ) + + assert result["total"] == 6 + assert result["succeeded"] == 3 + assert result["failed"] == 2 + assert result["running"] == 1 + assert result["stopped"] == 0 + assert result["partial-succeeded"] == 0 + + def test_count_with_status_filter( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Count only runs matching the requested status.""" + for _ in range(3): + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + ) + + assert result["total"] == 3 + assert result["succeeded"] == 3 + assert result["failed"] == 0 + + def test_count_with_invalid_status_raises( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Invalid status raises StatementError because the column uses an enum type.""" + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + + with pytest.raises(sa_exc.StatementError) as exc_info: + repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="invalid_status", + ) + assert isinstance(exc_info.value.orig, ValueError) + + def test_count_with_time_range( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Time range filter excludes runs created outside the window.""" + # Recent run (within 1 day) + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + # Old run (8 days ago) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(days=-8), + ) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status=None, + time_range="7d", + ) + + assert result["total"] == 1 + assert result["succeeded"] == 1 + + def test_count_with_status_and_time_range( + self, + repository: DifyAPISQLAlchemyWorkflowRunRepository, + db_session_with_containers: Session, + test_scope: _TestScope, + ) -> None: + """Both status and time_range filters apply together.""" + # Recent succeeded + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.SUCCEEDED) + # Recent failed + _create_workflow_run(db_session_with_containers, test_scope, status=WorkflowExecutionStatus.FAILED) + # Old succeeded (outside time range) + _create_workflow_run( + db_session_with_containers, + test_scope, + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_offset=timedelta(days=-8), + ) + + result = repository.get_workflow_runs_count( + tenant_id=test_scope.tenant_id, + app_id=test_scope.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + status="succeeded", + time_range="7d", + ) + + assert result["total"] == 1 + assert result["succeeded"] == 1 + assert result["failed"] == 0 diff --git a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py index 73df2d9ed9..638a61c815 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py +++ b/api/tests/test_containers_integration_tests/services/dataset_collection_binding.py @@ -9,9 +9,10 @@ from itertools import starmap from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from extensions.ext_database import db from models.dataset import DatasetCollectionBinding +from models.enums import CollectionBindingType from services.dataset_service import DatasetCollectionBindingService @@ -28,10 +29,11 @@ class DatasetCollectionBindingTestDataFactory: @staticmethod def create_collection_binding( + db_session_with_containers: Session, provider_name: str = "openai", model_name: str = "text-embedding-ada-002", collection_name: str = "collection-abc", - collection_type: str = "dataset", + collection_type: str = CollectionBindingType.DATASET, ) -> DatasetCollectionBinding: """ Create a DatasetCollectionBinding with specified attributes. @@ -40,7 +42,7 @@ class DatasetCollectionBindingTestDataFactory: provider_name: Name of the embedding model provider (e.g., "openai", "cohere") model_name: Name of the embedding model (e.g., "text-embedding-ada-002") collection_name: Name of the vector database collection - collection_type: Type of collection (default: "dataset") + collection_type: Type of collection (default: CollectionBindingType.DATASET) Returns: DatasetCollectionBinding instance @@ -51,8 +53,8 @@ class DatasetCollectionBindingTestDataFactory: collection_name=collection_name, type=collection_type, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -64,7 +66,7 @@ class TestDatasetCollectionBindingServiceGetBinding: including various provider/model combinations, collection types, and edge cases. """ - def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_existing_binding_success(self, db_session_with_containers: Session): """ Test successful retrieval of an existing collection binding. @@ -75,8 +77,9 @@ class TestDatasetCollectionBindingServiceGetBinding: # Arrange provider_name = "openai" model_name = "text-embedding-ada-002" - collection_type = "dataset" + collection_type = CollectionBindingType.DATASET existing_binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name=provider_name, model_name=model_name, collection_name="existing-collection", @@ -92,7 +95,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.id == existing_binding.id assert result.collection_name == "existing-collection" - def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_create_new_binding_success(self, db_session_with_containers: Session): """ Test successful creation of a new collection binding when none exists. @@ -102,7 +105,7 @@ class TestDatasetCollectionBindingServiceGetBinding: # Arrange provider_name = f"provider-{uuid4()}" model_name = f"model-{uuid4()}" - collection_type = "dataset" + collection_type = CollectionBindingType.DATASET # Act result = DatasetCollectionBindingService.get_dataset_collection_binding( @@ -116,7 +119,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.type == collection_type assert result.collection_name is not None - def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with different collection type.""" # Arrange provider_name = "openai" @@ -133,7 +136,7 @@ class TestDatasetCollectionBindingServiceGetBinding: assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_default_collection_type(self, db_session_with_containers: Session): """Test get_dataset_collection_binding with default collection type parameter.""" # Arrange provider_name = "openai" @@ -143,11 +146,13 @@ class TestDatasetCollectionBindingServiceGetBinding: result = DatasetCollectionBindingService.get_dataset_collection_binding(provider_name, model_name) # Assert - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET assert result.provider_name == provider_name assert result.model_name == model_name - def test_get_dataset_collection_binding_different_provider_model_combination(self, db_session_with_containers): + def test_get_dataset_collection_binding_different_provider_model_combination( + self, db_session_with_containers: Session + ): """Test get_dataset_collection_binding with various provider/model combinations.""" # Arrange combinations = [ @@ -174,39 +179,47 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: including successful retrieval and error handling for missing bindings. """ - def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_success(self, db_session_with_containers: Session): """Test successful retrieval of collection binding by ID and type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act - result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(binding.id, "dataset") + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + binding.id, CollectionBindingType.DATASET + ) # Assert assert result.id == binding.id assert result.provider_name == "openai" assert result.model_name == "text-embedding-ada-002" assert result.collection_name == "test-collection" - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET - def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_not_found_error(self, db_session_with_containers: Session): """Test error handling when collection binding is not found by ID and type.""" # Arrange non_existent_id = str(uuid4()) # Act & Assert with pytest.raises(ValueError, match="Dataset collection binding not found"): - DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(non_existent_id, "dataset") + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type( + non_existent_id, CollectionBindingType.DATASET + ) - def test_get_dataset_collection_binding_by_id_and_type_different_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_different_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID and type with different collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", @@ -222,14 +235,17 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: assert result.id == binding.id assert result.type == "custom_type" - def test_get_dataset_collection_binding_by_id_and_type_default_collection_type(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_default_collection_type( + self, db_session_with_containers: Session + ): """Test retrieval by ID with default collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act @@ -237,16 +253,17 @@ class TestDatasetCollectionBindingServiceGetBindingByIdAndType: # Assert assert result.id == binding.id - assert result.type == "dataset" + assert result.type == CollectionBindingType.DATASET - def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers): + def test_get_dataset_collection_binding_by_id_and_type_wrong_type_error(self, db_session_with_containers: Session): """Test error when binding exists but with wrong collection type.""" # Arrange binding = DatasetCollectionBindingTestDataFactory.create_collection_binding( + db_session_with_containers, provider_name="openai", model_name="text-embedding-ada-002", collection_name="test-collection", - collection_type="dataset", + collection_type=CollectionBindingType.DATASET, ) # Act & Assert diff --git a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py index 9871ef37e6..6b35f867d7 100644 --- a/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py +++ b/api/tests/test_containers_integration_tests/services/dataset_service_update_delete.py @@ -10,11 +10,12 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import AppDatasetJoin, Dataset, DatasetPermissionEnum +from models.enums import DataSourceType from models.model import App from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -27,6 +28,7 @@ class DatasetUpdateDeleteTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -37,13 +39,13 @@ class DatasetUpdateDeleteTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -51,14 +53,15 @@ class DatasetUpdateDeleteTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -70,7 +73,7 @@ class DatasetUpdateDeleteTestDataFactory: tenant_id=tenant_id, name=name, description="Test description", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, @@ -78,12 +81,12 @@ class DatasetUpdateDeleteTestDataFactory: retrieval_model={"top_k": 2}, enable_api=enable_api, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_app(tenant_id: str, created_by: str, name: str = "Test App") -> App: + def create_app(db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test App") -> App: """Create a real app for AppDatasetJoin.""" app = App( tenant_id=tenant_id, @@ -96,16 +99,16 @@ class DatasetUpdateDeleteTestDataFactory: enable_api=True, created_by=created_by, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app @staticmethod - def create_app_dataset_join(app_id: str, dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, app_id: str, dataset_id: str) -> AppDatasetJoin: """Create a real AppDatasetJoin record.""" join = AppDatasetJoin(app_id=app_id, dataset_id=dataset_id) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @@ -114,7 +117,7 @@ class TestDatasetServiceDeleteDataset: Comprehensive integration tests for DatasetService.delete_dataset method. """ - def test_delete_dataset_success(self, db_session_with_containers): + def test_delete_dataset_success(self, db_session_with_containers: Session): """ Test successful deletion of a dataset. @@ -130,8 +133,10 @@ class TestDatasetServiceDeleteDataset: - Method returns True """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act with patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted: @@ -139,10 +144,10 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None mock_dataset_was_deleted.send.assert_called_once_with(dataset) - def test_delete_dataset_not_found(self, db_session_with_containers): + def test_delete_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -156,7 +161,9 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) dataset_id = str(uuid4()) # Act @@ -165,7 +172,7 @@ class TestDatasetServiceDeleteDataset: # Assert assert result is False - def test_delete_dataset_permission_denied_error(self, db_session_with_containers): + def test_delete_dataset_permission_denied_error(self, db_session_with_containers: Session): """ Test error handling when user lacks permission. @@ -178,19 +185,22 @@ class TestDatasetServiceDeleteDataset: - No database operations are performed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) normal_user, _ = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL, tenant=tenant, ) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act & Assert with pytest.raises(NoPermissionError): DatasetService.delete_dataset(dataset.id, normal_user) # Verify no deletion was attempted - assert db.session.get(Dataset, dataset.id) is not None + assert db_session_with_containers.get(Dataset, dataset.id) is not None class TestDatasetServiceDatasetUseCheck: @@ -198,7 +208,7 @@ class TestDatasetServiceDatasetUseCheck: Comprehensive integration tests for DatasetService.dataset_use_check method. """ - def test_dataset_use_check_in_use(self, db_session_with_containers): + def test_dataset_use_check_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is in use. @@ -211,10 +221,12 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) - app = DatasetUpdateDeleteTestDataFactory.create_app(tenant.id, owner.id) - DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(app.id, dataset.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + app = DatasetUpdateDeleteTestDataFactory.create_app(db_session_with_containers, tenant.id, owner.id) + DatasetUpdateDeleteTestDataFactory.create_app_dataset_join(db_session_with_containers, app.id, dataset.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -222,7 +234,7 @@ class TestDatasetServiceDatasetUseCheck: # Assert assert result is True - def test_dataset_use_check_not_in_use(self, db_session_with_containers): + def test_dataset_use_check_not_in_use(self, db_session_with_containers: Session): """ Test detection when dataset is not in use. @@ -235,8 +247,10 @@ class TestDatasetServiceDatasetUseCheck: - Database query is executed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) # Act result = DatasetService.dataset_use_check(dataset.id) @@ -250,7 +264,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: Comprehensive integration tests for DatasetService.update_dataset_api_status method. """ - def test_update_dataset_api_status_enable_success(self, db_session_with_containers): + def test_update_dataset_api_status_enable_success(self, db_session_with_containers: Session): """ Test successful enabling of dataset API access. @@ -264,8 +278,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -276,12 +294,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is True assert dataset.updated_by == owner.id assert dataset.updated_at == current_time - def test_update_dataset_api_status_disable_success(self, db_session_with_containers): + def test_update_dataset_api_status_disable_success(self, db_session_with_containers: Session): """ Test successful disabling of dataset API access. @@ -295,8 +313,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - Transaction is committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=True) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=True + ) current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) # Act @@ -307,11 +329,11 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, False) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False assert dataset.updated_by == owner.id - def test_update_dataset_api_status_not_found_error(self, db_session_with_containers): + def test_update_dataset_api_status_not_found_error(self, db_session_with_containers: Session): """ Test error handling when dataset is not found. @@ -330,7 +352,7 @@ class TestDatasetServiceUpdateDatasetApiStatus: with pytest.raises(NotFound, match="Dataset not found"): DatasetService.update_dataset_api_status(dataset_id, True) - def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers): + def test_update_dataset_api_status_missing_current_user_error(self, db_session_with_containers: Session): """ Test error handling when current_user is missing. @@ -343,8 +365,12 @@ class TestDatasetServiceUpdateDatasetApiStatus: - No updates are committed """ # Arrange - owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - dataset = DatasetUpdateDeleteTestDataFactory.create_dataset(tenant.id, owner.id, enable_api=False) + owner, tenant = DatasetUpdateDeleteTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + dataset = DatasetUpdateDeleteTestDataFactory.create_dataset( + db_session_with_containers, tenant.id, owner.id, enable_api=False + ) # Act & Assert with ( @@ -354,6 +380,6 @@ class TestDatasetServiceUpdateDatasetApiStatus: DatasetService.update_dataset_api_status(dataset.id, True) # Verify no commit was attempted - db.session.rollback() - db.session.refresh(dataset) + db_session_with_containers.rollback() + db_session_with_containers.refresh(dataset) assert dataset.enable_api is False diff --git a/api/tests/test_containers_integration_tests/services/document_service_status.py b/api/tests/test_containers_integration_tests/services/document_service_status.py index c08ea2a93b..f995ac7bef 100644 --- a/api/tests/test_containers_integration_tests/services/document_service_status.py +++ b/api/tests/test_containers_integration_tests/services/document_service_status.py @@ -13,9 +13,10 @@ from uuid import uuid4 import pytest +from extensions.storage.storage_type import StorageType from models import Account from models.dataset import Dataset, Document -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus from models.model import UploadFile from services.dataset_service import DocumentService from services.errors.document import DocumentIndexingError @@ -88,7 +89,7 @@ class DocumentStatusTestDataFactory: data_source_info=json.dumps(data_source_info or {}), batch=f"batch-{uuid4()}", name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, doc_form="text_model", ) @@ -100,7 +101,7 @@ class DocumentStatusTestDataFactory: document.paused_by = paused_by document.paused_at = paused_at document.doc_metadata = doc_metadata or {} - if indexing_status == "completed" and "completed_at" not in kwargs: + if indexing_status == IndexingStatus.COMPLETED and "completed_at" not in kwargs: document.completed_at = FIXED_TIME for key, value in kwargs.items(): @@ -139,7 +140,7 @@ class DocumentStatusTestDataFactory: dataset = Dataset( tenant_id=tenant_id, name=name, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) dataset.id = dataset_id @@ -198,7 +199,7 @@ class DocumentStatusTestDataFactory: """ upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"uploads/{uuid4()}", name=name, size=128, @@ -291,7 +292,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, is_paused=False, ) @@ -326,7 +327,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=False, ) @@ -354,7 +355,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="parsing", + indexing_status=IndexingStatus.PARSING, is_paused=False, ) @@ -383,7 +384,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, is_paused=False, ) @@ -412,7 +413,7 @@ class TestDocumentServicePauseDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="error", + indexing_status=IndexingStatus.ERROR, is_paused=False, ) @@ -487,7 +488,7 @@ class TestDocumentServiceRecoverDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=True, paused_by=str(uuid4()), paused_at=paused_time, @@ -526,7 +527,7 @@ class TestDocumentServiceRecoverDocument: db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="indexing", + indexing_status=IndexingStatus.INDEXING, is_paused=False, ) @@ -609,7 +610,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = None @@ -619,7 +620,7 @@ class TestDocumentServiceRetryDocument: # Assert db_session_with_containers.refresh(document) - assert document.indexing_status == "waiting" + assert document.indexing_status == IndexingStatus.WAITING expected_cache_key = f"document_{document.id}_is_retried" mock_document_service_dependencies["redis_client"].setex.assert_called_once_with(expected_cache_key, 600, 1) @@ -646,14 +647,14 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) document2 = DocumentStatusTestDataFactory.create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, position=2, ) @@ -665,8 +666,8 @@ class TestDocumentServiceRetryDocument: # Assert db_session_with_containers.refresh(document1) db_session_with_containers.refresh(document2) - assert document1.indexing_status == "waiting" - assert document2.indexing_status == "waiting" + assert document1.indexing_status == IndexingStatus.WAITING + assert document2.indexing_status == IndexingStatus.WAITING mock_document_service_dependencies["retry_task"].delay.assert_called_once_with( dataset.id, [document1.id, document2.id], mock_document_service_dependencies["user_id"] @@ -693,7 +694,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = "1" @@ -703,7 +704,7 @@ class TestDocumentServiceRetryDocument: DocumentService.retry_document(dataset.id, [document]) db_session_with_containers.refresh(document) - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR def test_retry_document_missing_current_user_error( self, db_session_with_containers, mock_document_service_dependencies @@ -726,7 +727,7 @@ class TestDocumentServiceRetryDocument: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) mock_document_service_dependencies["redis_client"].get.return_value = None @@ -816,7 +817,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document2 = DocumentStatusTestDataFactory.create_document( db_session_with_containers, @@ -824,7 +825,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) document_ids = [document1.id, document2.id] @@ -866,7 +867,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: tenant_id=dataset.tenant_id, document_id=str(uuid4()), enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, completed_at=FIXED_TIME, ) document_ids = [document.id] @@ -909,7 +910,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: document_id=str(uuid4()), archived=False, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -951,7 +952,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: document_id=str(uuid4()), archived=True, enabled=True, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -1015,7 +1016,7 @@ class TestDocumentServiceBatchUpdateDocumentStatus: dataset_id=dataset.id, tenant_id=dataset.tenant_id, document_id=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) document_ids = [document.id] @@ -1098,7 +1099,7 @@ class TestDocumentServiceRenameDocument: document_id=document_id, dataset_id=dataset.id, tenant_id=tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1139,7 +1140,7 @@ class TestDocumentServiceRenameDocument: dataset_id=dataset.id, tenant_id=tenant_id, doc_metadata={"existing_key": "existing_value"}, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1187,7 +1188,7 @@ class TestDocumentServiceRenameDocument: dataset_id=dataset.id, tenant_id=tenant_id, data_source_info={"upload_file_id": upload_file.id}, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act @@ -1277,7 +1278,7 @@ class TestDocumentServiceRenameDocument: document_id=document_id, dataset_id=dataset.id, tenant_id=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Act & Assert diff --git a/api/tests/test_containers_integration_tests/services/test_account_service.py b/api/tests/test_containers_integration_tests/services/test_account_service.py index 606e7e0b57..cc9596d15f 100644 --- a/api/tests/test_containers_integration_tests/services/test_account_service.py +++ b/api/tests/test_containers_integration_tests/services/test_account_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import Unauthorized from configs import dify_config @@ -19,6 +20,7 @@ from services.errors.account import ( TenantNotFoundError, ) from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAccountService: @@ -45,14 +47,14 @@ class TestAccountService: "passport_service": mock_passport_service, } - def test_create_account_and_login(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_login(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation and login with correct password. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -70,7 +72,9 @@ class TestAccountService: logged_in = AccountService.authenticate(email, password) assert logged_in.id == account.id - def test_create_account_without_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_without_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation without password (for OAuth users). """ @@ -92,7 +96,7 @@ class TestAccountService: assert account.password_salt is None def test_create_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account create with invalid new password format. @@ -113,7 +117,9 @@ class TestAccountService: password="invalid_new_password", ) - def test_create_account_registration_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_registration_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation when registration is disabled. """ @@ -128,17 +134,19 @@ class TestAccountService: email=email, name=name, interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) - def test_create_account_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account creation when email is in freeze period. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True @@ -154,24 +162,26 @@ class TestAccountService: dify_config.BILLING_ENABLED = False # Reset config for other tests - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent account. """ fake = Faker() email = fake.email() - password = fake.password(length=12) + password = generate_valid_password(fake) with pytest.raises(AccountPasswordError): AccountService.authenticate(email, password) - def test_authenticate_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -186,22 +196,21 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(AccountLoginError): AccountService.authenticate(email, password) - def test_authenticate_wrong_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_wrong_password(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with wrong password. """ fake = Faker() email = fake.email() name = fake.name() - correct_password = fake.password(length=12) - wrong_password = fake.password(length=12) + correct_password = generate_valid_password(fake) + wrong_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -217,14 +226,16 @@ class TestAccountService: with pytest.raises(AccountPasswordError): AccountService.authenticate(email, wrong_password) - def test_authenticate_with_invite_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_with_invite_token( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invite token to set password for account without password. """ fake = Faker() email = fake.email() name = fake.name() - new_password = fake.password(length=12) + new_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -249,7 +260,7 @@ class TestAccountService: assert authenticated_account.password_salt is not None def test_authenticate_pending_account_activation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication activates pending account. @@ -257,7 +268,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -270,24 +281,25 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Authenticate should activate the account authenticated_account = AccountService.authenticate(email, password) assert authenticated_account.status == AccountStatus.ACTIVE assert authenticated_account.initialized_at is not None - def test_update_account_password_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_password_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful password update. """ fake = Faker() email = fake.email() name = fake.name() - old_password = fake.password(length=12) - new_password = fake.password(length=12) + old_password = generate_valid_password(fake) + new_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -308,7 +320,7 @@ class TestAccountService: assert authenticated_account.id == account.id def test_update_account_password_wrong_current_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with wrong current password. @@ -316,9 +328,9 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - old_password = fake.password(length=12) - wrong_password = fake.password(length=12) - new_password = fake.password(length=12) + old_password = generate_valid_password(fake) + wrong_password = generate_valid_password(fake) + new_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -335,7 +347,7 @@ class TestAccountService: AccountService.update_account_password(account, wrong_password, new_password) def test_update_account_password_invalid_new_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test password update with invalid new password format. @@ -343,7 +355,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - old_password = fake.password(length=12) + old_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -360,14 +372,14 @@ class TestAccountService: with pytest.raises(ValueError): # Password validation error AccountService.update_account_password(account, old_password, "123") - def test_create_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account creation with automatic tenant creation. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies[ @@ -387,14 +399,13 @@ class TestAccountService: assert account.email == email # Verify tenant was created and linked - from extensions.ext_database import db - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" def test_create_account_and_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace creation is disabled. @@ -402,7 +413,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies[ @@ -419,7 +430,7 @@ class TestAccountService: ) def test_create_account_and_tenant_workspace_limit_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account creation when workspace limit is exceeded. @@ -427,7 +438,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies[ @@ -446,7 +457,9 @@ class TestAccountService: password=password, ) - def test_link_account_integrate_new_provider(self, db_session_with_containers, mock_external_service_dependencies): + def test_link_account_integrate_new_provider( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test linking account with new OAuth provider. """ @@ -469,15 +482,18 @@ class TestAccountService: AccountService.link_account_integrate("new-google", "google_open_id_123", account) # Verify integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="new-google").first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="new-google") + .first() + ) assert integration is not None assert integration.open_id == "google_open_id_123" def test_link_account_integrate_existing_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test linking account with existing provider (should update). @@ -504,22 +520,23 @@ class TestAccountService: AccountService.link_account_integrate("exists-google", "google_open_id_456", account) # Verify integration was updated - from extensions.ext_database import db from models import AccountIntegrate integration = ( - db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider="exists-google").first() + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider="exists-google") + .first() ) assert integration.open_id == "google_open_id_456" - def test_close_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_close_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test closing an account. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -536,19 +553,18 @@ class TestAccountService: AccountService.close_account(account) # Verify account status changed - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.CLOSED - def test_update_account_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating account fields. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) updated_name = fake.name() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -568,14 +584,16 @@ class TestAccountService: assert updated_account.name == updated_name assert updated_account.interface_theme == "dark" - def test_update_account_invalid_field(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_account_invalid_field( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating account with invalid field. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -591,14 +609,14 @@ class TestAccountService: with pytest.raises(AttributeError): AccountService.update_account(account, invalid_field="value") - def test_update_login_info(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_login_info(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating login information. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -616,20 +634,19 @@ class TestAccountService: AccountService.update_login_info(account, ip_address=ip_address) # Verify login info was updated - from extensions.ext_database import db - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.last_login_ip == ip_address assert account.last_login_at is not None - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login with token generation. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -659,14 +676,16 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_login_pending_account_activation(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_pending_account_activation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test login activates pending account. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -680,24 +699,23 @@ class TestAccountService: password=password, ) account.status = AccountStatus.PENDING - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Login should activate the account token_pair = AccountService.login(account) - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.status == AccountStatus.ACTIVE - def test_logout(self, db_session_with_containers, mock_external_service_dependencies): + def test_logout(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test logout functionality. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -723,14 +741,14 @@ class TestAccountService: refresh_token_key = f"account_refresh_token:{account.id}" assert redis_client.get(refresh_token_key) is None - def test_refresh_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful token refresh. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -757,7 +775,7 @@ class TestAccountService: assert new_token_pair.access_token == "new_mock_access_token" assert new_token_pair.refresh_token != initial_token_pair.refresh_token - def test_refresh_token_invalid_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test refresh token with invalid token. """ @@ -766,14 +784,16 @@ class TestAccountService: with pytest.raises(ValueError, match="Invalid refresh token"): AccountService.refresh_token(invalid_token) - def test_refresh_token_invalid_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_refresh_token_invalid_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test refresh token with valid token but invalid account. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -791,23 +811,22 @@ class TestAccountService: token_pair = AccountService.login(account) # Delete account - from extensions.ext_database import db - db.session.delete(account) - db.session.commit() + db_session_with_containers.delete(account) + db_session_with_containers.commit() # Try to refresh token with deleted account with pytest.raises(ValueError, match="Invalid account"): AccountService.refresh_token(token_pair.refresh_token) - def test_load_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading user by ID successfully. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -830,7 +849,7 @@ class TestAccountService: assert loaded_user.id == account.id assert loaded_user.email == account.email - def test_load_user_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading non-existent user. """ @@ -839,14 +858,14 @@ class TestAccountService: loaded_user = AccountService.load_user(non_existent_user_id) assert loaded_user is None - def test_load_user_banned_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_user_banned_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading banned user raises Unauthorized. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -861,21 +880,20 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.load_user(account.id) - def test_get_account_jwt_token(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_account_jwt_token(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test JWT token generation for account. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -902,14 +920,14 @@ class TestAccountService: assert call_args["iss"] is not None assert call_args["sub"] == "Console API Passport" - def test_load_logged_in_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_load_logged_in_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test loading logged in account by ID. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -931,14 +949,16 @@ class TestAccountService: assert loaded_account is not None assert loaded_account.id == account.id - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email successfully. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -957,7 +977,9 @@ class TestAccountService: assert found_user is not None assert found_user.id == account.id - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through non-existent email. """ @@ -968,7 +990,7 @@ class TestAccountService: assert found_user is None def test_get_user_through_email_banned_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting banned user through email raises Unauthorized. @@ -976,7 +998,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -991,14 +1013,15 @@ class TestAccountService: # Ban the account account.status = AccountStatus.BANNED - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(Unauthorized): # Unauthorized exception AccountService.get_user_through_email(email) - def test_get_user_through_email_in_freeze(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_in_freeze( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting user through email that is in freeze period. """ @@ -1014,14 +1037,14 @@ class TestAccountService: # Reset config dify_config.BILLING_ENABLED = False - def test_delete_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account deletion (should add task to queue and sync to enterprise). """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1050,7 +1073,7 @@ class TestAccountService: mock_delete_task.delay.assert_called_once_with(account.id) def test_generate_account_deletion_verification_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generating account deletion verification code. @@ -1058,7 +1081,7 @@ class TestAccountService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1079,14 +1102,16 @@ class TestAccountService: assert len(code) == 6 assert code.isdigit() - def test_verify_account_deletion_code_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_valid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying valid account deletion code. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -1106,14 +1131,16 @@ class TestAccountService: is_valid = AccountService.verify_account_deletion_code(token, code) assert is_valid is True - def test_verify_account_deletion_code_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_verify_account_deletion_code_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test verifying invalid account deletion code. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) wrong_code = fake.numerify(text="######") # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -1135,7 +1162,7 @@ class TestAccountService: assert is_valid is False def test_verify_account_deletion_code_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test verifying account deletion code with invalid token. @@ -1167,7 +1194,7 @@ class TestTenantService: "billing_service": mock_billing_service, } - def test_create_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant creation with default settings. """ @@ -1187,7 +1214,7 @@ class TestTenantService: assert tenant.encrypt_public_key is not None def test_create_tenant_workspace_creation_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant creation when workspace creation is disabled. @@ -1202,7 +1229,9 @@ class TestTenantService: with pytest.raises(NotAllowedCreateWorkspace): # NotAllowedCreateWorkspace exception TenantService.create_tenant(name=tenant_name) - def test_create_tenant_with_custom_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_with_custom_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant creation with custom name and setup flag. """ @@ -1221,7 +1250,9 @@ class TestTenantService: assert tenant.status == "normal" assert tenant.encrypt_public_key is not None - def test_create_tenant_member_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful tenant member creation. """ @@ -1229,7 +1260,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1251,7 +1282,9 @@ class TestTenantService: assert tenant_member.account_id == account.id assert tenant_member.role == "admin" - def test_create_tenant_member_duplicate_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_duplicate_owner( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test creating duplicate owner for a tenant (should fail). """ @@ -1259,10 +1292,10 @@ class TestTenantService: tenant_name = fake.company() email1 = fake.email() name1 = fake.name() - password1 = fake.password(length=12) + password1 = generate_valid_password(fake) email2 = fake.email() name2 = fake.name() - password2 = fake.password(length=12) + password2 = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1290,7 +1323,9 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant already has an owner"): TenantService.create_tenant_member(tenant, account2, role="owner") - def test_create_tenant_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_tenant_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating role for existing tenant member. """ @@ -1298,7 +1333,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1323,14 +1358,14 @@ class TestTenantService: assert tenant_member2.account_id == tenant_member1.account_id assert tenant_member2.role == "editor" - def test_get_join_tenants_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_join_tenants_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting join tenants for an account. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant1_name = fake.company() tenant2_name = fake.company() # Setup mocks @@ -1361,7 +1396,7 @@ class TestTenantService: assert tenant2_name in tenant_names def test_get_current_tenant_by_account_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant by account successfully. @@ -1369,7 +1404,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies[ @@ -1388,9 +1423,8 @@ class TestTenantService: # Add account to tenant and set as current TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get current tenant current_tenant = TenantService.get_current_tenant_by_account(account) @@ -1400,7 +1434,7 @@ class TestTenantService: assert current_tenant.role == "owner" def test_get_current_tenant_by_account_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting current tenant when account has no current tenant. @@ -1408,7 +1442,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1426,14 +1460,14 @@ class TestTenantService: with pytest.raises((AttributeError, TenantNotFoundError)): TenantService.get_current_tenant_by_account(account) - def test_switch_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tenant switching. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant1_name = fake.company() tenant2_name = fake.company() # Setup mocks @@ -1457,25 +1491,24 @@ class TestTenantService: # Set initial current tenant account.current_tenant = tenant1 - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Switch to second tenant TenantService.switch_tenant(account, tenant2.id) # Verify tenant was switched - db.session.refresh(account) + db_session_with_containers.refresh(account) assert account.current_tenant_id == tenant2.id - def test_switch_tenant_no_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_no_tenant_id(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tenant switching without providing tenant ID. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1493,14 +1526,16 @@ class TestTenantService: with pytest.raises(ValueError, match="Tenant ID must be provided"): TenantService.switch_tenant(account, None) - def test_switch_tenant_account_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_tenant_account_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test switching to a tenant where account is not a member. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) tenant_name = fake.company() # Setup mocks mock_external_service_dependencies[ @@ -1520,7 +1555,7 @@ class TestTenantService: with pytest.raises(Exception, match="Tenant not found or account is not a member of the tenant"): TenantService.switch_tenant(account, tenant.id) - def test_has_roles_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking if tenant has specific roles. """ @@ -1528,10 +1563,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1570,7 +1605,7 @@ class TestTenantService: has_normal = TenantService.has_roles(tenant, [TenantAccountRole.NORMAL]) assert has_normal is False - def test_has_roles_invalid_role_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_has_roles_invalid_role_type(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test checking roles with invalid role type. """ @@ -1589,7 +1624,7 @@ class TestTenantService: with pytest.raises(ValueError, match="all roles must be TenantAccountRole"): TenantService.has_roles(tenant, [invalid_role]) - def test_get_user_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting user role in a tenant. """ @@ -1597,7 +1632,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1620,7 +1655,9 @@ class TestTenantService: assert user_role == "editor" - def test_check_member_permission_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission successfully. """ @@ -1628,10 +1665,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1660,7 +1697,7 @@ class TestTenantService: TenantService.check_member_permission(tenant, owner_account, member_account, "add") def test_check_member_permission_invalid_action( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test checking member permission with invalid action. @@ -1669,7 +1706,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) invalid_action = "invalid_action_that_doesnt_exist" # Setup mocks mock_external_service_dependencies[ @@ -1692,7 +1729,9 @@ class TestTenantService: with pytest.raises(Exception, match="Invalid action"): TenantService.check_member_permission(tenant, account, None, invalid_action) - def test_check_member_permission_operate_self(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_member_permission_operate_self( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test checking member permission when trying to operate self. """ @@ -1700,7 +1739,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1722,7 +1761,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.check_member_permission(tenant, account, account, "remove") - def test_remove_member_from_tenant_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful member removal from tenant (should sync to enterprise). """ @@ -1730,10 +1771,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1770,16 +1811,17 @@ class TestTenantService: ) # Verify member was removed - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join is None def test_remove_member_from_tenant_operate_self( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test removing member when trying to operate self. @@ -1788,7 +1830,7 @@ class TestTenantService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1810,7 +1852,9 @@ class TestTenantService: with pytest.raises(Exception, match="Cannot operate self"): TenantService.remove_member_from_tenant(tenant, account, account) - def test_remove_member_from_tenant_not_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_member_from_tenant_not_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test removing member who is not in the tenant. """ @@ -1818,10 +1862,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) non_member_email = fake.email() non_member_name = fake.name() - non_member_password = fake.password(length=12) + non_member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1849,7 +1893,7 @@ class TestTenantService: with pytest.raises(Exception, match="Member not in tenant"): TenantService.remove_member_from_tenant(tenant, non_member_account, owner_account) - def test_update_member_role_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful member role update. """ @@ -1857,10 +1901,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1889,15 +1933,16 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "admin", owner_account) # Verify role was updated - from extensions.ext_database import db from models.account import TenantAccountJoin member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert member_join.role == "admin" - def test_update_member_role_to_owner(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_to_owner(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating member role to owner (should change current owner to admin). """ @@ -1905,10 +1950,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1937,19 +1982,24 @@ class TestTenantService: TenantService.update_member_role(tenant, member_account, "owner", owner_account) # Verify roles were updated correctly - from extensions.ext_database import db from models.account import TenantAccountJoin owner_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=owner_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=owner_account.id) + .first() ) member_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=member_account.id) + .first() ) assert owner_join.role == "admin" assert member_join.role == "owner" - def test_update_member_role_already_assigned(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_member_role_already_assigned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating member role to already assigned role. """ @@ -1957,10 +2007,10 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) member_email = fake.email() member_name = fake.name() - member_password = fake.password(length=12) + member_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -1989,7 +2039,7 @@ class TestTenantService: with pytest.raises(Exception, match="The provided role is already assigned to the member"): TenantService.update_member_role(tenant, member_account, "admin", owner_account) - def test_get_tenant_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_count_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant count successfully. """ @@ -2014,7 +2064,7 @@ class TestTenantService: assert tenant_count >= 3 def test_create_owner_tenant_if_not_exist_new_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant for new user without existing tenants. @@ -2022,7 +2072,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) workspace_name = fake.company() # Setup mocks mock_external_service_dependencies[ @@ -2044,17 +2094,16 @@ class TestTenantService: TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == workspace_name def test_create_owner_tenant_if_not_exist_existing_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when user already has a tenant. @@ -2062,7 +2111,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) existing_tenant_name = fake.company() new_workspace_name = fake.company() # Setup mocks @@ -2083,20 +2132,19 @@ class TestTenantService: existing_tenant = TenantService.create_tenant(name=existing_tenant_name) TenantService.create_tenant_member(existing_tenant, account, role="owner") account.current_tenant = existing_tenant - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Try to create owner tenant again (should not create new one) TenantService.create_owner_tenant_if_not_exist(account, name=new_workspace_name) # Verify no new tenant was created - tenant_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).all() + tenant_joins = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).all() assert len(tenant_joins) == 1 assert account.current_tenant.id == existing_tenant.id def test_create_owner_tenant_if_not_exist_workspace_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating owner tenant when workspace creation is disabled. @@ -2104,7 +2152,7 @@ class TestTenantService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) workspace_name = fake.company() # Setup mocks to disable workspace creation mock_external_service_dependencies[ @@ -2123,7 +2171,7 @@ class TestTenantService: with pytest.raises(WorkSpaceNotAllowedCreateError): # WorkSpaceNotAllowedCreateError exception TenantService.create_owner_tenant_if_not_exist(account, name=workspace_name) - def test_get_tenant_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_members_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting tenant members successfully. """ @@ -2131,13 +2179,13 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) normal_email = fake.email() normal_name = fake.name() - normal_password = fake.password(length=12) + normal_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -2187,7 +2235,9 @@ class TestTenantService: elif member.email == normal_email: assert member.role == "normal" - def test_get_dataset_operator_members_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_operator_members_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting dataset operator members successfully. """ @@ -2195,13 +2245,13 @@ class TestTenantService: tenant_name = fake.company() owner_email = fake.email() owner_name = fake.name() - owner_password = fake.password(length=12) + owner_password = generate_valid_password(fake) operator_email = fake.email() operator_name = fake.name() - operator_password = fake.password(length=12) + operator_password = generate_valid_password(fake) normal_email = fake.email() normal_name = fake.name() - normal_password = fake.password(length=12) + normal_password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies[ "feature_service" @@ -2240,7 +2290,7 @@ class TestTenantService: assert dataset_operators[0].email == operator_email assert dataset_operators[0].role == "dataset_operator" - def test_get_custom_config_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_custom_config_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting custom config successfully. """ @@ -2259,9 +2309,8 @@ class TestTenantService: # Set custom config custom_config = {"theme": theme, "language": language, "feature_flags": {"beta": True}} tenant.custom_config_dict = custom_config - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Get custom config retrieved_config = TenantService.get_custom_config(tenant.id) @@ -2296,24 +2345,23 @@ class TestRegisterService: "passport_service": mock_passport_service, } - def test_setup_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful system setup with account creation and tenant setup. """ fake = Faker() admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False - from extensions.ext_database import db from models.model import DifySetup - db.session.query(DifySetup).delete() - db.session.commit() + db_session_with_containers.query(DifySetup).delete() + db_session_with_containers.commit() # Execute setup RegisterService.setup( @@ -2327,7 +2375,7 @@ class TestRegisterService: # Verify account was created from models import Account - account = db.session.query(Account).filter_by(email=admin_email).first() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() assert account is not None assert account.name == admin_name assert account.last_login_ip == ip_address @@ -2335,24 +2383,24 @@ class TestRegisterService: assert account.status == "active" # Verify DifySetup was created - dify_setup = db.session.query(DifySetup).first() + dify_setup = db_session_with_containers.query(DifySetup).first() assert dify_setup is not None # Verify tenant was created and linked from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_setup_failure_rollback(self, db_session_with_containers, mock_external_service_dependencies): + def test_setup_failure_rollback(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test setup failure with proper rollback of all created entities. """ fake = Faker() admin_email = fake.email() admin_name = fake.name() - admin_password = fake.password(length=12) + admin_password = generate_valid_password(fake) ip_address = fake.ipv4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2373,28 +2421,27 @@ class TestRegisterService: ) # Verify no entities were created (rollback worked) - from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin from models.model import DifySetup - account = db.session.query(Account).filter_by(email=admin_email).first() - tenant_count = db.session.query(Tenant).count() - tenant_join_count = db.session.query(TenantAccountJoin).count() - dify_setup_count = db.session.query(DifySetup).count() + account = db_session_with_containers.query(Account).filter_by(email=admin_email).first() + tenant_count = db_session_with_containers.query(Tenant).count() + tenant_join_count = db_session_with_containers.query(TenantAccountJoin).count() + dify_setup_count = db_session_with_containers.query(DifySetup).count() assert account is None assert tenant_count == 0 assert tenant_join_count == 0 assert dify_setup_count == 0 - def test_register_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful account registration with workspace creation. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2421,16 +2468,15 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" assert account.current_tenant is not None assert account.current_tenant.name == f"{name}'s Workspace" - def test_register_with_oauth(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_oauth(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration with OAuth integration. """ @@ -2467,21 +2513,26 @@ class TestRegisterService: assert account.initialized_at is not None # Verify OAuth integration was created - from extensions.ext_database import db from models import AccountIntegrate - integration = db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + integration = ( + db_session_with_containers.query(AccountIntegrate) + .filter_by(account_id=account.id, provider=provider) + .first() + ) assert integration is not None assert integration.open_id == open_id - def test_register_with_pending_status(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_with_pending_status( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration with pending status. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2511,21 +2562,22 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is not None assert tenant_join.role == "owner" - def test_register_workspace_creation_disabled(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_creation_disabled( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace creation is disabled. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2549,20 +2601,21 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_workspace_limit_exceeded(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_workspace_limit_exceeded( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test account registration when workspace limit is exceeded. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2589,20 +2642,19 @@ class TestRegisterService: assert account.initialized_at is not None # Verify tenant was created and linked - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_register_without_workspace(self, db_session_with_containers, mock_external_service_dependencies): + def test_register_without_workspace(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test account registration without workspace creation. """ fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2624,13 +2676,14 @@ class TestRegisterService: assert account.initialized_at is not None # Verify no tenant was created - from extensions.ext_database import db from models.account import TenantAccountJoin - tenant_join = db.session.query(TenantAccountJoin).filter_by(account_id=account.id).first() + tenant_join = db_session_with_containers.query(TenantAccountJoin).filter_by(account_id=account.id).first() assert tenant_join is None - def test_invite_new_member_new_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_new_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a new member who doesn't have an account yet. """ @@ -2638,7 +2691,7 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) new_member_email = fake.email() language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks @@ -2682,22 +2735,25 @@ class TestRegisterService: mock_send_mail.delay.assert_called_once() # Verify new account was created with pending status - from extensions.ext_database import db from models import Account, TenantAccountJoin - new_account = db.session.query(Account).filter_by(email=new_member_email).first() + new_account = db_session_with_containers.query(Account).filter_by(email=new_member_email).first() assert new_account is not None assert new_account.name == new_member_email.split("@")[0] # Default name from email assert new_account.status == "pending" # Verify tenant member was created tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=new_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=new_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "normal" - def test_invite_new_member_existing_account(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_account( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting an existing member who is not in the tenant yet. """ @@ -2705,10 +2761,10 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) existing_member_email = fake.email() existing_member_name = fake.name() - existing_member_password = fake.password(length=12) + existing_member_password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2749,16 +2805,19 @@ class TestRegisterService: mock_send_mail.delay.assert_not_called() # Verify tenant member was created for existing account - from extensions.ext_database import db from models.account import TenantAccountJoin tenant_join = ( - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=existing_account.id).first() + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=existing_account.id) + .first() ) assert tenant_join is not None assert tenant_join.role == "admin" - def test_invite_new_member_existing_member(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_existing_member( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member who is already in the tenant with pending status. """ @@ -2766,10 +2825,10 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) existing_pending_member_email = fake.email() existing_pending_member_name = fake.name() - existing_pending_member_password = fake.password(length=12) + existing_pending_member_password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2793,9 +2852,8 @@ class TestRegisterService: password=existing_pending_member_password, ) existing_account.status = "pending" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2820,7 +2878,9 @@ class TestRegisterService: # Verify email task was called mock_send_mail.delay.assert_called_once() - def test_invite_new_member_no_inviter(self, db_session_with_containers, mock_external_service_dependencies): + def test_invite_new_member_no_inviter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test inviting a member without providing an inviter. """ @@ -2846,7 +2906,7 @@ class TestRegisterService: ) def test_invite_new_member_account_already_in_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test inviting a member who is already in the tenant with active status. @@ -2855,10 +2915,10 @@ class TestRegisterService: tenant_name = fake.company() inviter_email = fake.email() inviter_name = fake.name() - inviter_password = fake.password(length=12) + inviter_password = generate_valid_password(fake) already_in_tenant_email = fake.email() already_in_tenant_name = fake.name() - already_in_tenant_password = fake.password(length=12) + already_in_tenant_password = generate_valid_password(fake) language = fake.random_element(elements=("en-US", "zh-CN")) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -2882,9 +2942,8 @@ class TestRegisterService: password=already_in_tenant_password, ) existing_account.status = "active" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Add existing account to tenant TenantService.create_tenant_member(tenant, existing_account, role="normal") @@ -2899,7 +2958,9 @@ class TestRegisterService: inviter=inviter, ) - def test_generate_invite_token_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_invite_token_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation of invite token. """ @@ -2907,7 +2968,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -2943,7 +3004,7 @@ class TestRegisterService: assert invitation_data["email"] == account.email assert invitation_data["workspace_id"] == tenant.id - def test_is_valid_invite_token_valid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_valid(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation of valid invite token. """ @@ -2951,7 +3012,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -2974,7 +3035,9 @@ class TestRegisterService: # Verify token is valid assert is_valid is True - def test_is_valid_invite_token_invalid(self, db_session_with_containers, mock_external_service_dependencies): + def test_is_valid_invite_token_invalid( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation of invalid invite token. """ @@ -2987,7 +3050,7 @@ class TestRegisterService: assert is_valid is False def test_revoke_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token with workspace ID and email. @@ -2996,7 +3059,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3030,7 +3093,7 @@ class TestRegisterService: assert redis_client.get(token_key) is not None def test_revoke_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test revoking token without workspace ID and email. @@ -3039,7 +3102,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3073,7 +3136,7 @@ class TestRegisterService: assert redis_client.get(token_key) is None def test_get_invitation_if_token_valid_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with valid token. @@ -3082,7 +3145,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False @@ -3122,7 +3185,7 @@ class TestRegisterService: assert result["data"]["workspace_id"] == tenant.id def test_get_invitation_if_token_valid_invalid_token( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid token. @@ -3142,7 +3205,7 @@ class TestRegisterService: assert result is None def test_get_invitation_if_token_valid_invalid_tenant( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with invalid tenant. @@ -3150,7 +3213,7 @@ class TestRegisterService: fake = Faker() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) invalid_tenant_id = fake.uuid4() token = fake.uuid4() # Setup mocks @@ -3192,7 +3255,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_account_mismatch( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with account ID mismatch. @@ -3201,7 +3264,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) token = fake.uuid4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -3242,7 +3305,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_if_token_valid_tenant_not_normal( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation data with tenant not in normal status. @@ -3251,7 +3314,7 @@ class TestRegisterService: tenant_name = fake.company() email = fake.email() name = fake.name() - password = fake.password(length=12) + password = generate_valid_password(fake) token = fake.uuid4() # Setup mocks mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True @@ -3268,10 +3331,9 @@ class TestRegisterService: TenantService.create_tenant_member(tenant, account, role="normal") # Change tenant status to non-normal - tenant.status = "suspended" - from extensions.ext_database import db + tenant.status = "archive" - db.session.commit() + db_session_with_containers.commit() # Create a real token from extensions.ext_redis import redis_client @@ -3300,7 +3362,7 @@ class TestRegisterService: redis_client.delete(token_key) def test_get_invitation_by_token_with_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token with workspace ID and email. @@ -3339,7 +3401,7 @@ class TestRegisterService: redis_client.delete(cache_key) def test_get_invitation_by_token_without_workspace_and_email( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting invitation by token without workspace ID and email. @@ -3372,7 +3434,7 @@ class TestRegisterService: # Clean up redis_client.delete(token_key) - def test_get_invitation_token_key(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_invitation_token_key(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting invitation token key. """ diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index e7cc140582..b51fbc3a42 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -3,13 +3,16 @@ from unittest.mock import MagicMock, create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.plugin.impl.exc import PluginDaemonClientSideError from models import Account +from models.enums import ConversationFromSource, MessageFileBelongsTo from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought from services.account_service import AccountService, TenantService from services.agent_service import AgentService from services.app_service import AppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAgentService: @@ -87,7 +90,7 @@ class TestAgentService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -110,7 +113,7 @@ class TestAgentService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -133,13 +136,12 @@ class TestAgentService: # Update the app model config to set agent_mode for agent-chat mode if app.mode == "agent-chat" and app.app_model_config: app.app_model_config.agent_mode = json.dumps({"enabled": True, "strategy": "react", "tools": []}) - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() return app, account - def _create_test_conversation_and_message(self, db_session_with_containers, app, account): + def _create_test_conversation_and_message(self, db_session_with_containers: Session, app, account): """ Helper method to create a test conversation and message with agent thoughts. @@ -153,8 +155,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation conversation = Conversation( id=fake.uuid4(), @@ -165,10 +165,10 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -180,12 +180,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -204,14 +204,14 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return conversation, message - def _create_test_agent_thoughts(self, db_session_with_containers, message): + def _create_test_agent_thoughts(self, db_session_with_containers: Session, message): """ Helper method to create test agent thoughts for a message. @@ -224,8 +224,6 @@ class TestAgentService: """ fake = Faker() - from extensions.ext_database import db - agent_thoughts = [] # Create first agent thought @@ -251,7 +249,7 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought1) + db_session_with_containers.add(thought1) agent_thoughts.append(thought1) # Create second agent thought @@ -277,14 +275,14 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought2) + db_session_with_containers.add(thought2) agent_thoughts.append(thought2) - db.session.commit() + db_session_with_containers.commit() return agent_thoughts - def test_get_agent_logs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of agent logs with complete data. """ @@ -344,7 +342,7 @@ class TestAgentService: assert dataset_tool_call["tool_icon"] == "" # dataset-retrieval tools have empty icon def test_get_agent_logs_conversation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when conversation is not found. @@ -358,7 +356,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Conversation not found"): AgentService.get_agent_logs(app, fake.uuid4(), fake.uuid4()) - def test_get_agent_logs_message_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_message_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when message is not found. """ @@ -372,7 +372,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Message not found"): AgentService.get_agent_logs(app, str(conversation.id), fake.uuid4()) - def test_get_agent_logs_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when conversation is from end user. """ @@ -381,8 +383,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create end user end_user = EndUser( id=fake.uuid4(), @@ -393,8 +393,8 @@ class TestAgentService: session_id=fake.uuid4(), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create conversation with end user conversation = Conversation( @@ -406,10 +406,10 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -421,12 +421,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -445,10 +445,10 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -457,7 +457,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == end_user.name - def test_get_agent_logs_with_unknown_executor(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_unknown_executor( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval when executor is unknown. """ @@ -466,8 +468,6 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create conversation with non-existent account conversation = Conversation( id=fake.uuid4(), @@ -478,10 +478,10 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -493,12 +493,12 @@ class TestAgentService: agent_mode=json.dumps({"enabled": True, "strategy": "react", "tools": []}), ) app_model_config.id = fake.uuid4() - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Update conversation with app model config conversation.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() # Create message message = Message( @@ -517,10 +517,10 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -529,7 +529,9 @@ class TestAgentService: assert result is not None assert result["meta"]["executor"] == "Unknown" - def test_get_agent_logs_with_tool_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_tool_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with tool errors. """ @@ -539,8 +541,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with tool error thought_with_error = MessageAgentThought( message_id=message.id, @@ -564,8 +564,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_error) - db.session.commit() + db_session_with_containers.add(thought_with_error) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -580,7 +580,7 @@ class TestAgentService: assert tool_call["error"] == "Tool execution failed" def test_get_agent_logs_without_agent_thoughts( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval when message has no agent thoughts. @@ -600,7 +600,7 @@ class TestAgentService: assert len(result["iterations"]) == 0 def test_get_agent_logs_app_model_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is not found. @@ -610,11 +610,9 @@ class TestAgentService: # Create test data app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Remove app model config to test error handling app.app_model_config_id = None - db.session.commit() + db_session_with_containers.commit() # Create conversation without app model config conversation = Conversation( @@ -626,11 +624,11 @@ class TestAgentService: inputs={}, status="normal", mode="chat", - from_source="api", + from_source=ConversationFromSource.API, app_model_config_id=None, # Explicitly set to None ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( @@ -649,17 +647,17 @@ class TestAgentService: answer_unit_price=0.001, provider_response_latency=1.5, currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() # Execute the method under test with pytest.raises(ValueError, match="App model config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) def test_get_agent_logs_agent_config_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when agent config is not found. @@ -677,7 +675,9 @@ class TestAgentService: with pytest.raises(ValueError, match="Agent config not found"): AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) - def test_list_agent_providers_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_agent_providers_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful listing of agent providers. """ @@ -698,7 +698,7 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_providers.assert_called_once_with(str(app.tenant_id)) - def test_get_agent_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of specific agent provider. """ @@ -720,7 +720,9 @@ class TestAgentService: mock_plugin_client = mock_external_service_dependencies["plugin_agent_client"].return_value mock_plugin_client.fetch_agent_strategy_provider.assert_called_once_with(str(app.tenant_id), provider_name) - def test_get_agent_provider_plugin_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_provider_plugin_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when plugin daemon client raises an error. """ @@ -741,7 +743,7 @@ class TestAgentService: AgentService.get_agent_provider(str(account.id), str(app.tenant_id), provider_name) def test_get_agent_logs_with_complex_tool_data( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with complex tool data and multiple tools. @@ -752,8 +754,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with multiple tools complex_thought = MessageAgentThought( message_id=message.id, @@ -799,8 +799,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(complex_thought) - db.session.commit() + db_session_with_containers.add(complex_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -831,7 +831,7 @@ class TestAgentService: assert tool_calls[2]["status"] == "success" assert tool_calls[2]["tool_icon"] == "" # dataset-retrieval tools have empty icon - def test_get_agent_logs_with_files(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_files(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test agent logs retrieval with message files and agent thought files. """ @@ -841,8 +841,7 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from core.workflow.file import FileTransferMethod, FileType - from extensions.ext_database import db + from dify_graph.file import FileTransferMethod, FileType from models.enums import CreatorUserRole # Add files to message @@ -854,7 +853,7 @@ class TestAgentService: type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/file1.jpg", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) @@ -863,13 +862,13 @@ class TestAgentService: type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, url="http://example.com/file2.png", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role=CreatorUserRole.ACCOUNT, created_by=message.from_account_id, ) - db.session.add(message_file1) - db.session.add(message_file2) - db.session.commit() + db_session_with_containers.add(message_file1) + db_session_with_containers.add(message_file2) + db_session_with_containers.commit() # Create agent thought with files thought_with_files = MessageAgentThought( @@ -895,8 +894,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(thought_with_files) - db.session.commit() + db_session_with_containers.add(thought_with_files) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -912,7 +911,7 @@ class TestAgentService: assert "file2" in iterations[0]["files"] def test_get_agent_logs_with_different_timezone( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test agent logs retrieval with different timezone settings. @@ -938,7 +937,9 @@ class TestAgentService: assert "T" in start_time # ISO format assert "+08:00" in start_time or "Z" in start_time # Timezone offset - def test_get_agent_logs_with_empty_tool_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_empty_tool_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with empty tool data. """ @@ -948,8 +949,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with empty tool data empty_thought = MessageAgentThought( message_id=message.id, @@ -964,8 +963,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(empty_thought) - db.session.commit() + db_session_with_containers.add(empty_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) @@ -979,7 +978,9 @@ class TestAgentService: tool_calls = iterations[0]["tool_calls"] assert len(tool_calls) == 0 # No tools to process - def test_get_agent_logs_with_malformed_json(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_agent_logs_with_malformed_json( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test agent logs retrieval with malformed JSON data in tool fields. """ @@ -989,8 +990,6 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from extensions.ext_database import db - # Create agent thought with malformed JSON malformed_thought = MessageAgentThought( message_id=message.id, @@ -1005,8 +1004,8 @@ class TestAgentService: created_by_role="account", created_by=message.from_account_id, ) - db.session.add(malformed_thought) - db.session.commit() + db_session_with_containers.add(malformed_thought) + db_session_with_containers.commit() # Execute the method under test result = AgentService.get_agent_logs(app, str(conversation.id), str(message.id)) diff --git a/api/tests/test_containers_integration_tests/services/test_annotation_service.py b/api/tests/test_containers_integration_tests/services/test_annotation_service.py index 4f5190e533..95fc73f45a 100644 --- a/api/tests/test_containers_integration_tests/services/test_annotation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_annotation_service.py @@ -2,12 +2,15 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account +from models.enums import ConversationFromSource, InvokeFrom from models.model import MessageAnnotation from services.annotation_service import AppAnnotationService from services.app_service import AppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAnnotationService: @@ -52,7 +55,7 @@ class TestAnnotationService: "current_user": mock_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -77,7 +80,7 @@ class TestAnnotationService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -115,11 +118,10 @@ class TestAnnotationService: tenant_id, ) - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -135,23 +137,22 @@ class TestAnnotationService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -174,18 +175,18 @@ class TestAnnotationService: provider_response_latency=0, total_price=0, currency="USD", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_insert_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation. @@ -211,9 +212,8 @@ class TestAnnotationService: assert annotation.id is not None # Verify annotation was saved to database - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.id is not None # Verify add_annotation_to_index_task was called (when annotation setting exists) @@ -221,7 +221,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_insert_app_annotation_directly_requires_question( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Question must be provided when inserting annotations directly. @@ -238,7 +238,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id) def test_insert_app_annotation_directly_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test direct insertion of app annotation when app is not found. @@ -260,7 +260,7 @@ class TestAnnotationService: AppAnnotationService.insert_app_annotation_directly(annotation_args, non_existent_app_id) def test_update_app_annotation_directly_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation. @@ -298,7 +298,7 @@ class TestAnnotationService: mock_external_service_dependencies["update_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_new( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating new annotation from message. @@ -307,8 +307,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { @@ -333,7 +333,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_update( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test updating existing annotation from message. @@ -342,8 +342,8 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial annotation initial_args = { @@ -373,7 +373,7 @@ class TestAnnotationService: mock_external_service_dependencies["add_task"].delay.assert_not_called() def test_up_insert_app_annotation_from_message_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message when app is not found. @@ -395,7 +395,7 @@ class TestAnnotationService: AppAnnotationService.up_insert_app_annotation_from_message(annotation_args, non_existent_app_id) def test_get_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of annotation list by app ID. @@ -428,7 +428,7 @@ class TestAnnotationService: assert annotation.account_id == account.id def test_get_annotation_list_by_app_id_with_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list with keyword search. @@ -462,7 +462,7 @@ class TestAnnotationService: assert unique_keyword in annotation_list[0].question or unique_keyword in annotation_list[0].content def test_get_annotation_list_by_app_id_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test retrieval of annotation list with special characters in keyword to verify SQL injection prevention. @@ -534,7 +534,7 @@ class TestAnnotationService: assert all("50%" in (item.question or "") or "50%" in (item.content or "") for item in annotation_list) def test_get_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of annotation list when app is not found. @@ -549,7 +549,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="App not found"): AppAnnotationService.get_annotation_list_by_app_id(non_existent_app_id, page=1, limit=10, keyword="") - def test_delete_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of app annotation. """ @@ -568,16 +570,19 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - from extensions.ext_database import db - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called (when annotation setting exists) # Note: In this test, no annotation setting exists, so task should not be called mock_external_service_dependencies["delete_task"].delay.assert_not_called() - def test_delete_app_annotation_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_annotation_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deletion of app annotation when app is not found. """ @@ -593,7 +598,7 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(non_existent_app_id, annotation_id) def test_delete_app_annotation_annotation_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test deletion of app annotation when annotation is not found. @@ -606,7 +611,9 @@ class TestAnnotationService: with pytest.raises(NotFound, match="Annotation not found"): AppAnnotationService.delete_app_annotation(app.id, non_existent_annotation_id) - def test_enable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of app annotation. """ @@ -632,7 +639,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["enable_task"].delay.assert_called_once() - def test_disable_app_annotation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_app_annotation_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of app annotation. """ @@ -651,7 +660,9 @@ class TestAnnotationService: # Verify task was called mock_external_service_dependencies["disable_task"].delay.assert_called_once() - def test_enable_app_annotation_cached_job(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_app_annotation_cached_job( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test enabling app annotation when job is already cached. """ @@ -685,7 +696,9 @@ class TestAnnotationService: # Clean up redis_client.delete(enable_app_annotation_key) - def test_get_annotation_hit_histories_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_hit_histories_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation hit histories. """ @@ -709,7 +722,7 @@ class TestAnnotationService: query=f"Query {i}: {fake.sentence()}", user_id=account.id, message_id=fake.uuid4(), - from_source="console", + from_source=ConversationFromSource.CONSOLE, score=0.8 + (i * 0.1), ) @@ -728,7 +741,9 @@ class TestAnnotationService: assert history.app_id == app.id assert history.account_id == account.id - def test_add_annotation_history_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_annotation_history_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful addition of annotation history. """ @@ -758,21 +773,20 @@ class TestAnnotationService: query=query, user_id=account.id, message_id=message_id, - from_source="console", + from_source=ConversationFromSource.CONSOLE, score=score, ) # Verify hit count was incremented - from extensions.ext_database import db - db.session.refresh(annotation) + db_session_with_containers.refresh(annotation) assert annotation.hit_count == initial_hit_count + 1 # Verify history was created from models.model import AppAnnotationHitHistory history = ( - db.session.query(AppAnnotationHitHistory) + db_session_with_containers.query(AppAnnotationHitHistory) .where( AppAnnotationHitHistory.annotation_id == annotation.id, AppAnnotationHitHistory.message_id == message_id ) @@ -786,7 +800,9 @@ class TestAnnotationService: assert history.score == score assert history.source == "console" - def test_get_annotation_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_annotation_by_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of annotation by ID. """ @@ -811,7 +827,9 @@ class TestAnnotationService: assert retrieved_annotation.content == annotation_args["answer"] assert retrieved_annotation.account_id == account.id - def test_batch_import_app_annotations_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_batch_import_app_annotations_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful batch import of app annotations. """ @@ -854,7 +872,7 @@ class TestAnnotationService: mock_external_service_dependencies["batch_import_task"].delay.assert_called_once() def test_batch_import_app_annotations_empty_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import with empty CSV file. @@ -889,7 +907,7 @@ class TestAnnotationService: assert "empty" in result["error_msg"].lower() def test_batch_import_app_annotations_quota_exceeded( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test batch import when quota is exceeded. @@ -935,7 +953,7 @@ class TestAnnotationService: assert "limit" in result["error_msg"].lower() def test_get_app_annotation_setting_by_app_id_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting enabled app annotation setting by app ID. @@ -944,7 +962,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -956,8 +973,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -967,8 +984,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Get annotation setting result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) @@ -981,7 +998,7 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" def test_get_app_annotation_setting_by_app_id_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting disabled app annotation setting by app ID. @@ -996,7 +1013,7 @@ class TestAnnotationService: assert result["enabled"] is False def test_update_app_annotation_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of app annotation setting. @@ -1005,7 +1022,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1017,8 +1033,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1028,8 +1044,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Update annotation setting update_args = { @@ -1046,11 +1062,11 @@ class TestAnnotationService: assert result["embedding_model"]["embedding_model_name"] == "text-embedding-ada-002" # Verify database was updated - db.session.refresh(annotation_setting) + db_session_with_containers.refresh(annotation_setting) assert annotation_setting.score_threshold == 0.9 def test_export_annotation_list_by_app_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful export of annotation list by app ID. @@ -1083,7 +1099,7 @@ class TestAnnotationService: assert annotation.created_at <= exported_annotations[i - 1].created_at def test_export_annotation_list_by_app_id_app_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test export of annotation list when app is not found. @@ -1099,7 +1115,7 @@ class TestAnnotationService: AppAnnotationService.export_annotation_list_by_app_id(non_existent_app_id) def test_insert_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct insertion of app annotation with annotation setting enabled. @@ -1108,7 +1124,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1120,8 +1135,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1131,8 +1146,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Setup annotation data annotation_args = { @@ -1161,7 +1176,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_update_app_annotation_directly_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful direct update of app annotation with annotation setting enabled. @@ -1170,7 +1185,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1182,8 +1196,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1193,8 +1207,8 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # First, create an annotation original_args = { @@ -1234,7 +1248,7 @@ class TestAnnotationService: assert call_args[4] == collection_binding.id # collection_binding_id def test_delete_app_annotation_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful deletion of app annotation with annotation setting enabled. @@ -1243,7 +1257,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1255,8 +1268,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1267,8 +1280,8 @@ class TestAnnotationService: updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create an annotation first annotation_args = { @@ -1285,7 +1298,9 @@ class TestAnnotationService: AppAnnotationService.delete_app_annotation(app.id, annotation_id) # Verify annotation was deleted - deleted_annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + deleted_annotation = ( + db_session_with_containers.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first() + ) assert deleted_annotation is None # Verify delete_annotation_index_task was called @@ -1297,7 +1312,7 @@ class TestAnnotationService: assert call_args[3] == collection_binding.id # collection_binding_id def test_up_insert_app_annotation_from_message_with_setting_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating annotation from message with annotation setting enabled. @@ -1306,7 +1321,6 @@ class TestAnnotationService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create annotation setting first - from extensions.ext_database import db from models.dataset import DatasetCollectionBinding from models.model import AppAnnotationSetting @@ -1318,8 +1332,8 @@ class TestAnnotationService: collection_name=f"annotation_collection_{fake.uuid4()}", ) collection_binding.id = str(fake.uuid4()) - db.session.add(collection_binding) - db.session.flush() + db_session_with_containers.add(collection_binding) + db_session_with_containers.flush() # Create annotation setting annotation_setting = AppAnnotationSetting( @@ -1329,12 +1343,12 @@ class TestAnnotationService: created_user_id=account.id, updated_user_id=account.id, ) - db.session.add(annotation_setting) - db.session.commit() + db_session_with_containers.add(annotation_setting) + db_session_with_containers.commit() # Create a conversation and message first - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Setup annotation data with message_id annotation_args = { diff --git a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py index 8c8be2e670..7ce7357b41 100644 --- a/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_based_extension_service.py @@ -2,10 +2,12 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models.api_based_extension import APIBasedExtension from services.account_service import AccountService, TenantService from services.api_based_extension_service import APIBasedExtensionService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAPIBasedExtensionService: @@ -31,7 +33,7 @@ class TestAPIBasedExtensionService: "requestor_instance": mock_requestor_instance, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -54,14 +56,14 @@ class TestAPIBasedExtensionService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant return account, tenant - def test_save_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful saving of API-based extension. """ @@ -90,15 +92,16 @@ class TestAPIBasedExtensionService: assert saved_extension.created_at is not None # Verify extension was saved to database - from extensions.ext_database import db - db.session.refresh(saved_extension) + db_session_with_containers.refresh(saved_extension) assert saved_extension.id is not None # Verify ping connection was called mock_external_service_dependencies["requestor_instance"].request.assert_called_once() - def test_save_extension_validation_errors(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_validation_errors( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation errors when saving extension with invalid data. """ @@ -132,7 +135,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all extensions by tenant ID. """ @@ -169,7 +174,7 @@ class TestAPIBasedExtensionService: # Verify descending order (newer first) assert extension.created_at <= extension_list[i - 1].created_at - def test_get_with_tenant_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of extension by tenant ID and extension ID. """ @@ -200,7 +205,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_key == extension_data.api_key # Should be decrypted assert retrieved_extension.created_at is not None - def test_get_with_tenant_id_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when extension is not found. """ @@ -214,7 +221,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="API based extension is not found"): APIBasedExtensionService.get_with_tenant_id(tenant.id, non_existent_extension_id) - def test_delete_extension_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_extension_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of extension. """ @@ -238,12 +245,15 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.delete(created_extension) # Verify extension was deleted - from extensions.ext_database import db - deleted_extension = db.session.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + deleted_extension = ( + db_session_with_containers.query(APIBasedExtension).where(APIBasedExtension.id == extension_id).first() + ) assert deleted_extension is None - def test_save_extension_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when saving extension with duplicate name. """ @@ -272,7 +282,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="name must be unique, it is already existed"): APIBasedExtensionService.save(extension_data2) - def test_save_extension_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of existing extension. """ @@ -329,7 +341,9 @@ class TestAPIBasedExtensionService: assert retrieved_extension.api_endpoint == new_endpoint assert retrieved_extension.api_key == new_api_key # Should be decrypted when retrieved - def test_save_extension_connection_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_connection_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test connection error when saving extension with invalid endpoint. """ @@ -356,7 +370,7 @@ class TestAPIBasedExtensionService: APIBasedExtensionService.save(extension_data) def test_save_extension_invalid_api_key_length( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation error when saving extension with API key that is too short. @@ -378,7 +392,7 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must be at least 5 characters"): APIBasedExtensionService.save(extension_data) - def test_save_extension_empty_fields(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_empty_fields(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test validation errors when saving extension with empty required fields. """ @@ -412,7 +426,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="api_key must not be empty"): APIBasedExtensionService.save(extension_data) - def test_get_all_by_tenant_id_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_by_tenant_id_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extensions when no extensions exist for tenant. """ @@ -428,7 +444,9 @@ class TestAPIBasedExtensionService: assert len(extension_list) == 0 assert extension_list == [] - def test_save_extension_invalid_ping_response(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_invalid_ping_response( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is invalid. """ @@ -452,7 +470,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'result': 'invalid'}"): APIBasedExtensionService.save(extension_data) - def test_save_extension_missing_ping_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_extension_missing_ping_result( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test validation error when ping response is missing result field. """ @@ -476,7 +496,9 @@ class TestAPIBasedExtensionService: with pytest.raises(ValueError, match="{'status': 'ok'}"): APIBasedExtensionService.save(extension_data) - def test_get_with_tenant_id_wrong_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_with_tenant_id_wrong_tenant( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of extension when tenant ID doesn't match. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index e2a450b90c..8a362e1f5e 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -9,6 +9,7 @@ from models.model import App, AppModelConfig from services.account_service import AccountService, TenantService from services.app_dsl_service import AppDslService, ImportMode, ImportStatus from services.app_service import AppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAppDslService: @@ -89,7 +90,7 @@ class TestAppDslService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py index 8544d23cdf..5b1a4790f5 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_generate_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_generate_service.py @@ -2,13 +2,16 @@ import uuid from unittest.mock import ANY, MagicMock, patch import pytest +import sqlalchemy as sa from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models.model import EndUser from models.workflow import Workflow from services.app_generate_service import AppGenerateService from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestAppGenerateService: @@ -118,7 +121,9 @@ class TestAppGenerateService: "global_dify_config": mock_global_dify_config, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies, mode="chat"): + def _create_test_app_and_account( + self, db_session_with_containers: Session, mock_external_service_dependencies, mode="chat" + ): """ Helper method to create a test app and account for testing. @@ -144,7 +149,7 @@ class TestAppGenerateService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -169,7 +174,7 @@ class TestAppGenerateService: return app, account - def _create_test_workflow(self, db_session_with_containers, app): + def _create_test_workflow(self, db_session_with_containers: Session, app): """ Helper method to create a test workflow for testing. @@ -191,14 +196,14 @@ class TestAppGenerateService: status="published", ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_generate_completion_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_completion_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for completion mode app. """ @@ -226,7 +231,7 @@ class TestAppGenerateService: mock_external_service_dependencies["completion_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["completion_generator"].convert_to_event_stream.assert_called_once() - def test_generate_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_chat_mode_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful generation for chat mode app. """ @@ -250,7 +255,9 @@ class TestAppGenerateService: mock_external_service_dependencies["chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_agent_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_agent_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for agent chat mode app. """ @@ -274,7 +281,9 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].return_value.generate.assert_called_once() mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() - def test_generate_advanced_chat_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_advanced_chat_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for advanced chat mode app. """ @@ -300,7 +309,9 @@ class TestAppGenerateService: "advanced_chat_generator" ].return_value.convert_to_event_stream.assert_called_once() - def test_generate_workflow_mode_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_workflow_mode_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful generation for workflow mode app. """ @@ -324,7 +335,9 @@ class TestAppGenerateService: mock_external_service_dependencies["message_based_generator"].retrieve_events.assert_called_once() mock_external_service_dependencies["workflow_generator"].convert_to_event_stream.assert_called_once() - def test_generate_with_specific_workflow_id(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_specific_workflow_id( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with a specific workflow ID. """ @@ -355,7 +368,9 @@ class TestAppGenerateService: "workflow_service" ].return_value.get_published_workflow_by_id.assert_called_once() - def test_generate_with_debugger_invoke_from(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_debugger_invoke_from( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with debugger invoke from. """ @@ -378,7 +393,9 @@ class TestAppGenerateService: # Verify draft workflow was fetched for debugger mock_external_service_dependencies["workflow_service"].return_value.get_draft_workflow.assert_called_once() - def test_generate_with_non_streaming_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_non_streaming_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with non-streaming mode. """ @@ -401,7 +418,7 @@ class TestAppGenerateService: # Verify rate limit exit was called for non-streaming mode mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with EndUser instead of Account. """ @@ -421,10 +438,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} @@ -438,7 +453,7 @@ class TestAppGenerateService: assert result == ["test_response"] def test_generate_with_billing_enabled_sandbox_plan( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with billing enabled and sandbox plan. @@ -466,7 +481,9 @@ class TestAppGenerateService: # Verify billing service was called to consume quota mock_external_service_dependencies["billing_service"].update_tenant_feature_plan_usage.assert_called_once() - def test_generate_with_invalid_app_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_invalid_app_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with invalid app mode. """ @@ -476,22 +493,22 @@ class TestAppGenerateService: ) # Manually set invalid mode after creation + # With EnumText, invalid values are rejected at the DB level during flush, + # raising StatementError wrapping ValueError app.mode = "invalid_mode" # Setup test arguments args = {"inputs": {"query": fake.text(max_nb_chars=50)}, "response_mode": "streaming"} - # Execute the method under test and expect ValueError - with pytest.raises(ValueError) as exc_info: + # Execute the method under test and expect either ValueError (direct) or + # StatementError (from EnumText validation during autoflush) + with pytest.raises((ValueError, sa.exc.StatementError)): AppGenerateService.generate( app_model=app, user=account, args=args, invoke_from=InvokeFrom.SERVICE_API, streaming=True ) - # Verify error message - assert "Invalid app mode" in str(exc_info.value) - def test_generate_with_workflow_id_format_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with invalid workflow ID format. @@ -518,7 +535,7 @@ class TestAppGenerateService: assert "Invalid workflow_id format" in str(exc_info.value) def test_generate_with_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not found. @@ -552,7 +569,7 @@ class TestAppGenerateService: assert f"Workflow not found with id: {workflow_id}" in str(exc_info.value) def test_generate_with_workflow_not_initialized_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not initialized for debugger. @@ -578,7 +595,7 @@ class TestAppGenerateService: assert "Workflow not initialized" in str(exc_info.value) def test_generate_with_workflow_not_published_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation when workflow is not published for non-debugger. @@ -604,7 +621,7 @@ class TestAppGenerateService: assert "Workflow not published" in str(exc_info.value) def test_generate_single_iteration_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for advanced chat mode. @@ -631,7 +648,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single iteration generation for workflow mode. @@ -658,7 +675,7 @@ class TestAppGenerateService: ].return_value.single_iteration_generate.assert_called_once() def test_generate_single_iteration_invalid_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test single iteration generation with invalid app mode. @@ -681,7 +698,7 @@ class TestAppGenerateService: assert "Invalid app mode" in str(exc_info.value) def test_generate_single_loop_advanced_chat_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for advanced chat mode. @@ -708,7 +725,7 @@ class TestAppGenerateService: ].return_value.single_loop_generate.assert_called_once() def test_generate_single_loop_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful single loop generation for workflow mode. @@ -732,7 +749,9 @@ class TestAppGenerateService: # Verify workflow generator was called mock_external_service_dependencies["workflow_generator"].return_value.single_loop_generate.assert_called_once() - def test_generate_single_loop_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_single_loop_invalid_mode( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test single loop generation with invalid app mode. """ @@ -753,7 +772,9 @@ class TestAppGenerateService: # Verify error message assert "Invalid app mode" in str(exc_info.value) - def test_generate_more_like_this_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_more_like_this_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful more like this generation. """ @@ -778,7 +799,7 @@ class TestAppGenerateService: ].return_value.generate_more_like_this.assert_called_once() def test_generate_more_like_this_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test more like this generation with EndUser. @@ -799,10 +820,8 @@ class TestAppGenerateService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() message_id = fake.uuid4() @@ -815,7 +834,7 @@ class TestAppGenerateService: assert result == ["more_like_this_response"] def test_get_max_active_requests_with_app_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with app-specific limit. @@ -835,7 +854,7 @@ class TestAppGenerateService: assert result == 10 def test_get_max_active_requests_with_config_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with config limit being smaller. @@ -856,7 +875,7 @@ class TestAppGenerateService: assert result <= 100 def test_get_max_active_requests_with_zero_limits( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting max active requests with zero limits (infinite). @@ -875,7 +894,9 @@ class TestAppGenerateService: # Verify the result (should return config limit when app limit is 0) assert result == 100 # dify_config.APP_MAX_ACTIVE_REQUESTS - def test_generate_with_exception_cleanup(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_exception_cleanup( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that rate limit exit is called when an exception occurs. """ @@ -904,7 +925,9 @@ class TestAppGenerateService: # Verify rate limit exit was called for cleanup mock_external_service_dependencies["rate_limit"].return_value.exit.assert_called_once() - def test_generate_with_agent_mode_detection(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_agent_mode_detection( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test generation with agent mode detection based on app configuration. """ @@ -932,7 +955,7 @@ class TestAppGenerateService: mock_external_service_dependencies["agent_chat_generator"].convert_to_event_stream.assert_called_once() def test_generate_with_different_invoke_from_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test generation with different invoke from values. @@ -962,7 +985,7 @@ class TestAppGenerateService: # Verify the result assert result == ["test_response"] - def test_generate_with_complex_args(self, db_session_with_containers, mock_external_service_dependencies): + def test_generate_with_complex_args(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test generation with complex arguments including files and external trace ID. """ diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index 745d6c97b0..d79f80c009 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -2,11 +2,13 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants.model_template import default_app_templates from models import Account from models.model import App, Site from services.account_service import AccountService, TenantService +from tests.test_containers_integration_tests.helpers import generate_valid_password # Delay import of AppService to avoid circular dependency # from services.app_service import AppService @@ -44,7 +46,7 @@ class TestAppService: "account_feature_service": mock_account_feature_service, } - def test_create_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app creation with basic parameters. """ @@ -55,7 +57,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -98,7 +100,9 @@ class TestAppService: assert app.is_public is False assert app.is_universal is False - def test_create_app_with_different_modes(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_with_different_modes( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app creation with different app modes. """ @@ -109,7 +113,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -141,7 +145,7 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.created_by == account.id - def test_get_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app retrieval. """ @@ -152,7 +156,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -189,7 +193,7 @@ class TestAppService: assert retrieved_app.tenant_id == created_app.tenant_id assert retrieved_app.created_by == created_app.created_by - def test_get_paginate_apps_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful paginated app list retrieval. """ @@ -200,7 +204,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -243,7 +247,9 @@ class TestAppService: assert app.tenant_id == tenant.id assert app.mode == "chat" - def test_get_paginate_apps_with_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with various filters. """ @@ -254,7 +260,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -316,7 +322,9 @@ class TestAppService: my_apps = app_service.get_paginate_apps(account.id, tenant.id, created_by_me_args) assert len(my_apps.items) == 1 - def test_get_paginate_apps_with_tag_filters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_apps_with_tag_filters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test paginated app list with tag filters. """ @@ -327,7 +335,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -386,7 +394,7 @@ class TestAppService: # Should return None when no apps match tag filter assert paginated_apps is None - def test_update_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app update with all fields. """ @@ -397,7 +405,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -455,7 +463,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app name update. """ @@ -466,7 +474,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -508,7 +516,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_icon_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app icon update. """ @@ -519,7 +527,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -565,7 +573,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app site status update. """ @@ -576,7 +586,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -623,7 +633,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_api_status_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_api_status_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful app API status update. """ @@ -634,7 +646,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -681,7 +693,9 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_update_app_site_status_no_change(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_app_site_status_no_change( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app site status update when status doesn't change. """ @@ -692,7 +706,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -732,7 +746,7 @@ class TestAppService: assert updated_app.tenant_id == app.tenant_id assert updated_app.created_by == app.created_by - def test_delete_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app deletion. """ @@ -743,7 +757,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -778,12 +792,13 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_delete_app_with_related_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_app_with_related_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app deletion with related data cleanup. """ @@ -794,7 +809,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -839,12 +854,11 @@ class TestAppService: mock_delete_task.delay.assert_called_once_with(tenant_id=tenant.id, app_id=app_id) # Verify app was deleted from database - from extensions.ext_database import db - deleted_app = db.session.query(App).filter_by(id=app_id).first() + deleted_app = db_session_with_containers.query(App).filter_by(id=app_id).first() assert deleted_app is None - def test_get_app_meta_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_meta_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app metadata retrieval. """ @@ -855,7 +869,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -883,7 +897,7 @@ class TestAppService: assert "tool_icons" in app_meta # Note: get_app_meta currently only returns tool_icons - def test_get_app_code_by_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_code_by_id_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app code retrieval by app ID. """ @@ -894,7 +908,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -923,7 +937,7 @@ class TestAppService: assert app_code is not None assert len(app_code) > 0 - def test_get_app_id_by_code_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_id_by_code_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful app ID retrieval by app code. """ @@ -934,7 +948,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -963,10 +977,9 @@ class TestAppService: site.status = "normal" site.default_language = "en-US" site.customize_token_strategy = "uuid" - from extensions.ext_database import db - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Get app ID by code app_id = AppService.get_app_id_by_code(site.code) @@ -974,7 +987,7 @@ class TestAppService: # Verify app ID was retrieved correctly assert app_id == app.id - def test_create_app_invalid_mode(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_app_invalid_mode(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test app creation with invalid mode. """ @@ -985,7 +998,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -1010,7 +1023,7 @@ class TestAppService: app_service.create_app(tenant.id, app_args, account) def test_get_apps_with_special_characters_in_name( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test app retrieval with special characters in name search to verify SQL injection prevention. @@ -1027,7 +1040,7 @@ class TestAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant diff --git a/api/tests/test_containers_integration_tests/services/test_attachment_service.py b/api/tests/test_containers_integration_tests/services/test_attachment_service.py new file mode 100644 index 0000000000..768a8baee2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_attachment_service.py @@ -0,0 +1,80 @@ +"""Testcontainers integration tests for AttachmentService.""" + +import base64 +from datetime import UTC, datetime +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from werkzeug.exceptions import NotFound + +import services.attachment_service as attachment_service_module +from extensions.ext_database import db +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import UploadFile +from services.attachment_service import AttachmentService + + +class TestAttachmentService: + def _create_upload_file(self, db_session_with_containers, *, tenant_id: str | None = None) -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id or str(uuid4()), + storage_type=StorageType.OPENDAL, + key=f"upload/{uuid4()}.txt", + name="test-file.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=datetime.now(UTC), + used=False, + ) + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + def test_should_initialize_with_sessionmaker(self): + session_factory = sessionmaker() + + service = AttachmentService(session_factory=session_factory) + + assert service._session_maker is session_factory + + def test_should_initialize_with_engine(self): + engine = create_engine("sqlite:///:memory:") + + service = AttachmentService(session_factory=engine) + session = service._session_maker() + try: + assert session.bind == engine + finally: + session.close() + engine.dispose() + + @pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1]) + def test_should_raise_assertion_error_for_invalid_session_factory(self, invalid_session_factory): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + AttachmentService(session_factory=invalid_session_factory) + + def test_should_return_base64_when_file_exists(self, db_session_with_containers): + upload_file = self._create_upload_file(db_session_with_containers) + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load: + result = service.get_file_base64(upload_file.id) + + assert result == base64.b64encode(b"binary-content").decode() + mock_load.assert_called_once_with(upload_file.key) + + def test_should_raise_not_found_when_file_missing(self, db_session_with_containers): + service = AttachmentService(session_factory=sessionmaker(bind=db.engine)) + + with patch.object(attachment_service_module.storage, "load_once") as mock_load: + with pytest.raises(NotFound, match="File not found"): + service.get_file_base64(str(uuid4())) + + mock_load.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_conversation_service.py index 5f64e6f674..6180d98b1e 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_service.py @@ -10,6 +10,7 @@ from sqlalchemy import select from core.app.entities.app_invoke_entities import InvokeFrom from models.account import Account, Tenant, TenantAccountJoin +from models.enums import ConversationFromSource from models.model import App, Conversation, EndUser, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.conversation_service import ConversationService @@ -107,7 +108,7 @@ class ConversationServiceIntegrationTestDataFactory: system_instruction_tokens=0, status="normal", invoke_from=invoke_from.value, - from_source="api" if isinstance(user, EndUser) else "console", + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, dialogue_count=0, @@ -154,7 +155,7 @@ class ConversationServiceIntegrationTestDataFactory: currency="USD", status="normal", invoke_from=InvokeFrom.WEB_APP.value, - from_source="api" if isinstance(user, EndUser) else "console", + from_source=ConversationFromSource.API if isinstance(user, EndUser) else ConversationFromSource.CONSOLE, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, ) diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py new file mode 100644 index 0000000000..42a2215896 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -0,0 +1,58 @@ +"""Testcontainers integration tests for ConversationVariableUpdater.""" + +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import sessionmaker + +from dify_graph.variables import StringVariable +from extensions.ext_database import db +from models.workflow import ConversationVariable +from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater + + +class TestConversationVariableUpdater: + def _create_conversation_variable( + self, db_session_with_containers, *, conversation_id: str, variable: StringVariable, app_id: str | None = None + ) -> ConversationVariable: + row = ConversationVariable( + id=variable.id, + conversation_id=conversation_id, + app_id=app_id or str(uuid4()), + data=variable.model_dump_json(), + ) + db_session_with_containers.add(row) + db_session_with_containers.commit() + return row + + def test_should_update_conversation_variable_data_and_commit(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="old value") + self._create_conversation_variable( + db_session_with_containers, conversation_id=conversation_id, variable=variable + ) + + updated_variable = StringVariable(id=variable.id, name="topic", value="new value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + updater.update(conversation_id=conversation_id, variable=updated_variable) + + db_session_with_containers.expire_all() + row = db_session_with_containers.get(ConversationVariable, (variable.id, conversation_id)) + assert row is not None + assert row.data == updated_variable.model_dump_json() + + def test_should_raise_not_found_when_variable_missing(self, db_session_with_containers): + conversation_id = str(uuid4()) + variable = StringVariable(id=str(uuid4()), name="topic", value="value") + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"): + updater.update(conversation_id=conversation_id, variable=variable) + + def test_should_do_nothing_when_flush_is_called(self, db_session_with_containers): + updater = ConversationVariableUpdater(sessionmaker(bind=db.engine)) + + result = updater.flush() + + assert result is None diff --git a/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py new file mode 100644 index 0000000000..25de0588fa --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_credit_pool_service.py @@ -0,0 +1,103 @@ +"""Testcontainers integration tests for CreditPoolService.""" + +from uuid import uuid4 + +import pytest + +from core.errors.error import QuotaExceededError +from models import TenantCreditPool +from services.credit_pool_service import CreditPoolService + + +class TestCreditPoolService: + def _create_tenant_id(self) -> str: + return str(uuid4()) + + def test_create_default_pool(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + + pool = CreditPoolService.create_default_pool(tenant_id) + + assert isinstance(pool, TenantCreditPool) + assert pool.tenant_id == tenant_id + assert pool.pool_type == "trial" + assert pool.quota_used == 0 + assert pool.quota_limit > 0 + + def test_get_pool_returns_pool_when_exists(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type="trial") + + assert result is not None + assert result.tenant_id == tenant_id + assert result.pool_type == "trial" + + def test_get_pool_returns_none_when_not_exists(self, db_session_with_containers): + result = CreditPoolService.get_pool(tenant_id=self._create_tenant_id(), pool_type="trial") + + assert result is None + + def test_check_credits_available_returns_false_when_no_pool(self, db_session_with_containers): + result = CreditPoolService.check_credits_available(tenant_id=self._create_tenant_id(), credits_required=10) + + assert result is False + + def test_check_credits_available_returns_true_when_sufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=10) + + assert result is True + + def test_check_credits_available_returns_false_when_insufficient(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + # Exhaust credits + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + result = CreditPoolService.check_credits_available(tenant_id=tenant_id, credits_required=1) + + assert result is False + + def test_check_and_deduct_credits_raises_when_no_pool(self, db_session_with_containers): + with pytest.raises(QuotaExceededError, match="Credit pool not found"): + CreditPoolService.check_and_deduct_credits(tenant_id=self._create_tenant_id(), credits_required=10) + + def test_check_and_deduct_credits_raises_when_no_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + pool.quota_used = pool.quota_limit + db_session_with_containers.commit() + + with pytest.raises(QuotaExceededError, match="No credits remaining"): + CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=10) + + def test_check_and_deduct_credits_deducts_required_amount(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + CreditPoolService.create_default_pool(tenant_id) + credits_required = 10 + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=credits_required) + + assert result == credits_required + db_session_with_containers.expire_all() + pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert pool.quota_used == credits_required + + def test_check_and_deduct_credits_caps_at_remaining(self, db_session_with_containers): + tenant_id = self._create_tenant_id() + pool = CreditPoolService.create_default_pool(tenant_id) + remaining = 5 + pool.quota_used = pool.quota_limit - remaining + db_session_with_containers.commit() + + result = CreditPoolService.check_and_deduct_credits(tenant_id=tenant_id, credits_required=200) + + assert result == remaining + db_session_with_containers.expire_all() + updated_pool = CreditPoolService.get_pool(tenant_id=tenant_id) + assert updated_pool.quota_used == pool.quota_limit diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py new file mode 100644 index 0000000000..55bfb64e18 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_permission_service.py @@ -0,0 +1,570 @@ +""" +Container-backed integration tests for dataset permission services on the real SQL path. + +This module exercises persisted DatasetPermission rows and dataset permission +checks with testcontainers-backed infrastructure instead of database-chain mocks. +""" + +from uuid import uuid4 + +import pytest + +from extensions.ext_database import db +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import ( + Dataset, + DatasetPermission, + DatasetPermissionEnum, +) +from models.enums import DataSourceType +from services.dataset_service import DatasetPermissionService, DatasetService +from services.errors.account import NoPermissionError + + +class DatasetPermissionTestDataFactory: + """Create persisted entities and request payloads for dataset permission integration tests.""" + + @staticmethod + def create_account_with_tenant( + role: TenantAccountRole = TenantAccountRole.NORMAL, + tenant: Tenant | None = None, + ) -> tuple[Account, Tenant]: + """Create a real account and tenant with specified role.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + if tenant is None: + tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") + db.session.add_all([account, tenant]) + else: + db.session.add(account) + + db.session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + tenant_id: str, + created_by: str, + permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, + name: str = "Test Dataset", + ) -> Dataset: + """Create a real dataset with specified attributes.""" + dataset = Dataset( + tenant_id=tenant_id, + name=name, + description="desc", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique="high_quality", + created_by=created_by, + permission=permission, + provider="vendor", + retrieval_model={"top_k": 2}, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_dataset_permission( + dataset_id: str, + account_id: str, + tenant_id: str, + has_permission: bool = True, + ) -> DatasetPermission: + """Create a real DatasetPermission instance.""" + permission = DatasetPermission( + dataset_id=dataset_id, + account_id=account_id, + tenant_id=tenant_id, + has_permission=has_permission, + ) + db.session.add(permission) + db.session.commit() + return permission + + @staticmethod + def build_user_list_payload(user_ids: list[str]) -> list[dict[str, str]]: + """Build the request payload shape used by partial-member list updates.""" + return [{"user_id": user_id} for user_id in user_ids] + + +class TestDatasetPermissionServiceGetPartialMemberList: + """Verify partial-member list reads against persisted DatasetPermission rows.""" + + def test_get_dataset_partial_member_list_with_members(self, db_session_with_containers): + """ + Test retrieving partial member list with multiple members. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + user_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + user_3, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + expected_account_ids = [user_1.id, user_2.id, user_3.id] + for account_id in expected_account_ids: + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, account_id, tenant.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert set(result) == set(expected_account_ids) + assert len(result) == 3 + + def test_get_dataset_partial_member_list_with_single_member(self, db_session_with_containers): + """ + Test retrieving partial member list with single member. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + expected_account_ids = [user.id] + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert set(result) == set(expected_account_ids) + assert len(result) == 1 + + def test_get_dataset_partial_member_list_empty(self, db_session_with_containers): + """ + Test retrieving partial member list when no members exist. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + # Act + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + + # Assert + assert result == [] + assert len(result) == 0 + + +class TestDatasetPermissionServiceUpdatePartialMemberList: + """Verify partial-member list updates against persisted DatasetPermission rows.""" + + def test_update_partial_member_list_add_new_members(self, db_session_with_containers): + """ + Test adding new partial members to a dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + user_list = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert set(result) == {member_1.id, member_2.id} + + def test_update_partial_member_list_replace_existing(self, db_session_with_containers): + """ + Test replacing existing partial members with new ones. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + old_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + old_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + new_member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + new_member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + old_users = DatasetPermissionTestDataFactory.build_user_list_payload([old_member_1.id, old_member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, old_users) + + new_users = DatasetPermissionTestDataFactory.build_user_list_payload([new_member_1.id, new_member_2.id]) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, new_users) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert set(result) == {new_member_1.id, new_member_2.id} + + def test_update_partial_member_list_empty_list(self, db_session_with_containers): + """ + Test updating with empty member list (clearing all members). + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + + # Act + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, []) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_update_partial_member_list_database_error_rollback(self, db_session_with_containers): + """ + Test error handling and rollback on database error. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + existing_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + replacement_member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + DatasetPermissionService.update_partial_member_list( + tenant.id, + dataset.id, + DatasetPermissionTestDataFactory.build_user_list_payload([existing_member.id]), + ) + user_list = DatasetPermissionTestDataFactory.build_user_list_payload([replacement_member.id]) + rollback_called = {"count": 0} + original_rollback = db.session.rollback + + # Act / Assert + with pytest.MonkeyPatch.context() as mp: + + def _raise_commit(): + raise Exception("Database connection error") + + def _rollback_and_mark(): + rollback_called["count"] += 1 + original_rollback() + + mp.setattr("services.dataset_service.db.session.commit", _raise_commit) + mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + with pytest.raises(Exception, match="Database connection error"): + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, user_list) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert rollback_called["count"] == 1 + assert result == [existing_member.id] + assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 1 + + +class TestDatasetPermissionServiceClearPartialMemberList: + """Verify partial-member clearing against persisted DatasetPermission rows.""" + + def test_clear_partial_member_list_success(self, db_session_with_containers): + """ + Test successful clearing of partial member list. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + + # Act + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_clear_partial_member_list_empty_list(self, db_session_with_containers): + """ + Test clearing partial member list when no members exist. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + + # Act + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert result == [] + + def test_clear_partial_member_list_database_error_rollback(self, db_session_with_containers): + """ + Test error handling and rollback on database error. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + member_1, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + member_2, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + dataset = DatasetPermissionTestDataFactory.create_dataset(tenant.id, owner.id) + users = DatasetPermissionTestDataFactory.build_user_list_payload([member_1.id, member_2.id]) + DatasetPermissionService.update_partial_member_list(tenant.id, dataset.id, users) + rollback_called = {"count": 0} + original_rollback = db.session.rollback + + # Act / Assert + with pytest.MonkeyPatch.context() as mp: + + def _raise_commit(): + raise Exception("Database connection error") + + def _rollback_and_mark(): + rollback_called["count"] += 1 + original_rollback() + + mp.setattr("services.dataset_service.db.session.commit", _raise_commit) + mp.setattr("services.dataset_service.db.session.rollback", _rollback_and_mark) + with pytest.raises(Exception, match="Database connection error"): + DatasetPermissionService.clear_partial_member_list(dataset.id) + + # Assert + result = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert rollback_called["count"] == 1 + assert set(result) == {member_1.id, member_2.id} + assert db_session_with_containers.query(DatasetPermission).filter_by(dataset_id=dataset.id).count() == 2 + + +class TestDatasetServiceCheckDatasetPermission: + """Verify dataset access checks against persisted partial-member permissions.""" + + def test_check_dataset_permission_different_tenant_should_fail(self, db_session_with_containers): + """Test that users from different tenants cannot access dataset.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + other_user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, owner.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other_user) + + def test_check_dataset_permission_owner_can_access_any_dataset(self, db_session_with_containers): + """Test that tenant owners can access any dataset regardless of permission level.""" + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + creator, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, owner) + + def test_check_dataset_permission_only_me_creator_can_access(self, db_session_with_containers): + """Test ONLY_ME permission allows only the dataset creator to access.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + DatasetService.check_dataset_permission(dataset, creator) + + def test_check_dataset_permission_only_me_others_cannot_access(self, db_session_with_containers): + """Test ONLY_ME permission denies access to non-creators.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + other, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ONLY_ME + ) + + with pytest.raises(NoPermissionError): + DatasetService.check_dataset_permission(dataset, other) + + def test_check_dataset_permission_all_team_allows_access(self, db_session_with_containers): + """Test ALL_TEAM permission allows any team member to access the dataset.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + member, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, tenant=tenant + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.ALL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, member) + + def test_check_dataset_permission_partial_members_with_permission_success(self, db_session_with_containers): + """ + Test that user with explicit permission can access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act (should not raise) + DatasetService.check_dataset_permission(dataset, user) + + # Assert + permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert user.id in permissions + + def test_check_dataset_permission_partial_members_without_permission_error(self, db_session_with_containers): + """ + Test error when user without permission tries to access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + # Act & Assert + with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_partial_team_creator_can_access(self, db_session_with_containers): + """Test PARTIAL_TEAM permission allows creator to access without explicit permission.""" + creator, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.EDITOR) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, creator.id, permission=DatasetPermissionEnum.PARTIAL_TEAM + ) + + DatasetService.check_dataset_permission(dataset, creator) + + +class TestDatasetServiceCheckDatasetOperatorPermission: + """Verify operator permission checks against persisted partial-member permissions.""" + + def test_check_dataset_operator_permission_partial_members_with_permission_success( + self, db_session_with_containers + ): + """ + Test that user with explicit permission can access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + DatasetPermissionTestDataFactory.create_dataset_permission(dataset.id, user.id, tenant.id) + + # Act (should not raise) + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + # Assert + permissions = DatasetPermissionService.get_dataset_partial_member_list(dataset.id) + assert user.id in permissions + + def test_check_dataset_operator_permission_partial_members_without_permission_error( + self, db_session_with_containers + ): + """ + Test error when user without permission tries to access partial_members dataset. + """ + # Arrange + owner, tenant = DatasetPermissionTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + user, _ = DatasetPermissionTestDataFactory.create_account_with_tenant( + role=TenantAccountRole.NORMAL, + tenant=tenant, + ) + + dataset = DatasetPermissionTestDataFactory.create_dataset( + tenant.id, + owner.id, + permission=DatasetPermissionEnum.PARTIAL_TEAM, + ) + + # Act & Assert + with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index f05c47913e..ac3d9f9604 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,14 +9,16 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from core.model_runtime.entities.model_entities import ModelType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from extensions.ext_database import db +from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole -from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings +from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline +from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import RerankingModel, RetrievalModel +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity from services.errors.dataset import DatasetNameDuplicateError @@ -24,7 +26,9 @@ class DatasetServiceIntegrationDataFactory: """Factory for creating real database entities used by integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create an account and tenant, then bind the account as current tenant member.""" account = Account( email=f"{uuid4()}@example.com", @@ -33,8 +37,8 @@ class DatasetServiceIntegrationDataFactory: status="active", ) tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -42,8 +46,8 @@ class DatasetServiceIntegrationDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.flush() + db_session_with_containers.add(join) + db_session_with_containers.flush() # Keep tenant context on the in-memory user without opening a separate session. account.role = role @@ -52,6 +56,7 @@ class DatasetServiceIntegrationDataFactory: @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -70,7 +75,7 @@ class DatasetServiceIntegrationDataFactory: tenant_id=tenant_id, name=name, description=description, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, created_by=created_by, provider=provider, @@ -81,28 +86,30 @@ class DatasetServiceIntegrationDataFactory: collection_binding_id=collection_binding_id, chunk_structure=chunk_structure, ) - db.session.add(dataset) - db.session.flush() + db_session_with_containers.add(dataset) + db_session_with_containers.flush() return dataset @staticmethod - def create_document(dataset: Dataset, created_by: str, name: str = "doc.txt") -> Document: + def create_document( + db_session_with_containers: Session, dataset: Dataset, created_by: str, name: str = "doc.txt" + ) -> Document: """Create a document row belonging to the given dataset.""" document = Document( tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info='{"upload_file_id": "upload-file-id"}', batch=str(uuid4()), name=name, - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) - db.session.add(document) - db.session.flush() + db_session_with_containers.add(document) + db_session_with_containers.flush() return document @staticmethod @@ -117,10 +124,10 @@ class DatasetServiceIntegrationDataFactory: class TestDatasetServiceCreateDataset: """Integration coverage for DatasetService.create_empty_dataset.""" - def test_create_internal_dataset_basic_success(self, db_session_with_containers): + def test_create_internal_dataset_basic_success(self, db_session_with_containers: Session): """Create a basic internal dataset with minimal configuration.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -132,17 +139,17 @@ class TestDatasetServiceCreateDataset: ) # Assert - created_dataset = db.session.get(Dataset, result.id) + created_dataset = db_session_with_containers.get(Dataset, result.id) assert created_dataset is not None assert created_dataset.provider == "vendor" assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME assert created_dataset.embedding_model_provider is None assert created_dataset.embedding_model is None - def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_economy_indexing(self, db_session_with_containers: Session): """Create an internal dataset with economy indexing and no embedding model.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) # Act result = DatasetService.create_empty_dataset( @@ -154,15 +161,15 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "economy" assert result.embedding_model_provider is None assert result.embedding_model is None - def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers): + def test_create_internal_dataset_with_high_quality_indexing(self, db_session_with_containers: Session): """Create a high-quality dataset and persist embedding model settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act @@ -178,7 +185,7 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.indexing_technique == "high_quality" assert result.embedding_model_provider == embedding_model.provider assert result.embedding_model == embedding_model.model_name @@ -187,11 +194,12 @@ class TestDatasetServiceCreateDataset: model_type=ModelType.TEXT_EMBEDDING, ) - def test_create_dataset_duplicate_name_error(self, db_session_with_containers): + def test_create_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Raise duplicate-name error when the same tenant already has the name.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Duplicate Dataset", @@ -208,10 +216,10 @@ class TestDatasetServiceCreateDataset: account=account, ) - def test_create_external_dataset_success(self, db_session_with_containers): + def test_create_external_dataset_success(self, db_session_with_containers: Session): """Create an external dataset and persist external knowledge binding.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) external_knowledge_api_id = str(uuid4()) external_knowledge_id = "knowledge-123" @@ -230,16 +238,16 @@ class TestDatasetServiceCreateDataset: ) # Assert - binding = db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() + binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(dataset_id=result.id).first() assert result.provider == "external" assert binding is not None assert binding.external_knowledge_id == external_knowledge_id assert binding.external_knowledge_api_id == external_knowledge_api_id - def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers): + def test_create_dataset_with_retrieval_model_and_reranking(self, db_session_with_containers: Session): """Create a high-quality dataset with retrieval/reranking settings.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() retrieval_model = RetrievalModel( search_method=RetrievalMethod.SEMANTIC_SEARCH, @@ -270,24 +278,299 @@ class TestDatasetServiceCreateDataset: ) # Assert - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.retrieval_model == retrieval_model.model_dump() mock_check_reranking.assert_called_once_with(tenant.id, "cohere", "rerank-english-v2.0") + def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( + self, db_session_with_containers: Session + ): + """Create high-quality dataset with explicitly configured embedding model.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + embedding_provider = "openai" + embedding_model_name = "text-embedding-3-small" + embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model( + provider=embedding_provider, model_name=embedding_model_name + ) + + # Act + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, + ): + mock_model_manager.return_value.get_model_instance.return_value = embedding_model + + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Custom Embedding Dataset", + description=None, + indexing_technique="high_quality", + account=account, + embedding_model_provider=embedding_provider, + embedding_model_name=embedding_model_name, + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.indexing_technique == "high_quality" + assert result.embedding_model_provider == embedding_provider + assert result.embedding_model == embedding_model_name + mock_check_embedding.assert_called_once_with(tenant.id, embedding_provider, embedding_model_name) + mock_model_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id=tenant.id, + provider=embedding_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=embedding_model_name, + ) + + def test_create_internal_dataset_with_retrieval_model(self, db_session_with_containers: Session): + """Persist retrieval model settings when creating an internal dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=False, + top_k=2, + score_threshold_enabled=True, + score_threshold=0.0, + ) + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Retrieval Model Dataset", + description=None, + indexing_technique=None, + account=account, + retrieval_model=retrieval_model, + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.retrieval_model == retrieval_model.model_dump() + + def test_create_internal_dataset_with_custom_permission(self, db_session_with_containers: Session): + """Persist canonical custom permission when creating an internal dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + + # Act + result = DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="Custom Permission Dataset", + description=None, + indexing_technique=None, + account=account, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.permission == DatasetPermissionEnum.ALL_TEAM + + def test_create_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): + """Raise error when external API template does not exist.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + external_knowledge_api_id = str(uuid4()) + + # Act / Assert + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = None + with pytest.raises(ValueError, match=r"External API template not found\.?"): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Missing API Dataset", + description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id="knowledge-123", + ) + + def test_create_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): + """Raise error when external knowledge id is missing for external dataset creation.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + external_knowledge_api_id = str(uuid4()) + + # Act / Assert + with patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api") as mock_get_api: + mock_get_api.return_value = Mock(id=external_knowledge_api_id) + with pytest.raises(ValueError, match="external_knowledge_id is required"): + DatasetService.create_empty_dataset( + tenant_id=tenant.id, + name="External Missing Knowledge Dataset", + description=None, + indexing_technique=None, + account=account, + provider="external", + external_knowledge_api_id=external_knowledge_api_id, + external_knowledge_id=None, + ) + + +class TestDatasetServiceCreateRagPipelineDataset: + """Integration coverage for DatasetService.create_empty_rag_pipeline_dataset.""" + + def test_create_rag_pipeline_dataset_with_name_success(self, db_session_with_containers: Session): + """Create rag-pipeline dataset and pipeline rows when a name is provided.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="RAG Pipeline Dataset", + description="RAG Pipeline Description", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + created_dataset = db_session_with_containers.get(Dataset, result.id) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) + assert created_dataset is not None + assert created_dataset.name == entity.name + assert created_dataset.runtime_mode == DatasetRuntimeMode.RAG_PIPELINE + assert created_dataset.created_by == account.id + assert created_dataset.permission == DatasetPermissionEnum.ONLY_ME + assert created_pipeline is not None + assert created_pipeline.name == entity.name + assert created_pipeline.created_by == account.id + + def test_create_rag_pipeline_dataset_with_auto_generated_name(self, db_session_with_containers: Session): + """Create rag-pipeline dataset with generated incremental name when input name is empty.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + generated_name = "Untitled 1" + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with ( + patch("services.dataset_service.current_user", account), + patch("services.dataset_service.generate_incremental_name") as mock_generate_name, + ): + mock_generate_name.return_value = generated_name + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db_session_with_containers.refresh(result) + created_pipeline = db_session_with_containers.get(Pipeline, result.pipeline_id) + assert result.name == generated_name + assert created_pipeline is not None + assert created_pipeline.name == generated_name + mock_generate_name.assert_called_once() + + def test_create_rag_pipeline_dataset_duplicate_name_error(self, db_session_with_containers: Session): + """Raise duplicate-name error when rag-pipeline dataset name already exists.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + duplicate_name = "Duplicate RAG Dataset" + DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=account.id, + name=duplicate_name, + indexing_technique=None, + ) + db_session_with_containers.commit() + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name=duplicate_name, + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act / Assert + with ( + patch("services.dataset_service.current_user", account), + pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {duplicate_name} already exists"), + ): + DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + def test_create_rag_pipeline_dataset_with_custom_permission(self, db_session_with_containers: Session): + """Persist canonical custom permission for rag-pipeline dataset creation.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") + entity = RagPipelineDatasetCreateEntity( + name="Custom Permission RAG Dataset", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.permission == DatasetPermissionEnum.ALL_TEAM + + def test_create_rag_pipeline_dataset_with_icon_info(self, db_session_with_containers: Session): + """Persist icon metadata when creating rag-pipeline dataset.""" + # Arrange + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + icon_info = IconInfo( + icon="📚", + icon_background="#E8F5E9", + icon_type="emoji", + icon_url="https://example.com/icon.png", + ) + entity = RagPipelineDatasetCreateEntity( + name="Icon Info RAG Dataset", + description="", + icon_info=icon_info, + permission=DatasetPermissionEnum.ONLY_ME, + ) + + # Act + with patch("services.dataset_service.current_user", account): + result = DatasetService.create_empty_rag_pipeline_dataset( + tenant_id=tenant.id, rag_pipeline_dataset_create_entity=entity + ) + + # Assert + db_session_with_containers.refresh(result) + assert result.icon_info == icon_info.model_dump() + class TestDatasetServiceUpdateAndDeleteDataset: """Integration coverage for SQL-backed update and delete behavior.""" - def test_update_dataset_duplicate_name_error(self, db_session_with_containers): + def test_update_dataset_duplicate_name_error(self, db_session_with_containers: Session): """Reject update when target name already exists within the same tenant.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) source_dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Source Dataset", ) DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Existing Dataset", @@ -297,17 +580,20 @@ class TestDatasetServiceUpdateAndDeleteDataset: with pytest.raises(ValueError, match="Dataset name already exists"): DatasetService.update_dataset(source_dataset.id, {"name": "Existing Dataset"}, account) - def test_delete_dataset_with_documents_success(self, db_session_with_containers): + def test_delete_dataset_with_documents_success(self, db_session_with_containers: Session): """Delete a dataset that already has documents.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", chunk_structure="text_model", ) - DatasetServiceIntegrationDataFactory.create_document(dataset=dataset, created_by=account.id) + DatasetServiceIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, created_by=account.id + ) # Act with patch("services.dataset_service.dataset_was_deleted") as dataset_deleted_signal: @@ -315,14 +601,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_empty_dataset_success(self, db_session_with_containers): + def test_delete_empty_dataset_success(self, db_session_with_containers: Session): """Delete a dataset that has no documents and no indexing technique.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -335,14 +622,15 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) - def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers: Session): """Delete dataset when indexing_technique is None but doc_form path still exists.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique=None, @@ -355,17 +643,17 @@ class TestDatasetServiceUpdateAndDeleteDataset: # Assert assert result is True - assert db.session.get(Dataset, dataset.id) is None + assert db_session_with_containers.get(Dataset, dataset.id) is None dataset_deleted_signal.send.assert_called_once_with(dataset) class TestDatasetServiceRetrievalConfiguration: """Integration coverage for retrieval configuration persistence.""" - def test_get_dataset_retrieval_configuration(self, db_session_with_containers): + def test_get_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Return retrieval configuration that is persisted in SQL.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) retrieval_model = { "search_method": "semantic_search", "top_k": 5, @@ -373,6 +661,7 @@ class TestDatasetServiceRetrievalConfiguration: "reranking_enable": True, } dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, retrieval_model=retrieval_model, @@ -387,11 +676,12 @@ class TestDatasetServiceRetrievalConfiguration: assert result.retrieval_model["search_method"] == "semantic_search" assert result.retrieval_model["top_k"] == 5 - def test_update_dataset_retrieval_configuration(self, db_session_with_containers): + def test_update_dataset_retrieval_configuration(self, db_session_with_containers: Session): """Persist retrieval configuration updates through DatasetService.update_dataset.""" # Arrange - account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant() + account, tenant = DatasetServiceIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetServiceIntegrationDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, indexing_technique="high_quality", @@ -413,6 +703,6 @@ class TestDatasetServiceRetrievalConfiguration: result = DatasetService.update_dataset(dataset.id, update_data, account) # Assert - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert result.id == dataset.id assert dataset.retrieval_model == update_data["retrieval_model"] diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py new file mode 100644 index 0000000000..7983b1cd93 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_batch_update_document_status.py @@ -0,0 +1,696 @@ +"""Integration tests for DocumentService.batch_update_document_status. + +This suite validates SQL-backed batch status updates with testcontainers. +It keeps database access real and only patches non-DB side effects. +""" + +import datetime +import json +from dataclasses import dataclass +from unittest.mock import call, patch +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus +from services.dataset_service import DocumentService +from services.errors.document import DocumentIndexingError + +FIXED_TIME = datetime.datetime(2023, 1, 1, 12, 0, 0) + + +@dataclass +class UserDouble: + """Minimal user object for batch update operations.""" + + id: str + + +class DocumentBatchUpdateIntegrationDataFactory: + """Factory for creating persisted entities used in integration tests.""" + + @staticmethod + def create_dataset( + db_session_with_containers: Session, + dataset_id: str | None = None, + tenant_id: str | None = None, + name: str = "Test Dataset", + created_by: str | None = None, + ) -> Dataset: + """Create and persist a dataset.""" + dataset = Dataset( + tenant_id=tenant_id or str(uuid4()), + name=name, + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=created_by or str(uuid4()), + ) + if dataset_id: + dataset.id = dataset_id + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers: Session, + dataset: Dataset, + document_id: str | None = None, + name: str = "test_document.pdf", + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + completed_at: datetime.datetime | None = None, + position: int = 1, + created_by: str | None = None, + commit: bool = True, + **kwargs, + ) -> Document: + """Create a document bound to the given dataset and persist it.""" + document = Document( + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=position, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info=json.dumps({"upload_file_id": str(uuid4())}), + batch=f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=created_by or str(uuid4()), + doc_form="text_model", + ) + document.id = document_id or str(uuid4()) + document.enabled = enabled + document.archived = archived + document.indexing_status = indexing_status + document.completed_at = ( + completed_at + if completed_at is not None + else (FIXED_TIME if indexing_status == IndexingStatus.COMPLETED else None) + ) + + for key, value in kwargs.items(): + setattr(document, key, value) + + db_session_with_containers.add(document) + if commit: + db_session_with_containers.commit() + return document + + @staticmethod + def create_multiple_documents( + db_session_with_containers: Session, + dataset: Dataset, + document_ids: list[str], + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + ) -> list[Document]: + """Create and persist multiple documents for one dataset in a single transaction.""" + documents: list[Document] = [] + for index, doc_id in enumerate(document_ids, start=1): + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + document_id=doc_id, + name=f"document_{doc_id}.pdf", + enabled=enabled, + archived=archived, + indexing_status=indexing_status, + position=index, + commit=False, + ) + documents.append(document) + db_session_with_containers.commit() + return documents + + @staticmethod + def create_user(user_id: str | None = None) -> UserDouble: + """Create a lightweight user for update metadata fields.""" + return UserDouble(id=user_id or str(uuid4())) + + +class TestDatasetServiceBatchUpdateDocumentStatus: + """Integration coverage for batch document status updates.""" + + @pytest.fixture + def patched_dependencies(self): + """Patch non-DB collaborators only.""" + with ( + patch("services.dataset_service.redis_client") as redis_client, + patch("services.dataset_service.add_document_to_index_task") as add_task, + patch("services.dataset_service.remove_document_from_index_task") as remove_task, + patch("services.dataset_service.naive_utc_now") as naive_utc_now, + ): + naive_utc_now.return_value = FIXED_TIME + redis_client.get.return_value = None + yield { + "redis_client": redis_client, + "add_task": add_task, + "remove_task": remove_task, + "naive_utc_now": naive_utc_now, + } + + def _assert_document_enabled(self, document: Document, current_time: datetime.datetime): + """Verify enabled-state fields after action=enable.""" + assert document.enabled is True + assert document.disabled_at is None + assert document.disabled_by is None + assert document.updated_at == current_time + + def _assert_document_disabled(self, document: Document, user_id: str, current_time: datetime.datetime): + """Verify disabled-state fields after action=disable.""" + assert document.enabled is False + assert document.disabled_at == current_time + assert document.disabled_by == user_id + assert document.updated_at == current_time + + def _assert_document_archived(self, document: Document, user_id: str, current_time: datetime.datetime): + """Verify archived-state fields after action=archive.""" + assert document.archived is True + assert document.archived_at == current_time + assert document.archived_by == user_id + assert document.updated_at == current_time + + def _assert_document_unarchived(self, document: Document): + """Verify unarchived-state fields after action=un_archive.""" + assert document.archived is False + assert document.archived_at is None + assert document.archived_by is None + + def test_batch_update_enable_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Enable disabled documents and trigger indexing side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()), str(uuid4())] + disabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, + dataset=dataset, + document_ids=document_ids, + enabled=False, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, document_ids=document_ids, action="enable", user=user + ) + + # Assert + for document in disabled_docs: + db_session_with_containers.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + + expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_add_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].get.assert_has_calls(expected_get_calls) + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["add_task"].delay.assert_has_calls(expected_add_calls) + + def test_batch_update_enable_already_enabled_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip enable operation for already-enabled documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + assert document.enabled is True + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_disable_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Disable completed documents and trigger remove-index tasks.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()), str(uuid4())] + enabled_docs = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, + dataset=dataset, + document_ids=document_ids, + enabled=True, + indexing_status=IndexingStatus.COMPLETED, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="disable", + user=user, + ) + + # Assert + for document in enabled_docs: + db_session_with_containers.refresh(document) + self._assert_document_disabled(document, user.id, FIXED_TIME) + + expected_get_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_remove_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].get.assert_has_calls(expected_get_calls) + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["remove_task"].delay.assert_has_calls(expected_remove_calls) + + def test_batch_update_disable_already_disabled_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip disable operation for already-disabled documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=False, + indexing_status=IndexingStatus.COMPLETED, + completed_at=FIXED_TIME, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[disabled_doc.id], + action="disable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(disabled_doc) + assert disabled_doc.enabled is False + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_disable_non_completed_document_error( + self, db_session_with_containers: Session, patched_dependencies + ): + """Raise error when disabling a non-completed document.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + non_completed_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + indexing_status=IndexingStatus.INDEXING, + completed_at=None, + ) + + # Act / Assert + with pytest.raises(DocumentIndexingError, match="is not completed"): + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[non_completed_doc.id], + action="disable", + user=user, + ) + + def test_batch_update_archive_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Archive enabled documents and trigger remove-index task.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_archived(document, user.id, FIXED_TIME) + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + patched_dependencies["remove_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_archive_already_archived_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip archive operation for already-archived documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + assert document.archived is True + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_archive_disabled_document_no_index_removal( + self, db_session_with_containers: Session, patched_dependencies + ): + """Archive disabled document without index-removal side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_archived(document, user.id, FIXED_TIME) + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["remove_task"].delay.assert_not_called() + + def test_batch_update_unarchive_documents_success(self, db_session_with_containers: Session, patched_dependencies): + """Unarchive enabled documents and trigger add-index task.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_unarchived(document) + assert document.updated_at == FIXED_TIME + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + patched_dependencies["add_task"].delay.assert_called_once_with(document.id) + + def test_batch_update_unarchive_already_unarchived_document_skipped( + self, db_session_with_containers: Session, patched_dependencies + ): + """Skip unarchive operation for already-unarchived documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, archived=False + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + assert document.archived is False + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_unarchive_disabled_document_no_index_addition( + self, db_session_with_containers: Session, patched_dependencies + ): + """Unarchive disabled document without index-add side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False, archived=True + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="un_archive", + user=user, + ) + + # Assert + db_session_with_containers.refresh(document) + self._assert_document_unarchived(document) + assert document.updated_at == FIXED_TIME + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_document_indexing_error_redis_cache_hit( + self, db_session_with_containers: Session, patched_dependencies + ): + """Raise DocumentIndexingError when redis indicates active indexing.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + name="test_document.pdf", + enabled=True, + ) + patched_dependencies["redis_client"].get.return_value = "indexing" + + # Act / Assert + with pytest.raises(DocumentIndexingError, match="is being indexed") as exc_info: + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + assert "test_document.pdf" in str(exc_info.value) + patched_dependencies["redis_client"].get.assert_called_once_with(f"document_{document.id}_indexing") + + def test_batch_update_async_task_error_handling(self, db_session_with_containers: Session, patched_dependencies): + """Persist DB update, then propagate async task error.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + patched_dependencies["add_task"].delay.side_effect = Exception("Celery task error") + + # Act / Assert + with pytest.raises(Exception, match="Celery task error"): + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[document.id], + action="enable", + user=user, + ) + + db_session_with_containers.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + + def test_batch_update_empty_document_list(self, db_session_with_containers: Session, patched_dependencies): + """Return early when document_ids is empty.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + + # Act + result = DocumentService.batch_update_document_status( + dataset=dataset, document_ids=[], action="enable", user=user + ) + + # Assert + assert result is None + patched_dependencies["redis_client"].get.assert_not_called() + patched_dependencies["redis_client"].setex.assert_not_called() + + def test_batch_update_document_not_found_skipped(self, db_session_with_containers: Session, patched_dependencies): + """Skip IDs that do not map to existing dataset documents.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + missing_document_id = str(uuid4()) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=[missing_document_id], + action="enable", + user=user, + ) + + # Assert + patched_dependencies["redis_client"].get.assert_not_called() + patched_dependencies["redis_client"].setex.assert_not_called() + patched_dependencies["add_task"].delay.assert_not_called() + + def test_batch_update_mixed_document_states_and_actions( + self, db_session_with_containers: Session, patched_dependencies + ): + """Process only the applicable document in a mixed-state enable batch.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + disabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + enabled_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + position=2, + ) + archived_doc = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + archived=True, + position=3, + ) + document_ids = [disabled_doc.id, enabled_doc.id, archived_doc.id] + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(disabled_doc) + db_session_with_containers.refresh(enabled_doc) + db_session_with_containers.refresh(archived_doc) + self._assert_document_enabled(disabled_doc, FIXED_TIME) + assert enabled_doc.enabled is True + assert archived_doc.enabled is True + + patched_dependencies["redis_client"].setex.assert_called_once_with( + f"document_{disabled_doc.id}_indexing", + 600, + 1, + ) + patched_dependencies["add_task"].delay.assert_called_once_with(disabled_doc.id) + + def test_batch_update_large_document_list_performance( + self, db_session_with_containers: Session, patched_dependencies + ): + """Handle large document lists with consistent updates and side effects.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + document_ids = [str(uuid4()) for _ in range(100)] + documents = DocumentBatchUpdateIntegrationDataFactory.create_multiple_documents( + db_session_with_containers, + dataset=dataset, + document_ids=document_ids, + enabled=False, + ) + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + ) + + # Assert + for document in documents: + db_session_with_containers.refresh(document) + self._assert_document_enabled(document, FIXED_TIME) + + assert patched_dependencies["redis_client"].setex.call_count == len(document_ids) + assert patched_dependencies["add_task"].delay.call_count == len(document_ids) + + expected_setex_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] + expected_task_calls = [call(doc_id) for doc_id in document_ids] + patched_dependencies["redis_client"].setex.assert_has_calls(expected_setex_calls) + patched_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) + + def test_batch_update_mixed_document_states_complex_scenario( + self, db_session_with_containers: Session, patched_dependencies + ): + """Process a complex mixed-state batch and update only eligible records.""" + # Arrange + dataset = DocumentBatchUpdateIntegrationDataFactory.create_dataset(db_session_with_containers) + user = DocumentBatchUpdateIntegrationDataFactory.create_user() + doc1 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=False + ) + doc2 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=2 + ) + doc3 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=3 + ) + doc4 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, dataset=dataset, enabled=True, position=4 + ) + doc5 = DocumentBatchUpdateIntegrationDataFactory.create_document( + db_session_with_containers, + dataset=dataset, + enabled=True, + archived=True, + position=5, + ) + missing_id = str(uuid4()) + + document_ids = [doc1.id, doc2.id, doc3.id, doc4.id, doc5.id, missing_id] + + # Act + DocumentService.batch_update_document_status( + dataset=dataset, + document_ids=document_ids, + action="enable", + user=user, + ) + + # Assert + db_session_with_containers.refresh(doc1) + db_session_with_containers.refresh(doc2) + db_session_with_containers.refresh(doc3) + db_session_with_containers.refresh(doc4) + db_session_with_containers.refresh(doc5) + self._assert_document_enabled(doc1, FIXED_TIME) + assert doc2.enabled is True + assert doc3.enabled is True + assert doc4.enabled is True + assert doc5.enabled is True + + patched_dependencies["redis_client"].setex.assert_called_once_with(f"document_{doc1.id}_indexing", 600, 1) + patched_dependencies["add_task"].delay.assert_called_once_with(doc1.id) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py new file mode 100644 index 0000000000..ed070527c9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_delete_dataset.py @@ -0,0 +1,245 @@ +"""Container-backed integration tests for DatasetService.delete_dataset real SQL paths.""" + +from unittest.mock import patch +from uuid import uuid4 + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom +from services.dataset_service import DatasetService + + +class DatasetDeleteIntegrationDataFactory: + """Create persisted entities used by delete_dataset integration tests.""" + + @staticmethod + def create_account_with_tenant(db_session_with_containers) -> tuple[Account, Tenant]: + """Persist an owner account, tenant, and tenant join for dataset deletion tests.""" + account = Account( + email=f"owner-{uuid4()}@example.com", + name="Owner", + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant( + name=f"tenant-{uuid4()}", + status="normal", + ) + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + db_session_with_containers, + tenant_id: str, + created_by: str, + *, + indexing_technique: str | None, + chunk_structure: str | None, + index_struct: str | None = '{"type": "paragraph"}', + collection_binding_id: str | None = None, + pipeline_id: str | None = None, + ) -> Dataset: + """Persist a dataset with delete_dataset-relevant fields configured.""" + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + indexing_technique=indexing_technique, + index_struct=index_struct, + created_by=created_by, + collection_binding_id=collection_binding_id, + pipeline_id=pipeline_id, + chunk_structure=chunk_structure, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + @staticmethod + def create_document( + db_session_with_containers, + *, + tenant_id: str, + dataset_id: str, + created_by: str, + doc_form: str = "text_model", + ) -> Document: + """Persist a document so dataset.doc_form resolves through the real document path.""" + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.UPLOAD_FILE, + batch=f"batch-{uuid4()}", + name="Document", + created_from=DocumentCreatedFrom.WEB, + created_by=created_by, + doc_form=doc_form, + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + return document + + +class TestDatasetServiceDeleteDataset: + """Integration coverage for DatasetService.delete_dataset using testcontainers.""" + + def test_delete_dataset_with_documents_success(self, db_session_with_containers): + """Delete a dataset with documents and dispatch cleanup through the real signal handler.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique="high_quality", + chunk_structure=None, + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + DatasetDeleteIntegrationDataFactory.create_document( + db_session_with_containers, + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=owner.id, + doc_form="text_model", + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_called_once_with( + dataset.id, + dataset.tenant_id, + dataset.indexing_technique, + dataset.index_struct, + dataset.collection_binding_id, + dataset.doc_form, + dataset.pipeline_id, + ) + + def test_delete_empty_dataset_success(self, db_session_with_containers): + """Delete an empty dataset without scheduling cleanup when both gating fields are absent.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique=None, + chunk_structure=None, + index_struct=None, + collection_binding_id=None, + pipeline_id=None, + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_with_partial_none_values(self, db_session_with_containers): + """Delete a dataset without cleanup when indexing_technique is missing but doc_form resolves.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique=None, + chunk_structure="text_model", + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, db_session_with_containers): + """Delete a dataset without cleanup when indexing exists but doc_form resolves to None.""" + # Arrange + owner, tenant = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetDeleteIntegrationDataFactory.create_dataset( + db_session_with_containers, + tenant_id=tenant.id, + created_by=owner.id, + indexing_technique="high_quality", + chunk_structure=None, + index_struct='{"type": "paragraph"}', + collection_binding_id=str(uuid4()), + pipeline_id=str(uuid4()), + ) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(dataset.id, owner) + + # Assert + db_session_with_containers.expire_all() + assert result is True + assert db_session_with_containers.get(Dataset, dataset.id) is None + clean_dataset_delay.assert_not_called() + + def test_delete_dataset_not_found(self, db_session_with_containers): + """Return False without scheduling cleanup when the target dataset does not exist.""" + # Arrange + owner, _ = DatasetDeleteIntegrationDataFactory.create_account_with_tenant(db_session_with_containers) + missing_dataset_id = str(uuid4()) + + # Act + with patch( + "events.event_handlers.clean_when_dataset_deleted.clean_dataset_task.delay", + autospec=True, + ) as clean_dataset_delay: + result = DatasetService.delete_dataset(missing_dataset_id, owner) + + # Assert + assert result is False + clean_dataset_delay.assert_not_called() diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py index 6effe795e2..c4b3a57bb2 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_get_segments.py @@ -10,9 +10,11 @@ Tests the retrieval of document segments with pagination and filtering: from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom from services.dataset_service import SegmentService @@ -23,6 +25,7 @@ class SegmentServiceTestDataFactory: @staticmethod def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER, tenant: Tenant | None = None, ) -> tuple[Account, Tenant]: @@ -33,13 +36,13 @@ class SegmentServiceTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() if tenant is None: tenant = Tenant(name=f"tenant-{uuid4()}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -47,49 +50,52 @@ class SegmentServiceTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_dataset(tenant_id: str, created_by: str) -> Dataset: + def create_dataset(db_session_with_containers: Session, tenant_id: str, created_by: str) -> Dataset: """Create a real dataset.""" dataset = Dataset( tenant_id=tenant_id, name=f"Test Dataset {uuid4()}", description="Test description", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=DatasetPermissionEnum.ONLY_ME, provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_document(tenant_id: str, dataset_id: str, created_by: str) -> Document: + def create_document( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str + ) -> Document: """Create a real document.""" document = Document( tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=f"batch-{uuid4()}", name=f"test-doc-{uuid4()}.txt", - created_from="api", + created_from=DocumentCreatedFrom.API, created_by=created_by, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document @staticmethod def create_segment( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, document_id: str, @@ -112,8 +118,8 @@ class SegmentServiceTestDataFactory: tokens=tokens, created_by=created_by, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment @@ -130,7 +136,7 @@ class TestSegmentServiceGetSegments: - Combined filters """ - def test_get_segments_basic_pagination(self, db_session_with_containers): + def test_get_segments_basic_pagination(self, db_session_with_containers: Session): """ Test basic pagination functionality. @@ -140,11 +146,14 @@ class TestSegmentServiceGetSegments: - Returns segments and total count """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) segment1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -153,6 +162,7 @@ class TestSegmentServiceGetSegments: content="First segment", ) segment2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -170,7 +180,7 @@ class TestSegmentServiceGetSegments: assert items[0].id == segment1.id assert items[1].id == segment2.id - def test_get_segments_with_status_filter(self, db_session_with_containers): + def test_get_segments_with_status_filter(self, db_session_with_containers: Session): """ Test filtering by status list. @@ -179,11 +189,14 @@ class TestSegmentServiceGetSegments: - Only segments with matching status are returned """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -192,6 +205,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -200,6 +214,7 @@ class TestSegmentServiceGetSegments: status="indexing", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -219,7 +234,7 @@ class TestSegmentServiceGetSegments: statuses = {item.status for item in items} assert statuses == {"completed", "indexing"} - def test_get_segments_with_empty_status_list(self, db_session_with_containers): + def test_get_segments_with_empty_status_list(self, db_session_with_containers: Session): """ Test with empty status list. @@ -228,11 +243,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied to avoid WHERE false condition """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -241,6 +259,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -256,7 +275,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_with_keyword_search(self, db_session_with_containers): + def test_get_segments_with_keyword_search(self, db_session_with_containers: Session): """ Test keyword search functionality. @@ -265,11 +284,14 @@ class TestSegmentServiceGetSegments: - Search pattern includes wildcards (%keyword%) """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -278,6 +300,7 @@ class TestSegmentServiceGetSegments: content="This contains search term in the middle", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -294,7 +317,7 @@ class TestSegmentServiceGetSegments: assert total == 1 assert "search term" in items[0].content - def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers): + def test_get_segments_ordering_by_position_and_id(self, db_session_with_containers: Session): """ Test ordering by position and id. @@ -304,12 +327,15 @@ class TestSegmentServiceGetSegments: - This prevents duplicate data across pages when positions are not unique """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with different positions seg_pos2 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -318,6 +344,7 @@ class TestSegmentServiceGetSegments: content="Position 2", ) seg_pos1 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -326,6 +353,7 @@ class TestSegmentServiceGetSegments: content="Position 1", ) seg_pos3 = SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -344,7 +372,7 @@ class TestSegmentServiceGetSegments: assert items[1].id == seg_pos2.id assert items[2].id == seg_pos3.id - def test_get_segments_empty_results(self, db_session_with_containers): + def test_get_segments_empty_results(self, db_session_with_containers: Session): """ Test when no segments match the criteria. @@ -353,7 +381,7 @@ class TestSegmentServiceGetSegments: - Total count is 0 """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) non_existent_doc_id = str(uuid4()) # Act @@ -363,7 +391,7 @@ class TestSegmentServiceGetSegments: assert items == [] assert total == 0 - def test_get_segments_combined_filters(self, db_session_with_containers): + def test_get_segments_combined_filters(self, db_session_with_containers: Session): """ Test with multiple filters combined. @@ -372,12 +400,15 @@ class TestSegmentServiceGetSegments: - Status list and keyword search both applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create segments with various statuses and content SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -387,6 +418,7 @@ class TestSegmentServiceGetSegments: content="This is important information", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -396,6 +428,7 @@ class TestSegmentServiceGetSegments: content="This is also important", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -421,7 +454,7 @@ class TestSegmentServiceGetSegments: assert items[0].status == "completed" assert "important" in items[0].content - def test_get_segments_with_none_status_list(self, db_session_with_containers): + def test_get_segments_with_none_status_list(self, db_session_with_containers: Session): """ Test with None status list. @@ -430,11 +463,14 @@ class TestSegmentServiceGetSegments: - No status filter is applied """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -443,6 +479,7 @@ class TestSegmentServiceGetSegments: status="completed", ) SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, @@ -462,7 +499,7 @@ class TestSegmentServiceGetSegments: assert len(items) == 2 assert total == 2 - def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers): + def test_get_segments_pagination_max_per_page_limit(self, db_session_with_containers: Session): """ Test that max_per_page is correctly set to 100. @@ -471,13 +508,16 @@ class TestSegmentServiceGetSegments: - This prevents excessive page sizes """ # Arrange - owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant() - dataset = SegmentServiceTestDataFactory.create_dataset(tenant.id, owner.id) - document = SegmentServiceTestDataFactory.create_document(tenant.id, dataset.id, owner.id) + owner, tenant = SegmentServiceTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = SegmentServiceTestDataFactory.create_dataset(db_session_with_containers, tenant.id, owner.id) + document = SegmentServiceTestDataFactory.create_document( + db_session_with_containers, tenant.id, dataset.id, owner.id + ) # Create 105 segments to exceed max_per_page of 100 for i in range(105): SegmentServiceTestDataFactory.create_segment( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, document_id=document.id, diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py index f605a286ed..3021d8984d 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_retrieval.py @@ -13,7 +13,8 @@ This test suite covers: import json from uuid import uuid4 -from extensions.ext_database import db +from sqlalchemy.orm import Session + from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -23,6 +24,7 @@ from models.dataset import ( DatasetProcessRule, DatasetQuery, ) +from models.enums import DatasetQuerySource, DataSourceType, ProcessRuleMode from models.model import Tag, TagBinding from services.dataset_service import DatasetService, DocumentService @@ -31,7 +33,9 @@ class DatasetRetrievalTestDataFactory: """Factory class for creating database-backed test data for dataset retrieval integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.NORMAL) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.NORMAL + ) -> tuple[Account, Tenant]: """Create an account and tenant with the specified role.""" account = Account( email=f"{uuid4()}@example.com", @@ -43,8 +47,8 @@ class DatasetRetrievalTestDataFactory: name=f"tenant-{uuid4()}", status="normal", ) - db.session.add_all([account, tenant]) - db.session.flush() + db_session_with_containers.add_all([account, tenant]) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -52,14 +56,16 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod - def create_account_in_tenant(tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER) -> Account: + def create_account_in_tenant( + db_session_with_containers: Session, tenant: Tenant, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> Account: """Create an account and add it to an existing tenant.""" account = Account( email=f"{uuid4()}@example.com", @@ -67,8 +73,8 @@ class DatasetRetrievalTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() join = TenantAccountJoin( tenant_id=tenant.id, @@ -76,14 +82,15 @@ class DatasetRetrievalTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, name: str = "Test Dataset", @@ -94,19 +101,21 @@ class DatasetRetrievalTestDataFactory: tenant_id=tenant_id, name=name, description="desc", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=created_by, permission=permission, provider="vendor", retrieval_model={"top_k": 2}, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod - def create_dataset_permission(dataset_id: str, tenant_id: str, account_id: str) -> DatasetPermission: + def create_dataset_permission( + db_session_with_containers: Session, dataset_id: str, tenant_id: str, account_id: str + ) -> DatasetPermission: """Create a dataset permission.""" permission = DatasetPermission( dataset_id=dataset_id, @@ -114,12 +123,14 @@ class DatasetRetrievalTestDataFactory: account_id=account_id, has_permission=True, ) - db.session.add(permission) - db.session.commit() + db_session_with_containers.add(permission) + db_session_with_containers.commit() return permission @staticmethod - def create_process_rule(dataset_id: str, created_by: str, mode: str, rules: dict) -> DatasetProcessRule: + def create_process_rule( + db_session_with_containers: Session, dataset_id: str, created_by: str, mode: str, rules: dict + ) -> DatasetProcessRule: """Create a dataset process rule.""" process_rule = DatasetProcessRule( dataset_id=dataset_id, @@ -127,38 +138,40 @@ class DatasetRetrievalTestDataFactory: mode=mode, rules=json.dumps(rules), ) - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule @staticmethod - def create_dataset_query(dataset_id: str, created_by: str, content: str) -> DatasetQuery: + def create_dataset_query( + db_session_with_containers: Session, dataset_id: str, created_by: str, content: str + ) -> DatasetQuery: """Create a dataset query.""" dataset_query = DatasetQuery( dataset_id=dataset_id, content=content, - source="web", + source=DatasetQuerySource.APP, source_app_id=None, created_by_role="account", created_by=created_by, ) - db.session.add(dataset_query) - db.session.commit() + db_session_with_containers.add(dataset_query) + db_session_with_containers.commit() return dataset_query @staticmethod - def create_app_dataset_join(dataset_id: str) -> AppDatasetJoin: + def create_app_dataset_join(db_session_with_containers: Session, dataset_id: str) -> AppDatasetJoin: """Create an app-dataset join.""" join = AppDatasetJoin( app_id=str(uuid4()), dataset_id=dataset_id, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return join @staticmethod - def create_tag_binding(tenant_id: str, created_by: str, target_id: str) -> Tag: + def create_tag_binding(db_session_with_containers: Session, tenant_id: str, created_by: str, target_id: str) -> Tag: """Create a knowledge tag and bind it to the target dataset.""" tag = Tag( tenant_id=tenant_id, @@ -166,8 +179,8 @@ class DatasetRetrievalTestDataFactory: name=f"tag-{uuid4()}", created_by=created_by, ) - db.session.add(tag) - db.session.flush() + db_session_with_containers.add(tag) + db_session_with_containers.flush() binding = TagBinding( tenant_id=tenant_id, @@ -175,8 +188,8 @@ class DatasetRetrievalTestDataFactory: target_id=target_id, created_by=created_by, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return tag @@ -195,15 +208,16 @@ class TestDatasetServiceGetDatasets: # ==================== Basic Retrieval Tests ==================== - def test_get_datasets_basic_pagination(self, db_session_with_containers): + def test_get_datasets_basic_pagination(self, db_session_with_containers: Session): """Test basic pagination without user or filters.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 for i in range(5): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"Dataset {i}", @@ -217,21 +231,23 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 5 assert total == 5 - def test_get_datasets_with_search(self, db_session_with_containers): + def test_get_datasets_with_search(self, db_session_with_containers: Session): """Test get_datasets with search keyword.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 search = "test" DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Test Dataset", permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name="Another Dataset", @@ -245,26 +261,32 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_with_tag_filtering(self, db_session_with_containers): + def test_get_datasets_with_tag_filtering(self, db_session_with_containers: Session): """Test get_datasets with tag_ids filtering.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 dataset_1 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) dataset_2 = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) - tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_1.id) - tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding(tenant.id, account.id, dataset_2.id) + tag_1 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_1.id + ) + tag_2 = DatasetRetrievalTestDataFactory.create_tag_binding( + db_session_with_containers, tenant.id, account.id, dataset_2.id + ) tag_ids = [tag_1.id, tag_2.id] # Act @@ -274,16 +296,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 2 assert total == 2 - def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers): + def test_get_datasets_with_empty_tag_ids(self, db_session_with_containers: Session): """Test get_datasets with empty tag_ids skips tag filtering and returns all matching datasets.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 tag_ids = [] for i in range(3): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, name=f"dataset-{i}", @@ -300,19 +323,21 @@ class TestDatasetServiceGetDatasets: # ==================== Permission-Based Filtering Tests ==================== - def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers): + def test_get_datasets_without_user_shows_only_all_team(self, db_session_with_containers: Session): """Test that without user, only ALL_TEAM datasets are shown.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) page = 1 per_page = 20 DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ALL_TEAM, ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -325,15 +350,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_owner_with_include_all(self, db_session_with_containers): + def test_get_datasets_owner_with_include_all(self, db_session_with_containers: Session): """Test that OWNER with include_all=True sees all datasets.""" # Arrange - owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + owner, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) for i, permission in enumerate( [DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM] ): DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, name=f"dataset-{i}", @@ -353,12 +381,15 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 3 assert total == 3 - def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_only_me_permission(self, db_session_with_containers: Session): """Test that normal user sees ONLY_ME datasets they created.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, permission=DatasetPermissionEnum.ONLY_ME, @@ -371,13 +402,18 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_all_team_permission(self, db_session_with_containers: Session): """Test that normal user sees ALL_TEAM datasets.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -390,18 +426,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers): + def test_get_datasets_normal_user_partial_team_with_permission(self, db_session_with_containers: Session): """Test that normal user sees PARTIAL_TEAM datasets they have permission for.""" # Arrange - user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) + user, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER + ) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.PARTIAL_TEAM, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, user.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, user.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=user) @@ -410,20 +453,25 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_with_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR only sees datasets they have explicit permission for.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ONLY_ME, ) - DatasetRetrievalTestDataFactory.create_dataset_permission(dataset.id, tenant.id, operator.id) + DatasetRetrievalTestDataFactory.create_dataset_permission( + db_session_with_containers, dataset.id, tenant.id, operator.id + ) # Act datasets, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id=tenant.id, user=operator) @@ -432,14 +480,17 @@ class TestDatasetServiceGetDatasets: assert len(datasets) == 1 assert total == 1 - def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers): + def test_get_datasets_dataset_operator_without_permissions(self, db_session_with_containers: Session): """Test that DATASET_OPERATOR without permissions returns empty result.""" # Arrange operator, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant( - role=TenantAccountRole.DATASET_OPERATOR + db_session_with_containers, role=TenantAccountRole.DATASET_OPERATOR + ) + owner = DatasetRetrievalTestDataFactory.create_account_in_tenant( + db_session_with_containers, tenant, role=TenantAccountRole.OWNER ) - owner = DatasetRetrievalTestDataFactory.create_account_in_tenant(tenant, role=TenantAccountRole.OWNER) DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, permission=DatasetPermissionEnum.ALL_TEAM, @@ -456,11 +507,13 @@ class TestDatasetServiceGetDatasets: class TestDatasetServiceGetDataset: """Comprehensive integration tests for DatasetService.get_dataset method.""" - def test_get_dataset_success(self, db_session_with_containers): + def test_get_dataset_success(self, db_session_with_containers: Session): """Test successful retrieval of a single dataset.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_dataset(dataset.id) @@ -469,7 +522,7 @@ class TestDatasetServiceGetDataset: assert result is not None assert result.id == dataset.id - def test_get_dataset_not_found(self, db_session_with_containers): + def test_get_dataset_not_found(self, db_session_with_containers: Session): """Test retrieval when dataset doesn't exist.""" # Arrange dataset_id = str(uuid4()) @@ -484,12 +537,15 @@ class TestDatasetServiceGetDataset: class TestDatasetServiceGetDatasetsByIds: """Comprehensive integration tests for DatasetService.get_datasets_by_ids method.""" - def test_get_datasets_by_ids_success(self, db_session_with_containers): + def test_get_datasets_by_ids_success(self, db_session_with_containers: Session): """Test successful bulk retrieval of datasets by IDs.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) datasets = [ - DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) for _ in range(3) + DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) + for _ in range(3) ] dataset_ids = [dataset.id for dataset in datasets] @@ -501,7 +557,7 @@ class TestDatasetServiceGetDatasetsByIds: assert total == 3 assert all(dataset.id in dataset_ids for dataset in result_datasets) - def test_get_datasets_by_ids_empty_list(self, db_session_with_containers): + def test_get_datasets_by_ids_empty_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with empty list returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -514,7 +570,7 @@ class TestDatasetServiceGetDatasetsByIds: assert datasets == [] assert total == 0 - def test_get_datasets_by_ids_none_list(self, db_session_with_containers): + def test_get_datasets_by_ids_none_list(self, db_session_with_containers: Session): """Test get_datasets_by_ids with None returns empty result.""" # Arrange tenant_id = str(uuid4()) @@ -530,20 +586,23 @@ class TestDatasetServiceGetDatasetsByIds: class TestDatasetServiceGetProcessRules: """Comprehensive integration tests for DatasetService.get_process_rules method.""" - def test_get_process_rules_with_existing_rule(self, db_session_with_containers): + def test_get_process_rules_with_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when rule exists.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) rules_data = { "pre_processing_rules": [{"id": "remove_extra_spaces", "enabled": True}], "segmentation": {"delimiter": "\n", "max_tokens": 500}, } DatasetRetrievalTestDataFactory.create_process_rule( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, - mode="custom", + mode=ProcessRuleMode.CUSTOM, rules=rules_data, ) @@ -554,11 +613,13 @@ class TestDatasetServiceGetProcessRules: assert result["mode"] == "custom" assert result["rules"] == rules_data - def test_get_process_rules_without_existing_rule(self, db_session_with_containers): + def test_get_process_rules_without_existing_rule(self, db_session_with_containers: Session): """Test retrieval of process rules when no rule exists (returns defaults).""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_process_rules(dataset.id) @@ -572,16 +633,19 @@ class TestDatasetServiceGetProcessRules: class TestDatasetServiceGetDatasetQueries: """Comprehensive integration tests for DatasetService.get_dataset_queries method.""" - def test_get_dataset_queries_success(self, db_session_with_containers): + def test_get_dataset_queries_success(self, db_session_with_containers: Session): """Test successful retrieval of dataset queries.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 for i in range(3): DatasetRetrievalTestDataFactory.create_dataset_query( + db_session_with_containers, dataset_id=dataset.id, created_by=account.id, content=f"query-{i}", @@ -595,11 +659,13 @@ class TestDatasetServiceGetDatasetQueries: assert total == 3 assert all(query.dataset_id == dataset.id for query in queries) - def test_get_dataset_queries_empty_result(self, db_session_with_containers): + def test_get_dataset_queries_empty_result(self, db_session_with_containers: Session): """Test retrieval when no queries exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) page = 1 per_page = 20 @@ -614,14 +680,16 @@ class TestDatasetServiceGetDatasetQueries: class TestDatasetServiceGetRelatedApps: """Comprehensive integration tests for DatasetService.get_related_apps method.""" - def test_get_related_apps_success(self, db_session_with_containers): + def test_get_related_apps_success(self, db_session_with_containers: Session): """Test successful retrieval of related apps.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) for _ in range(2): - DatasetRetrievalTestDataFactory.create_app_dataset_join(dataset.id) + DatasetRetrievalTestDataFactory.create_app_dataset_join(db_session_with_containers, dataset.id) # Act result = DatasetService.get_related_apps(dataset.id) @@ -630,11 +698,13 @@ class TestDatasetServiceGetRelatedApps: assert len(result) == 2 assert all(join.dataset_id == dataset.id for join in result) - def test_get_related_apps_empty_result(self, db_session_with_containers): + def test_get_related_apps_empty_result(self, db_session_with_containers: Session): """Test retrieval when no related apps exist.""" # Arrange - account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant() - dataset = DatasetRetrievalTestDataFactory.create_dataset(tenant_id=tenant.id, created_by=account.id) + account, tenant = DatasetRetrievalTestDataFactory.create_account_with_tenant(db_session_with_containers) + dataset = DatasetRetrievalTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=account.id + ) # Act result = DatasetService.get_related_apps(dataset.id) diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index f6d9dfddae..fd81948247 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,11 +2,12 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from sqlalchemy.orm import Session -from core.model_runtime.entities.model_entities import ModelType -from extensions.ext_database import db +from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, ExternalKnowledgeBindings +from models.enums import DataSourceType from services.dataset_service import DatasetService from services.errors.account import NoPermissionError @@ -15,7 +16,9 @@ class DatasetUpdateTestDataFactory: """Factory class for creating real test data for dataset update integration tests.""" @staticmethod - def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + def create_account_with_tenant( + db_session_with_containers: Session, role: TenantAccountRole = TenantAccountRole.OWNER + ) -> tuple[Account, Tenant]: """Create a real account and tenant with the given role.""" account = Account( email=f"{uuid4()}@example.com", @@ -23,12 +26,12 @@ class DatasetUpdateTestDataFactory: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant(name=f"tenant-{account.id}", status="normal") - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() join = TenantAccountJoin( tenant_id=tenant.id, @@ -36,14 +39,15 @@ class DatasetUpdateTestDataFactory: role=role, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() account.current_tenant = tenant return account, tenant @staticmethod def create_dataset( + db_session_with_containers: Session, tenant_id: str, created_by: str, provider: str = "vendor", @@ -61,7 +65,7 @@ class DatasetUpdateTestDataFactory: tenant_id=tenant_id, name=name, description=description, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique=indexing_technique, created_by=created_by, provider=provider, @@ -71,12 +75,13 @@ class DatasetUpdateTestDataFactory: embedding_model=embedding_model, collection_binding_id=collection_binding_id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset @staticmethod def create_external_binding( + db_session_with_containers: Session, tenant_id: str, dataset_id: str, created_by: str, @@ -93,8 +98,8 @@ class DatasetUpdateTestDataFactory: external_knowledge_id=external_knowledge_id, external_knowledge_api_id=external_knowledge_api_id, ) - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() return binding @@ -112,10 +117,11 @@ class TestDatasetServiceUpdateDataset: # ==================== External Dataset Tests ==================== - def test_update_external_dataset_success(self, db_session_with_containers): + def test_update_external_dataset_success(self, db_session_with_containers: Session): """Test successful update of external dataset.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -124,12 +130,13 @@ class TestDatasetServiceUpdateDataset: retrieval_model="old_model", ) binding = DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, ) binding_id = binding.id - db.session.expunge(binding) + db_session_with_containers.expunge(binding) update_data = { "name": "new_name", @@ -142,8 +149,8 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) - updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() + db_session_with_containers.refresh(dataset) + updated_binding = db_session_with_containers.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -153,15 +160,17 @@ class TestDatasetServiceUpdateDataset: assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] assert result.id == dataset.id - def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers: Session): """Test error when external knowledge id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -173,17 +182,19 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers): + def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers: Session): """Test error when external knowledge api id is missing.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", ) DatasetUpdateTestDataFactory.create_external_binding( + db_session_with_containers, tenant_id=tenant.id, dataset_id=dataset.id, created_by=user.id, @@ -195,12 +206,13 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge api id is required" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() - def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers): + def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers: Session): """Test error when external knowledge binding is not found.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="external", @@ -216,15 +228,16 @@ class TestDatasetServiceUpdateDataset: DatasetService.update_dataset(dataset.id, update_data, user) assert "External knowledge binding not found" in str(context.value) - db.session.rollback() + db_session_with_containers.rollback() # ==================== Internal Dataset Basic Tests ==================== - def test_update_internal_dataset_basic_success(self, db_session_with_containers): + def test_update_internal_dataset_basic_success(self, db_session_with_containers: Session): """Test successful update of internal dataset with basic fields.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -244,7 +257,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description == "new_description" @@ -254,11 +267,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.embedding_model == "text-embedding-ada-002" assert result.id == dataset.id - def test_update_internal_dataset_filter_none_values(self, db_session_with_containers): + def test_update_internal_dataset_filter_none_values(self, db_session_with_containers: Session): """Test that None values are filtered out except for description field.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -278,7 +292,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.description is None @@ -289,11 +303,12 @@ class TestDatasetServiceUpdateDataset: # ==================== Indexing Technique Switch Tests ==================== - def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to economy.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -312,7 +327,7 @@ class TestDatasetServiceUpdateDataset: result = DatasetService.update_dataset(dataset.id, update_data, user) mock_task.delay.assert_called_once_with(dataset.id, "remove") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "economy" assert dataset.embedding_model is None assert dataset.embedding_model_provider is None @@ -320,10 +335,11 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers): + def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers: Session): """Test updating internal dataset indexing technique to high_quality.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -366,7 +382,7 @@ class TestDatasetServiceUpdateDataset: mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") mock_task.delay.assert_called_once_with(dataset.id, "add") - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.indexing_technique == "high_quality" assert dataset.embedding_model == "text-embedding-ada-002" assert dataset.embedding_model_provider == "openai" @@ -380,9 +396,10 @@ class TestDatasetServiceUpdateDataset: self, db_session_with_containers ): """Test preserving embedding settings when indexing technique remains unchanged.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -399,7 +416,7 @@ class TestDatasetServiceUpdateDataset: } result = DatasetService.update_dataset(dataset.id, update_data, user) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.name == "new_name" assert dataset.indexing_technique == "high_quality" @@ -409,11 +426,12 @@ class TestDatasetServiceUpdateDataset: assert dataset.retrieval_model == "new_model" assert result.id == dataset.id - def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers: Session): """Test updating internal dataset with new embedding model.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) existing_binding_id = str(uuid4()) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", @@ -465,7 +483,7 @@ class TestDatasetServiceUpdateDataset: regenerate_vectors_only=True, ) - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.embedding_model == "text-embedding-3-small" assert dataset.embedding_model_provider == "openai" assert dataset.collection_binding_id == binding.id @@ -474,9 +492,9 @@ class TestDatasetServiceUpdateDataset: # ==================== Error Handling Tests ==================== - def test_update_dataset_not_found_error(self, db_session_with_containers): + def test_update_dataset_not_found_error(self, db_session_with_containers: Session): """Test error when dataset is not found.""" - user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) update_data = {"name": "new_name"} with pytest.raises(ValueError) as context: @@ -484,11 +502,16 @@ class TestDatasetServiceUpdateDataset: assert "Dataset not found" in str(context.value) - def test_update_dataset_permission_error(self, db_session_with_containers): + def test_update_dataset_permission_error(self, db_session_with_containers: Session): """Test error when user doesn't have permission.""" - owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) - outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.OWNER + ) + outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant( + db_session_with_containers, role=TenantAccountRole.NORMAL + ) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=owner.id, provider="vendor", @@ -500,10 +523,11 @@ class TestDatasetServiceUpdateDataset: with pytest.raises(NoPermissionError): DatasetService.update_dataset(dataset.id, update_data, outsider) - def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers): + def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers: Session): """Test error when embedding model is not available.""" - user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(db_session_with_containers) dataset = DatasetUpdateTestDataFactory.create_dataset( + db_session_with_containers, tenant_id=tenant.id, created_by=user.id, provider="vendor", diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 546292109e..5f86cb2ae9 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -7,7 +7,7 @@ from uuid import uuid4 from sqlalchemy import select -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py index 124056e10f..c6aa89c733 100644 --- a/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py +++ b/api/tests/test_containers_integration_tests/services/test_document_service_display_status.py @@ -4,6 +4,7 @@ from uuid import uuid4 from sqlalchemy import select from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from services.dataset_service import DocumentService @@ -11,7 +12,7 @@ def _create_dataset(db_session_with_containers) -> Dataset: dataset = Dataset( tenant_id=str(uuid4()), name=f"dataset-{uuid4()}", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) dataset.id = str(uuid4()) @@ -35,11 +36,11 @@ def _create_document( tenant_id=tenant_id, dataset_id=dataset_id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch=f"batch-{uuid4()}", name=f"doc-{uuid4()}", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), doc_form="text_model", ) @@ -48,7 +49,7 @@ def _create_document( document.enabled = enabled document.archived = archived document.is_paused = is_paused - if indexing_status == "completed": + if indexing_status == IndexingStatus.COMPLETED: document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db_session_with_containers.add(document) @@ -62,7 +63,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, position=1, @@ -71,7 +72,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, position=2, @@ -80,7 +81,7 @@ def test_build_display_status_filters_available(db_session_with_containers): db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, position=3, @@ -101,14 +102,14 @@ def test_apply_display_status_filter_applies_when_status_present(db_session_with db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, position=1, ) _create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) @@ -125,14 +126,14 @@ def test_apply_display_status_filter_returns_same_when_invalid(db_session_with_c db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, position=1, ) doc2 = _create_document( db_session_with_containers, dataset_id=dataset.id, tenant_id=dataset.tenant_id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, position=2, ) diff --git a/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py new file mode 100644 index 0000000000..bffa520ce6 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_document_service_rename_document.py @@ -0,0 +1,253 @@ +"""Container-backed integration tests for DocumentService.rename_document real SQL paths.""" + +import datetime +import json +from unittest.mock import create_autospec, patch +from uuid import uuid4 + +import pytest + +from extensions.storage.storage_type import StorageType +from models import Account +from models.dataset import Dataset, Document +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom +from models.model import UploadFile +from services.dataset_service import DocumentService + +FIXED_UPLOAD_CREATED_AT = datetime.datetime(2024, 1, 1, 0, 0, 0) + + +@pytest.fixture +def mock_env(): + """Patch only non-SQL dependency used by rename_document: current_user context.""" + with patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user: + current_user.current_tenant_id = str(uuid4()) + current_user.id = str(uuid4()) + yield {"current_user": current_user} + + +def make_dataset(db_session_with_containers, dataset_id=None, tenant_id=None, built_in_field_enabled=False): + """Persist a dataset row for rename_document integration scenarios.""" + dataset_id = dataset_id or str(uuid4()) + tenant_id = tenant_id or str(uuid4()) + + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=str(uuid4()), + ) + dataset.id = dataset_id + dataset.built_in_field_enabled = built_in_field_enabled + + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + return dataset + + +def make_document( + db_session_with_containers, + document_id=None, + dataset_id=None, + tenant_id=None, + name="Old Name", + data_source_info=None, + doc_metadata=None, +): + """Persist a document row used by rename_document integration scenarios.""" + document_id = document_id or str(uuid4()) + dataset_id = dataset_id or str(uuid4()) + tenant_id = tenant_id or str(uuid4()) + + doc = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info=json.dumps(data_source_info or {}), + batch=f"batch-{uuid4()}", + name=name, + created_from=DocumentCreatedFrom.WEB, + created_by=str(uuid4()), + doc_form="text_model", + ) + doc.id = document_id + doc.indexing_status = "completed" + doc.doc_metadata = dict(doc_metadata or {}) + + db_session_with_containers.add(doc) + db_session_with_containers.commit() + return doc + + +def make_upload_file(db_session_with_containers, tenant_id: str, file_id: str, name: str): + """Persist an upload file row referenced by document.data_source_info.""" + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key=f"uploads/{uuid4()}", + name=name, + size=128, + extension="pdf", + mime_type="application/pdf", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid4()), + created_at=FIXED_UPLOAD_CREATED_AT, + used=False, + ) + upload_file.id = file_id + + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() + return upload_file + + +def test_rename_document_success(db_session_with_containers, mock_env): + """Rename succeeds and returns the renamed document identity by id.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "New Document Name" + dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset_id, + tenant_id=mock_env["current_user"].current_tenant_id, + ) + + # Act + result = DocumentService.rename_document(dataset.id, document_id, new_name) + + # Assert + db_session_with_containers.refresh(document) + assert result.id == document.id + assert document.name == new_name + + +def test_rename_document_with_built_in_fields(db_session_with_containers, mock_env): + """Built-in document_name metadata is updated while existing metadata keys are preserved.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "Renamed" + dataset = make_dataset( + db_session_with_containers, + dataset_id, + mock_env["current_user"].current_tenant_id, + built_in_field_enabled=True, + ) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=mock_env["current_user"].current_tenant_id, + doc_metadata={"foo": "bar"}, + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + assert document.name == new_name + assert document.doc_metadata["document_name"] == new_name + assert document.doc_metadata["foo"] == "bar" + + +def test_rename_document_updates_upload_file_when_present(db_session_with_containers, mock_env): + """Rename propagates to UploadFile.name when upload_file_id is present in data_source_info.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + file_id = str(uuid4()) + new_name = "Renamed" + dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=mock_env["current_user"].current_tenant_id, + data_source_info={"upload_file_id": file_id}, + ) + upload_file = make_upload_file( + db_session_with_containers, + tenant_id=mock_env["current_user"].current_tenant_id, + file_id=file_id, + name="old.pdf", + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(upload_file) + assert document.name == new_name + assert upload_file.name == new_name + + +def test_rename_document_does_not_update_upload_file_when_missing_id(db_session_with_containers, mock_env): + """Rename does not update UploadFile when data_source_info lacks upload_file_id.""" + # Arrange + dataset_id = str(uuid4()) + document_id = str(uuid4()) + new_name = "Another Name" + dataset = make_dataset(db_session_with_containers, dataset_id, mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + document_id=document_id, + dataset_id=dataset.id, + tenant_id=mock_env["current_user"].current_tenant_id, + data_source_info={"url": "https://example.com"}, + ) + untouched_file = make_upload_file( + db_session_with_containers, + tenant_id=mock_env["current_user"].current_tenant_id, + file_id=str(uuid4()), + name="untouched.pdf", + ) + + # Act + DocumentService.rename_document(dataset.id, document.id, new_name) + + # Assert + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(untouched_file) + assert document.name == new_name + assert untouched_file.name == "untouched.pdf" + + +def test_rename_document_dataset_not_found(db_session_with_containers, mock_env): + """Rename raises Dataset not found when dataset id does not exist.""" + # Arrange + missing_dataset_id = str(uuid4()) + + # Act / Assert + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document(missing_dataset_id, str(uuid4()), "x") + + +def test_rename_document_not_found(db_session_with_containers, mock_env): + """Rename raises Document not found when document id is absent in the dataset.""" + # Arrange + dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id) + + # Act / Assert + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, str(uuid4()), "x") + + +def test_rename_document_permission_denied_when_tenant_mismatch(db_session_with_containers, mock_env): + """Rename raises No permission when document tenant differs from current_user tenant.""" + # Arrange + dataset = make_dataset(db_session_with_containers, str(uuid4()), mock_env["current_user"].current_tenant_id) + document = make_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=str(uuid4()), + ) + + # Act / Assert + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "x") diff --git a/api/tests/test_containers_integration_tests/services/test_feature_service.py b/api/tests/test_containers_integration_tests/services/test_feature_service.py index bd2fd14ffa..b3e7dd2a59 100644 --- a/api/tests/test_containers_integration_tests/services/test_feature_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feature_service.py @@ -358,10 +358,9 @@ class TestFeatureService: assert result is not None assert isinstance(result, SystemFeatureModel) - # --- 1. Verify Response Payload Optimization (Data Minimization) --- - # Ensure only essential UI flags are returned to unauthenticated clients - # to keep the payload lightweight and adhere to architectural boundaries. - assert result.license.status == LicenseStatus.NONE + # --- 1. Verify only license *status* is exposed to unauthenticated clients --- + # Detailed license info (expiry, workspaces) remains auth-gated. + assert result.license.status == LicenseStatus.ACTIVE assert result.license.expired_at == "" assert result.license.workspaces.enabled is False assert result.license.workspaces.limit == 0 diff --git a/api/tests/test_containers_integration_tests/services/test_feedback_service.py b/api/tests/test_containers_integration_tests/services/test_feedback_service.py index 60919dff0d..771f406775 100644 --- a/api/tests/test_containers_integration_tests/services/test_feedback_service.py +++ b/api/tests/test_containers_integration_tests/services/test_feedback_service.py @@ -8,6 +8,7 @@ from unittest import mock import pytest from extensions.ext_database import db +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, Conversation, Message from services.feedback_service import FeedbackService @@ -47,8 +48,8 @@ class TestFeedbackService: app_id=app_id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="like", - from_source="user", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, content="Great answer!", from_end_user_id="user-123", from_account_id=None, @@ -61,8 +62,8 @@ class TestFeedbackService: app_id=app_id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="dislike", - from_source="admin", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.ADMIN, content="Could be more detailed", from_end_user_id=None, from_account_id="admin-456", @@ -179,8 +180,8 @@ class TestFeedbackService: # Test with filters result = FeedbackService.export_feedbacks( app_id=sample_data["app"].id, - from_source="admin", - rating="dislike", + from_source=FeedbackFromSource.ADMIN, + rating=FeedbackRating.DISLIKE, has_comment=True, start_date="2024-01-01", end_date="2024-12-31", @@ -293,8 +294,8 @@ class TestFeedbackService: app_id=sample_data["app"].id, conversation_id="test-conversation-id", message_id="test-message-id", - rating="dislike", - from_source="user", + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.USER, content="回答不够详细,需要更多信息", from_end_user_id="user-123", from_account_id=None, diff --git a/api/tests/test_containers_integration_tests/services/test_file_service.py b/api/tests/test_containers_integration_tests/services/test_file_service.py index 93516a0030..42dbdef1c9 100644 --- a/api/tests/test_containers_integration_tests/services/test_file_service.py +++ b/api/tests/test_containers_integration_tests/services/test_file_service.py @@ -5,9 +5,11 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import Engine +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from configs import dify_config +from extensions.storage.storage_type import StorageType from models import Account, Tenant from models.enums import CreatorUserRole from models.model import EndUser, UploadFile @@ -19,7 +21,7 @@ class TestFileService: """Integration tests for FileService using testcontainers.""" @pytest.fixture - def engine(self, db_session_with_containers): + def engine(self, db_session_with_containers: Session): bind = db_session_with_containers.get_bind() assert isinstance(bind, Engine) return bind @@ -46,7 +48,7 @@ class TestFileService: "extract_processor": mock_extract_processor, } - def _create_test_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account for testing. @@ -67,18 +69,16 @@ class TestFileService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -89,15 +89,15 @@ class TestFileService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_end_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test end user for testing. @@ -118,14 +118,14 @@ class TestFileService: session_id=fake.uuid4(), ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_upload_file(self, db_session_with_containers, mock_external_service_dependencies, account): + def _create_test_upload_file( + self, db_session_with_containers: Session, mock_external_service_dependencies, account + ): """ Helper method to create a test upload file for testing. @@ -141,7 +141,7 @@ class TestFileService: upload_file = UploadFile( tenant_id=account.current_tenant_id if hasattr(account, "current_tenant_id") else str(fake.uuid4()), - storage_type="local", + storage_type=StorageType.LOCAL, key=f"upload_files/test/{fake.uuid4()}.txt", name="test_file.txt", size=1024, @@ -155,15 +155,13 @@ class TestFileService: source_url="", ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file # Test upload_file method - def test_upload_file_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful file upload with valid parameters. """ @@ -196,7 +194,9 @@ class TestFileService: assert upload_file.id is not None - def test_upload_file_with_end_user(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_end_user( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with end user instead of account. """ @@ -219,7 +219,7 @@ class TestFileService: assert upload_file.created_by_role == CreatorUserRole.END_USER def test_upload_file_with_datasets_source( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with datasets source parameter. @@ -244,7 +244,7 @@ class TestFileService: assert upload_file.source_url == "https://example.com/source" def test_upload_file_invalid_filename_characters( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with invalid filename characters. @@ -264,8 +264,29 @@ class TestFileService: user=account, ) + def test_upload_file_allows_regular_punctuation_in_filename( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): + """ + Test file upload allows punctuation that is safe when stored as metadata. + """ + account = self._create_test_account(db_session_with_containers, mock_external_service_dependencies) + + filename = 'candidate?resume for "dify"|v2:.txt' + content = b"test content" + mimetype = "text/plain" + + upload_file = FileService(engine).upload_file( + filename=filename, + content=content, + mimetype=mimetype, + user=account, + ) + + assert upload_file.name == filename + def test_upload_file_filename_too_long( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with filename that exceeds length limit. @@ -295,7 +316,7 @@ class TestFileService: assert len(base_name) <= 200 def test_upload_file_datasets_unsupported_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload for datasets with unsupported file type. @@ -316,7 +337,9 @@ class TestFileService: source="datasets", ) - def test_upload_file_too_large(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_too_large( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with file size exceeding limit. """ @@ -338,7 +361,7 @@ class TestFileService: # Test is_file_size_within_limit method def test_is_file_size_within_limit_image_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files within limit. @@ -351,7 +374,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_video_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for video files within limit. @@ -364,7 +387,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_audio_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for audio files within limit. @@ -377,7 +400,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_document_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for document files within limit. @@ -390,7 +413,7 @@ class TestFileService: assert result is True def test_is_file_size_within_limit_image_exceeded( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for image files exceeding limit. @@ -403,7 +426,7 @@ class TestFileService: assert result is False def test_is_file_size_within_limit_unknown_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file size check for unknown file extension. @@ -416,7 +439,7 @@ class TestFileService: assert result is True # Test upload_text method - def test_upload_text_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_success(self, db_session_with_containers: Session, engine, mock_external_service_dependencies): """ Test successful text upload. """ @@ -447,7 +470,9 @@ class TestFileService: # Verify storage was called mock_external_service_dependencies["storage"].save.assert_called_once() - def test_upload_text_name_too_long(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_name_too_long( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with name that exceeds length limit. """ @@ -472,7 +497,9 @@ class TestFileService: assert upload_file.name == "a" * 200 # Test get_file_preview method - def test_get_file_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_file_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful file preview generation. """ @@ -484,9 +511,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() result = FileService(engine).get_file_preview(file_id=upload_file.id) @@ -494,7 +520,7 @@ class TestFileService: mock_external_service_dependencies["extract_processor"].load_from_upload_file.assert_called_once() def test_get_file_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with non-existent file. @@ -506,7 +532,7 @@ class TestFileService: FileService(engine).get_file_preview(file_id=non_existent_id) def test_get_file_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with unsupported file type. @@ -519,15 +545,14 @@ class TestFileService: # Update file to have non-document extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_file_preview(file_id=upload_file.id) def test_get_file_preview_text_truncation( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file preview with text that exceeds preview limit. @@ -540,9 +565,8 @@ class TestFileService: # Update file to have document extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock long text content long_text = "x" * 5000 # Longer than PREVIEW_WORDS_LIMIT @@ -554,7 +578,9 @@ class TestFileService: assert result == "x" * 3000 # Test get_image_preview method - def test_get_image_preview_success(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_get_image_preview_success( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test successful image preview generation. """ @@ -566,9 +592,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -586,7 +611,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_image_signature.assert_called_once() def test_get_image_preview_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with invalid signature. @@ -613,7 +638,7 @@ class TestFileService: ) def test_get_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-existent file. @@ -634,7 +659,7 @@ class TestFileService: ) def test_get_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test image preview with non-image file type. @@ -647,9 +672,8 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() timestamp = "1234567890" nonce = "test_nonce" @@ -665,7 +689,7 @@ class TestFileService: # Test get_file_generator_by_file_id method def test_get_file_generator_by_file_id_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful file generator retrieval. @@ -692,7 +716,7 @@ class TestFileService: mock_external_service_dependencies["file_helpers"].verify_file_signature.assert_called_once() def test_get_file_generator_by_file_id_invalid_signature( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with invalid signature. @@ -719,7 +743,7 @@ class TestFileService: ) def test_get_file_generator_by_file_id_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file generator retrieval with non-existent file. @@ -741,7 +765,7 @@ class TestFileService: # Test get_public_image_preview method def test_get_public_image_preview_success( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test successful public image preview generation. @@ -754,9 +778,8 @@ class TestFileService: # Update file to have image extension upload_file.extension = "jpg" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() generator, mime_type = FileService(engine).get_public_image_preview(file_id=upload_file.id) @@ -765,7 +788,7 @@ class TestFileService: mock_external_service_dependencies["storage"].load.assert_called_once() def test_get_public_image_preview_file_not_found( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-existent file. @@ -777,7 +800,7 @@ class TestFileService: FileService(engine).get_public_image_preview(file_id=non_existent_id) def test_get_public_image_preview_unsupported_file_type( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test public image preview with non-image file type. @@ -790,15 +813,16 @@ class TestFileService: # Update file to have non-image extension upload_file.extension = "pdf" - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() with pytest.raises(UnsupportedFileTypeError): FileService(engine).get_public_image_preview(file_id=upload_file.id) # Test edge cases and boundary conditions - def test_upload_file_empty_content(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_content( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty content. """ @@ -820,7 +844,7 @@ class TestFileService: assert upload_file.size == 0 def test_upload_file_special_characters_in_name( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with special characters in filename (but valid ones). @@ -843,7 +867,7 @@ class TestFileService: assert upload_file.name == filename def test_upload_file_different_case_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with different case extensions. @@ -865,7 +889,9 @@ class TestFileService: assert upload_file is not None assert upload_file.extension == "pdf" # Should be converted to lowercase - def test_upload_text_empty_text(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_text_empty_text( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test text upload with empty text. """ @@ -888,7 +914,9 @@ class TestFileService: assert upload_file is not None assert upload_file.size == 0 - def test_file_size_limits_edge_cases(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_file_size_limits_edge_cases( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file size limits with edge case values. """ @@ -908,7 +936,9 @@ class TestFileService: result = FileService(engine).is_file_size_within_limit(extension=extension, file_size=file_size) assert result is False - def test_upload_file_with_source_url(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_with_source_url( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with source URL that gets overridden by signed URL. """ @@ -946,7 +976,7 @@ class TestFileService: # Test file extension blacklist def test_upload_file_blocked_extension( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension. @@ -969,7 +999,7 @@ class TestFileService: ) def test_upload_file_blocked_extension_case_insensitive( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with blocked extension (case insensitive). @@ -992,7 +1022,9 @@ class TestFileService: user=account, ) - def test_upload_file_not_in_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_not_in_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with extension not in blacklist. """ @@ -1016,7 +1048,9 @@ class TestFileService: assert upload_file.name == filename assert upload_file.extension == "pdf" - def test_upload_file_empty_blacklist(self, db_session_with_containers, engine, mock_external_service_dependencies): + def test_upload_file_empty_blacklist( + self, db_session_with_containers: Session, engine, mock_external_service_dependencies + ): """ Test file upload with empty blacklist (default behavior). """ @@ -1041,7 +1075,7 @@ class TestFileService: assert upload_file.extension == "sh" def test_upload_file_multiple_blocked_extensions( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with multiple blocked extensions. @@ -1066,7 +1100,7 @@ class TestFileService: ) def test_upload_file_no_extension_with_blacklist( - self, db_session_with_containers, engine, mock_external_service_dependencies + self, db_session_with_containers: Session, engine, mock_external_service_dependencies ): """ Test file upload with no extension when blacklist is configured. diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 9c978f830f..70d05792ce 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,8 +4,8 @@ from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -68,7 +68,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - inputs=[], user_actions=[], ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value + node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT graph = json.dumps({"nodes": [{"id": "human-node", "data": node_data}], "edges": []}) workflow = Workflow.new( diff --git a/api/tests/test_containers_integration_tests/services/test_message_export_service.py b/api/tests/test_containers_integration_tests/services/test_message_export_service.py new file mode 100644 index 0000000000..00dfe9dda4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_message_export_service.py @@ -0,0 +1,234 @@ +import datetime +import json +import uuid +from decimal import Decimal + +import pytest +from sqlalchemy.orm import Session + +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ConversationFromSource, FeedbackFromSource, FeedbackRating +from models.model import ( + App, + AppAnnotationHitHistory, + Conversation, + DatasetRetrieverResource, + Message, + MessageAgentThought, + MessageAnnotation, + MessageChain, + MessageFeedback, + MessageFile, +) +from models.web import SavedMessage +from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats + + +class TestAppMessageExportServiceIntegration: + @pytest.fixture(autouse=True) + def cleanup_database(self, db_session_with_containers: Session): + yield + db_session_with_containers.query(DatasetRetrieverResource).delete() + db_session_with_containers.query(AppAnnotationHitHistory).delete() + db_session_with_containers.query(SavedMessage).delete() + db_session_with_containers.query(MessageFile).delete() + db_session_with_containers.query(MessageAgentThought).delete() + db_session_with_containers.query(MessageChain).delete() + db_session_with_containers.query(MessageAnnotation).delete() + db_session_with_containers.query(MessageFeedback).delete() + db_session_with_containers.query(Message).delete() + db_session_with_containers.query(Conversation).delete() + db_session_with_containers.query(App).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() + + @staticmethod + def _create_app_context(session: Session) -> tuple[App, Conversation]: + account = Account( + email=f"test-{uuid.uuid4()}@example.com", + name="tester", + interface_language="en-US", + status="active", + ) + session.add(account) + session.flush() + + tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal") + session.add(tenant) + session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + session.add(join) + session.flush() + + app = App( + tenant_id=tenant.id, + name="export-app", + description="integration test app", + mode="chat", + enable_site=True, + enable_api=True, + api_rpm=60, + api_rph=3600, + is_demo=False, + is_public=False, + created_by=account.id, + updated_by=account.id, + ) + session.add(app) + session.flush() + + conversation = Conversation( + app_id=app.id, + app_model_config_id=str(uuid.uuid4()), + model_provider="openai", + model_id="gpt-4o-mini", + mode="chat", + name="conv", + inputs={"seed": 1}, + status="normal", + from_source=ConversationFromSource.API, + from_end_user_id=str(uuid.uuid4()), + ) + session.add(conversation) + session.commit() + return app, conversation + + @staticmethod + def _create_message( + session: Session, + app: App, + conversation: Conversation, + created_at: datetime.datetime, + *, + query: str, + answer: str, + inputs: dict, + message_metadata: str | None, + ) -> Message: + message = Message( + app_id=app.id, + conversation_id=conversation.id, + model_provider="openai", + model_id="gpt-4o-mini", + inputs=inputs, + query=query, + answer=answer, + message=[{"role": "assistant", "content": answer}], + message_tokens=10, + message_unit_price=Decimal("0.001"), + answer_tokens=20, + answer_unit_price=Decimal("0.002"), + total_price=Decimal("0.003"), + currency="USD", + message_metadata=message_metadata, + from_source=ConversationFromSource.API, + from_end_user_id=conversation.from_end_user_id, + created_at=created_at, + ) + session.add(message) + session.flush() + return message + + def test_iter_records_with_stats(self, db_session_with_containers: Session): + app, conversation = self._create_app_context(db_session_with_containers) + + first_inputs = { + "plain": "v1", + "nested": {"a": 1, "b": [1, {"x": True}]}, + "list": ["x", 2, {"y": "z"}], + } + second_inputs = {"other": "value", "items": [1, 2, 3]} + + base_time = datetime.datetime(2026, 2, 25, 10, 0, 0) + first_message = self._create_message( + db_session_with_containers, + app, + conversation, + created_at=base_time, + query="q1", + answer="a1", + inputs=first_inputs, + message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}), + ) + second_message = self._create_message( + db_session_with_containers, + app, + conversation, + created_at=base_time + datetime.timedelta(minutes=1), + query="q2", + answer="a2", + inputs=second_inputs, + message_metadata=None, + ) + + user_feedback_1 = MessageFeedback( + app_id=app.id, + conversation_id=conversation.id, + message_id=first_message.id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, + content="first", + from_end_user_id=conversation.from_end_user_id, + ) + user_feedback_2 = MessageFeedback( + app_id=app.id, + conversation_id=conversation.id, + message_id=first_message.id, + rating=FeedbackRating.DISLIKE, + from_source=FeedbackFromSource.USER, + content="second", + from_end_user_id=conversation.from_end_user_id, + ) + admin_feedback = MessageFeedback( + app_id=app.id, + conversation_id=conversation.id, + message_id=first_message.id, + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.ADMIN, + content="should-be-filtered", + from_account_id=str(uuid.uuid4()), + ) + db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback]) + user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2) + user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3) + admin_feedback.created_at = base_time + datetime.timedelta(minutes=4) + db_session_with_containers.commit() + + service = AppMessageExportService( + app_id=app.id, + start_from=base_time - datetime.timedelta(minutes=1), + end_before=base_time + datetime.timedelta(minutes=10), + filename="unused", + batch_size=1, + dry_run=True, + ) + stats = AppMessageExportStats() + records = list(service._iter_records_with_stats(stats)) + service._finalize_stats(stats) + + assert len(records) == 2 + assert records[0].message_id == first_message.id + assert records[1].message_id == second_message.id + + assert records[0].inputs == first_inputs + assert records[1].inputs == second_inputs + + assert records[0].retriever_resources == [{"dataset_id": "ds-1"}] + assert records[1].retriever_resources == [] + + assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"] + assert [feedback.content for feedback in records[0].feedback] == ["first", "second"] + assert records[1].feedback == [] + + assert stats.batches == 2 + assert stats.total_messages == 2 + assert stats.messages_with_feedback == 1 + assert stats.total_feedbacks == 2 diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index ece6de6cdf..85dc04b162 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -2,7 +2,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session +from models.enums import ConversationFromSource, FeedbackRating, InvokeFrom from models.model import MessageFeedback from services.app_service import AppService from services.errors.message import ( @@ -12,6 +14,7 @@ from services.errors.message import ( SuggestedQuestionsAfterAnswerDisabledError, ) from services.message_service import MessageService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestMessageService: @@ -69,7 +72,7 @@ class TestMessageService: # "current_user": mock_current_user, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -94,7 +97,7 @@ class TestMessageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -127,11 +130,10 @@ class TestMessageService: # mock_external_service_dependencies["current_user"].id = account_id # mock_external_service_dependencies["current_user"].current_tenant_id = tenant_id - def _create_test_conversation(self, app, account, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, account, fake): """ Helper method to create a test conversation with all required fields. """ - from extensions.ext_database import db from models.model import Conversation conversation = Conversation( @@ -147,23 +149,22 @@ class TestMessageService: system_instruction="", system_instruction_tokens=0, status="normal", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) - db.session.add(conversation) - db.session.flush() + db_session_with_containers.add(conversation) + db_session_with_containers.flush() return conversation - def _create_test_message(self, app, conversation, account, fake): + def _create_test_message(self, db_session_with_containers: Session, app, conversation, account, fake): """ Helper method to create a test message with all required fields. """ import json - from extensions.ext_database import db from models.model import Message message = Message( @@ -186,17 +187,19 @@ class TestMessageService: provider_response_latency=0, total_price=0, currency="USD", - invoke_from="console", - from_source="console", + invoke_from=InvokeFrom.EXPLORE, + from_source=ConversationFromSource.CONSOLE, from_end_user_id=None, from_account_id=account.id, ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_pagination_by_first_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by first ID. """ @@ -204,10 +207,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by first ID @@ -228,7 +231,9 @@ class TestMessageService: # Verify messages are in ascending order assert result.data[0].created_at <= result.data[1].created_at - def test_pagination_by_first_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_first_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by first ID when no user is provided. """ @@ -246,7 +251,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_no_conversation_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID when no conversation ID is provided. @@ -265,7 +270,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_first_id_invalid_first_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by first ID with invalid first_id. @@ -274,8 +279,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid first_id with pytest.raises(FirstMessageNotExistsError): @@ -287,7 +292,9 @@ class TestMessageService: limit=10, ) - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID. """ @@ -295,10 +302,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination by last ID @@ -319,7 +326,7 @@ class TestMessageService: assert result.data[0].created_at >= result.data[1].created_at def test_pagination_by_last_id_with_include_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with include_ids filter. @@ -328,10 +335,10 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and multiple messages - conversation = self._create_test_conversation(app, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) messages = [] for i in range(5): - message = self._create_test_message(app, conversation, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) messages.append(message) # Test pagination with include_ids @@ -347,7 +354,9 @@ class TestMessageService: for message in result.data: assert message.id in include_ids - def test_pagination_by_last_id_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pagination by last ID when no user is provided. """ @@ -363,7 +372,7 @@ class TestMessageService: assert result.has_more is False def test_pagination_by_last_id_invalid_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with invalid last_id. @@ -372,8 +381,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test pagination with invalid last_id with pytest.raises(LastMessageNotExistsError): @@ -385,7 +394,7 @@ class TestMessageService: conversation_id=conversation.id, ) - def test_create_feedback_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of feedback. """ @@ -393,11 +402,11 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create feedback - rating = "like" + rating = FeedbackRating.LIKE content = fake.text(max_nb_chars=100) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=rating, content=content @@ -413,7 +422,7 @@ class TestMessageService: assert feedback.from_account_id == account.id assert feedback.from_end_user_id is None - def test_create_feedback_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test creating feedback when no user is provided. """ @@ -421,16 +430,22 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no user with pytest.raises(ValueError, match="user cannot be None"): MessageService.create_feedback( - app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100) + app_model=app, + message_id=message.id, + user=None, + rating=FeedbackRating.LIKE, + content=fake.text(max_nb_chars=100), ) - def test_create_feedback_update_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_update_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test updating existing feedback. """ @@ -438,18 +453,18 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback - initial_rating = "like" + initial_rating = FeedbackRating.LIKE initial_content = fake.text(max_nb_chars=100) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content ) # Update feedback - updated_rating = "dislike" + updated_rating = FeedbackRating.DISLIKE updated_content = fake.text(max_nb_chars=100) updated_feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content @@ -462,7 +477,9 @@ class TestMessageService: assert updated_feedback.rating != initial_rating assert updated_feedback.content != initial_content - def test_create_feedback_delete_existing(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_feedback_delete_existing( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting existing feedback by setting rating to None. """ @@ -470,25 +487,30 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create initial feedback feedback = MessageService.create_feedback( - app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100) + app_model=app, + message_id=message.id, + user=account, + rating=FeedbackRating.LIKE, + content=fake.text(max_nb_chars=100), ) # Delete feedback by setting rating to None MessageService.create_feedback(app_model=app, message_id=message.id, user=account, rating=None, content=None) # Verify feedback was deleted - from extensions.ext_database import db - deleted_feedback = db.session.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + deleted_feedback = ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.id == feedback.id).first() + ) assert deleted_feedback is None def test_create_feedback_no_rating_when_not_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creating feedback with no rating when feedback doesn't exist. @@ -497,8 +519,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test creating feedback with no rating when no feedback exists with pytest.raises(ValueError, match="rating cannot be None when feedback not exists"): @@ -506,7 +528,9 @@ class TestMessageService: app_model=app, message_id=message.id, user=account, rating=None, content=None ) - def test_get_all_messages_feedbacks_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_all_messages_feedbacks_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of all message feedbacks. """ @@ -516,14 +540,14 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks feedbacks = [] for i in range(3): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) feedback = MessageService.create_feedback( app_model=app, message_id=message.id, user=account, - rating="like" if i % 2 == 0 else "dislike", + rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE, content=f"Feedback {i}: {fake.text(max_nb_chars=50)}", ) feedbacks.append(feedback) @@ -539,7 +563,7 @@ class TestMessageService: assert result[i]["created_at"] >= result[i + 1]["created_at"] def test_get_all_messages_feedbacks_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of message feedbacks. @@ -549,11 +573,15 @@ class TestMessageService: # Create multiple conversations and messages with feedbacks for i in range(5): - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) MessageService.create_feedback( - app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}" + app_model=app, + message_id=message.id, + user=account, + rating=FeedbackRating.LIKE, + content=f"Feedback {i}", ) # Get feedbacks with pagination @@ -569,7 +597,7 @@ class TestMessageService: page_2_ids = {feedback["id"] for feedback in result_page_2} assert len(page_1_ids.intersection(page_2_ids)) == 0 - def test_get_message_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of message. """ @@ -577,8 +605,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Get message retrieved_message = MessageService.get_message(app_model=app, user=account, message_id=message.id) @@ -590,7 +618,7 @@ class TestMessageService: assert retrieved_message.from_source == "console" assert retrieved_message.from_account_id == account.id - def test_get_message_not_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_not_exists(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message that doesn't exist. """ @@ -601,7 +629,7 @@ class TestMessageService: with pytest.raises(MessageNotExistsError): MessageService.get_message(app_model=app, user=account, message_id=fake.uuid4()) - def test_get_message_wrong_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_message_wrong_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting message with wrong user (different account). """ @@ -609,8 +637,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Create another account from services.account_service import AccountService, TenantService @@ -619,7 +647,7 @@ class TestMessageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(other_account, name=fake.company()) @@ -628,7 +656,7 @@ class TestMessageService: MessageService.get_message(app_model=app, user=other_account, message_id=message.id) def test_get_suggested_questions_after_answer_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful generation of suggested questions after answer. @@ -637,8 +665,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the LLMGenerator to return specific questions mock_questions = ["What is AI?", "How does machine learning work?", "Tell me about neural networks"] @@ -665,7 +693,7 @@ class TestMessageService: mock_external_service_dependencies["trace_manager_instance"].add_trace_task.assert_called_once() def test_get_suggested_questions_after_answer_no_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no user is provided. @@ -674,8 +702,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Test getting suggested questions with no user from core.app.entities.app_invoke_entities import InvokeFrom @@ -686,7 +714,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when feature is disabled. @@ -695,8 +723,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock the feature to be disabled mock_external_service_dependencies[ @@ -712,7 +740,7 @@ class TestMessageService: ) def test_get_suggested_questions_after_answer_no_workflow( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions when no workflow exists. @@ -721,8 +749,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock no workflow mock_external_service_dependencies["workflow_service"].return_value.get_published_workflow.return_value = None @@ -738,7 +766,7 @@ class TestMessageService: assert result == [] def test_get_suggested_questions_after_answer_debugger_mode( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting suggested questions in debugger mode. @@ -747,8 +775,8 @@ class TestMessageService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) # Create a conversation and message - conversation = self._create_test_conversation(app, account, fake) - message = self._create_test_message(app, conversation, account, fake) + conversation = self._create_test_conversation(db_session_with_containers, app, account, fake) + message = self._create_test_message(db_session_with_containers, app, conversation, account, fake) # Mock questions mock_questions = ["Debug question 1", "Debug question 2"] diff --git a/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py b/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py index 772365ba54..f2cb667204 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service_extra_contents.py @@ -4,6 +4,7 @@ from decimal import Decimal import pytest +from models.enums import ConversationFromSource from models.model import Message from services import message_service from tests.test_containers_integration_tests.helpers.execution_extra_content import ( @@ -36,7 +37,7 @@ def test_attach_message_extra_contents_assigns_serialized_payload(db_session_wit total_price=Decimal(0), currency="USD", status="normal", - from_source="console", + from_source=ConversationFromSource.CONSOLE, from_account_id=fixture.account.id, ) db_session_with_containers.add(message_without_extra_content) diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 5b6db64c09..8707f2e827 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -6,11 +6,19 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from enums.cloud_plan import CloudPlan -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.enums import ( + ConversationFromSource, + DataSourceType, + FeedbackFromSource, + FeedbackRating, + MessageChainType, + MessageFileBelongsTo, +) from models.model import ( App, AppAnnotationHitHistory, @@ -40,25 +48,25 @@ class TestMessagesCleanServiceIntegration: PLAN_CACHE_KEY_PREFIX = BillingService._PLAN_CACHE_KEY_PREFIX # "tenant_plan:" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before and after each test to ensure isolation.""" yield # Clear all test data in correct order (respecting foreign key constraints) - db.session.query(DatasetRetrieverResource).delete() - db.session.query(AppAnnotationHitHistory).delete() - db.session.query(SavedMessage).delete() - db.session.query(MessageFile).delete() - db.session.query(MessageAgentThought).delete() - db.session.query(MessageChain).delete() - db.session.query(MessageAnnotation).delete() - db.session.query(MessageFeedback).delete() - db.session.query(Message).delete() - db.session.query(Conversation).delete() - db.session.query(App).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DatasetRetrieverResource).delete() + db_session_with_containers.query(AppAnnotationHitHistory).delete() + db_session_with_containers.query(SavedMessage).delete() + db_session_with_containers.query(MessageFile).delete() + db_session_with_containers.query(MessageAgentThought).delete() + db_session_with_containers.query(MessageChain).delete() + db_session_with_containers.query(MessageAnnotation).delete() + db_session_with_containers.query(MessageFeedback).delete() + db_session_with_containers.query(Message).delete() + db_session_with_containers.query(Conversation).delete() + db_session_with_containers.query(App).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() @pytest.fixture(autouse=True) def cleanup_redis(self): @@ -100,7 +108,7 @@ class TestMessagesCleanServiceIntegration: with patch("services.retention.conversation.messages_clean_policy.dify_config.BILLING_ENABLED", False): yield - def _create_account_and_tenant(self, plan: str = CloudPlan.SANDBOX): + def _create_account_and_tenant(self, db_session_with_containers: Session, plan: str = CloudPlan.SANDBOX): """Helper to create account and tenant.""" fake = Faker() @@ -110,28 +118,28 @@ class TestMessagesCleanServiceIntegration: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.flush() + db_session_with_containers.add(account) + db_session_with_containers.flush() tenant = Tenant( name=fake.company(), plan=str(plan), status="normal", ) - db.session.add(tenant) - db.session.flush() + db_session_with_containers.add(tenant) + db_session_with_containers.flush() tenant_account_join = TenantAccountJoin( tenant_id=tenant.id, account_id=account.id, role=TenantAccountRole.OWNER, ) - db.session.add(tenant_account_join) - db.session.commit() + db_session_with_containers.add(tenant_account_join) + db_session_with_containers.commit() return account, tenant - def _create_app(self, tenant, account): + def _create_app(self, db_session_with_containers: Session, tenant, account): """Helper to create an app.""" fake = Faker() @@ -149,12 +157,12 @@ class TestMessagesCleanServiceIntegration: created_by=account.id, updated_by=account.id, ) - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_conversation(self, app): + def _create_conversation(self, db_session_with_containers: Session, app): """Helper to create a conversation.""" conversation = Conversation( app_id=app.id, @@ -165,15 +173,17 @@ class TestMessagesCleanServiceIntegration: name="Test conversation", inputs={}, status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid.uuid4()), ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def _create_message(self, app, conversation, created_at=None, with_relations=True): + def _create_message( + self, db_session_with_containers: Session, app, conversation, created_at=None, with_relations=True + ): """Helper to create a message with optional related records.""" if created_at is None: created_at = datetime.datetime.now() @@ -193,31 +203,31 @@ class TestMessagesCleanServiceIntegration: answer_unit_price=Decimal("0.002"), total_price=Decimal("0.003"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, from_account_id=conversation.from_end_user_id, created_at=created_at, ) - db.session.add(message) - db.session.flush() + db_session_with_containers.add(message) + db_session_with_containers.flush() if with_relations: - self._create_message_relations(message) + self._create_message_relations(db_session_with_containers, message) - db.session.commit() + db_session_with_containers.commit() return message - def _create_message_relations(self, message): + def _create_message_relations(self, db_session_with_containers: Session, message): """Helper to create all message-related records.""" # MessageFeedback feedback = MessageFeedback( app_id=message.app_id, conversation_id=message.conversation_id, message_id=message.id, - rating="like", - from_source="api", + rating=FeedbackRating.LIKE, + from_source=FeedbackFromSource.USER, from_end_user_id=str(uuid.uuid4()), ) - db.session.add(feedback) + db_session_with_containers.add(feedback) # MessageAnnotation annotation = MessageAnnotation( @@ -228,17 +238,17 @@ class TestMessagesCleanServiceIntegration: content="Test annotation", account_id=message.from_account_id, ) - db.session.add(annotation) + db_session_with_containers.add(annotation) # MessageChain chain = MessageChain( message_id=message.id, - type="system", + type=MessageChainType.SYSTEM, input=json.dumps({"test": "input"}), output=json.dumps({"test": "output"}), ) - db.session.add(chain) - db.session.flush() + db_session_with_containers.add(chain) + db_session_with_containers.flush() # MessageFile file = MessageFile( @@ -246,11 +256,11 @@ class TestMessagesCleanServiceIntegration: type="image", transfer_method="local_file", url="http://example.com/test.jpg", - belongs_to="user", + belongs_to=MessageFileBelongsTo.USER, created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(file) + db_session_with_containers.add(file) # SavedMessage saved = SavedMessage( @@ -259,9 +269,9 @@ class TestMessagesCleanServiceIntegration: created_by_role="end_user", created_by=str(uuid.uuid4()), ) - db.session.add(saved) + db_session_with_containers.add(saved) - db.session.flush() + db_session_with_containers.flush() # AppAnnotationHitHistory hit = AppAnnotationHitHistory( @@ -275,7 +285,7 @@ class TestMessagesCleanServiceIntegration: annotation_question="Test annotation question", annotation_content="Test annotation content", ) - db.session.add(hit) + db_session_with_containers.add(hit) # DatasetRetrieverResource resource = DatasetRetrieverResource( @@ -285,7 +295,7 @@ class TestMessagesCleanServiceIntegration: dataset_name="Test dataset", document_id=str(uuid.uuid4()), document_name="Test document", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, segment_id=str(uuid.uuid4()), score=0.9, content="Test content", @@ -296,25 +306,29 @@ class TestMessagesCleanServiceIntegration: retriever_from="dataset", created_by=message.from_account_id, ) - db.session.add(resource) + db_session_with_containers.add(resource) def test_billing_disabled_deletes_all_messages_in_time_range( - self, db_session_with_containers, mock_billing_disabled + self, db_session_with_containers: Session, mock_billing_disabled ): """Test that BillingDisabledPolicy deletes all messages within time range regardless of tenant plan.""" # Arrange - Create tenant with messages (plan doesn't matter for billing disabled) - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: in-range (should be deleted) and out-of-range (should be kept) in_range_date = datetime.datetime(2024, 1, 15, 12, 0, 0) out_of_range_date = datetime.datetime(2024, 1, 25, 12, 0, 0) - in_range_msg = self._create_message(app, conv, created_at=in_range_date, with_relations=True) + in_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=in_range_date, with_relations=True + ) in_range_msg_id = in_range_msg.id - out_of_range_msg = self._create_message(app, conv, created_at=out_of_range_date, with_relations=True) + out_of_range_msg = self._create_message( + db_session_with_containers, app, conv, created_at=out_of_range_date, with_relations=True + ) out_of_range_msg_id = out_of_range_msg.id # Act - create_message_clean_policy should return BillingDisabledPolicy @@ -336,17 +350,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # In-range message deleted - assert db.session.query(Message).where(Message.id == in_range_msg_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == in_range_msg_id).count() == 0 # Out-of-range message kept - assert db.session.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == out_of_range_msg_id).count() == 1 # Related records of in-range message deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == in_range_msg_id).count() == 0 - assert db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == in_range_msg_id).count() == 0 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == in_range_msg_id) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id == in_range_msg_id) + .count() + == 0 + ) # Related records of out-of-range message kept - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id == out_of_range_msg_id).count() == 1 + assert ( + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id == out_of_range_msg_id) + .count() + == 1 + ) - def test_no_messages_returns_empty_stats(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_no_messages_returns_empty_stats( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning when there are no messages to delete (B1).""" # Arrange end_before = datetime.datetime.now() - datetime.timedelta(days=30) @@ -371,36 +402,42 @@ class TestMessagesCleanServiceIntegration: assert stats["filtered_messages"] == 0 assert stats["total_deleted"] == 0 - def test_mixed_sandbox_and_paid_tenants(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_mixed_sandbox_and_paid_tenants( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cleaning with mixed sandbox and paid tenants (B2).""" # Arrange - Create sandbox tenants with expired messages sandbox_tenants = [] sandbox_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) sandbox_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 3 expired messages per sandbox tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) sandbox_message_ids.append(msg.id) # Create paid tenants with expired messages (should NOT be deleted) paid_tenants = [] paid_message_ids = [] for i in range(2): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) paid_tenants.append(tenant) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 2 expired messages per paid tenant expired_date = datetime.datetime.now() - datetime.timedelta(days=35) for j in range(2): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=j)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=j) + ) paid_message_ids.append(msg.id) # Mock billing service - return plan and expiration_date @@ -442,29 +479,39 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 6 # Only sandbox messages should be deleted - assert db.session.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(sandbox_message_ids)).count() == 0 # Paid messages should remain - assert db.session.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 + assert db_session_with_containers.query(Message).where(Message.id.in_(paid_message_ids)).count() == 4 # Related records of sandbox messages should be deleted - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(sandbox_message_ids)).count() == 0 assert ( - db.session.query(MessageAnnotation).where(MessageAnnotation.message_id.in_(sandbox_message_ids)).count() + db_session_with_containers.query(MessageFeedback) + .where(MessageFeedback.message_id.in_(sandbox_message_ids)) + .count() + == 0 + ) + assert ( + db_session_with_containers.query(MessageAnnotation) + .where(MessageAnnotation.message_id.in_(sandbox_message_ids)) + .count() == 0 ) - def test_cursor_pagination_multiple_batches(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_cursor_pagination_multiple_batches( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test cursor pagination works correctly across multiple batches (B3).""" # Arrange - Create sandbox tenant with messages that will span multiple batches - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create 10 expired messages with different timestamps base_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(10): msg = self._create_message( + db_session_with_containers, app, conv, created_at=base_date + datetime.timedelta(hours=i), @@ -498,20 +545,22 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 10 # All messages should be deleted - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 0 - def test_dry_run_does_not_delete(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_dry_run_does_not_delete(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test dry_run mode does not delete messages (B4).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create expired messages expired_date = datetime.datetime.now() - datetime.timedelta(days=35) message_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) message_ids.append(msg.id) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -540,21 +589,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # But NOT deleted # All messages should still exist - assert db.session.query(Message).where(Message.id.in_(message_ids)).count() == 3 + assert db_session_with_containers.query(Message).where(Message.id.in_(message_ids)).count() == 3 # Related records should also still exist - assert db.session.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() == 3 + assert ( + db_session_with_containers.query(MessageFeedback).where(MessageFeedback.message_id.in_(message_ids)).count() + == 3 + ) - def test_partial_plan_data_safe_default(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_partial_plan_data_safe_default( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns partial data, unknown tenants are preserved (B5).""" # Arrange - Create 3 tenants tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) tenants_data.append( { @@ -600,28 +654,30 @@ class TestMessagesCleanServiceIntegration: # Check which messages were deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 + db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 0 ) # Sandbox tenant's message deleted assert ( - db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 ) # Professional tenant's message preserved assert ( - db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 + db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 1 ) # Unknown tenant's message preserved (safe default) - def test_empty_plan_data_skips_deletion(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_empty_plan_data_skips_deletion( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test when billing returns empty data, skip deletion entirely (B6).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date) + msg = self._create_message(db_session_with_containers, app, conv, created_at=expired_date) msg_id = msg.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service to return empty data (simulating failure/no data scenario) with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -644,17 +700,20 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Message should still exist (safe default - don't delete if plan is unknown) - assert db.session.query(Message).where(Message.id == msg_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_id).count() == 1 - def test_time_range_boundary_behavior(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_time_range_boundary_behavior( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test that messages are correctly filtered by [start_from, end_before) time range (B7).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create messages: before range, in range, after range msg_before = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 1, 12, 0, 0), # Before start_from @@ -663,6 +722,7 @@ class TestMessagesCleanServiceIntegration: msg_before_id = msg_before.id msg_at_start = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 10, 12, 0, 0), # At start_from (inclusive) @@ -671,6 +731,7 @@ class TestMessagesCleanServiceIntegration: msg_at_start_id = msg_at_start.id msg_in_range = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 15, 12, 0, 0), # In range @@ -679,6 +740,7 @@ class TestMessagesCleanServiceIntegration: msg_in_range_id = msg_in_range.id msg_at_end = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 20, 12, 0, 0), # At end_before (exclusive) @@ -687,6 +749,7 @@ class TestMessagesCleanServiceIntegration: msg_at_end_id = msg_at_end.id msg_after = self._create_message( + db_session_with_containers, app, conv, created_at=datetime.datetime(2024, 1, 25, 12, 0, 0), # After end_before @@ -694,7 +757,7 @@ class TestMessagesCleanServiceIntegration: ) msg_after_id = msg_after.id - db.session.commit() + db_session_with_containers.commit() # Mock billing service with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: @@ -722,17 +785,17 @@ class TestMessagesCleanServiceIntegration: # Verify specific messages using stored IDs # Before range, kept - assert db.session.query(Message).where(Message.id == msg_before_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_before_id).count() == 1 # At start (inclusive), deleted - assert db.session.query(Message).where(Message.id == msg_at_start_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_start_id).count() == 0 # In range, deleted - assert db.session.query(Message).where(Message.id == msg_in_range_id).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == msg_in_range_id).count() == 0 # At end (exclusive), kept - assert db.session.query(Message).where(Message.id == msg_at_end_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_at_end_id).count() == 1 # After range, kept - assert db.session.query(Message).where(Message.id == msg_after_id).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == msg_after_id).count() == 1 - def test_grace_period_scenarios(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_grace_period_scenarios(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test cleaning with different graceful period scenarios (B8).""" # Arrange - Create 5 different tenants with different plan and expiration scenarios now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) @@ -740,50 +803,60 @@ class TestMessagesCleanServiceIntegration: # Scenario 1: Sandbox plan with expiration within graceful period (5 days ago) # Should NOT be deleted - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) msg1_id = msg1.id expired_5_days_ago = now_timestamp - (5 * 24 * 60 * 60) # Within grace period # Scenario 2: Sandbox plan with expiration beyond graceful period (10 days ago) # Should be deleted - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) msg2_id = msg2.id expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Beyond grace period # Scenario 3: Sandbox plan with expiration_date = -1 (no previous subscription) # Should be deleted - account3, tenant3 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app3 = self._create_app(tenant3, account3) - conv3 = self._create_conversation(app3) - msg3 = self._create_message(app3, conv3, created_at=expired_date, with_relations=False) + account3, tenant3 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app3 = self._create_app(db_session_with_containers, tenant3, account3) + conv3 = self._create_conversation(db_session_with_containers, app3) + msg3 = self._create_message( + db_session_with_containers, app3, conv3, created_at=expired_date, with_relations=False + ) msg3_id = msg3.id # Scenario 4: Non-sandbox plan (professional) with no expiration (future date) # Should NOT be deleted - account4, tenant4 = self._create_account_and_tenant(plan=CloudPlan.PROFESSIONAL) - app4 = self._create_app(tenant4, account4) - conv4 = self._create_conversation(app4) - msg4 = self._create_message(app4, conv4, created_at=expired_date, with_relations=False) + account4, tenant4 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.PROFESSIONAL) + app4 = self._create_app(db_session_with_containers, tenant4, account4) + conv4 = self._create_conversation(db_session_with_containers, app4) + msg4 = self._create_message( + db_session_with_containers, app4, conv4, created_at=expired_date, with_relations=False + ) msg4_id = msg4.id future_expiration = now_timestamp + (365 * 24 * 60 * 60) # Active for 1 year # Scenario 5: Sandbox plan with expiration exactly at grace period boundary (8 days ago) # Should NOT be deleted (boundary is exclusive: > graceful_period) - account5, tenant5 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app5 = self._create_app(tenant5, account5) - conv5 = self._create_conversation(app5) - msg5 = self._create_message(app5, conv5, created_at=expired_date, with_relations=False) + account5, tenant5 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app5 = self._create_app(db_session_with_containers, tenant5, account5) + conv5 = self._create_conversation(db_session_with_containers, app5) + msg5 = self._create_message( + db_session_with_containers, app5, conv5, created_at=expired_date, with_relations=False + ) msg5_id = msg5.id expired_exactly_8_days_ago = now_timestamp - (8 * 24 * 60 * 60) # Exactly at boundary - db.session.commit() + db_session_with_containers.commit() # Mock billing service with all scenarios plan_map = { @@ -832,23 +905,31 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 2 # Verify each scenario using saved IDs - assert db.session.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept - assert db.session.query(Message).where(Message.id == msg2_id).count() == 0 # Beyond grace, deleted - assert db.session.query(Message).where(Message.id == msg3_id).count() == 0 # No subscription, deleted - assert db.session.query(Message).where(Message.id == msg4_id).count() == 1 # Professional plan, kept - assert db.session.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept + assert db_session_with_containers.query(Message).where(Message.id == msg1_id).count() == 1 # Within grace, kept + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2_id).count() == 0 + ) # Beyond grace, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg3_id).count() == 0 + ) # No subscription, deleted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg4_id).count() == 1 + ) # Professional plan, kept + assert db_session_with_containers.query(Message).where(Message.id == msg5_id).count() == 1 # At boundary, kept - def test_tenant_whitelist(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_tenant_whitelist(self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist): """Test that whitelisted tenants' messages are not deleted (B9).""" # Arrange - Create 3 sandbox tenants with expired messages tenants_data = [] for i in range(3): - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg = self._create_message(app, conv, created_at=expired_date, with_relations=False) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date, with_relations=False + ) tenants_data.append( { @@ -897,27 +978,33 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 1 # Verify tenant0's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[0]["message_id"]).count() == 1 # Verify tenant1's message still exists (whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[1]["message_id"]).count() == 1 # Verify tenant2's message was deleted (not whitelisted) - assert db.session.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id == tenants_data[2]["message_id"]).count() == 0 - def test_from_days_cleans_old_messages(self, db_session_with_containers, mock_billing_enabled, mock_whitelist): + def test_from_days_cleans_old_messages( + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist + ): """Test from_days correctly cleans messages older than N days (B11).""" # Arrange - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) # Create old messages (should be deleted - older than 30 days) old_date = datetime.datetime.now() - datetime.timedelta(days=45) old_msg_ids = [] for i in range(3): msg = self._create_message( - app, conv, created_at=old_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=old_date - datetime.timedelta(hours=i), + with_relations=False, ) old_msg_ids.append(msg.id) @@ -926,11 +1013,15 @@ class TestMessagesCleanServiceIntegration: recent_msg_ids = [] for i in range(2): msg = self._create_message( - app, conv, created_at=recent_date - datetime.timedelta(hours=i), with_relations=False + db_session_with_containers, + app, + conv, + created_at=recent_date - datetime.timedelta(hours=i), + with_relations=False, ) recent_msg_ids.append(msg.id) - db.session.commit() + db_session_with_containers.commit() with patch("services.billing_service.BillingService.get_plan_bulk") as mock_billing: mock_billing.return_value = { @@ -955,30 +1046,34 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Old messages deleted - assert db.session.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(old_msg_ids)).count() == 0 # Recent messages kept - assert db.session.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 + assert db_session_with_containers.query(Message).where(Message.id.in_(recent_msg_ids)).count() == 2 def test_whitelist_precedence_over_grace_period( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that whitelist takes precedence over grace period logic.""" # Arrange - Create 2 sandbox tenants now_timestamp = int(datetime.datetime.now(datetime.UTC).timestamp()) # Tenant1: whitelisted, expired beyond grace period - account1, tenant1 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app1 = self._create_app(tenant1, account1) - conv1 = self._create_conversation(app1) + account1, tenant1 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app1 = self._create_app(db_session_with_containers, tenant1, account1) + conv1 = self._create_conversation(db_session_with_containers, app1) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) - msg1 = self._create_message(app1, conv1, created_at=expired_date, with_relations=False) + msg1 = self._create_message( + db_session_with_containers, app1, conv1, created_at=expired_date, with_relations=False + ) expired_30_days_ago = now_timestamp - (30 * 24 * 60 * 60) # Well beyond 21-day grace # Tenant2: not whitelisted, within grace period - account2, tenant2 = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app2 = self._create_app(tenant2, account2) - conv2 = self._create_conversation(app2) - msg2 = self._create_message(app2, conv2, created_at=expired_date, with_relations=False) + account2, tenant2 = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app2 = self._create_app(db_session_with_containers, tenant2, account2) + conv2 = self._create_conversation(db_session_with_containers, app2) + msg2 = self._create_message( + db_session_with_containers, app2, conv2, created_at=expired_date, with_relations=False + ) expired_10_days_ago = now_timestamp - (10 * 24 * 60 * 60) # Within 21-day grace # Mock billing service @@ -1019,22 +1114,26 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 0 # Verify both messages still exist - assert db.session.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted - assert db.session.query(Message).where(Message.id == msg2.id).count() == 1 # Within grace period + assert db_session_with_containers.query(Message).where(Message.id == msg1.id).count() == 1 # Whitelisted + assert ( + db_session_with_containers.query(Message).where(Message.id == msg2.id).count() == 1 + ) # Within grace period def test_empty_whitelist_deletes_eligible_messages( - self, db_session_with_containers, mock_billing_enabled, mock_whitelist + self, db_session_with_containers: Session, mock_billing_enabled, mock_whitelist ): """Test that empty whitelist behaves as no whitelist (all eligible messages deleted).""" # Arrange - Create sandbox tenant with expired messages - account, tenant = self._create_account_and_tenant(plan=CloudPlan.SANDBOX) - app = self._create_app(tenant, account) - conv = self._create_conversation(app) + account, tenant = self._create_account_and_tenant(db_session_with_containers, plan=CloudPlan.SANDBOX) + app = self._create_app(db_session_with_containers, tenant, account) + conv = self._create_conversation(db_session_with_containers, app) expired_date = datetime.datetime.now() - datetime.timedelta(days=35) msg_ids = [] for i in range(3): - msg = self._create_message(app, conv, created_at=expired_date - datetime.timedelta(hours=i)) + msg = self._create_message( + db_session_with_containers, app, conv, created_at=expired_date - datetime.timedelta(hours=i) + ) msg_ids.append(msg.id) # Mock billing service @@ -1068,4 +1167,4 @@ class TestMessagesCleanServiceIntegration: assert stats["total_deleted"] == 3 # Verify all messages were deleted - assert db.session.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_service.py b/api/tests/test_containers_integration_tests/services/test_metadata_service.py index e04725627b..e847329c5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_metadata_service.py +++ b/api/tests/test_containers_integration_tests/services/test_metadata_service.py @@ -2,10 +2,12 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.built_in_field import BuiltInField from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetMetadata, DatasetMetadataBinding, Document +from models.enums import DatasetMetadataType, DataSourceType, DocumentCreatedFrom from services.entities.knowledge_entities.knowledge_entities import MetadataArgs from services.metadata_service import MetadataService @@ -32,7 +34,7 @@ class TestMetadataService: "document_service": mock_document_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -53,18 +55,16 @@ class TestMetadataService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -73,15 +73,17 @@ class TestMetadataService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, account, tenant): + def _create_test_dataset( + self, db_session_with_containers: Session, mock_external_service_dependencies, account, tenant + ): """ Helper method to create a test dataset for testing. @@ -100,19 +102,19 @@ class TestMetadataService: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, built_in_field_enabled=False, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, mock_external_service_dependencies, dataset, account): + def _create_test_document( + self, db_session_with_containers: Session, mock_external_service_dependencies, dataset, account + ): """ Helper method to create a test document for testing. @@ -131,24 +133,22 @@ class TestMetadataService: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch="test-batch", name=fake.file_name(), - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text", doc_language="en", ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def test_create_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata creation with valid parameters. """ @@ -164,7 +164,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].id = account.id - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") # Act: Execute the method under test result = MetadataService.create_metadata(dataset.id, metadata_args) @@ -178,13 +178,14 @@ class TestMetadataService: assert result.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None - def test_create_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name exceeds 255 characters. """ @@ -201,13 +202,15 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id long_name = "a" * 256 # 256 characters, exceeding 255 limit - metadata_args = MetadataArgs(type="string", name=long_name) + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=long_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_create_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata creation fails when name already exists in the same dataset. """ @@ -224,18 +227,18 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create first metadata - first_metadata_args = MetadataArgs(type="string", name="duplicate_name") + first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="duplicate_name") MetadataService.create_metadata(dataset.id, first_metadata_args) # Try to create second metadata with same name - second_metadata_args = MetadataArgs(type="number", name="duplicate_name") + second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="duplicate_name") # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists."): MetadataService.create_metadata(dataset.id, second_metadata_args) def test_create_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata creation fails when name conflicts with built-in field names. @@ -254,13 +257,15 @@ class TestMetadataService: # Try to create metadata with built-in field name built_in_field_name = BuiltInField.document_name - metadata_args = MetadataArgs(type="string", name=built_in_field_name) + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name=built_in_field_name) # Act & Assert: Verify proper error handling with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.create_metadata(dataset.id, metadata_args) - def test_update_metadata_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful metadata name update with valid parameters. """ @@ -277,7 +282,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -291,12 +296,13 @@ class TestMetadataService: assert result.updated_at is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == new_name - def test_update_metadata_name_too_long(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_too_long( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name exceeds 255 characters. """ @@ -313,7 +319,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with too long name @@ -323,7 +329,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name cannot exceed 255 characters."): MetadataService.update_metadata_name(dataset.id, metadata.id, long_name) - def test_update_metadata_name_already_exists(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_already_exists( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when new name already exists in the same dataset. """ @@ -340,10 +348,10 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create two metadata entries - first_metadata_args = MetadataArgs(type="string", name="first_metadata") + first_metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="first_metadata") first_metadata = MetadataService.create_metadata(dataset.id, first_metadata_args) - second_metadata_args = MetadataArgs(type="number", name="second_metadata") + second_metadata_args = MetadataArgs(type=DatasetMetadataType.NUMBER, name="second_metadata") second_metadata = MetadataService.create_metadata(dataset.id, second_metadata_args) # Try to update first metadata with second metadata's name @@ -351,7 +359,7 @@ class TestMetadataService: MetadataService.update_metadata_name(dataset.id, first_metadata.id, "second_metadata") def test_update_metadata_name_conflicts_with_built_in_field( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata name update fails when new name conflicts with built-in field names. @@ -369,7 +377,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="old_name") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="old_name") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Try to update with built-in field name @@ -378,7 +386,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Metadata name already exists in Built-in fields."): MetadataService.update_metadata_name(dataset.id, metadata.id, built_in_field_name) - def test_update_metadata_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_metadata_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test metadata name update fails when metadata ID does not exist. """ @@ -406,7 +416,7 @@ class TestMetadataService: # Assert: Verify the method returns None when metadata is not found assert result is None - def test_delete_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful metadata deletion with valid parameters. """ @@ -423,7 +433,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata first - metadata_args = MetadataArgs(type="string", name="to_be_deleted") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="to_be_deleted") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -434,12 +444,11 @@ class TestMetadataService: assert result.id == metadata.id # Verify metadata was deleted from database - from extensions.ext_database import db - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None - def test_delete_metadata_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_metadata_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test metadata deletion fails when metadata ID does not exist. """ @@ -467,7 +476,7 @@ class TestMetadataService: assert result is None def test_delete_metadata_with_document_bindings( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata deletion successfully removes document metadata bindings. @@ -488,7 +497,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create metadata binding @@ -500,15 +509,13 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Set document metadata document.doc_metadata = {"test_metadata": "test_value"} - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.delete_metadata(dataset.id, metadata.id) @@ -517,13 +524,13 @@ class TestMetadataService: assert result is not None # Verify metadata was deleted from database - deleted_metadata = db.session.query(DatasetMetadata).filter_by(id=metadata.id).first() + deleted_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(id=metadata.id).first() assert deleted_metadata is None # Note: The service attempts to update document metadata but may not succeed # due to mock configuration. The main functionality (metadata deletion) is verified. - def test_get_built_in_fields_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_built_in_fields_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of built-in metadata fields. """ @@ -548,7 +555,9 @@ class TestMetadataService: assert "string" in field_types assert "time" in field_types - def test_enable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful enabling of built-in fields for a dataset. """ @@ -579,16 +588,15 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (enabling built-in fields) is verified def test_enable_built_in_field_already_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields when they are already enabled. @@ -607,10 +615,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -619,11 +626,11 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the method returns early without changes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True def test_enable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test enabling built-in fields for a dataset with no documents. @@ -647,12 +654,13 @@ class TestMetadataService: MetadataService.enable_built_in_field(dataset) # Assert: Verify the expected outcomes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is True - def test_disable_built_in_field_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_built_in_field_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful disabling of built-in fields for a dataset. """ @@ -673,10 +681,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Set document metadata with built-in fields document.doc_metadata = { @@ -686,8 +693,8 @@ class TestMetadataService: BuiltInField.last_update_date: 1234567890.0, BuiltInField.source: "test_source", } - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [ @@ -698,14 +705,14 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False # Note: Document metadata update depends on DocumentService mock working correctly # The main functionality (disabling built-in fields) is verified def test_disable_built_in_field_already_disabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields when they are already disabled. @@ -732,13 +739,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the method returns early without changes - from extensions.ext_database import db - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False def test_disable_built_in_field_with_no_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test disabling built-in fields for a dataset with no documents. @@ -757,10 +763,9 @@ class TestMetadataService: # Enable built-in fields first dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Mock DocumentService.get_working_documents_by_dataset_id to return empty list mock_external_service_dependencies["document_service"].get_working_documents_by_dataset_id.return_value = [] @@ -769,10 +774,12 @@ class TestMetadataService: MetadataService.disable_built_in_field(dataset) # Assert: Verify the expected outcomes - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) assert dataset.built_in_field_enabled is False - def test_update_documents_metadata_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_documents_metadata_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful update of documents metadata. """ @@ -792,7 +799,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Mock DocumentService.get_document @@ -815,24 +822,25 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify document metadata was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" # Verify metadata binding was created binding = ( - db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata.id, document_id=document.id).first() + db_session_with_containers.query(DatasetMetadataBinding) + .filter_by(metadata_id=metadata.id, document_id=document.id) + .first() ) assert binding is not None assert binding.tenant_id == tenant.id assert binding.dataset_id == dataset.id def test_update_documents_metadata_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when built-in fields are enabled. @@ -850,17 +858,16 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Mock DocumentService.get_document @@ -884,7 +891,7 @@ class TestMetadataService: # Assert: Verify the expected outcomes # Verify document metadata was updated with both custom and built-in fields - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.doc_metadata is not None assert "test_metadata" in document.doc_metadata assert document.doc_metadata["test_metadata"] == "test_value" @@ -893,7 +900,7 @@ class TestMetadataService: # The main functionality (custom metadata update) is verified def test_update_documents_metadata_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of documents metadata when document is not found. @@ -911,7 +918,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create metadata operation data @@ -936,7 +943,7 @@ class TestMetadataService: MetadataService.update_documents_metadata(dataset, operation_data) def test_knowledge_base_metadata_lock_check_dataset_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for dataset operations. @@ -959,7 +966,7 @@ class TestMetadataService: assert call_args[0][0] == f"dataset_metadata_lock_{dataset_id}" def test_knowledge_base_metadata_lock_check_document_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check for document operations. @@ -982,7 +989,7 @@ class TestMetadataService: assert call_args[0][0] == f"document_metadata_lock_{document_id}" def test_knowledge_base_metadata_lock_check_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when lock already exists. @@ -999,7 +1006,7 @@ class TestMetadataService: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) def test_knowledge_base_metadata_lock_check_document_lock_exists( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test metadata lock check when document lock already exists. @@ -1013,7 +1020,9 @@ class TestMetadataService: with pytest.raises(ValueError, match="Another document metadata operation is running, please wait a moment."): MetadataService.knowledge_base_metadata_lock_check(None, document_id) - def test_get_dataset_metadatas_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of dataset metadata information. """ @@ -1030,7 +1039,7 @@ class TestMetadataService: mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Create document and metadata binding @@ -1046,10 +1055,8 @@ class TestMetadataService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(binding) - db.session.commit() + db_session_with_containers.add(binding) + db_session_with_containers.commit() # Act: Execute the method under test result = MetadataService.get_dataset_metadatas(dataset) @@ -1071,7 +1078,7 @@ class TestMetadataService: assert result["built_in_field_enabled"] is False def test_get_dataset_metadatas_with_built_in_fields_enabled( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test retrieval of dataset metadata when built-in fields are enabled. @@ -1086,17 +1093,16 @@ class TestMetadataService: # Enable built-in fields dataset.built_in_field_enabled = True - from extensions.ext_database import db - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["current_user"].current_tenant_id = tenant.id mock_external_service_dependencies["current_user"].id = account.id # Create metadata - metadata_args = MetadataArgs(type="string", name="test_metadata") + metadata_args = MetadataArgs(type=DatasetMetadataType.STRING, name="test_metadata") metadata = MetadataService.create_metadata(dataset.id, metadata_args) # Act: Execute the method under test @@ -1114,7 +1120,9 @@ class TestMetadataService: # Verify built-in field status assert result["built_in_field_enabled"] is True - def test_get_dataset_metadatas_no_metadata(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_dataset_metadatas_no_metadata( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of dataset metadata when no metadata exists. """ diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 7c8472e819..989df42499 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from models.account import TenantAccountJoin, TenantAccountRole from models.model import Account, Tenant @@ -67,7 +68,7 @@ class TestModelLoadBalancingService: "credential_schema": mock_credential_schema, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -88,18 +89,16 @@ class TestModelLoadBalancingService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -108,8 +107,8 @@ class TestModelLoadBalancingService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -117,7 +116,7 @@ class TestModelLoadBalancingService: return account, tenant def _create_test_provider_and_setting( - self, db_session_with_containers, tenant_id, mock_external_service_dependencies + self, db_session_with_containers: Session, tenant_id, mock_external_service_dependencies ): """ Helper method to create a test provider and provider model setting. @@ -132,8 +131,6 @@ class TestModelLoadBalancingService: """ fake = Faker() - from extensions.ext_database import db - # Create provider provider = Provider( tenant_id=tenant_id, @@ -141,8 +138,8 @@ class TestModelLoadBalancingService: provider_type="custom", is_valid=True, ) - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Create provider model setting provider_model_setting = ProviderModelSetting( @@ -153,12 +150,14 @@ class TestModelLoadBalancingService: enabled=True, load_balancing_enabled=False, ) - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider, provider_model_setting - def test_enable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing enablement. @@ -193,14 +192,15 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None - def test_disable_model_load_balancing_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_disable_model_load_balancing_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful model load balancing disablement. @@ -235,15 +235,14 @@ class TestModelLoadBalancingService: assert call_args.kwargs["model_type"].value == "llm" # ModelType enum value # Verify database state - from extensions.ext_database import db - db.session.refresh(provider) - db.session.refresh(provider_model_setting) + db_session_with_containers.refresh(provider) + db_session_with_containers.refresh(provider_model_setting) assert provider.id is not None assert provider_model_setting.id is not None def test_enable_model_load_balancing_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist. @@ -275,11 +274,12 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() - def test_get_load_balancing_configs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_load_balancing_configs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of load balancing configurations. @@ -298,7 +298,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -309,11 +308,11 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Verify the config was created - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Setup mocks for get_load_balancing_configs method @@ -358,11 +357,11 @@ class TestModelLoadBalancingService: assert configs[0]["ttl"] == 0 # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None def test_get_load_balancing_configs_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when provider does not exist in get_load_balancing_configs. @@ -394,12 +393,11 @@ class TestModelLoadBalancingService: assert "Provider nonexistent_provider does not exist." in str(exc_info.value) # Verify no database state changes occurred - from extensions.ext_database import db - db.session.rollback() + db_session_with_containers.rollback() def test_get_load_balancing_configs_with_inherit_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test load balancing configs retrieval with inherit configuration. @@ -419,7 +417,6 @@ class TestModelLoadBalancingService: ) # Create load balancing config - from extensions.ext_database import db load_balancing_config = LoadBalancingModelConfig( tenant_id=tenant.id, @@ -430,8 +427,8 @@ class TestModelLoadBalancingService: encrypted_config='{"api_key": "test_key"}', enabled=True, ) - db.session.add(load_balancing_config) - db.session.commit() + db_session_with_containers.add(load_balancing_config) + db_session_with_containers.commit() # Setup mocks for inherit config scenario mock_provider_config = mock_external_service_dependencies["provider_config"] @@ -467,11 +464,11 @@ class TestModelLoadBalancingService: assert configs[1]["name"] == "config1" # Verify database state - db.session.refresh(load_balancing_config) + db_session_with_containers.refresh(load_balancing_config) assert load_balancing_config.id is not None # Verify inherit config was created in database - inherit_configs = db.session.scalars( + inherit_configs = db_session_with_containers.scalars( select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.name == "__inherit__") ).all() assert len(inherit_configs) == 1 diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index f7044f7d45..6afc5aa43c 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,9 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from core.model_runtime.entities.model_entities import FetchFrom, ModelType +from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -29,7 +30,7 @@ class TestModelProviderService: "model_provider_factory": mock_model_provider_factory, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,18 +51,16 @@ class TestModelProviderService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +69,8 @@ class TestModelProviderService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -80,7 +79,7 @@ class TestModelProviderService: def _create_test_provider( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str = "openai", @@ -109,16 +108,14 @@ class TestModelProviderService: quota_used=0, ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider def _create_test_provider_model( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -149,16 +146,14 @@ class TestModelProviderService: is_valid=True, ) - from extensions.ext_database import db - - db.session.add(provider_model) - db.session.commit() + db_session_with_containers.add(provider_model) + db_session_with_containers.commit() return provider_model def _create_test_provider_model_setting( self, - db_session_with_containers, + db_session_with_containers: Session, mock_external_service_dependencies, tenant_id: str, provider_name: str, @@ -190,14 +185,12 @@ class TestModelProviderService: load_balancing_enabled=False, ) - from extensions.ext_database import db - - db.session.add(provider_model_setting) - db.session.commit() + db_session_with_containers.add(provider_model_setting) + db_session_with_containers.commit() return provider_model_setting - def test_get_provider_list_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_list_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful provider list retrieval. @@ -275,7 +268,7 @@ class TestModelProviderService: mock_provider_config.is_custom_configuration_available.assert_called_once() def test_get_provider_list_with_model_type_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test provider list retrieval with model type filtering. @@ -374,7 +367,9 @@ class TestModelProviderService: assert result[0].provider == "cohere" assert ModelType.TEXT_EMBEDDING in result[0].supported_model_types - def test_get_models_by_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by provider. @@ -407,8 +402,8 @@ class TestModelProviderService: # Create mock models from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - from core.model_runtime.entities.common_entities import I18nObject - from core.model_runtime.entities.provider_entities import ProviderEntity + from dify_graph.model_runtime.entities.common_entities import I18nObject + from dify_graph.model_runtime.entities.provider_entities import ProviderEntity # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( @@ -485,7 +480,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_configurations.get_models.assert_called_once_with(provider="openai") - def test_get_provider_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_provider_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of provider credentials. @@ -543,7 +540,7 @@ class TestModelProviderService: mock_method.assert_called_once_with(tenant.id, "openai") def test_provider_credentials_validate_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful validation of provider credentials. @@ -585,7 +582,7 @@ class TestModelProviderService: mock_provider_configuration.validate_provider_credentials.assert_called_once_with(test_credentials) def test_provider_credentials_validate_invalid_provider( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test validation failure for non-existent provider. @@ -617,7 +614,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) def test_get_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of default model for a specific model type. @@ -643,7 +640,7 @@ class TestModelProviderService: # Create mock default model response from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity - from core.model_runtime.entities.common_entities import I18nObject + from dify_graph.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", @@ -673,7 +670,7 @@ class TestModelProviderService: mock_provider_manager.get_default_model.assert_called_once_with(tenant_id=tenant.id, model_type=ModelType.LLM) def test_update_default_model_of_model_type_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of default model for a specific model type. @@ -706,7 +703,9 @@ class TestModelProviderService: tenant_id=tenant.id, model_type=ModelType.LLM, provider="openai", model="gpt-4" ) - def test_get_model_provider_icon_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_provider_icon_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model provider icon. @@ -743,7 +742,9 @@ class TestModelProviderService: # Verify mock interactions mock_model_provider_factory.get_provider_icon.assert_called_once_with("openai", "icon_small", "en_US") - def test_switch_preferred_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_switch_preferred_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful switching of preferred provider type. @@ -779,7 +780,7 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.switch_preferred_provider_type.assert_called_once() - def test_enable_model_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_enable_model_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful enabling of a model. @@ -815,7 +816,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configuration.enable_model.assert_called_once_with(model_type=ModelType.LLM, model="gpt-4") - def test_get_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model credentials. @@ -872,7 +875,9 @@ class TestModelProviderService: # Verify the method was called with correct parameters mock_method.assert_called_once_with(tenant.id, "openai", "llm", "gpt-4", None) - def test_model_credentials_validate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_model_credentials_validate_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful validation of model credentials. @@ -914,7 +919,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials ) - def test_save_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful saving of model credentials. @@ -955,7 +962,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credentials=test_credentials, credential_name="testname" ) - def test_remove_model_credentials_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_remove_model_credentials_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful removal of model credentials. @@ -993,7 +1002,9 @@ class TestModelProviderService: model_type=ModelType.LLM, model="gpt-4", credential_id="5540007c-b988-46e0-b1c7-9b5fb9f330d6" ) - def test_get_models_by_model_type_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_models_by_model_type_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of models by model type. @@ -1070,7 +1081,9 @@ class TestModelProviderService: mock_provider_manager.get_configurations.assert_called_once_with(tenant.id) mock_provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) - def test_get_model_parameter_rules_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_model_parameter_rules_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of model parameter rules. @@ -1137,7 +1150,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when no credentials are available. @@ -1181,7 +1194,7 @@ class TestModelProviderService: ) def test_get_model_parameter_rules_provider_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parameter rules retrieval when provider does not exist. diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index 9e6b9837ae..94a4e62560 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -2,11 +2,14 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session +from models.enums import ConversationFromSource from models.model import EndUser, Message from models.web import SavedMessage from services.app_service import AppService from services.saved_message_service import SavedMessageService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestSavedMessageService: @@ -38,7 +41,7 @@ class TestSavedMessageService: "message_service": mock_message_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -63,7 +66,7 @@ class TestSavedMessageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -85,7 +88,7 @@ class TestSavedMessageService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -108,14 +111,12 @@ class TestSavedMessageService: is_anonymous=False, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_message(self, db_session_with_containers, app, user): + def _create_test_message(self, db_session_with_containers: Session, app, user): """ Helper method to create a test message for testing. @@ -132,29 +133,30 @@ class TestSavedMessageService: # Create a simple conversation first from models.model import Conversation + is_account = hasattr(user, "current_tenant") + from_source = ConversationFromSource.CONSOLE if is_account else ConversationFromSource.API + conversation = Conversation( app_id=app.id, - from_source="account" if hasattr(user, "current_tenant") else "end_user", - from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, - from_account_id=user.id if hasattr(user, "current_tenant") else None, + from_source=from_source, + from_end_user_id=user.id if not is_account else None, + from_account_id=user.id if is_account else None, name=fake.sentence(nb_words=3), inputs={}, status="normal", mode="chat", ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message( app_id=app.id, conversation_id=conversation.id, - from_source="account" if hasattr(user, "current_tenant") else "end_user", - from_end_user_id=user.id if not hasattr(user, "current_tenant") else None, - from_account_id=user.id if hasattr(user, "current_tenant") else None, + from_source=from_source, + from_end_user_id=user.id if not is_account else None, + from_account_id=user.id if is_account else None, inputs={}, query=fake.sentence(nb_words=5), message=fake.text(max_nb_chars=100), @@ -165,16 +167,16 @@ class TestSavedMessageService: answer_unit_price=0.002, total_price=0.003, currency="USD", - status="success", + status="normal", ) - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message def test_pagination_by_last_id_success_with_account_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with account user. @@ -207,10 +209,8 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -240,15 +240,15 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "account" assert saved_message2.created_by_role == "account" def test_pagination_by_last_id_success_with_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination by last ID with end user. @@ -282,10 +282,8 @@ class TestSavedMessageService: created_by=end_user.id, ) - from extensions.ext_database import db - - db.session.add_all([saved_message1, saved_message2]) - db.session.commit() + db_session_with_containers.add_all([saved_message1, saved_message2]) + db_session_with_containers.commit() # Mock MessageService.pagination_by_last_id return value from libs.infinite_scroll_pagination import InfiniteScrollPagination @@ -317,14 +315,16 @@ class TestSavedMessageService: assert actual_include_ids == expected_include_ids # Verify database state - db.session.refresh(saved_message1) - db.session.refresh(saved_message2) + db_session_with_containers.refresh(saved_message1) + db_session_with_containers.refresh(saved_message2) assert saved_message1.id is not None assert saved_message2.id is not None assert saved_message1.created_by_role == "end_user" assert saved_message2.created_by_role == "end_user" - def test_save_success_with_new_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_success_with_new_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful save of a new message. @@ -347,10 +347,9 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was created in database - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -373,10 +372,12 @@ class TestSavedMessageService: ) # Verify database state - db.session.refresh(saved_message) + db_session_with_containers.refresh(saved_message) assert saved_message.id is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -396,12 +397,11 @@ class TestSavedMessageService: assert "User is required" in str(exc_info.value) # Verify no database operations were performed - from extensions.ext_database import db - saved_messages = db.session.query(SavedMessage).all() + saved_messages = db_session_with_containers.query(SavedMessage).all() assert len(saved_messages) == 0 - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -422,10 +422,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -435,7 +434,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -457,14 +458,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -481,7 +480,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -494,11 +493,13 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None - def test_pagination_by_last_id_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_error_no_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when no user is provided. @@ -522,7 +523,7 @@ class TestSavedMessageService: # Instead, we verify that the error was properly raised pass - def test_save_error_no_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_error_no_user(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when saving message with no user. @@ -543,10 +544,9 @@ class TestSavedMessageService: assert result is None # Verify no saved message was created - from extensions.ext_database import db saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -556,7 +556,9 @@ class TestSavedMessageService: assert saved_message is None - def test_delete_success_existing_message(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_success_existing_message( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful deletion of an existing saved message. @@ -578,14 +580,12 @@ class TestSavedMessageService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(saved_message) - db.session.commit() + db_session_with_containers.add(saved_message) + db_session_with_containers.commit() # Verify saved message exists assert ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -602,7 +602,7 @@ class TestSavedMessageService: # Assert: Verify the expected outcomes # Check if saved message was deleted from database deleted_saved_message = ( - db.session.query(SavedMessage) + db_session_with_containers.query(SavedMessage) .where( SavedMessage.app_id == app.id, SavedMessage.message_id == message.id, @@ -615,6 +615,6 @@ class TestSavedMessageService: assert deleted_saved_message is None # Verify database state - db.session.commit() + db_session_with_containers.commit() # The message should still exist, only the saved_message should be deleted - assert db.session.query(Message).where(Message.id == message.id).first() is not None + assert db_session_with_containers.query(Message).where(Message.id == message.id).first() is not None diff --git a/api/tests/test_containers_integration_tests/services/test_tag_service.py b/api/tests/test_containers_integration_tests/services/test_tag_service.py index e8c7f17e0b..1a72e3b6c2 100644 --- a/api/tests/test_containers_integration_tests/services/test_tag_service.py +++ b/api/tests/test_containers_integration_tests/services/test_tag_service.py @@ -4,10 +4,12 @@ from unittest.mock import create_autospec, patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset +from models.enums import DataSourceType, TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -29,7 +31,7 @@ class TestTagService: "current_user": mock_current_user, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,18 +52,16 @@ class TestTagService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -70,8 +70,8 @@ class TestTagService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -82,7 +82,7 @@ class TestTagService: return account, tenant - def _create_test_dataset(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_dataset(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test dataset for testing. @@ -101,20 +101,18 @@ class TestTagService: description=fake.text(max_nb_chars=100), provider="vendor", permission="only_me", - data_source_type="upload", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", tenant_id=tenant_id, created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant_id): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id): """ Helper method to create a test app for testing. @@ -141,15 +139,13 @@ class TestTagService: created_by=mock_external_service_dependencies["current_user"].id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app def _create_test_tags( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, tag_type, count=3 + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, tag_type, count=3 ): """ Helper method to create test tags for testing. @@ -176,16 +172,14 @@ class TestTagService: ) tags.append(tag) - from extensions.ext_database import db - for tag in tags: - db.session.add(tag) - db.session.commit() + db_session_with_containers.add(tag) + db_session_with_containers.commit() return tags def _create_test_tag_bindings( - self, db_session_with_containers, mock_external_service_dependencies, tags, target_id, tenant_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tags, target_id, tenant_id ): """ Helper method to create test tag bindings for testing. @@ -211,15 +205,13 @@ class TestTagService: ) tag_bindings.append(tag_binding) - from extensions.ext_database import db - for tag_binding in tag_bindings: - db.session.add(tag_binding) - db.session.commit() + db_session_with_containers.add(tag_binding) + db_session_with_containers.commit() return tag_bindings - def test_get_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags with binding count. @@ -270,7 +262,9 @@ class TestTagService: # The ordering is handled by the database, we just verify the results are returned assert len(result) == 3 - def test_get_tags_with_keyword_filter(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_with_keyword_filter( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval with keyword filtering. @@ -291,12 +285,11 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_development" tags[1].name = "machine_learning" tags[2].name = "web_development" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword filter result = TagService.get_tags("app", tenant.id, keyword="development") @@ -314,7 +307,7 @@ class TestTagService: assert len(result_no_match) == 0 def test_get_tags_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test tag retrieval with special characters in keyword to verify SQL injection prevention. @@ -330,8 +323,6 @@ class TestTagService: db_session_with_containers, mock_external_service_dependencies ) - from extensions.ext_database import db - # Create tags with special characters in names tag_with_percent = Tag( name="50% discount", @@ -340,7 +331,7 @@ class TestTagService: created_by=account.id, ) tag_with_percent.id = str(uuid.uuid4()) - db.session.add(tag_with_percent) + db_session_with_containers.add(tag_with_percent) tag_with_underscore = Tag( name="test_data_tag", @@ -349,7 +340,7 @@ class TestTagService: created_by=account.id, ) tag_with_underscore.id = str(uuid.uuid4()) - db.session.add(tag_with_underscore) + db_session_with_containers.add(tag_with_underscore) tag_with_backslash = Tag( name="path\\to\\tag", @@ -358,7 +349,7 @@ class TestTagService: created_by=account.id, ) tag_with_backslash.id = str(uuid.uuid4()) - db.session.add(tag_with_backslash) + db_session_with_containers.add(tag_with_backslash) # Create tag that should NOT match tag_no_match = Tag( @@ -368,9 +359,9 @@ class TestTagService: created_by=account.id, ) tag_no_match.id = str(uuid.uuid4()) - db.session.add(tag_no_match) + db_session_with_containers.add(tag_no_match) - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Test 1 - Search with % character result = TagService.get_tags("app", tenant.id, keyword="50%") @@ -392,7 +383,7 @@ class TestTagService: assert len(result) == 1 assert all("50%" in item.name for item in result) - def test_get_tags_empty_result(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_empty_result(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag retrieval when no tags exist. @@ -414,7 +405,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_target_ids_by_tag_ids_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_target_ids_by_tag_ids_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of target IDs by tag IDs. @@ -469,7 +462,7 @@ class TestTagService: assert second_dataset_count == 1 def test_get_target_ids_by_tag_ids_empty_tag_ids( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval with empty tag IDs list. @@ -493,7 +486,7 @@ class TestTagService: assert isinstance(result, list) def test_get_target_ids_by_tag_ids_no_matching_tags( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target ID retrieval when no tags match the criteria. @@ -521,7 +514,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tags by tag name. @@ -542,11 +535,10 @@ class TestTagService: ) # Update tag names to make them searchable - from extensions.ext_database import db tags[0].name = "python_tag" tags[1].name = "ml_tag" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test result = TagService.get_tag_by_tag_name("app", tenant.id, "python_tag") @@ -555,10 +547,12 @@ class TestTagService: assert result is not None assert len(result) == 1 assert result[0].name == "python_tag" - assert result[0].type == "app" + assert result[0].type == TagType.APP assert result[0].tenant_id == tenant.id - def test_get_tag_by_tag_name_no_matches(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_no_matches( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name when no matches exist. @@ -580,7 +574,9 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_get_tag_by_tag_name_empty_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_by_tag_name_empty_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by name with empty parameters. @@ -605,7 +601,9 @@ class TestTagService: assert result_empty_name is not None assert len(result_empty_name) == 0 - def test_get_tags_by_target_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tags by target ID. @@ -640,11 +638,13 @@ class TestTagService: # Verify all tags are returned for tag in result: - assert tag.type == "app" + assert tag.type == TagType.APP assert tag.tenant_id == tenant.id assert tag.id in [t.id for t in tags] - def test_get_tags_by_target_id_no_bindings(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tags_by_target_id_no_bindings( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag retrieval by target ID when no tags are bound. @@ -669,7 +669,7 @@ class TestTagService: assert len(result) == 0 assert isinstance(result, list) - def test_save_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag creation. @@ -698,17 +698,18 @@ class TestTagService: assert result.id is not None # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None # Verify tag was actually saved to database - saved_tag = db.session.query(Tag).where(Tag.id == result.id).first() + saved_tag = db_session_with_containers.query(Tag).where(Tag.id == result.id).first() assert saved_tag is not None assert saved_tag.name == "test_tag_name" - def test_save_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag creation with duplicate name. @@ -731,7 +732,7 @@ class TestTagService: TagService.save_tags(tag_args) assert "Tag name already exists" in str(exc_info.value) - def test_update_tags_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag update. @@ -763,17 +764,16 @@ class TestTagService: assert result.id == tag.id # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.name == "updated_name" # Verify tag was actually updated in database - updated_tag = db.session.query(Tag).where(Tag.id == tag.id).first() + updated_tag = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert updated_tag is not None assert updated_tag.name == "updated_name" - def test_update_tags_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag update for non-existent tag. @@ -799,7 +799,9 @@ class TestTagService: TagService.update_tags(update_args, non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_update_tags_duplicate_name_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_tags_duplicate_name_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag update with duplicate name. @@ -828,7 +830,9 @@ class TestTagService: TagService.update_tags(update_args, tag2.id) assert "Tag name already exists" in str(exc_info.value) - def test_get_tag_binding_count_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tag_binding_count_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of tag binding count. @@ -863,7 +867,7 @@ class TestTagService: assert result_tag_without_bindings == 0 def test_get_tag_binding_count_non_existent_tag( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test binding count retrieval for non-existent tag. @@ -889,7 +893,7 @@ class TestTagService: # Assert: Verify the expected outcomes assert result == 0 - def test_delete_tag_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag deletion. @@ -916,12 +920,11 @@ class TestTagService: ) # Verify tag and binding exist before deletion - from extensions.ext_database import db - tag_before = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_before = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_before is not None - binding_before = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_before = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_before is not None # Act: Execute the method under test @@ -929,14 +932,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag was deleted - tag_after = db.session.query(Tag).where(Tag.id == tag.id).first() + tag_after = db_session_with_containers.query(Tag).where(Tag.id == tag.id).first() assert tag_after is None # Verify tag binding was deleted - binding_after = db.session.query(TagBinding).where(TagBinding.tag_id == tag.id).first() + binding_after = db_session_with_containers.query(TagBinding).where(TagBinding.tag_id == tag.id).first() assert binding_after is None - def test_delete_tag_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_not_found_error(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test tag deletion for non-existent tag. @@ -960,7 +963,7 @@ class TestTagService: TagService.delete_tag(non_existent_tag_id) assert "Tag not found" in str(exc_info.value) - def test_save_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding creation. @@ -988,12 +991,11 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify tag bindings were created for tag in tags: binding = ( - db.session.query(TagBinding) + db_session_with_containers.query(TagBinding) .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) .first() ) @@ -1001,7 +1003,9 @@ class TestTagService: assert binding.tenant_id == tenant.id assert binding.created_by == account.id - def test_save_tag_binding_duplicate_handling(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_duplicate_handling( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with duplicate bindings. @@ -1032,15 +1036,16 @@ class TestTagService: TagService.save_tag_binding(binding_args) # Assert: Verify the expected outcomes - from extensions.ext_database import db # Verify only one binding exists - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 1 - def test_save_tag_binding_invalid_target_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_save_tag_binding_invalid_target_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tag binding creation with invalid target type. @@ -1071,7 +1076,7 @@ class TestTagService: TagService.save_tag_binding(binding_args) assert "Invalid binding type" in str(exc_info.value) - def test_delete_tag_binding_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_tag_binding_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful tag binding deletion. @@ -1098,10 +1103,11 @@ class TestTagService: ) # Verify binding exists before deletion - from extensions.ext_database import db binding_before = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_before is not None @@ -1112,12 +1118,14 @@ class TestTagService: # Assert: Verify the expected outcomes # Verify tag binding was deleted binding_after = ( - db.session.query(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id).first() + db_session_with_containers.query(TagBinding) + .where(TagBinding.tag_id == tag.id, TagBinding.target_id == dataset.id) + .first() ) assert binding_after is None def test_delete_tag_binding_non_existent_binding( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tag binding deletion for non-existent binding. @@ -1145,15 +1153,14 @@ class TestTagService: # Assert: Verify the expected outcomes # No error should be raised, and database state should remain unchanged - from extensions.ext_database import db - bindings = db.session.scalars( + bindings = db_session_with_containers.scalars( select(TagBinding).where(TagBinding.tag_id == tag.id, TagBinding.target_id == app.id) ).all() assert len(bindings) == 0 def test_check_target_exists_knowledge_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful target existence check for knowledge type. @@ -1179,7 +1186,7 @@ class TestTagService: # No exception should be raised for existing dataset def test_check_target_exists_knowledge_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test target existence check for non-existent knowledge dataset. @@ -1204,7 +1211,9 @@ class TestTagService: TagService.check_target_exists("knowledge", non_existent_dataset_id) assert "Dataset not found" in str(exc_info.value) - def test_check_target_exists_app_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful target existence check for app type. @@ -1228,7 +1237,9 @@ class TestTagService: # Assert: Verify the expected outcomes # No exception should be raised for existing app - def test_check_target_exists_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for non-existent app. @@ -1252,7 +1263,9 @@ class TestTagService: TagService.check_target_exists("app", non_existent_app_id) assert "App not found" in str(exc_info.value) - def test_check_target_exists_invalid_type(self, db_session_with_containers, mock_external_service_dependencies): + def test_check_target_exists_invalid_type( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test target existence check for invalid type. diff --git a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py index 5315960d73..e0ea8211f6 100644 --- a/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_trigger_provider_service.py @@ -2,14 +2,15 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.plugin.entities.plugin_daemon import CredentialType from core.trigger.entities.entities import Subscription as TriggerSubscriptionEntity -from extensions.ext_database import db from models.provider_ids import TriggerProviderID from models.trigger import TriggerSubscription from services.trigger.trigger_provider_service import TriggerProviderService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestTriggerProviderService: @@ -47,7 +48,7 @@ class TestTriggerProviderService: "account_feature_service": mock_account_feature_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -75,7 +76,7 @@ class TestTriggerProviderService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -84,7 +85,7 @@ class TestTriggerProviderService: def _create_test_subscription( self, - db_session_with_containers, + db_session_with_containers: Session, tenant_id, user_id, provider_id, @@ -135,14 +136,14 @@ class TestTriggerProviderService: expires_at=-1, ) - db.session.add(subscription) - db.session.commit() - db.session.refresh(subscription) + db_session_with_containers.add(subscription) + db_session_with_containers.commit() + db_session_with_containers.refresh(subscription) return subscription def test_rebuild_trigger_subscription_success_with_merged_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful rebuild with credential merging (HIDDEN_VALUE handling). @@ -217,7 +218,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "new-secret-value" # New value # Verify database state was updated - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == "updated_name" assert subscription.parameters == {"param1": "updated_value"} @@ -244,7 +245,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_with_all_new_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are new (no HIDDEN_VALUE). @@ -304,7 +305,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == "completely-new-secret" def test_rebuild_trigger_subscription_with_all_hidden_values( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when all credentials are HIDDEN_VALUE (preserve all existing). @@ -363,7 +364,7 @@ class TestTriggerProviderService: assert subscribe_credentials["api_secret"] == original_credentials["api_secret"] def test_rebuild_trigger_subscription_with_missing_key_uses_unknown_value( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test rebuild when HIDDEN_VALUE is used for a key that doesn't exist in original. @@ -422,7 +423,7 @@ class TestTriggerProviderService: assert subscribe_credentials["non_existent_key"] == UNKNOWN_VALUE def test_rebuild_trigger_subscription_rollback_on_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that transaction is rolled back on error. @@ -470,12 +471,12 @@ class TestTriggerProviderService: ) # Verify subscription state was not changed (rolled back) - db.session.refresh(subscription) + db_session_with_containers.refresh(subscription) assert subscription.name == original_name assert subscription.parameters == original_parameters def test_rebuild_trigger_subscription_subscription_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error when subscription is not found. @@ -501,7 +502,7 @@ class TestTriggerProviderService: ) def test_rebuild_trigger_subscription_name_uniqueness_check( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that name uniqueness is checked when updating name. diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index bbbf48ede9..6b95954480 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -3,14 +3,17 @@ from unittest.mock import patch import pytest from faker import Faker from sqlalchemy import select +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from models import Account +from models.enums import ConversationFromSource from models.model import Conversation, EndUser from models.web import PinnedConversation from services.account_service import AccountService, TenantService from services.app_service import AppService from services.web_conversation_service import WebConversationService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWebConversationService: @@ -45,7 +48,7 @@ class TestWebConversationService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -68,7 +71,7 @@ class TestWebConversationService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -90,7 +93,7 @@ class TestWebConversationService: return app, account - def _create_test_end_user(self, db_session_with_containers, app): + def _create_test_end_user(self, db_session_with_containers: Session, app): """ Helper method to create a test end user for testing. @@ -111,14 +114,12 @@ class TestWebConversationService: tenant_id=app.tenant_id, ) - from extensions.ext_database import db - - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() return end_user - def _create_test_conversation(self, db_session_with_containers, app, user, fake): + def _create_test_conversation(self, db_session_with_containers: Session, app, user, fake): """ Helper method to create a test conversation for testing. @@ -145,21 +146,21 @@ class TestWebConversationService: system_instruction_tokens=50, status="normal", invoke_from=InvokeFrom.WEB_APP, - from_source="console" if isinstance(user, Account) else "api", + from_source=ConversationFromSource.CONSOLE if isinstance(user, Account) else ConversationFromSource.API, from_end_user_id=user.id if isinstance(user, EndUser) else None, from_account_id=user.id if isinstance(user, Account) else None, dialogue_count=0, is_deleted=False, ) - from extensions.ext_database import db - - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() return conversation - def test_pagination_by_last_id_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pagination_by_last_id_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination by last ID with basic parameters. """ @@ -194,7 +195,7 @@ class TestWebConversationService: assert result.data[1].updated_at >= result.data[2].updated_at def test_pagination_by_last_id_with_pinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with pinned conversation filter. @@ -222,11 +223,9 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation1) - db.session.add(pinned_conversation2) - db.session.commit() + db_session_with_containers.add(pinned_conversation1) + db_session_with_containers.add(pinned_conversation2) + db_session_with_containers.commit() # Test pagination with pinned filter result = WebConversationService.pagination_by_last_id( @@ -251,7 +250,7 @@ class TestWebConversationService: assert set(returned_ids) == set(expected_ids) def test_pagination_by_last_id_with_unpinned_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination by last ID with unpinned conversation filter. @@ -273,10 +272,8 @@ class TestWebConversationService: created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(pinned_conversation) - db.session.commit() + db_session_with_containers.add(pinned_conversation) + db_session_with_containers.commit() # Test pagination with unpinned filter result = WebConversationService.pagination_by_last_id( @@ -303,7 +300,7 @@ class TestWebConversationService: expected_unpinned_ids = [conv.id for conv in conversations[1:]] assert set(returned_ids) == set(expected_unpinned_ids) - def test_pin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful pinning of a conversation. """ @@ -317,10 +314,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -336,7 +332,9 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "account" assert pinned_conversation.created_by == account.id - def test_pin_conversation_already_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_already_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation that is already pinned (should not create duplicate). """ @@ -353,9 +351,8 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify only one pinned conversation record exists - from extensions.ext_database import db - pinned_conversations = db.session.scalars( + pinned_conversations = db_session_with_containers.scalars( select(PinnedConversation).where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -366,7 +363,9 @@ class TestWebConversationService: assert len(pinned_conversations) == 1 - def test_pin_conversation_with_end_user(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_with_end_user( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test pinning a conversation with an end user. """ @@ -383,10 +382,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, end_user) # Verify the conversation was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -402,7 +400,7 @@ class TestWebConversationService: assert pinned_conversation.created_by_role == "end_user" assert pinned_conversation.created_by == end_user.id - def test_unpin_conversation_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful unpinning of a conversation. """ @@ -416,10 +414,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -436,7 +433,7 @@ class TestWebConversationService: # Verify it was unpinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -448,7 +445,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_not_pinned(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_not_pinned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test unpinning a conversation that is not pinned (should not cause error). """ @@ -462,10 +461,9 @@ class TestWebConversationService: WebConversationService.unpin(app, conversation.id, account) # Verify no pinned conversation record exists - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -478,7 +476,7 @@ class TestWebConversationService: assert pinned_conversation is None def test_pagination_by_last_id_user_required_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that pagination_by_last_id raises ValueError when user is None. @@ -499,7 +497,7 @@ class TestWebConversationService: sort_by="-updated_at", ) - def test_pin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_pin_conversation_user_none(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test that pin method returns early when user is None. """ @@ -513,10 +511,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, None) # Verify no pinned conversation was created - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -526,7 +523,9 @@ class TestWebConversationService: assert pinned_conversation is None - def test_unpin_conversation_user_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_unpin_conversation_user_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that unpin method returns early when user is None. """ @@ -540,10 +539,9 @@ class TestWebConversationService: WebConversationService.pin(app, conversation.id, account) # Verify it was pinned - from extensions.ext_database import db pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, @@ -560,7 +558,7 @@ class TestWebConversationService: # Verify the conversation is still pinned pinned_conversation = ( - db.session.query(PinnedConversation) + db_session_with_containers.query(PinnedConversation) .where( PinnedConversation.app_id == app.id, PinnedConversation.conversation_id == conversation.id, diff --git a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py index d1c566e477..4fe65d5803 100644 --- a/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webapp_auth_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound, Unauthorized from libs.password import hash_password @@ -11,6 +12,7 @@ from models import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAcco from models.model import App, Site from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError from services.webapp_auth_service import WebAppAuthService, WebAppAuthType +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWebAppAuthService: @@ -45,7 +47,7 @@ class TestWebAppAuthService: "enterprise_service": mock_enterprise_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -68,18 +70,16 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,15 +88,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_account_with_password(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_with_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test account with password for testing. @@ -108,7 +110,7 @@ class TestWebAppAuthService: tuple: (account, tenant, password) - Created account, tenant and password """ fake = Faker() - password = fake.password(length=12) + password = generate_valid_password(fake) # Create account with password import uuid @@ -131,18 +133,16 @@ class TestWebAppAuthService: account.password = base64.b64encode(password_hash).decode() account.password_salt = base64.b64encode(salt).decode() - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -151,15 +151,17 @@ class TestWebAppAuthService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant, password - def _create_test_app_and_site(self, db_session_with_containers, mock_external_service_dependencies, tenant): + def _create_test_app_and_site( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant + ): """ Helper method to create a test app and site for testing. @@ -188,10 +190,8 @@ class TestWebAppAuthService: enable_api=True, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create site site = Site( @@ -203,12 +203,12 @@ class TestWebAppAuthService: status="normal", customize_token_strategy="not_allow", ) - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() return app, site - def test_authenticate_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful authentication with valid email and password. @@ -233,14 +233,15 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.password is not None assert result.password_salt is not None - def test_authenticate_account_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with non-existent email. @@ -262,7 +263,7 @@ class TestWebAppAuthService: with pytest.raises(AccountNotFoundError): WebAppAuthService.authenticate(non_existent_email, "any_password") - def test_authenticate_account_banned(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_account_banned(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test authentication with banned account. @@ -272,7 +273,7 @@ class TestWebAppAuthService: """ # Arrange: Create banned account fake = Faker() - password = fake.password(length=12) + password = generate_valid_password(fake) unique_email = f"test_{uuid.uuid4().hex[:8]}@example.com" account = Account( @@ -292,10 +293,8 @@ class TestWebAppAuthService: account.password = base64.b64encode(password_hash).decode() account.password_salt = base64.b64encode(salt).decode() - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(AccountLoginError) as exc_info: @@ -303,7 +302,9 @@ class TestWebAppAuthService: assert "Account is banned." in str(exc_info.value) - def test_authenticate_invalid_password(self, db_session_with_containers, mock_external_service_dependencies): + def test_authenticate_invalid_password( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test authentication with invalid password. @@ -323,7 +324,7 @@ class TestWebAppAuthService: assert "Invalid email or password." in str(exc_info.value) def test_authenticate_account_without_password( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test authentication for account without password. @@ -345,10 +346,8 @@ class TestWebAppAuthService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(AccountPasswordError) as exc_info: @@ -356,7 +355,7 @@ class TestWebAppAuthService: assert "Invalid email or password." in str(exc_info.value) - def test_login_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_login_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful login and JWT token generation. @@ -388,7 +387,9 @@ class TestWebAppAuthService: assert call_args["auth_type"] == "internal" assert "exp" in call_args - def test_get_user_through_email_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful user retrieval through email. @@ -413,12 +414,13 @@ class TestWebAppAuthService: assert result.status == AccountStatus.ACTIVE # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None - def test_get_user_through_email_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test user retrieval with non-existent email. @@ -435,7 +437,9 @@ class TestWebAppAuthService: # Assert: Verify proper handling assert result is None - def test_get_user_through_email_banned(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_user_through_email_banned( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test user retrieval with banned account. @@ -456,10 +460,8 @@ class TestWebAppAuthService: status=AccountStatus.BANNED, ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(Unauthorized) as exc_info: @@ -468,7 +470,7 @@ class TestWebAppAuthService: assert "Account is banned." in str(exc_info.value) def test_send_email_code_login_email_with_account( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email with account. @@ -509,7 +511,7 @@ class TestWebAppAuthService: assert "code" in mail_call_args[1] def test_send_email_code_login_email_with_email_only( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email with email only. @@ -549,7 +551,7 @@ class TestWebAppAuthService: assert "code" in mail_call_args[1] def test_send_email_code_login_email_no_email_provided( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test sending email code login email without providing email. @@ -566,7 +568,9 @@ class TestWebAppAuthService: assert "Email must be provided." in str(exc_info.value) - def test_get_email_code_login_data_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_email_code_login_data_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful retrieval of email code login data. @@ -593,7 +597,9 @@ class TestWebAppAuthService: "mock_token", "email_code_login" ) - def test_get_email_code_login_data_no_data(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_email_code_login_data_no_data( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test email code login data retrieval when no data exists. @@ -617,7 +623,7 @@ class TestWebAppAuthService: ) def test_revoke_email_code_login_token_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful revocation of email code login token. @@ -636,7 +642,7 @@ class TestWebAppAuthService: "mock_token", "email_code_login" ) - def test_create_end_user_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful end user creation. @@ -668,14 +674,15 @@ class TestWebAppAuthService: assert result.external_user_id == "enterpriseuser" # Verify database state - from extensions.ext_database import db - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.created_at is not None assert result.updated_at is not None - def test_create_end_user_site_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_site_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test end user creation with non-existent site code. @@ -693,7 +700,9 @@ class TestWebAppAuthService: assert "Site not found." in str(exc_info.value) - def test_create_end_user_app_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_end_user_app_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test end user creation when app is not found. @@ -708,10 +717,8 @@ class TestWebAppAuthService: status="normal", ) - from extensions.ext_database import db - - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() site = Site( app_id="00000000-0000-0000-0000-000000000000", @@ -722,8 +729,8 @@ class TestWebAppAuthService: status="normal", customize_token_strategy="not_allow", ) - db.session.add(site) - db.session.commit() + db_session_with_containers.add(site) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling with pytest.raises(NotFound) as exc_info: @@ -732,7 +739,7 @@ class TestWebAppAuthService: assert "App not found." in str(exc_info.value) def test_is_app_require_permission_check_with_access_mode_private( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement for private access mode. @@ -751,7 +758,7 @@ class TestWebAppAuthService: assert result is True def test_is_app_require_permission_check_with_access_mode_public( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement for public access mode. @@ -770,7 +777,7 @@ class TestWebAppAuthService: assert result is False def test_is_app_require_permission_check_with_app_code( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement using app code. @@ -796,7 +803,7 @@ class TestWebAppAuthService: ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with("mock_app_id") def test_is_app_require_permission_check_no_parameters( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test permission check requirement with no parameters. @@ -814,7 +821,7 @@ class TestWebAppAuthService: assert "Either app_code or app_id must be provided." in str(exc_info.value) def test_get_app_auth_type_with_access_mode_public( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test app authentication type for public access mode. @@ -833,7 +840,7 @@ class TestWebAppAuthService: assert result == WebAppAuthType.PUBLIC def test_get_app_auth_type_with_access_mode_private( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test app authentication type for private access mode. @@ -851,7 +858,9 @@ class TestWebAppAuthService: # Assert: Verify correct result assert result == WebAppAuthType.INTERNAL - def test_get_app_auth_type_with_app_code(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_auth_type_with_app_code( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app authentication type using app code. @@ -878,7 +887,9 @@ class TestWebAppAuthService: "enterprise_service" ].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id") - def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_app_auth_type_no_parameters( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test app authentication type with no parameters. diff --git a/api/tests/test_containers_integration_tests/services/test_webhook_service.py b/api/tests/test_containers_integration_tests/services/test_webhook_service.py index 8f345b9cea..970da98c55 100644 --- a/api/tests/test_containers_integration_tests/services/test_webhook_service.py +++ b/api/tests/test_containers_integration_tests/services/test_webhook_service.py @@ -13,6 +13,7 @@ from models.trigger import AppTrigger, WorkflowWebhookTrigger from models.workflow import Workflow from services.account_service import AccountService, TenantService from services.trigger.webhook_service import WebhookService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWebhookService: @@ -60,7 +61,7 @@ class TestWebhookService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -172,7 +173,7 @@ class TestWebhookService: assert workflow.app_id == test_data["app"].id assert node_config is not None assert node_config["id"] == "webhook_node" - assert node_config["data"]["title"] == "Test Webhook" + assert node_config["data"].title == "Test Webhook" def test_get_webhook_trigger_and_workflow_not_found(self, flask_app_with_containers): """Test webhook trigger not found scenario.""" diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 040fb826e1..8ab8df2a5a 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -5,8 +5,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from core.workflow.entities.workflow_execution import WorkflowExecutionStatus +from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowRun from models.enums import CreatorUserRole from services.account_service import AccountService, TenantService @@ -14,6 +15,7 @@ from services.account_service import AccountService, TenantService # Delay import of AppService to avoid circular dependency # from services.app_service import AppService from services.workflow_app_service import WorkflowAppService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWorkflowAppService: @@ -48,7 +50,7 @@ class TestWorkflowAppService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -71,7 +73,7 @@ class TestWorkflowAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -96,7 +98,7 @@ class TestWorkflowAppService: return app, account - def _create_test_tenant_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_tenant_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test tenant and account for testing. @@ -119,14 +121,14 @@ class TestWorkflowAppService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant return tenant, account - def _create_test_app(self, db_session_with_containers, tenant, account): + def _create_test_app(self, db_session_with_containers: Session, tenant, account): """ Helper method to create a test app for testing. @@ -160,7 +162,7 @@ class TestWorkflowAppService: return app - def _create_test_workflow_data(self, db_session_with_containers, app, account): + def _create_test_workflow_data(self, db_session_with_containers: Session, app, account): """ Helper method to create test workflow data for testing. @@ -174,8 +176,6 @@ class TestWorkflowAppService: """ fake = Faker() - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -188,8 +188,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run workflow_run = WorkflowRun( @@ -212,8 +212,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -227,13 +227,13 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() return workflow, workflow_run, workflow_app_log def test_get_paginate_workflow_app_logs_basic_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination of workflow app logs with basic parameters. @@ -268,13 +268,12 @@ class TestWorkflowAppService: assert log_entry.workflow_run_id == workflow_run.id # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow_app_log) + db_session_with_containers.refresh(workflow_app_log) assert workflow_app_log.id is not None def test_get_paginate_workflow_app_logs_with_keyword_search( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with keyword search functionality. @@ -287,11 +286,10 @@ class TestWorkflowAppService: ) # Update workflow run with searchable content - from extensions.ext_database import db workflow_run.inputs = json.dumps({"search_term": "test_keyword", "input2": "other_value"}) workflow_run.outputs = json.dumps({"result": "test_keyword_found", "status": "success"}) - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test with keyword search service = WorkflowAppService() @@ -317,7 +315,7 @@ class TestWorkflowAppService: assert len(result_no_match["data"]) == 0 def test_get_paginate_workflow_app_logs_with_special_characters_in_keyword( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): r""" Test workflow app logs pagination with special characters in keyword to verify SQL injection prevention. @@ -332,8 +330,6 @@ class TestWorkflowAppService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) workflow, _, _ = self._create_test_workflow_data(db_session_with_containers, app, account) - from extensions.ext_database import db - service = WorkflowAppService() # Test 1: Search with % character @@ -353,8 +349,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_1) - db.session.flush() + db_session_with_containers.add(workflow_run_1) + db_session_with_containers.flush() workflow_app_log_1 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -367,8 +363,8 @@ class TestWorkflowAppService: ) workflow_app_log_1.id = str(uuid.uuid4()) workflow_app_log_1.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_1) - db.session.commit() + db_session_with_containers.add(workflow_app_log_1) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 @@ -395,8 +391,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_2) - db.session.flush() + db_session_with_containers.add(workflow_run_2) + db_session_with_containers.flush() workflow_app_log_2 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -409,8 +405,8 @@ class TestWorkflowAppService: ) workflow_app_log_2.id = str(uuid.uuid4()) workflow_app_log_2.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_2) - db.session.commit() + db_session_with_containers.add(workflow_app_log_2) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="test_data", page=1, limit=20 @@ -437,8 +433,8 @@ class TestWorkflowAppService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(workflow_run_4) - db.session.flush() + db_session_with_containers.add(workflow_run_4) + db_session_with_containers.flush() workflow_app_log_4 = WorkflowAppLog( tenant_id=app.tenant_id, @@ -451,8 +447,8 @@ class TestWorkflowAppService: ) workflow_app_log_4.id = str(uuid.uuid4()) workflow_app_log_4.created_at = datetime.now(UTC) - db.session.add(workflow_app_log_4) - db.session.commit() + db_session_with_containers.add(workflow_app_log_4) + db_session_with_containers.commit() result = service.get_paginate_workflow_app_logs( session=db_session_with_containers, app_model=app, keyword="50%", page=1, limit=20 @@ -467,7 +463,7 @@ class TestWorkflowAppService: assert workflow_run_4.id not in found_run_ids def test_get_paginate_workflow_app_logs_with_status_filter( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with status filtering. @@ -476,8 +472,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -490,8 +484,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow runs with different statuses statuses = ["succeeded", "failed", "running", "stopped"] @@ -519,8 +513,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1) if status != "running" else None, ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -533,8 +527,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -568,7 +562,7 @@ class TestWorkflowAppService: assert result_running["data"][0].workflow_run.status == "running" def test_get_paginate_workflow_app_logs_with_time_filtering( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with time-based filtering. @@ -577,8 +571,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -591,8 +583,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow runs with different timestamps base_time = datetime.now(UTC) @@ -627,8 +619,8 @@ class TestWorkflowAppService: created_at=timestamp, finished_at=timestamp + timedelta(minutes=1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -641,8 +633,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = timestamp - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -682,7 +674,7 @@ class TestWorkflowAppService: assert result_range["total"] == 2 # Should get logs from 2 hours ago and 1 hour ago def test_get_paginate_workflow_app_logs_with_pagination( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with different page sizes and limits. @@ -691,8 +683,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -705,8 +695,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create 25 workflow runs and logs total_logs = 25 @@ -734,8 +724,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -748,8 +738,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -798,7 +788,7 @@ class TestWorkflowAppService: assert len(result_large_limit["data"]) == total_logs def test_get_paginate_workflow_app_logs_with_user_role_filtering( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with user role and session filtering. @@ -807,8 +797,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -821,8 +809,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create end user end_user = EndUser( @@ -835,8 +823,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), updated_at=datetime.now(UTC), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create workflow runs and logs for both account and end user workflow_runs = [] @@ -864,8 +852,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i), finished_at=datetime.now(UTC) + timedelta(minutes=i + 1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -878,8 +866,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -906,8 +894,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC) + timedelta(minutes=i + 10), finished_at=datetime.now(UTC) + timedelta(minutes=i + 11), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() workflow_app_log = WorkflowAppLog( tenant_id=app.tenant_id, @@ -920,8 +908,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) + timedelta(minutes=i + 10) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() workflow_runs.append(workflow_run) workflow_app_logs.append(workflow_app_log) @@ -994,7 +982,7 @@ class TestWorkflowAppService: assert "Account not found" in str(exc_info.value) def test_get_paginate_workflow_app_logs_with_uuid_keyword_search( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with UUID keyword search functionality. @@ -1003,8 +991,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -1017,8 +1003,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run with specific UUID workflow_run_id = str(uuid.uuid4()) @@ -1042,8 +1028,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC) + timedelta(minutes=1), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -1057,8 +1043,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() # Act & Assert: Test UUID keyword search service = WorkflowAppService() @@ -1085,7 +1071,7 @@ class TestWorkflowAppService: assert result_invalid_uuid["total"] == 0 def test_get_paginate_workflow_app_logs_with_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with edge cases and boundary conditions. @@ -1094,8 +1080,6 @@ class TestWorkflowAppService: fake = Faker() app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) - from extensions.ext_database import db - # Create workflow workflow = Workflow( id=str(uuid.uuid4()), @@ -1108,8 +1092,8 @@ class TestWorkflowAppService: created_by=account.id, updated_by=account.id, ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create workflow run with edge case data workflow_run = WorkflowRun( @@ -1132,8 +1116,8 @@ class TestWorkflowAppService: created_at=datetime.now(UTC), finished_at=datetime.now(UTC), ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() # Create workflow app log workflow_app_log = WorkflowAppLog( @@ -1147,8 +1131,8 @@ class TestWorkflowAppService: ) workflow_app_log.id = str(uuid.uuid4()) workflow_app_log.created_at = datetime.now(UTC) - db.session.add(workflow_app_log) - db.session.commit() + db_session_with_containers.add(workflow_app_log) + db_session_with_containers.commit() # Act & Assert: Test edge cases service = WorkflowAppService() @@ -1185,7 +1169,7 @@ class TestWorkflowAppService: assert result_high_page["has_more"] is False def test_get_paginate_workflow_app_logs_with_empty_results( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with empty results and no data scenarios. @@ -1252,7 +1236,7 @@ class TestWorkflowAppService: assert "Account not found" in str(exc_info.value) def test_get_paginate_workflow_app_logs_with_complex_query_combinations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with complex query combinations. @@ -1352,7 +1336,7 @@ class TestWorkflowAppService: assert len(result_time_status_limit["data"]) <= 2 def test_get_paginate_workflow_app_logs_with_large_dataset_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with large dataset for performance validation. @@ -1444,7 +1428,7 @@ class TestWorkflowAppService: assert result_last_page["page"] == 3 def test_get_paginate_workflow_app_logs_with_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow app logs pagination with proper tenant isolation. diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 1f91b40963..572cf72fa0 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,8 +1,9 @@ import pytest from faker import Faker +from sqlalchemy.orm import Session -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.variables.segments import StringSegment +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -44,7 +45,7 @@ class TestWorkflowDraftVariableService: # WorkflowDraftVariableService doesn't have external dependencies that need mocking return {} - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, fake=None): + def _create_test_app(self, db_session_with_containers: Session, mock_external_service_dependencies, fake=None): """ Helper method to create a test app with realistic data for testing. @@ -75,13 +76,11 @@ class TestWorkflowDraftVariableService: app.created_by = fake.uuid4() app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, fake=None): """ Helper method to create a test workflow associated with an app. @@ -110,20 +109,20 @@ class TestWorkflowDraftVariableService: conversation_variables=[], rag_pipeline_variables=[], ) - from extensions.ext_database import db - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow def _create_test_variable( self, - db_session_with_containers, + db_session_with_containers: Session, app_id, node_id, name, value, variable_type: DraftVariableType = DraftVariableType.CONVERSATION, + user_id: str | None = None, fake=None, ): """ @@ -146,10 +145,15 @@ class TestWorkflowDraftVariableService: WorkflowDraftVariable: Created test variable instance with proper type configuration """ fake = fake or Faker() + if user_id is None: + app = db_session_with_containers.query(App).filter_by(id=app_id).first() + assert app is not None + user_id = app.created_by if variable_type == "conversation": # Create conversation variable using the appropriate factory method variable = WorkflowDraftVariable.new_conversation_variable( app_id=app_id, + user_id=user_id, name=name, value=value, description=fake.text(max_nb_chars=20), @@ -158,6 +162,7 @@ class TestWorkflowDraftVariableService: # Create system variable with editable flag and execution context variable = WorkflowDraftVariable.new_sys_variable( app_id=app_id, + user_id=user_id, name=name, value=value, node_execution_id=fake.uuid4(), @@ -167,6 +172,7 @@ class TestWorkflowDraftVariableService: # Create node variable with visibility and editability settings variable = WorkflowDraftVariable.new_node_variable( app_id=app_id, + user_id=user_id, node_id=node_id, name=name, value=value, @@ -174,13 +180,12 @@ class TestWorkflowDraftVariableService: visible=True, editable=True, ) - from extensions.ext_database import db - db.session.add(variable) - db.session.commit() + db_session_with_containers.add(variable) + db_session_with_containers.commit() return variable - def test_get_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting a single variable by ID successfully. @@ -192,7 +197,13 @@ class TestWorkflowDraftVariableService: app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) test_value = StringSegment(value=fake.word()) variable = self._create_test_variable( - db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "test_var", + test_value, + user_id=app.created_by, + fake=fake, ) service = WorkflowDraftVariableService(db_session_with_containers) retrieved_variable = service.get_variable(variable.id) @@ -202,7 +213,7 @@ class TestWorkflowDraftVariableService: assert retrieved_variable.app_id == app.id assert retrieved_variable.get_value().value == test_value.value - def test_get_variable_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test getting a variable that doesn't exist. @@ -217,7 +228,7 @@ class TestWorkflowDraftVariableService: assert retrieved_variable is None def test_get_draft_variables_by_selectors_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting variables by selectors successfully. @@ -253,7 +264,7 @@ class TestWorkflowDraftVariableService: ["test_node_1", "var3"], ] service = WorkflowDraftVariableService(db_session_with_containers) - retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors) + retrieved_variables = service.get_draft_variables_by_selectors(app.id, selectors, user_id=app.created_by) assert len(retrieved_variables) == 3 var_names = [var.name for var in retrieved_variables] assert "var1" in var_names @@ -268,7 +279,7 @@ class TestWorkflowDraftVariableService: assert var.get_value().value == var3_value.value def test_list_variables_without_values_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test listing variables without values successfully with pagination. @@ -291,7 +302,7 @@ class TestWorkflowDraftVariableService: fake=fake, ) service = WorkflowDraftVariableService(db_session_with_containers) - result = service.list_variables_without_values(app.id, page=1, limit=3) + result = service.list_variables_without_values(app.id, page=1, limit=3, user_id=app.created_by) assert result.total == 5 assert len(result.variables) == 3 assert result.variables[0].created_at >= result.variables[1].created_at @@ -300,7 +311,7 @@ class TestWorkflowDraftVariableService: assert var.name is not None assert var.app_id == app.id - def test_list_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_node_variables_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test listing variables for a specific node successfully. @@ -342,7 +353,7 @@ class TestWorkflowDraftVariableService: fake=fake, ) service = WorkflowDraftVariableService(db_session_with_containers) - result = service.list_node_variables(app.id, node_id) + result = service.list_node_variables(app.id, node_id, user_id=app.created_by) assert len(result.variables) == 2 for var in result.variables: assert var.node_id == node_id @@ -352,7 +363,9 @@ class TestWorkflowDraftVariableService: assert "var2" in var_names assert "var3" not in var_names - def test_list_conversation_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_conversation_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test listing conversation variables successfully. @@ -382,7 +395,7 @@ class TestWorkflowDraftVariableService: fake=fake, ) service = WorkflowDraftVariableService(db_session_with_containers) - result = service.list_conversation_variables(app.id) + result = service.list_conversation_variables(app.id, user_id=app.created_by) assert len(result.variables) == 2 for var in result.variables: assert var.node_id == CONVERSATION_VARIABLE_NODE_ID @@ -393,7 +406,7 @@ class TestWorkflowDraftVariableService: assert "conv_var2" in var_names assert "sys_var" not in var_names - def test_update_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test updating a variable's name and value successfully. @@ -418,14 +431,15 @@ class TestWorkflowDraftVariableService: assert updated_variable.name == "new_name" assert updated_variable.get_value().value == new_value.value assert updated_variable.last_edited_at is not None - from extensions.ext_database import db - db.session.refresh(variable) + db_session_with_containers.refresh(variable) assert variable.name == "new_name" assert variable.get_value().value == new_value.value assert variable.last_edited_at is not None - def test_update_variable_not_editable(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_variable_not_editable( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test that updating a non-editable variable raises an exception. @@ -445,17 +459,18 @@ class TestWorkflowDraftVariableService: node_execution_id=fake.uuid4(), editable=False, # Set as non-editable ) - from extensions.ext_database import db - db.session.add(variable) - db.session.commit() + db_session_with_containers.add(variable) + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) with pytest.raises(UpdateNotSupportedError) as exc_info: service.update_variable(variable, name="new_name", value=new_value) assert "variable not support updating" in str(exc_info.value) assert variable.id in str(exc_info.value) - def test_reset_conversation_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_reset_conversation_variable_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test resetting conversation variable successfully. @@ -467,7 +482,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from core.workflow.variables.variables import StringVariable + from dify_graph.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -476,9 +491,8 @@ class TestWorkflowDraftVariableService: selector=[CONVERSATION_VARIABLE_NODE_ID, "test_conv_var"], ) workflow.conversation_variables = [conv_var] - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() modified_value = StringSegment(value=fake.word()) variable = self._create_test_variable( db_session_with_containers, @@ -489,17 +503,17 @@ class TestWorkflowDraftVariableService: fake=fake, ) variable.last_edited_at = fake.date_time() - db.session.commit() + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) reset_variable = service.reset_variable(workflow, variable) assert reset_variable is not None assert reset_variable.get_value().value == "default_value" assert reset_variable.last_edited_at is None - db.session.refresh(variable) + db_session_with_containers.refresh(variable) assert variable.get_value().value == "default_value" assert variable.last_edited_at is None - def test_delete_variable_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_variable_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test deleting a single variable successfully. @@ -513,14 +527,15 @@ class TestWorkflowDraftVariableService: variable = self._create_test_variable( db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "test_var", test_value, fake=fake ) - from extensions.ext_database import db - assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None + assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is not None service = WorkflowDraftVariableService(db_session_with_containers) service.delete_variable(variable) - assert db.session.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None + assert db_session_with_containers.query(WorkflowDraftVariable).filter_by(id=variable.id).first() is None - def test_delete_workflow_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_workflow_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting all variables for a workflow successfully. @@ -550,20 +565,88 @@ class TestWorkflowDraftVariableService: other_value, fake=fake, ) - from extensions.ext_database import db - app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() - other_app_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + app_variables = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + ) assert len(app_variables) == 3 assert len(other_app_variables) == 1 service = WorkflowDraftVariableService(db_session_with_containers) - service.delete_workflow_variables(app.id) - app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() - other_app_variables_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + service.delete_user_workflow_variables(app.id, user_id=app.created_by) + app_variables_after = db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id).all() + other_app_variables_after = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=other_app.id).all() + ) assert len(app_variables_after) == 0 assert len(other_app_variables_after) == 1 - def test_delete_node_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_draft_variables_are_isolated_between_users( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """ + Test draft variable isolation for different users in the same app. + + This test verifies that: + 1. Query APIs return only variables owned by the target user. + 2. User-scoped deletion only removes variables for that user and keeps + other users' variables in the same app untouched. + """ + fake = Faker() + app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) + user_a = app.created_by + user_b = fake.uuid4() + + # Use identical variable names on purpose to verify uniqueness scope includes user_id. + self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "shared_name", + StringSegment(value="value_a"), + user_id=user_a, + fake=fake, + ) + self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "shared_name", + StringSegment(value="value_b"), + user_id=user_b, + fake=fake, + ) + self._create_test_variable( + db_session_with_containers, + app.id, + CONVERSATION_VARIABLE_NODE_ID, + "only_a", + StringSegment(value="only_a"), + user_id=user_a, + fake=fake, + ) + + service = WorkflowDraftVariableService(db_session_with_containers) + + user_a_vars = service.list_conversation_variables(app.id, user_id=user_a) + user_b_vars = service.list_conversation_variables(app.id, user_id=user_b) + assert {v.name for v in user_a_vars.variables} == {"shared_name", "only_a"} + assert {v.name for v in user_b_vars.variables} == {"shared_name"} + + service.delete_user_workflow_variables(app.id, user_id=user_a) + + user_a_remaining = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, user_id=user_a).count() + ) + user_b_remaining = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, user_id=user_b).count() + ) + assert user_a_remaining == 0 + assert user_b_remaining == 1 + + def test_delete_node_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test deleting all variables for a specific node successfully. @@ -605,14 +688,15 @@ class TestWorkflowDraftVariableService: conv_value, fake=fake, ) - from extensions.ext_database import db - target_node_variables = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + target_node_variables = ( + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + ) other_node_variables = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() ) conv_variables = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -620,15 +704,15 @@ class TestWorkflowDraftVariableService: assert len(other_node_variables) == 1 assert len(conv_variables) == 1 service = WorkflowDraftVariableService(db_session_with_containers) - service.delete_node_variables(app.id, node_id) + service.delete_node_variables(app.id, node_id, user_id=app.created_by) target_node_variables_after = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id=node_id).all() ) other_node_variables_after = ( - db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() + db_session_with_containers.query(WorkflowDraftVariable).filter_by(app_id=app.id, node_id="other_node").all() ) conv_variables_after = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -637,7 +721,7 @@ class TestWorkflowDraftVariableService: assert len(conv_variables_after) == 1 def test_prefill_conversation_variable_default_values_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test prefill conversation variable default values successfully. @@ -650,7 +734,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from core.workflow.variables.variables import StringVariable + from dify_graph.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), @@ -665,13 +749,12 @@ class TestWorkflowDraftVariableService: selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_var2"], ) workflow.conversation_variables = [conv_var1, conv_var2] - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() service = WorkflowDraftVariableService(db_session_with_containers) - service.prefill_conversation_variable_default_values(workflow) + service.prefill_conversation_variable_default_values(workflow, user_id="00000000-0000-0000-0000-000000000001") draft_variables = ( - db.session.query(WorkflowDraftVariable) + db_session_with_containers.query(WorkflowDraftVariable) .filter_by(app_id=app.id, node_id=CONVERSATION_VARIABLE_NODE_ID) .all() ) @@ -686,7 +769,7 @@ class TestWorkflowDraftVariableService: assert var.get_variable_type() == DraftVariableType.CONVERSATION def test_get_conversation_id_from_draft_variable_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting conversation ID from draft variable successfully. @@ -709,11 +792,11 @@ class TestWorkflowDraftVariableService: fake=fake, ) service = WorkflowDraftVariableService(db_session_with_containers) - retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id, app.created_by) assert retrieved_conv_id == conversation_id def test_get_conversation_id_from_draft_variable_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting conversation ID when it doesn't exist. @@ -725,10 +808,12 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) service = WorkflowDraftVariableService(db_session_with_containers) - retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id) + retrieved_conv_id = service._get_conversation_id_from_draft_variable(app.id, app.created_by) assert retrieved_conv_id is None - def test_list_system_variables_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_list_system_variables_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test listing system variables successfully. @@ -764,7 +849,7 @@ class TestWorkflowDraftVariableService: db_session_with_containers, app.id, CONVERSATION_VARIABLE_NODE_ID, "conv_var", conv_var_value, fake=fake ) service = WorkflowDraftVariableService(db_session_with_containers) - result = service.list_system_variables(app.id) + result = service.list_system_variables(app.id, user_id=app.created_by) assert len(result.variables) == 2 for var in result.variables: assert var.node_id == SYSTEM_VARIABLE_NODE_ID @@ -775,7 +860,9 @@ class TestWorkflowDraftVariableService: assert "sys_var2" in var_names assert "conv_var" not in var_names - def test_get_variable_by_name_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_by_name_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting variables by name successfully for different types. @@ -809,20 +896,22 @@ class TestWorkflowDraftVariableService: fake=fake, ) service = WorkflowDraftVariableService(db_session_with_containers) - retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var") + retrieved_conv_var = service.get_conversation_variable(app.id, "test_conv_var", user_id=app.created_by) assert retrieved_conv_var is not None assert retrieved_conv_var.name == "test_conv_var" assert retrieved_conv_var.node_id == CONVERSATION_VARIABLE_NODE_ID - retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var") + retrieved_sys_var = service.get_system_variable(app.id, "test_sys_var", user_id=app.created_by) assert retrieved_sys_var is not None assert retrieved_sys_var.name == "test_sys_var" assert retrieved_sys_var.node_id == SYSTEM_VARIABLE_NODE_ID - retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var") + retrieved_node_var = service.get_node_variable(app.id, "test_node", "test_node_var", user_id=app.created_by) assert retrieved_node_var is not None assert retrieved_node_var.name == "test_node_var" assert retrieved_node_var.node_id == "test_node" - def test_get_variable_by_name_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_variable_by_name_not_found( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test getting variables by name when they don't exist. @@ -833,9 +922,14 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) service = WorkflowDraftVariableService(db_session_with_containers) - retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var") + retrieved_conv_var = service.get_conversation_variable(app.id, "non_existent_conv_var", user_id=app.created_by) assert retrieved_conv_var is None - retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var") + retrieved_sys_var = service.get_system_variable(app.id, "non_existent_sys_var", user_id=app.created_by) assert retrieved_sys_var is None - retrieved_node_var = service.get_node_variable(app.id, "test_node", "non_existent_node_var") + retrieved_node_var = service.get_node_variable( + app.id, + "test_node", + "non_existent_node_var", + user_id=app.created_by, + ) assert retrieved_node_var is None diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 3a88081db3..731770e01a 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -5,8 +5,9 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from models.enums import CreatorUserRole +from models.enums import ConversationFromSource, CreatorUserRole from models.model import ( Message, ) @@ -14,6 +15,7 @@ from models.workflow import WorkflowRun from services.account_service import AccountService, TenantService from services.app_service import AppService from services.workflow_run_service import WorkflowRunService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWorkflowRunService: @@ -48,7 +50,7 @@ class TestWorkflowRunService: "account_feature_service": mock_account_feature_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -71,7 +73,7 @@ class TestWorkflowRunService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -94,7 +96,7 @@ class TestWorkflowRunService: return app, account def _create_test_workflow_run( - self, db_session_with_containers, app, account, triggered_from="debugging", offset_minutes=0 + self, db_session_with_containers: Session, app, account, triggered_from="debugging", offset_minutes=0 ): """ Helper method to create a test workflow run for testing. @@ -110,8 +112,6 @@ class TestWorkflowRunService: """ fake = Faker() - from extensions.ext_database import db - # Create workflow run with offset timestamp base_time = datetime.now(UTC) created_time = base_time - timedelta(minutes=offset_minutes) @@ -136,12 +136,12 @@ class TestWorkflowRunService: finished_at=created_time, ) - db.session.add(workflow_run) - db.session.commit() + db_session_with_containers.add(workflow_run) + db_session_with_containers.commit() return workflow_run - def _create_test_message(self, db_session_with_containers, app, account, workflow_run): + def _create_test_message(self, db_session_with_containers: Session, app, account, workflow_run): """ Helper method to create a test message for testing. @@ -156,8 +156,6 @@ class TestWorkflowRunService: """ fake = Faker() - from extensions.ext_database import db - # Create conversation first (required for message) from models.model import Conversation @@ -167,11 +165,11 @@ class TestWorkflowRunService: inputs={}, status="normal", mode="chat", - from_source=CreatorUserRole.ACCOUNT, + from_source=ConversationFromSource.CONSOLE, from_account_id=account.id, ) - db.session.add(conversation) - db.session.commit() + db_session_with_containers.add(conversation) + db_session_with_containers.commit() # Create message message = Message() @@ -188,17 +186,19 @@ class TestWorkflowRunService: message.answer_price_unit = 0.001 message.currency = "USD" message.status = "normal" - message.from_source = CreatorUserRole.ACCOUNT + message.from_source = ConversationFromSource.CONSOLE message.from_account_id = account.id message.workflow_run_id = workflow_run.id message.inputs = {"input": "test input"} - db.session.add(message) - db.session.commit() + db_session_with_containers.add(message) + db_session_with_containers.commit() return message - def test_get_paginate_workflow_runs_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_paginate_workflow_runs_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful pagination of workflow runs with debugging trigger. @@ -239,7 +239,7 @@ class TestWorkflowRunService: assert workflow_run.tenant_id == app.tenant_id def test_get_paginate_workflow_runs_with_last_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of workflow runs with last_id parameter. @@ -282,7 +282,7 @@ class TestWorkflowRunService: assert workflow_run.tenant_id == app.tenant_id def test_get_paginate_workflow_runs_default_limit( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test pagination of workflow runs with default limit. @@ -320,7 +320,7 @@ class TestWorkflowRunService: assert workflow_run_result.tenant_id == app.tenant_id def test_get_paginate_advanced_chat_workflow_runs_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful pagination of advanced chat workflow runs with message information. @@ -365,7 +365,7 @@ class TestWorkflowRunService: assert workflow_run.app_id == app.id assert workflow_run.tenant_id == app.tenant_id - def test_get_workflow_run_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_workflow_run_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of workflow run by ID. @@ -395,7 +395,7 @@ class TestWorkflowRunService: assert result.type == "chat" assert result.version == "1.0.0" - def test_get_workflow_run_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_workflow_run_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test workflow run retrieval when run ID does not exist. @@ -419,7 +419,7 @@ class TestWorkflowRunService: assert result is None def test_get_workflow_run_node_executions_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of workflow run node executions. @@ -438,7 +438,6 @@ class TestWorkflowRunService: workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") # Create node executions - from extensions.ext_database import db from models.workflow import WorkflowNodeExecutionModel node_executions = [] @@ -462,7 +461,7 @@ class TestWorkflowRunService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(node_execution) + db_session_with_containers.add(node_execution) node_executions.append(node_execution) paused_node_execution = WorkflowNodeExecutionModel( @@ -484,9 +483,9 @@ class TestWorkflowRunService: created_by=account.id, created_at=datetime.now(UTC), ) - db.session.add(paused_node_execution) + db_session_with_containers.add(paused_node_execution) - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test workflow_run_service = WorkflowRunService() @@ -509,7 +508,7 @@ class TestWorkflowRunService: assert node_execution.node_id.startswith("node_") def test_get_workflow_run_node_executions_empty( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions for a workflow run with no executions. @@ -560,7 +559,7 @@ class TestWorkflowRunService: assert len(result) == 0 def test_get_workflow_run_node_executions_invalid_workflow_run_id( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions with invalid workflow run ID. @@ -611,7 +610,7 @@ class TestWorkflowRunService: assert len(result) == 0 def test_get_workflow_run_node_executions_database_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test getting node executions when database encounters an error. @@ -662,7 +661,7 @@ class TestWorkflowRunService: ) def test_get_workflow_run_node_executions_end_user( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test node execution retrieval for end user. @@ -680,7 +679,6 @@ class TestWorkflowRunService: workflow_run = self._create_test_workflow_run(db_session_with_containers, app, account, "debugging") # Create end user - from extensions.ext_database import db from models.model import EndUser end_user = EndUser( @@ -692,8 +690,8 @@ class TestWorkflowRunService: external_user_id=str(uuid.uuid4()), name=fake.name(), ) - db.session.add(end_user) - db.session.commit() + db_session_with_containers.add(end_user) + db_session_with_containers.commit() # Create node execution from models.workflow import WorkflowNodeExecutionModel @@ -717,8 +715,8 @@ class TestWorkflowRunService: created_by=end_user.id, created_at=datetime.now(UTC), ) - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() # Act: Execute the method under test workflow_run_service = WorkflowRunService() diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index c29cda9a73..a5fe052206 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -10,6 +10,7 @@ from unittest.mock import MagicMock import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, App, Workflow from models.model import AppMode @@ -32,7 +33,7 @@ class TestWorkflowService: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -61,24 +62,22 @@ class TestWorkflowService: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.current_tenant_id tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at - from extensions.ext_database import db - - db.session.add(tenant) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.add(account) + db_session_with_containers.commit() # Set the current tenant for the account account.current_tenant = tenant return account - def _create_test_app(self, db_session_with_containers, fake=None): + def _create_test_app(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test app with realistic data. @@ -106,13 +105,11 @@ class TestWorkflowService: ) app.updated_by = app.created_by - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() return app - def _create_test_workflow(self, db_session_with_containers, app, account, fake=None): + def _create_test_workflow(self, db_session_with_containers: Session, app, account, fake=None): """ Helper method to create a test workflow associated with an app. @@ -141,13 +138,11 @@ class TestWorkflowService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() return workflow - def test_get_node_last_run_success(self, db_session_with_containers): + def test_get_node_last_run_success(self, db_session_with_containers: Session): """ Test successful retrieval of the most recent execution for a specific node. @@ -180,10 +175,8 @@ class TestWorkflowService: node_execution.created_by = account.id # Required field node_execution.created_at = fake.date_time_this_year() - from extensions.ext_database import db - - db.session.add(node_execution) - db.session.commit() + db_session_with_containers.add(node_execution) + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -196,7 +189,7 @@ class TestWorkflowService: assert result.workflow_id == workflow.id assert result.status == "succeeded" - def test_get_node_last_run_not_found(self, db_session_with_containers): + def test_get_node_last_run_not_found(self, db_session_with_containers: Session): """ Test retrieval when no execution record exists for the specified node. @@ -217,7 +210,7 @@ class TestWorkflowService: # Assert assert result is None - def test_is_workflow_exist_true(self, db_session_with_containers): + def test_is_workflow_exist_true(self, db_session_with_containers: Session): """ Test workflow existence check when a draft workflow exists. @@ -238,7 +231,7 @@ class TestWorkflowService: # Assert assert result is True - def test_is_workflow_exist_false(self, db_session_with_containers): + def test_is_workflow_exist_false(self, db_session_with_containers: Session): """ Test workflow existence check when no draft workflow exists. @@ -258,7 +251,7 @@ class TestWorkflowService: # Assert assert result is False - def test_get_draft_workflow_success(self, db_session_with_containers): + def test_get_draft_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of a draft workflow. @@ -284,7 +277,7 @@ class TestWorkflowService: assert result.app_id == app.id assert result.tenant_id == app.tenant_id - def test_get_draft_workflow_not_found(self, db_session_with_containers): + def test_get_draft_workflow_not_found(self, db_session_with_containers: Session): """ Test draft workflow retrieval when no draft workflow exists. @@ -304,7 +297,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_by_id_success(self, db_session_with_containers): + def test_get_published_workflow_by_id_success(self, db_session_with_containers: Session): """ Test successful retrieval of a published workflow by ID. @@ -321,9 +314,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -336,7 +327,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers): + def test_get_published_workflow_by_id_draft_error(self, db_session_with_containers: Session): """ Test error when trying to retrieve a draft workflow as published. @@ -359,7 +350,7 @@ class TestWorkflowService: with pytest.raises(IsDraftWorkflowError): workflow_service.get_published_workflow_by_id(app, workflow.id) - def test_get_published_workflow_by_id_not_found(self, db_session_with_containers): + def test_get_published_workflow_by_id_not_found(self, db_session_with_containers: Session): """ Test retrieval when no workflow exists with the specified ID. @@ -379,7 +370,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_published_workflow_success(self, db_session_with_containers): + def test_get_published_workflow_success(self, db_session_with_containers: Session): """ Test successful retrieval of the current published workflow for an app. @@ -395,10 +386,8 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -411,7 +400,7 @@ class TestWorkflowService: assert result.version != Workflow.VERSION_DRAFT assert result.app_id == app.id - def test_get_published_workflow_no_workflow_id(self, db_session_with_containers): + def test_get_published_workflow_no_workflow_id(self, db_session_with_containers: Session): """ Test retrieval when app has no associated workflow ID. @@ -431,7 +420,7 @@ class TestWorkflowService: # Assert assert result is None - def test_get_all_published_workflow_pagination(self, db_session_with_containers): + def test_get_all_published_workflow_pagination(self, db_session_with_containers: Session): """ Test pagination of published workflows. @@ -455,15 +444,13 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflows[0].id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - First page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=1, limit=3, @@ -476,7 +463,7 @@ class TestWorkflowService: # Act - Second page result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, + session=db_session_with_containers, app_model=app, page=2, limit=3, @@ -487,7 +474,7 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert has_more is False - def test_get_all_published_workflow_user_filter(self, db_session_with_containers): + def test_get_all_published_workflow_user_filter(self, db_session_with_containers: Session): """ Test filtering published workflows by user. @@ -513,22 +500,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter by account1 result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=account1.id + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=account1.id ) # Assert assert len(result_workflows) == 1 assert result_workflows[0].created_by == account1.id - def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers): + def test_get_all_published_workflow_named_only_filter(self, db_session_with_containers: Session): """ Test filtering published workflows to show only named workflows. @@ -557,22 +542,20 @@ class TestWorkflowService: # Set the app's workflow_id to the first workflow app.workflow_id = workflow1.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act - Filter named only result_workflows, has_more = workflow_service.get_all_published_workflow( - session=db.session, app_model=app, page=1, limit=10, user_id=None, named_only=True + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None, named_only=True ) # Assert assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) - def test_sync_draft_workflow_create_new(self, db_session_with_containers): + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. @@ -624,7 +607,7 @@ class TestWorkflowService: assert result.features == json.dumps(features) assert result.created_by == account.id - def test_sync_draft_workflow_update_existing(self, db_session_with_containers): + def test_sync_draft_workflow_update_existing(self, db_session_with_containers: Session): """ Test updating an existing draft workflow through sync operation. @@ -688,7 +671,7 @@ class TestWorkflowService: assert result.features == json.dumps(new_features) assert result.updated_by == account.id - def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers): + def test_sync_draft_workflow_hash_mismatch_error(self, db_session_with_containers: Session): """ Test error when sync is attempted with mismatched hash. @@ -738,7 +721,7 @@ class TestWorkflowService: conversation_variables=conversation_variables, ) - def test_publish_workflow_success(self, db_session_with_containers): + def test_publish_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow publishing. @@ -755,9 +738,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = Workflow.VERSION_DRAFT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -777,7 +758,7 @@ class TestWorkflowService: assert len(result.version) > 10 # Should be a reasonable timestamp length assert result.created_by == account.id - def test_publish_workflow_no_draft_error(self, db_session_with_containers): + def test_publish_workflow_no_draft_error(self, db_session_with_containers: Session): """ Test error when publishing workflow without draft. @@ -797,7 +778,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_publish_workflow_already_published_error(self, db_session_with_containers): + def test_publish_workflow_already_published_error(self, db_session_with_containers: Session): """ Test error when publishing already published workflow. @@ -813,9 +794,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Already published - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -823,7 +802,82 @@ class TestWorkflowService: with pytest.raises(ValueError, match="No valid workflow found"): workflow_service.publish_workflow(session=db_session_with_containers, app_model=app, account=account) - def test_get_default_block_configs(self, db_session_with_containers): + def test_restore_published_workflow_to_draft_does_not_persist_normalized_source_features( + self, db_session_with_containers: Session + ): + """Restore copies legacy feature JSON into draft without rewriting the source row.""" + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + app.mode = AppMode.ADVANCED_CHAT + + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + published_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version="2026.03.19.001", + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(legacy_features), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + draft_workflow = Workflow( + id=fake.uuid4(), + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + environment_variables=[], + conversation_variables=[], + ) + db_session_with_containers.add(published_workflow) + db_session_with_containers.add(draft_workflow) + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + restored_workflow = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=published_workflow.id, + account=account, + ) + + db_session_with_containers.expire_all() + refreshed_published_workflow = ( + db_session_with_containers.query(Workflow).filter_by(id=published_workflow.id).first() + ) + refreshed_draft_workflow = db_session_with_containers.query(Workflow).filter_by(id=draft_workflow.id).first() + + assert restored_workflow.id == draft_workflow.id + assert refreshed_published_workflow is not None + assert refreshed_draft_workflow is not None + assert refreshed_published_workflow.serialized_features == json.dumps(legacy_features) + assert refreshed_draft_workflow.serialized_features == json.dumps(legacy_features) + + def test_get_default_block_configs(self, db_session_with_containers: Session): """ Test retrieval of default block configurations for all node types. @@ -847,7 +901,7 @@ class TestWorkflowService: assert isinstance(config, dict) # The structure can vary, so we just check it's a dict - def test_get_default_block_config_specific_type(self, db_session_with_containers): + def test_get_default_block_config_specific_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for a specific node type. @@ -867,7 +921,7 @@ class TestWorkflowService: # This is acceptable behavior assert result is None or isinstance(result, dict) - def test_get_default_block_config_invalid_type(self, db_session_with_containers): + def test_get_default_block_config_invalid_type(self, db_session_with_containers: Session): """ Test retrieval of default block configuration for invalid node type. @@ -881,13 +935,13 @@ class TestWorkflowService: # Act try: result = workflow_service.get_default_block_config(node_type=invalid_node_type) - # If we get here, the service should return None for invalid types - assert result is None + # If we get here, the service should return an empty config for invalid types. + assert result == {} except ValueError: # It's also acceptable for the service to raise a ValueError for invalid types pass - def test_get_default_block_config_with_filters(self, db_session_with_containers): + def test_get_default_block_config_with_filters(self, db_session_with_containers: Session): """ Test retrieval of default block configuration with filters. @@ -907,7 +961,7 @@ class TestWorkflowService: # Result might be None if filters don't match, but should not raise error assert result is None or isinstance(result, dict) - def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_chat_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from chat mode app to workflow mode. @@ -944,11 +998,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -969,7 +1021,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers): + def test_convert_to_workflow_completion_mode_success(self, db_session_with_containers: Session): """ Test successful conversion from completion mode app to workflow mode. @@ -1006,11 +1058,9 @@ class TestWorkflowService: ) app_model_config.id = fake.uuid4() - from extensions.ext_database import db - - db.session.add(app_model_config) + db_session_with_containers.add(app_model_config) app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = { @@ -1031,7 +1081,7 @@ class TestWorkflowService: assert result.icon_type == conversion_args["icon_type"] assert result.icon_background == conversion_args["icon_background"] - def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers): + def test_convert_to_workflow_unsupported_mode_error(self, db_session_with_containers: Session): """ Test error when attempting to convert unsupported app mode. @@ -1046,9 +1096,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() conversion_args = {"name": "Test"} @@ -1057,7 +1105,7 @@ class TestWorkflowService: with pytest.raises(ValueError, match="Current App mode: workflow is not supported convert to workflow"): workflow_service.convert_to_workflow(app_model=app, account=account, args=conversion_args) - def test_validate_features_structure_advanced_chat(self, db_session_with_containers): + def test_validate_features_structure_advanced_chat(self, db_session_with_containers: Session): """ Test feature structure validation for advanced chat mode apps. @@ -1069,9 +1117,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.ADVANCED_CHAT - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = { @@ -1088,7 +1134,7 @@ class TestWorkflowService: # The exact behavior depends on the AdvancedChatAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_workflow(self, db_session_with_containers): + def test_validate_features_structure_workflow(self, db_session_with_containers: Session): """ Test feature structure validation for workflow mode apps. @@ -1100,9 +1146,7 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) app.mode = AppMode.WORKFLOW - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() features = {"workflow_config": {"max_steps": 10, "timeout": 300}} @@ -1115,30 +1159,27 @@ class TestWorkflowService: # The exact behavior depends on the WorkflowAppConfigManager implementation assert result is not None or isinstance(result, dict) - def test_validate_features_structure_invalid_mode(self, db_session_with_containers): + def test_validate_features_structure_invalid_mode(self, db_session_with_containers: Session): """ Test error when validating features for invalid app mode. This test ensures that the service correctly handles feature validation for unsupported app modes, preventing invalid operations. + With EnumText, invalid values are rejected at the DB level during flush, + raising StatementError wrapping ValueError. """ # Arrange fake = Faker() app = self._create_test_app(db_session_with_containers, fake) app.mode = "invalid_mode" # Invalid mode - from extensions.ext_database import db + # Act & Assert - EnumText validation rejects invalid values at DB flush + import sqlalchemy as sa - db.session.commit() + with pytest.raises((ValueError, sa.exc.StatementError)): + db_session_with_containers.commit() - workflow_service = WorkflowService() - features = {"test": "value"} - - # Act & Assert - with pytest.raises(ValueError, match="Invalid app mode: invalid_mode"): - workflow_service.validate_features_structure(app_model=app, features=features) - - def test_update_workflow_success(self, db_session_with_containers): + def test_update_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow update with allowed fields. @@ -1152,16 +1193,14 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = {"marked_name": "Updated Workflow Name", "marked_comment": "Updated workflow comment"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1174,7 +1213,7 @@ class TestWorkflowService: assert result.marked_comment == update_data["marked_comment"] assert result.updated_by == account.id - def test_update_workflow_not_found(self, db_session_with_containers): + def test_update_workflow_not_found(self, db_session_with_containers: Session): """ Test workflow update when workflow doesn't exist. @@ -1186,15 +1225,13 @@ class TestWorkflowService: account = self._create_test_account(db_session_with_containers, fake) app = self._create_test_app(db_session_with_containers, fake) - from extensions.ext_database import db - workflow_service = WorkflowService() non_existent_workflow_id = fake.uuid4() update_data = {"marked_name": "Test"} # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id, account_id=account.id, @@ -1204,7 +1241,7 @@ class TestWorkflowService: # Assert assert result is None - def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers): + def test_update_workflow_ignores_disallowed_fields(self, db_session_with_containers: Session): """ Test that workflow update ignores disallowed fields. @@ -1218,9 +1255,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) original_name = workflow.marked_name - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() update_data = { @@ -1231,7 +1266,7 @@ class TestWorkflowService: # Act result = workflow_service.update_workflow( - session=db.session, + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id, account_id=account.id, @@ -1245,7 +1280,7 @@ class TestWorkflowService: assert result.graph == workflow.graph assert result.features == workflow.features - def test_delete_workflow_success(self, db_session_with_containers): + def test_delete_workflow_success(self, db_session_with_containers: Session): """ Test successful workflow deletion. @@ -1262,25 +1297,23 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) workflow.version = "2024.01.01.001" # Published version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() # Act result = workflow_service.delete_workflow( - session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id ) # Assert assert result is True # Verify workflow is actually deleted - deleted_workflow = db.session.query(Workflow).filter_by(id=workflow.id).first() + deleted_workflow = db_session_with_containers.query(Workflow).filter_by(id=workflow.id).first() assert deleted_workflow is None - def test_delete_workflow_draft_error(self, db_session_with_containers): + def test_delete_workflow_draft_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a draft workflow. @@ -1296,9 +1329,7 @@ class TestWorkflowService: workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) # Keep as draft version - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1306,9 +1337,11 @@ class TestWorkflowService: from services.errors.workflow_service import DraftWorkflowDeletionError with pytest.raises(DraftWorkflowDeletionError, match="Cannot delete draft workflow versions"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_in_use_error(self, db_session_with_containers): + def test_delete_workflow_in_use_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a workflow that's in use by an app. @@ -1327,9 +1360,7 @@ class TestWorkflowService: # Associate workflow with app app.workflow_id = workflow.id - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() workflow_service = WorkflowService() @@ -1337,9 +1368,11 @@ class TestWorkflowService: from services.errors.workflow_service import WorkflowInUseError with pytest.raises(WorkflowInUseError, match="Cannot delete workflow that is currently in use by app"): - workflow_service.delete_workflow(session=db.session, workflow_id=workflow.id, tenant_id=workflow.tenant_id) + workflow_service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow.id, tenant_id=workflow.tenant_id + ) - def test_delete_workflow_not_found_error(self, db_session_with_containers): + def test_delete_workflow_not_found_error(self, db_session_with_containers: Session): """ Test error when attempting to delete a non-existent workflow. @@ -1351,17 +1384,15 @@ class TestWorkflowService: app = self._create_test_app(db_session_with_containers, fake) non_existent_workflow_id = fake.uuid4() - from extensions.ext_database import db - workflow_service = WorkflowService() # Act & Assert with pytest.raises(ValueError, match=f"Workflow with ID {non_existent_workflow_id} not found"): workflow_service.delete_workflow( - session=db.session, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id + session=db_session_with_containers, workflow_id=non_existent_workflow_id, tenant_id=app.tenant_id ) - def test_run_free_workflow_node_success(self, db_session_with_containers): + def test_run_free_workflow_node_success(self, db_session_with_containers: Session): """ Test successful execution of a free workflow node. @@ -1393,8 +1424,8 @@ class TestWorkflowService: from unittest.mock import patch - from core.app.workflow.node_factory import DifyNodeFactory from core.model_manager import ModelInstance + from core.workflow.node_factory import DifyNodeFactory # Act with patch.object( @@ -1413,7 +1444,7 @@ class TestWorkflowService: assert result.workflow_id == "" # No workflow ID for free nodes assert result.index == 1 - def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers): + def test_run_free_workflow_node_with_complex_inputs(self, db_session_with_containers: Session): """ Test execution of a free workflow node with complex input data. @@ -1454,7 +1485,7 @@ class TestWorkflowService: error_msg = str(exc_info.value).lower() assert any(keyword in error_msg for keyword in ["start", "not supported", "external"]) - def test_handle_node_run_result_success(self, db_session_with_containers): + def test_handle_node_run_result_success(self, db_session_with_containers: Session): """ Test successful handling of node run results. @@ -1472,14 +1503,14 @@ class TestWorkflowService: import uuid from datetime import datetime - from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus - from core.workflow.graph_events import NodeRunSucceededEvent - from core.workflow.node_events import NodeRunResult - from core.workflow.nodes.base.node import Node + from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from dify_graph.graph_events import NodeRunSucceededEvent + from dify_graph.node_events import NodeRunResult + from dify_graph.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) - mock_node.node_type = NodeType.START + mock_node.node_type = BuiltinNodeTypes.START mock_node.title = "Test Node" mock_node.error_strategy = None @@ -1496,7 +1527,7 @@ class TestWorkflowService: mock_event = NodeRunSucceededEvent( id=str(uuid.uuid4()), node_id=node_id, - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_run_result=mock_result, start_at=datetime.now(), ) @@ -1517,19 +1548,19 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from core.workflow.enums import NodeType + from dify_graph.enums import BuiltinNodeTypes - assert result.node_type == NodeType.START # Should match the mock node type + assert result.node_type == BuiltinNodeTypes.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison - from core.workflow.enums import WorkflowNodeExecutionStatus + from dify_graph.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs is not None assert result.outputs is not None assert result.process_data is not None - def test_handle_node_run_result_failure(self, db_session_with_containers): + def test_handle_node_run_result_failure(self, db_session_with_containers: Session): """ Test handling of failed node run results. @@ -1547,14 +1578,14 @@ class TestWorkflowService: import uuid from datetime import datetime - from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus - from core.workflow.graph_events import NodeRunFailedEvent - from core.workflow.node_events import NodeRunResult - from core.workflow.nodes.base.node import Node + from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from dify_graph.graph_events import NodeRunFailedEvent + from dify_graph.node_events import NodeRunResult + from dify_graph.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) - mock_node.node_type = NodeType.LLM + mock_node.node_type = BuiltinNodeTypes.LLM mock_node.title = "Test Node" mock_node.error_strategy = None @@ -1569,7 +1600,7 @@ class TestWorkflowService: mock_event = NodeRunFailedEvent( id=str(uuid.uuid4()), node_id=node_id, - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_run_result=mock_result, error="Test error message", start_at=datetime.now(), @@ -1592,13 +1623,13 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from core.workflow.enums import WorkflowNodeExecutionStatus + from dify_graph.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error is not None assert "Test error message" in str(result.error) - def test_handle_node_run_result_continue_on_error(self, db_session_with_containers): + def test_handle_node_run_result_continue_on_error(self, db_session_with_containers: Session): """ Test handling of node run results with continue_on_error strategy. @@ -1616,14 +1647,14 @@ class TestWorkflowService: import uuid from datetime import datetime - from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus - from core.workflow.graph_events import NodeRunFailedEvent - from core.workflow.node_events import NodeRunResult - from core.workflow.nodes.base.node import Node + from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus + from dify_graph.graph_events import NodeRunFailedEvent + from dify_graph.node_events import NodeRunResult + from dify_graph.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) - mock_node.node_type = NodeType.TOOL + mock_node.node_type = BuiltinNodeTypes.TOOL mock_node.title = "Test Node" mock_node.error_strategy = ErrorStrategy.DEFAULT_VALUE mock_node.default_value_dict = {"default_output": "default_value"} @@ -1639,7 +1670,7 @@ class TestWorkflowService: mock_event = NodeRunFailedEvent( id=str(uuid.uuid4()), node_id=node_id, - node_type=NodeType.TOOL, + node_type=BuiltinNodeTypes.TOOL, node_run_result=mock_result, error="Test error message", start_at=datetime.now(), @@ -1662,7 +1693,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from core.workflow.enums import WorkflowNodeExecutionStatus + from dify_graph.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.EXCEPTION # Should be EXCEPTION, not FAILED assert result.outputs is not None diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 4249642bc9..92dec24c7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from services.workspace_service import WorkspaceService @@ -29,7 +30,7 @@ class TestWorkspaceService: "dify_config": mock_dify_config, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -50,10 +51,8 @@ class TestWorkspaceService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +61,8 @@ class TestWorkspaceService: plan="basic", custom_config='{"replace_webapp_logo": true, "remove_webapp_brand": false}', ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +71,15 @@ class TestWorkspaceService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def test_get_tenant_info_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of tenant information with all features enabled. @@ -121,13 +120,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_without_custom_config( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval when custom config features are disabled. @@ -167,13 +165,12 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - from extensions.ext_database import db - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_normal_user_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for normal user role without privileged features. @@ -191,11 +188,14 @@ class TestWorkspaceService: ) # Update the join to have normal role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.NORMAL - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -220,11 +220,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_admin_role_and_logo_replacement( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for admin role with logo replacement enabled. @@ -242,11 +242,14 @@ class TestWorkspaceService: ) # Update the join to have admin role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.ADMIN - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -268,10 +271,12 @@ class TestWorkspaceService: assert "replace_webapp_logo" in result["custom_config"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None - def test_get_tenant_info_with_tenant_none(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_tenant_info_with_tenant_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant info retrieval when tenant parameter is None. @@ -290,7 +295,7 @@ class TestWorkspaceService: assert result is None def test_get_tenant_info_with_custom_config_variations( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with various custom config configurations. @@ -323,10 +328,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -353,11 +356,11 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] == config["remove_webapp_brand"] # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_editor_role_and_limited_permissions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for editor role with limited permissions. @@ -375,11 +378,14 @@ class TestWorkspaceService: ) # Update the join to have editor role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.EDITOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -400,11 +406,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_dataset_operator_role( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval for dataset operator role. @@ -422,11 +428,14 @@ class TestWorkspaceService: ) # Update the join to have dataset operator role - from extensions.ext_database import db - join = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() + join = ( + db_session_with_containers.query(TenantAccountJoin) + .filter_by(tenant_id=tenant.id, account_id=account.id) + .first() + ) join.role = TenantAccountRole.DATASET_OPERATOR - db.session.commit() + db_session_with_containers.commit() # Setup mocks for feature service and tenant service mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -447,11 +456,11 @@ class TestWorkspaceService: assert "custom_config" not in result # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None def test_get_tenant_info_with_complex_custom_config_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant info retrieval with complex custom config scenarios. @@ -491,10 +500,8 @@ class TestWorkspaceService: # Update tenant custom config import json - from extensions.ext_database import db - tenant.custom_config = json.dumps(config) - db.session.commit() + db_session_with_containers.commit() # Setup mocks mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True @@ -525,5 +532,5 @@ class TestWorkspaceService: assert result["custom_config"]["remove_webapp_brand"] is False # Verify database state - db.session.refresh(tenant) + db_session_with_containers.refresh(tenant) assert tenant.id is not None diff --git a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py index 2ff71ea6ea..bffdca623a 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_api_tools_manage_service.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import TypeAdapter, ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ApiProviderSchemaType from models import Account, Tenant @@ -34,7 +35,7 @@ class TestApiToolManageService: "provider_controller": mock_provider_controller, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -55,18 +56,16 @@ class TestApiToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -77,8 +76,8 @@ class TestApiToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -118,7 +117,7 @@ class TestApiToolManageService: """ def test_parser_api_schema_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful parsing of API schema. @@ -163,7 +162,7 @@ class TestApiToolManageService: assert api_key_value_field["default"] == "" def test_parser_api_schema_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of invalid API schema. @@ -183,7 +182,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_parser_api_schema_malformed_json( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test parsing of malformed JSON schema. @@ -203,7 +202,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_convert_schema_to_tool_bundles_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles. @@ -233,7 +232,7 @@ class TestApiToolManageService: assert tool_bundle.operation_id == "testOperation" def test_convert_schema_to_tool_bundles_with_extra_info( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of schema to tool bundles with extra info. @@ -259,7 +258,7 @@ class TestApiToolManageService: assert isinstance(schema_type, str) def test_convert_schema_to_tool_bundles_invalid_schema( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of invalid schema to tool bundles. @@ -279,7 +278,7 @@ class TestApiToolManageService: assert "invalid schema" in str(exc_info.value) def test_create_api_tool_provider_success( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider. @@ -324,10 +323,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) @@ -347,7 +345,7 @@ class TestApiToolManageService: mock_external_service_dependencies["provider_controller"].load_bundled_tools.assert_called_once() def test_create_api_tool_provider_duplicate_name( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with duplicate name. @@ -404,7 +402,7 @@ class TestApiToolManageService: assert f"provider {provider_name} already exists" in str(exc_info.value) def test_create_api_tool_provider_invalid_schema_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with invalid schema type. @@ -436,7 +434,7 @@ class TestApiToolManageService: assert "validation error" in str(exc_info.value) def test_create_api_tool_provider_missing_auth_type( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test creation of API tool provider with missing auth type. @@ -479,7 +477,7 @@ class TestApiToolManageService: assert "auth_type is required" in str(exc_info.value) def test_create_api_tool_provider_with_api_key_auth( - self, flask_req_ctx_with_containers, db_session_with_containers, mock_external_service_dependencies + self, flask_req_ctx_with_containers, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful creation of API tool provider with API key authentication. @@ -522,10 +520,9 @@ class TestApiToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db provider = ( - db.session.query(ApiToolProvider) + db_session_with_containers.query(ApiToolProvider) .filter(ApiToolProvider.tenant_id == tenant.id, ApiToolProvider.name == provider_name) .first() ) diff --git a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py index 6cae83ac37..0f2e3980af 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_mcp_tools_manage_service.py @@ -2,6 +2,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import ToolProviderType from models import Account, Tenant @@ -41,7 +42,7 @@ class TestMCPToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -62,18 +63,16 @@ class TestMCPToolManageService: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -84,8 +83,8 @@ class TestMCPToolManageService: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant @@ -93,7 +92,7 @@ class TestMCPToolManageService: return account, tenant def _create_test_mcp_provider( - self, db_session_with_containers, mock_external_service_dependencies, tenant_id, user_id + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant_id, user_id ): """ Helper method to create a test MCP tool provider for testing. @@ -124,15 +123,13 @@ class TestMCPToolManageService: sse_read_timeout=300.0, ) - from extensions.ext_database import db - - db.session.add(mcp_provider) - db.session.commit() + db_session_with_containers.add(mcp_provider) + db_session_with_containers.commit() return mcp_provider def test_get_mcp_provider_by_provider_id_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by provider ID. @@ -153,9 +150,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(provider_id=mcp_provider.id, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -166,12 +162,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.server_identifier == mcp_provider.server_identifier def test_get_mcp_provider_by_provider_id_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by provider ID. @@ -190,14 +186,13 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=non_existent_id, tenant_id=tenant.id) def test_get_mcp_provider_by_provider_id_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by provider ID. @@ -223,14 +218,13 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(provider_id=mcp_provider1.id, tenant_id=tenant2.id) def test_get_mcp_provider_by_server_identifier_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful retrieval of MCP provider by server identifier. @@ -251,9 +245,8 @@ class TestMCPToolManageService: ) # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.get_provider(server_identifier=mcp_provider.server_identifier, tenant_id=tenant.id) # Assert: Verify the expected outcomes @@ -264,12 +257,12 @@ class TestMCPToolManageService: assert result.user_id == account.id # Verify database state - db.session.refresh(result) + db_session_with_containers.refresh(result) assert result.id is not None assert result.name == mcp_provider.name def test_get_mcp_provider_by_server_identifier_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP provider is not found by server identifier. @@ -288,14 +281,13 @@ class TestMCPToolManageService: non_existent_identifier = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=non_existent_identifier, tenant_id=tenant.id) def test_get_mcp_provider_by_server_identifier_tenant_isolation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tenant isolation when retrieving MCP provider by server identifier. @@ -321,13 +313,12 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.get_provider(server_identifier=mcp_provider1.server_identifier, tenant_id=tenant2.id) - def test_create_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful creation of MCP provider. @@ -365,9 +356,8 @@ class TestMCPToolManageService: # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -389,10 +379,9 @@ class TestMCPToolManageService: assert result.type == ToolProviderType.MCP # Verify database state - from extensions.ext_database import db created_provider = ( - db.session.query(MCPToolProvider) + db_session_with_containers.query(MCPToolProvider) .filter(MCPToolProvider.tenant_id == tenant.id, MCPToolProvider.name == "Test MCP Provider") .first() ) @@ -410,7 +399,9 @@ class TestMCPToolManageService: ) mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_called_once() - def test_create_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when creating MCP provider with duplicate name. @@ -427,9 +418,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider", @@ -463,7 +453,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server URL. @@ -481,9 +471,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -517,7 +506,7 @@ class TestMCPToolManageService: ) def test_create_mcp_provider_duplicate_server_identifier( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when creating MCP provider with duplicate server identifier. @@ -535,9 +524,8 @@ class TestMCPToolManageService: # Create first provider from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.create_provider( tenant_id=tenant.id, name="Test MCP Provider 1", @@ -570,7 +558,7 @@ class TestMCPToolManageService: ), ) - def test_retrieve_mcp_tools_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful retrieval of MCP tools for a tenant. @@ -602,9 +590,7 @@ class TestMCPToolManageService: ) provider3.name = "Gamma Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Setup mock for transformation service from core.tools.entities.api_entities import ToolProviderApiEntity @@ -647,9 +633,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=True) # Assert: Verify the expected outcomes @@ -666,7 +651,9 @@ class TestMCPToolManageService: mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.call_count == 3 ) - def test_retrieve_mcp_tools_empty_list(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_empty_list( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test retrieval of MCP tools when tenant has no providers. @@ -684,9 +671,8 @@ class TestMCPToolManageService: # No MCP providers created for this tenant # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_providers(tenant_id=tenant.id, for_list=False) # Assert: Verify the expected outcomes @@ -697,7 +683,9 @@ class TestMCPToolManageService: # Verify no transformation service calls for empty list mock_external_service_dependencies["tool_transform_service"].mcp_provider_to_user_provider.assert_not_called() - def test_retrieve_mcp_tools_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_retrieve_mcp_tools_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when retrieving MCP tools. @@ -756,9 +744,8 @@ class TestMCPToolManageService: ] # Act: Execute the method under test for both tenants - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result1 = service.list_providers(tenant_id=tenant1.id, for_list=True) result2 = service.list_providers(tenant_id=tenant2.id, for_list=True) @@ -769,7 +756,7 @@ class TestMCPToolManageService: assert result2[0].id == provider2.id def test_list_mcp_tool_from_remote_server_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful listing of MCP tools from remote server. @@ -797,9 +784,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to list tools mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -821,9 +806,8 @@ class TestMCPToolManageService: mock_client_instance.list_tools.return_value = mock_tools # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) result = service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes @@ -834,7 +818,7 @@ class TestMCPToolManageService: # Note: server_url is mocked, so we skip that assertion to avoid encryption issues # Verify database state was updated - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.tools != "[]" assert mcp_provider.updated_at is not None @@ -844,7 +828,7 @@ class TestMCPToolManageService: mock_mcp_client.assert_called_once() def test_list_mcp_tool_from_remote_server_auth_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server requires authentication. @@ -871,9 +855,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -887,19 +869,18 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPAuthError("Authentication required") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Please auth the tool first"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" def test_list_mcp_tool_from_remote_server_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when MCP server connection fails. @@ -926,9 +907,7 @@ class TestMCPToolManageService: mcp_provider.authed = True # Provider must be authenticated to test connection errors mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the decryption process at the rsa level to avoid key file issues with patch("libs.rsa.decrypt") as mock_decrypt: @@ -942,18 +921,17 @@ class TestMCPToolManageService: mock_client_instance.list_tools.side_effect = MCPError("Connection failed") # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="Failed to connect to MCP server: Connection failed"): service.list_provider_tools(tenant_id=tenant.id, provider_id=mcp_provider.id) # Verify database state was not changed - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True # Provider remains authenticated assert mcp_provider.tools == "[]" - def test_delete_mcp_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful deletion of MCP tool. @@ -974,20 +952,19 @@ class TestMCPToolManageService: ) # Verify provider exists - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() is not None # Act: Execute the method under test - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.delete_provider(tenant_id=tenant.id, provider_id=mcp_provider.id) # Assert: Verify the expected outcomes # Provider should be deleted from database - deleted_provider = db.session.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() + deleted_provider = db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider.id).first() assert deleted_provider is None - def test_delete_mcp_tool_not_found(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_not_found(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test error handling when deleting non-existent MCP tool. @@ -1005,13 +982,14 @@ class TestMCPToolManageService: non_existent_id = str(fake.uuid4()) # Act & Assert: Verify proper error handling - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant.id, provider_id=non_existent_id) - def test_delete_mcp_tool_tenant_isolation(self, db_session_with_containers, mock_external_service_dependencies): + def test_delete_mcp_tool_tenant_isolation( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test tenant isolation when deleting MCP tool. @@ -1036,18 +1014,16 @@ class TestMCPToolManageService: ) # Act & Assert: Verify tenant isolation - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool not found"): service.delete_provider(tenant_id=tenant2.id, provider_id=mcp_provider1.id) # Verify provider still exists in tenant1 - from extensions.ext_database import db - assert db.session.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None + assert db_session_with_containers.query(MCPToolProvider).filter_by(id=mcp_provider1.id).first() is not None - def test_update_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful update of MCP provider. @@ -1070,14 +1046,12 @@ class TestMCPToolManageService: original_name = mcp_provider.name original_icon = mcp_provider.icon - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act: Execute the method under test from core.entities.mcp_provider import MCPConfiguration - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider( tenant_id=tenant.id, provider_id=mcp_provider.id, @@ -1094,7 +1068,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.name == "Updated MCP Provider" assert mcp_provider.server_identifier == "updated_identifier_123" assert mcp_provider.timeout == 45.0 @@ -1108,7 +1082,9 @@ class TestMCPToolManageService: assert icon_data["content"] == "🚀" assert icon_data["background"] == "#4ECDC4" - def test_update_mcp_provider_duplicate_name(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_mcp_provider_duplicate_name( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test error handling when updating MCP provider with duplicate name. @@ -1134,15 +1110,12 @@ class TestMCPToolManageService: ) provider2.name = "Second Provider" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Act & Assert: Verify proper error handling for duplicate name from core.entities.mcp_provider import MCPConfiguration - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) with pytest.raises(ValueError, match="MCP tool First Provider already exists"): service.update_provider( tenant_id=tenant.id, @@ -1160,7 +1133,7 @@ class TestMCPToolManageService: ) def test_update_mcp_provider_credentials_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful update of MCP provider credentials. @@ -1185,9 +1158,7 @@ class TestMCPToolManageService: mcp_provider.authed = False mcp_provider.tools = "[]" - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1202,9 +1173,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1213,7 +1183,7 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is True assert mcp_provider.updated_at is not None @@ -1225,7 +1195,7 @@ class TestMCPToolManageService: assert "new_key" in credentials def test_update_mcp_provider_credentials_not_authed( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test update of MCP provider credentials when not authenticated. @@ -1249,9 +1219,7 @@ class TestMCPToolManageService: mcp_provider.authed = True mcp_provider.tools = '[{"name": "test_tool"}]' - from extensions.ext_database import db - - db.session.commit() + db_session_with_containers.commit() # Mock the provider controller and encryption with ( @@ -1266,9 +1234,8 @@ class TestMCPToolManageService: mock_encrypter_instance.encrypt.return_value = {"new_key": "encrypted_value"} # Act: Execute the method under test - from extensions.ext_database import db - service = MCPToolManageService(db.session()) + service = MCPToolManageService(db_session_with_containers) service.update_provider_credentials( provider_id=mcp_provider.id, tenant_id=tenant.id, @@ -1277,12 +1244,14 @@ class TestMCPToolManageService: ) # Assert: Verify the expected outcomes - db.session.refresh(mcp_provider) + db_session_with_containers.refresh(mcp_provider) assert mcp_provider.authed is False assert mcp_provider.tools == "[]" assert mcp_provider.updated_at is not None - def test_re_connect_mcp_provider_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful reconnection to MCP provider. @@ -1343,7 +1312,9 @@ class TestMCPToolManageService: sse_read_timeout=mcp_provider.sse_read_timeout, ) - def test_re_connect_mcp_provider_auth_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_re_connect_mcp_provider_auth_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test reconnection to MCP provider when authentication fails. @@ -1385,7 +1356,7 @@ class TestMCPToolManageService: assert result.encrypted_credentials == "{}" def test_re_connect_mcp_provider_connection_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test reconnection to MCP provider when connection fails. diff --git a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py index fa13790942..0f38218c51 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_tools_transform_service.py @@ -2,6 +2,7 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject @@ -27,7 +28,7 @@ class TestToolTransformService: } def _create_test_tool_provider( - self, db_session_with_containers, mock_external_service_dependencies, provider_type="api" + self, db_session_with_containers: Session, mock_external_service_dependencies, provider_type="api" ): """ Helper method to create a test tool provider for testing. @@ -47,41 +48,42 @@ class TestToolTransformService: name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', - icon_dark='{"background": "#252525", "content": "🔧"}', tenant_id="test_tenant_id", user_id="test_user_id", - credentials={"auth_type": "api_key_header", "api_key": "test_key"}, - provider_type="api", + credentials_str='{"auth_type": "api_key_header", "api_key": "test_key"}', + schema="{}", + schema_type_str="openapi", + tools_str="[]", ) elif provider_type == "builtin": provider = BuiltinToolProvider( name=fake.company(), - description=fake.text(max_nb_chars=100), - icon="🔧", - icon_dark="🔧", tenant_id="test_tenant_id", + user_id="test_user_id", provider="test_provider", credential_type="api_key", - credentials={"api_key": "test_key"}, + encrypted_credentials='{"api_key": "test_key"}', ) elif provider_type == "workflow": provider = WorkflowToolProvider( name=fake.company(), description=fake.text(max_nb_chars=100), icon='{"background": "#FF6B6B", "content": "🔧"}', - icon_dark='{"background": "#252525", "content": "🔧"}', tenant_id="test_tenant_id", user_id="test_user_id", - workflow_id="test_workflow_id", + app_id="test_workflow_id", + label="Test Workflow", + version="1.0.0", + parameter_configuration="[]", ) elif provider_type == "mcp": provider = MCPToolProvider( name=fake.company(), - description=fake.text(max_nb_chars=100), - provider_icon='{"background": "#FF6B6B", "content": "🔧"}', + icon='{"background": "#FF6B6B", "content": "🔧"}', tenant_id="test_tenant_id", user_id="test_user_id", server_url="https://mcp.example.com", + server_url_hash="test_server_url_hash", server_identifier="test_server", tools='[{"name": "test_tool", "description": "Test tool"}]', authed=True, @@ -89,14 +91,12 @@ class TestToolTransformService: else: raise ValueError(f"Unknown provider type: {provider_type}") - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() return provider - def test_get_plugin_icon_url_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_get_plugin_icon_url_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful plugin icon URL generation. @@ -126,7 +126,7 @@ class TestToolTransformService: assert result == expected_url def test_get_plugin_icon_url_with_empty_console_url( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test plugin icon URL generation when CONSOLE_API_URL is empty. @@ -156,7 +156,7 @@ class TestToolTransformService: assert result == expected_url def test_get_tool_provider_icon_url_builtin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for builtin providers. @@ -194,7 +194,7 @@ class TestToolTransformService: assert result == expected_encoded def test_get_tool_provider_icon_url_api_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for API providers. @@ -220,7 +220,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_api_invalid_json( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for API providers with invalid JSON. @@ -246,7 +246,7 @@ class TestToolTransformService: assert result["content"] == "😁" or result["content"] == "\ud83d\ude01" def test_get_tool_provider_icon_url_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for workflow providers. @@ -271,7 +271,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_mcp_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful tool provider icon URL generation for MCP providers. @@ -296,7 +296,7 @@ class TestToolTransformService: assert result["content"] == "🔧" def test_get_tool_provider_icon_url_unknown_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test tool provider icon URL generation for unknown provider types. @@ -317,7 +317,9 @@ class TestToolTransformService: # Assert: Verify the expected outcomes assert result == "" - def test_repack_provider_dict_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_dict_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with dictionary input. @@ -341,7 +343,9 @@ class TestToolTransformService: # Note: provider name may contain spaces that get URL encoded assert provider["name"].replace(" ", "%20") in provider["icon"] or provider["name"] in provider["icon"] - def test_repack_provider_entity_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful provider repacking with ToolProviderApiEntity input. @@ -389,7 +393,7 @@ class TestToolTransformService: assert "test_icon_dark.png" in provider.icon_dark def test_repack_provider_entity_no_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful provider repacking with ToolProviderApiEntity input without plugin_id. @@ -435,7 +439,9 @@ class TestToolTransformService: assert provider.icon_dark["background"] == "#252525" assert provider.icon_dark["content"] == "🔧" - def test_repack_provider_entity_no_dark_icon(self, db_session_with_containers, mock_external_service_dependencies): + def test_repack_provider_entity_no_dark_icon( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test provider repacking with ToolProviderApiEntity input without dark icon. @@ -477,7 +483,7 @@ class TestToolTransformService: assert provider.icon_dark == "" def test_builtin_provider_to_user_provider_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider. @@ -545,7 +551,7 @@ class TestToolTransformService: assert result.original_credentials == {"api_key": "decrypted_key"} def test_builtin_provider_to_user_provider_plugin_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of builtin provider to user provider with plugin. @@ -589,7 +595,7 @@ class TestToolTransformService: assert result.allow_delete is False def test_builtin_provider_to_user_provider_no_credentials( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of builtin provider to user provider without credentials. @@ -630,7 +636,9 @@ class TestToolTransformService: assert result.allow_delete is False assert result.masked_credentials == {"api_key": ""} - def test_api_provider_to_controller_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_api_provider_to_controller_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion of API provider to controller. @@ -655,10 +663,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -669,7 +675,7 @@ class TestToolTransformService: # Additional assertions would depend on the actual controller implementation def test_api_provider_to_controller_api_key_query( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with api_key_query auth type. @@ -693,10 +699,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -706,7 +710,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_api_provider_to_controller_backward_compatibility( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test conversion of API provider to controller with backward compatibility auth types. @@ -731,10 +735,8 @@ class TestToolTransformService: tools_str="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Act: Execute the method under test result = ToolTransformService.api_provider_to_controller(provider) @@ -744,7 +746,7 @@ class TestToolTransformService: assert hasattr(result, "from_db") def test_workflow_provider_to_controller_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of workflow provider to controller. @@ -769,10 +771,8 @@ class TestToolTransformService: parameter_configuration="[]", ) - from extensions.ext_database import db - - db.session.add(provider) - db.session.commit() + db_session_with_containers.add(provider) + db_session_with_containers.commit() # Mock the WorkflowToolProviderController.from_db method to avoid app dependency with patch("services.tools.tools_transform_service.WorkflowToolProviderController.from_db") as mock_from_db: diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index 24fe5c4670..34906a4e54 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -4,6 +4,7 @@ from unittest.mock import patch import pytest from faker import Faker from pydantic import ValidationError +from sqlalchemy.orm import Session from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError @@ -12,6 +13,7 @@ from models.workflow import Workflow as WorkflowModel from services.account_service import AccountService, TenantService from services.app_service import AppService from services.tools.workflow_tools_manage_service import WorkflowToolManageService +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestWorkflowToolManageService: @@ -63,7 +65,7 @@ class TestWorkflowToolManageService: "tool_transform_service": mock_tool_transform_service, } - def _create_test_app_and_account(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_app_and_account(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test app and account for testing. @@ -86,7 +88,7 @@ class TestWorkflowToolManageService: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -119,14 +121,12 @@ class TestWorkflowToolManageService: conversation_variables=[], ) - from extensions.ext_database import db - - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Update app to reference the workflow app.workflow_id = workflow.id - db.session.commit() + db_session_with_containers.commit() return app, account, workflow @@ -153,7 +153,9 @@ class TestWorkflowToolManageService: ), ] - def test_create_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_create_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool creation with valid parameters. @@ -198,11 +200,10 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state - from extensions.ext_database import db # Check if workflow tool provider was created created_tool_provider = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -230,7 +231,7 @@ class TestWorkflowToolManageService: ].workflow_provider_to_controller.assert_called_once() def test_create_workflow_tool_duplicate_name_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when name already exists. @@ -280,10 +281,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {first_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -293,7 +293,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_invalid_app_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app does not exist. @@ -331,10 +331,9 @@ class TestWorkflowToolManageService: assert f"App {non_existent_app_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -344,7 +343,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_invalid_parameters_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when parameters are invalid. @@ -387,10 +386,9 @@ class TestWorkflowToolManageService: assert "validation error" in str(exc_info.value).lower() # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -400,7 +398,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_duplicate_app_id_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app_id already exists. @@ -450,10 +448,9 @@ class TestWorkflowToolManageService: assert f"Tool with name {second_tool_name} or app_id {app.id} already exists" in str(exc_info.value) # Verify only one tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -463,7 +460,7 @@ class TestWorkflowToolManageService: assert tool_count == 1 def test_create_workflow_tool_workflow_not_found_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when app has no workflow. @@ -481,10 +478,9 @@ class TestWorkflowToolManageService: ) # Remove workflow reference from app - from extensions.ext_database import db app.workflow_id = None - db.session.commit() + db_session_with_containers.commit() # Attempt to create workflow tool for app without workflow tool_parameters = self._create_test_workflow_tool_parameters() @@ -505,7 +501,7 @@ class TestWorkflowToolManageService: # Verify no workflow tool was created tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -515,7 +511,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_create_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation fails when workflow contains human input nodes. @@ -558,10 +554,8 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - from extensions.ext_database import db - tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -570,7 +564,9 @@ class TestWorkflowToolManageService: assert tool_count == 0 - def test_update_workflow_tool_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful workflow tool update with valid parameters. @@ -603,10 +599,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -641,7 +636,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify database state was updated - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool is not None assert created_tool.name == updated_tool_name assert created_tool.label == updated_tool_label @@ -658,7 +653,7 @@ class TestWorkflowToolManageService: mock_external_service_dependencies["tool_transform_service"].workflow_provider_to_controller.assert_called() def test_update_workflow_tool_human_input_node_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update fails when workflow contains human input nodes. @@ -689,10 +684,8 @@ class TestWorkflowToolManageService: parameters=initial_tool_parameters, ) - from extensions.ext_database import db - created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -712,7 +705,7 @@ class TestWorkflowToolManageService: ] } ) - db.session.commit() + db_session_with_containers.commit() with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: WorkflowToolManageService.update_workflow_tool( @@ -728,10 +721,12 @@ class TestWorkflowToolManageService: assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == original_name - def test_update_workflow_tool_not_found_error(self, db_session_with_containers, mock_external_service_dependencies): + def test_update_workflow_tool_not_found_error( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test workflow tool update fails when tool does not exist. @@ -768,10 +763,9 @@ class TestWorkflowToolManageService: assert f"Tool {non_existent_tool_id} not found" in str(exc_info.value) # Verify no workflow tool was created - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, ) @@ -781,7 +775,7 @@ class TestWorkflowToolManageService: assert tool_count == 0 def test_update_workflow_tool_same_name_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool update succeeds when keeping the same name. @@ -813,10 +807,9 @@ class TestWorkflowToolManageService: ) # Get the created tool - from extensions.ext_database import db created_tool = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.app_id == app.id, @@ -840,12 +833,12 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} # Verify tool still exists with the same name - db.session.refresh(created_tool) + db_session_with_containers.refresh(created_tool) assert created_tool.name == first_tool_name assert created_tool.updated_at is not None def test_create_workflow_tool_with_file_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILE parameter having a file object as default. @@ -916,7 +909,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_with_files_parameter_default( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test workflow tool creation with FILES (Array[File]) parameter having file objects as default. @@ -991,7 +984,7 @@ class TestWorkflowToolManageService: assert result == {"result": "success"} def test_create_workflow_tool_db_commit_before_validation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that database commit happens before validation, causing DB pollution on validation failure. @@ -1035,10 +1028,9 @@ class TestWorkflowToolManageService: # Verify the tool was NOT created in database # This is the expected behavior (no pollution) - from extensions.ext_database import db tool_count = ( - db.session.query(WorkflowToolProvider) + db_session_with_containers.query(WorkflowToolProvider) .where( WorkflowToolProvider.tenant_id == account.current_tenant.id, WorkflowToolProvider.name == tool_name, diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index 2ffb884b82..c3fe6a2950 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -3,6 +3,7 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -11,9 +12,9 @@ from core.app.app_config.entities import ( ModelConfigEntity, PromptTemplateEntity, ) -from core.model_runtime.entities.llm_entities import LLMMode from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models import Account, Tenant from models.api_based_extension import APIBasedExtension from models.model import App, AppMode, AppModelConfig @@ -79,7 +80,7 @@ class TestWorkflowConverter: mock_config.app_model_config_dict = {} return mock_config - def _create_test_account_and_tenant(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_account_and_tenant(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Helper method to create a test account and tenant for testing. @@ -100,18 +101,16 @@ class TestWorkflowConverter: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join from models.account import TenantAccountJoin, TenantAccountRole @@ -122,15 +121,17 @@ class TestWorkflowConverter: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_app(self, db_session_with_containers, mock_external_service_dependencies, tenant, account): + def _create_test_app( + self, db_session_with_containers: Session, mock_external_service_dependencies, tenant, account + ): """ Helper method to create a test app for testing. @@ -163,10 +164,8 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Create app model config app_model_config = AppModelConfig( @@ -177,16 +176,16 @@ class TestWorkflowConverter: created_by=account.id, updated_by=account.id, ) - db.session.add(app_model_config) - db.session.commit() + db_session_with_containers.add(app_model_config) + db_session_with_containers.commit() # Link app model config to app app.app_model_config_id = app_model_config.id - db.session.commit() + db_session_with_containers.commit() return app - def test_convert_to_workflow_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_workflow_success(self, db_session_with_containers: Session, mock_external_service_dependencies): """ Test successful conversion of app to workflow. @@ -225,19 +224,18 @@ class TestWorkflowConverter: assert new_app.created_by == account.id # Verify database state - from extensions.ext_database import db - db.session.refresh(new_app) + db_session_with_containers.refresh(new_app) assert new_app.id is not None # Verify workflow was created - workflow = db.session.query(Workflow).where(Workflow.app_id == new_app.id).first() + workflow = db_session_with_containers.query(Workflow).where(Workflow.app_id == new_app.id).first() assert workflow is not None assert workflow.tenant_id == app.tenant_id assert workflow.type == "chat" def test_convert_to_workflow_without_app_model_config_error( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling when app model config is missing. @@ -270,16 +268,14 @@ class TestWorkflowConverter: updated_by=account.id, ) - from extensions.ext_database import db - - db.session.add(app) - db.session.commit() + db_session_with_containers.add(app) + db_session_with_containers.commit() # Act & Assert: Verify proper error handling workflow_converter = WorkflowConverter() # Check initial state - initial_workflow_count = db.session.query(Workflow).count() + initial_workflow_count = db_session_with_containers.query(Workflow).count() with pytest.raises(ValueError, match="App model config is required"): workflow_converter.convert_to_workflow( @@ -294,12 +290,12 @@ class TestWorkflowConverter: # Verify database state remains unchanged # The workflow creation happens in convert_app_model_config_to_workflow # which is called before the app_model_config check, so we need to clean up - db.session.rollback() - final_workflow_count = db.session.query(Workflow).count() + db_session_with_containers.rollback() + final_workflow_count = db_session_with_containers.query(Workflow).count() assert final_workflow_count == initial_workflow_count def test_convert_app_model_config_to_workflow_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion of app model config to workflow. @@ -356,16 +352,17 @@ class TestWorkflowConverter: assert answer_node["id"] == "answer" # Verify database state - from extensions.ext_database import db - db.session.refresh(workflow) + db_session_with_containers.refresh(workflow) assert workflow.id is not None # Verify features were set features = json.loads(workflow._features) if workflow._features else {} assert isinstance(features, dict) - def test_convert_to_start_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_start_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to start node. @@ -410,7 +407,9 @@ class TestWorkflowConverter: assert second_variable["label"] == "Number Input" assert second_variable["type"] == "number" - def test_convert_to_http_request_node_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_convert_to_http_request_node_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful conversion to HTTP request node. @@ -436,10 +435,8 @@ class TestWorkflowConverter: api_endpoint="https://api.example.com/test", ) - from extensions.ext_database import db - - db.session.add(api_based_extension) - db.session.commit() + db_session_with_containers.add(api_based_extension) + db_session_with_containers.commit() # Mock encrypter mock_external_service_dependencies["encrypter"].decrypt_token.return_value = "decrypted_api_key" @@ -489,7 +486,7 @@ class TestWorkflowConverter: assert external_data_variable_node_mapping["external_data"] == code_node["id"] def test_convert_to_knowledge_retrieval_node_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful conversion to knowledge retrieval node. @@ -513,7 +510,7 @@ class TestWorkflowConverter: retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, top_k=10, score_threshold=0.8, - reranking_model={"provider": "cohere", "model": "rerank-v2"}, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, reranking_enabled=True, ), ) @@ -546,8 +543,8 @@ class TestWorkflowConverter: multiple_config = node["data"]["multiple_retrieval_config"] assert multiple_config["top_k"] == 10 assert multiple_config["score_threshold"] == 0.8 - assert multiple_config["reranking_model"]["provider"] == "cohere" - assert multiple_config["reranking_model"]["model"] == "rerank-v2" + assert multiple_config["reranking_model"]["reranking_provider_name"] == "cohere" + assert multiple_config["reranking_model"]["reranking_model_name"] == "rerank-v2" # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py new file mode 100644 index 0000000000..29e1e240b4 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_deletion.py @@ -0,0 +1,158 @@ +"""Testcontainers integration tests for WorkflowService.delete_workflow.""" + +import json +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session, sessionmaker + +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin +from models.model import App +from models.tools import WorkflowToolProvider +from models.workflow import Workflow +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService + + +class TestWorkflowDeletion: + def _create_tenant_and_account(self, session: Session) -> tuple[Tenant, Account]: + tenant = Tenant(name=f"Tenant {uuid4()}") + session.add(tenant) + session.flush() + + account = Account( + name=f"Account {uuid4()}", + email=f"wf_del_{uuid4()}@example.com", + password="hashed", + password_salt="salt", + interface_language="en-US", + timezone="UTC", + ) + session.add(account) + session.flush() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role="owner", + current=True, + ) + session.add(join) + session.flush() + return tenant, account + + def _create_app(self, session: Session, *, tenant: Tenant, account: Account, workflow_id: str | None = None) -> App: + app = App( + tenant_id=tenant.id, + name=f"App {uuid4()}", + description="", + mode="workflow", + icon_type="emoji", + icon="bot", + icon_background="#FFFFFF", + enable_site=False, + enable_api=True, + api_rpm=100, + api_rph=100, + is_demo=False, + is_public=False, + is_universal=False, + created_by=account.id, + updated_by=account.id, + workflow_id=workflow_id, + ) + session.add(app) + session.flush() + return app + + def _create_workflow( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str = "1.0" + ) -> Workflow: + workflow = Workflow( + id=str(uuid4()), + tenant_id=tenant.id, + app_id=app.id, + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + _features=json.dumps({}), + created_by=account.id, + updated_by=account.id, + ) + session.add(workflow) + session.flush() + return workflow + + def _create_tool_provider( + self, session: Session, *, tenant: Tenant, app: App, account: Account, version: str + ) -> WorkflowToolProvider: + provider = WorkflowToolProvider( + name=f"tool-{uuid4()}", + label=f"Tool {uuid4()}", + icon="wrench", + app_id=app.id, + version=version, + user_id=account.id, + tenant_id=tenant.id, + description="test tool provider", + ) + session.add(provider) + session.flush() + return provider + + def test_delete_workflow_success(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + db_session_with_containers.commit() + workflow_id = workflow.id + + service = WorkflowService(sessionmaker(bind=db.engine)) + result = service.delete_workflow( + session=db_session_with_containers, workflow_id=workflow_id, tenant_id=tenant.id + ) + + assert result is True + db_session_with_containers.expire_all() + assert db_session_with_containers.get(Workflow, workflow_id) is None + + def test_delete_draft_workflow_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="draft" + ) + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(DraftWorkflowDeletionError): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_in_use_by_app_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + # Point app to this workflow + app.workflow_id = workflow.id + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="currently in use by app"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) + + def test_delete_workflow_published_as_tool_raises_error(self, db_session_with_containers): + tenant, account = self._create_tenant_and_account(db_session_with_containers) + app = self._create_app(db_session_with_containers, tenant=tenant, account=account) + workflow = self._create_workflow( + db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0" + ) + self._create_tool_provider(db_session_with_containers, tenant=tenant, app=app, account=account, version="1.0") + db_session_with_containers.commit() + + service = WorkflowService(sessionmaker(bind=db.engine)) + with pytest.raises(WorkflowInUseError, match="published as a tool"): + service.delete_workflow(session=db_session_with_containers, workflow_id=workflow.id, tenant_id=tenant.id) diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index f3ba126706..af9e8d0b2c 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -4,7 +4,7 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowNodeExecutionStatus +from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py index 8bb536c34a..94173c34bf 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_add_document_to_index_task.py @@ -2,12 +2,13 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetAutoDisableLog, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.add_document_to_index_task import add_document_to_index_task @@ -31,7 +32,9 @@ class TestAddDocumentToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +54,15 @@ class TestAddDocumentToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +71,8 @@ class TestAddDocumentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -77,12 +80,12 @@ class TestAddDocumentToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -90,24 +93,24 @@ class TestAddDocumentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document - def _create_test_segments(self, db_session_with_containers, document, dataset): + def _create_test_segments(self, db_session_with_containers: Session, document, dataset): """ Helper method to create test document segments. @@ -135,16 +138,18 @@ class TestAddDocumentToIndexTask: index_node_id=f"node_{i}", index_node_hash=f"hash_{i}", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments - def test_add_document_to_index_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_add_document_to_index_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful document indexing with paragraph index type. @@ -180,9 +185,9 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -191,7 +196,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with different index types. @@ -209,10 +214,10 @@ class TestAddDocumentToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -237,9 +242,9 @@ class TestAddDocumentToIndexTask: assert len(documents) == 3 # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -248,7 +253,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -275,7 +280,7 @@ class TestAddDocumentToIndexTask: # because indexing_cache_key is not defined in that case def test_add_document_to_index_invalid_indexing_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid indexing status. @@ -293,8 +298,8 @@ class TestAddDocumentToIndexTask: ) # Set invalid indexing status - document.indexing_status = "processing" - db.session.commit() + document.indexing_status = IndexingStatus.INDEXING + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) @@ -304,7 +309,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_add_document_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when document's dataset doesn't exist. @@ -326,16 +331,16 @@ class TestAddDocumentToIndexTask: redis_client.set(indexing_cache_key, "processing", ex=300) # Delete the dataset to simulate dataset not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Act: Execute the task add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR assert document.error is not None assert "doesn't exist" in document.error assert document.disabled_at is not None @@ -348,7 +353,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing with parent-child structure. @@ -367,10 +372,10 @@ class TestAddDocumentToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -413,9 +418,9 @@ class TestAddDocumentToIndexTask: assert len(doc.children) == 2 # Each document has 2 children # Verify database state changes - db.session.refresh(document) + db_session_with_containers.refresh(document) for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True assert segment.disabled_at is None assert segment.disabled_by is None @@ -424,13 +429,13 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_with_already_enabled_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test document indexing when segments are already enabled. This test verifies: - - Segments with status="completed" are processed regardless of enabled status + - Segments with status=SegmentStatus.COMPLETED are processed regardless of enabled status - Index processing occurs with all completed segments - Auto disable log deletion still occurs - Redis cache is cleared @@ -456,13 +461,13 @@ class TestAddDocumentToIndexTask: index_node_id=f"node_{i}", index_node_hash=f"hash_{i}", enabled=True, # Already enabled - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -478,7 +483,7 @@ class TestAddDocumentToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_called_once() # Verify the load method was called with all completed segments - # (implementation doesn't filter by enabled status, only by status="completed") + # (implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED) call_args = mock_external_service_dependencies["index_processor"].load.call_args assert call_args is not None documents = call_args[0][1] # Second argument should be documents list @@ -488,7 +493,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_auto_disable_log_deletion( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test that auto disable logs are properly deleted during indexing. @@ -515,10 +520,10 @@ class TestAddDocumentToIndexTask: document_id=document.id, ) log_entry.id = str(fake.uuid4()) - db.session.add(log_entry) + db_session_with_containers.add(log_entry) auto_disable_logs.append(log_entry) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -526,7 +531,9 @@ class TestAddDocumentToIndexTask: # Verify logs exist before processing existing_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(existing_logs) == 2 @@ -535,7 +542,9 @@ class TestAddDocumentToIndexTask: # Assert: Verify auto disable logs were deleted remaining_logs = ( - db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == document.id).all() + db_session_with_containers.query(DatasetAutoDisableLog) + .where(DatasetAutoDisableLog.document_id == document.id) + .all() ) assert len(remaining_logs) == 0 @@ -547,14 +556,14 @@ class TestAddDocumentToIndexTask: # Verify segments were enabled for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify redis cache was cleared assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -584,29 +593,29 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False - assert document.indexing_status == "error" + assert document.indexing_status == IndexingStatus.ERROR assert document.error is not None assert "Index processing failed" in document.error assert document.disabled_at is not None # Verify segments were not enabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False # Should remain disabled due to error # Verify redis cache was still cleared despite error assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_segment_filtering_edge_cases( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segment filtering with various edge cases. This test verifies: - - Only segments with status="completed" are processed (regardless of enabled status) + - Only segments with status=SegmentStatus.COMPLETED are processed (regardless of enabled status) - Segments with status!="completed" are NOT processed - Segments are ordered by position correctly - Mixed segment states are handled properly @@ -622,7 +631,7 @@ class TestAddDocumentToIndexTask: fake = Faker() segments = [] - # Segment 1: Should be processed (enabled=False, status="completed") + # Segment 1: Should be processed (enabled=False, status=SegmentStatus.COMPLETED) segment1 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -635,14 +644,14 @@ class TestAddDocumentToIndexTask: index_node_id="node_0", index_node_hash="hash_0", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) - db.session.add(segment1) + db_session_with_containers.add(segment1) segments.append(segment1) - # Segment 2: Should be processed (enabled=True, status="completed") - # Note: Implementation doesn't filter by enabled status, only by status="completed" + # Segment 2: Should be processed (enabled=True, status=SegmentStatus.COMPLETED) + # Note: Implementation doesn't filter by enabled status, only by status=SegmentStatus.COMPLETED segment2 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -655,10 +664,10 @@ class TestAddDocumentToIndexTask: index_node_id="node_1", index_node_hash="hash_1", enabled=True, # Already enabled, but will still be processed - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) - db.session.add(segment2) + db_session_with_containers.add(segment2) segments.append(segment2) # Segment 3: Should NOT be processed (enabled=False, status="processing") @@ -674,13 +683,13 @@ class TestAddDocumentToIndexTask: index_node_id="node_2", index_node_hash="hash_2", enabled=False, - status="processing", # Not completed + status=SegmentStatus.INDEXING, # Not completed created_by=document.created_by, ) - db.session.add(segment3) + db_session_with_containers.add(segment3) segments.append(segment3) - # Segment 4: Should be processed (enabled=False, status="completed") + # Segment 4: Should be processed (enabled=False, status=SegmentStatus.COMPLETED) segment4 = DocumentSegment( id=fake.uuid4(), tenant_id=document.tenant_id, @@ -693,13 +702,13 @@ class TestAddDocumentToIndexTask: index_node_id="node_3", index_node_hash="hash_3", enabled=False, - status="completed", + status=SegmentStatus.COMPLETED, created_by=document.created_by, ) - db.session.add(segment4) + db_session_with_containers.add(segment4) segments.append(segment4) - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -718,7 +727,7 @@ class TestAddDocumentToIndexTask: call_args = mock_external_service_dependencies["index_processor"].load.call_args assert call_args is not None documents = call_args[0][1] # Second argument should be documents list - assert len(documents) == 3 # 3 segments with status="completed" should be processed + assert len(documents) == 3 # 3 segments with status=SegmentStatus.COMPLETED should be processed # Verify correct segments were processed (by position order) # Segments 1, 2, 4 should be processed (positions 0, 1, 3) @@ -728,11 +737,11 @@ class TestAddDocumentToIndexTask: assert documents[2].metadata["doc_id"] == "node_3" # segment4, position 3 # Verify database state changes - db.session.refresh(document) - db.session.refresh(segment1) - db.session.refresh(segment2) - db.session.refresh(segment3) - db.session.refresh(segment4) + db_session_with_containers.refresh(document) + db_session_with_containers.refresh(segment1) + db_session_with_containers.refresh(segment2) + db_session_with_containers.refresh(segment3) + db_session_with_containers.refresh(segment4) # All segments should be enabled because the task updates ALL segments for the document assert segment1.enabled is True @@ -744,7 +753,7 @@ class TestAddDocumentToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_add_document_to_index_comprehensive_error_scenarios( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test comprehensive error scenarios and recovery. @@ -779,7 +788,7 @@ class TestAddDocumentToIndexTask: document.indexing_status = "completed" document.error = None document.disabled_at = None - db.session.commit() + db_session_with_containers.commit() # Set up Redis cache key indexing_cache_key = f"document_{document.id}_indexing" @@ -789,16 +798,16 @@ class TestAddDocumentToIndexTask: add_document_to_index_task(document.id) # Assert: Verify consistent error handling - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.enabled is False, f"Document should be disabled for {error_name}" - assert document.indexing_status == "error", f"Document status should be error for {error_name}" + assert document.indexing_status == IndexingStatus.ERROR, f"Document status should be error for {error_name}" assert document.error is not None, f"Error should be recorded for {error_name}" assert str(exception) in document.error, f"Error message should contain exception for {error_name}" assert document.disabled_at is not None, f"Disabled timestamp should be set for {error_name}" # Verify segments remain disabled due to error for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False, f"Segments should remain disabled for {error_name}" # Verify redis cache was still cleared despite error diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py index f94c5b19e6..210d9eb39e 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_clean_document_task.py @@ -11,11 +11,13 @@ from unittest.mock import Mock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from models.model import UploadFile from tasks.batch_clean_document_task import batch_clean_document_task @@ -49,7 +51,7 @@ class TestBatchCleanDocumentTask: "get_image_ids": mock_get_image_ids, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -69,16 +71,16 @@ class TestBatchCleanDocumentTask: status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -87,15 +89,15 @@ class TestBatchCleanDocumentTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account - def _create_test_dataset(self, db_session_with_containers, account): + def _create_test_dataset(self, db_session_with_containers: Session, account): """ Helper method to create a test dataset for testing. @@ -113,18 +115,18 @@ class TestBatchCleanDocumentTask: tenant_id=account.current_tenant.id, name=fake.word(), description=fake.sentence(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, embedding_model="text-embedding-ada-002", embedding_model_provider="openai", ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, dataset, account): + def _create_test_document(self, db_session_with_containers: Session, dataset, account): """ Helper method to create a test document for testing. @@ -144,21 +146,21 @@ class TestBatchCleanDocumentTask: dataset_id=dataset.id, position=0, name=fake.word(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info=json.dumps({"upload_file_id": str(uuid.uuid4())}), batch="test_batch", - created_from="test", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_document_segment(self, db_session_with_containers, document, account): + def _create_test_document_segment(self, db_session_with_containers: Session, document, account): """ Helper method to create a test document segment for testing. @@ -183,15 +185,15 @@ class TestBatchCleanDocumentTask: tokens=50, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def _create_test_upload_file(self, db_session_with_containers, account): + def _create_test_upload_file(self, db_session_with_containers: Session, account): """ Helper method to create a test upload file for testing. @@ -208,7 +210,7 @@ class TestBatchCleanDocumentTask: upload_file = UploadFile( tenant_id=account.current_tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{fake.file_name()}", name=fake.file_name(), size=1024, @@ -220,13 +222,13 @@ class TestBatchCleanDocumentTask: used=False, ) - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file def test_batch_clean_document_task_successful_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful cleanup of documents with segments and files. @@ -245,7 +247,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -261,18 +263,18 @@ class TestBatchCleanDocumentTask: # The task should have processed the segment and cleaned up the database # Verify database cleanup - db.session.commit() # Ensure all changes are committed + db_session_with_containers.commit() # Ensure all changes are committed # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_with_image_files( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of documents containing image references. @@ -297,11 +299,11 @@ class TestBatchCleanDocumentTask: tokens=50, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Store original IDs for verification segment_id = segment.id @@ -313,17 +315,17 @@ class TestBatchCleanDocumentTask: ) # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Verify that the task completed successfully by checking the log output # The task should have processed the segment and cleaned up the database def test_batch_clean_document_task_no_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when document has no segments. @@ -339,7 +341,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -354,21 +356,21 @@ class TestBatchCleanDocumentTask: # Since there are no segments, the task should handle this gracefully # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when dataset is not found. @@ -386,8 +388,8 @@ class TestBatchCleanDocumentTask: dataset_id = dataset.id # Delete the dataset to simulate not found scenario - db.session.delete(dataset) - db.session.commit() + db_session_with_containers.delete(dataset) + db_session_with_containers.commit() # Execute the task with non-existent dataset batch_clean_document_task(document_ids=[document_id], dataset_id=dataset_id, doc_form="text_model", file_ids=[]) @@ -399,14 +401,14 @@ class TestBatchCleanDocumentTask: mock_external_service_dependencies["storage"].delete.assert_not_called() # Verify that no database cleanup occurred - db.session.commit() + db_session_with_containers.commit() # Document should still exist since cleanup failed - existing_document = db.session.query(Document).filter_by(id=document_id).first() + existing_document = db_session_with_containers.query(Document).filter_by(id=document_id).first() assert existing_document is not None def test_batch_clean_document_task_storage_cleanup_failure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup when storage operations fail. @@ -423,7 +425,7 @@ class TestBatchCleanDocumentTask: # Update document to reference the upload file document.data_source_info = json.dumps({"upload_file_id": upload_file.id}) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_id = document.id @@ -442,18 +444,18 @@ class TestBatchCleanDocumentTask: # The task should continue processing even when storage operations fail # Verify database cleanup still occurred despite storage failure - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted from database - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted from database - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_multiple_documents( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup of multiple documents in a single batch operation. @@ -482,7 +484,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -498,20 +500,20 @@ class TestBatchCleanDocumentTask: # The task should process all documents and clean up all associated resources # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_different_doc_forms( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup with different document form types. @@ -527,12 +529,12 @@ class TestBatchCleanDocumentTask: for doc_form in doc_forms: dataset = self._create_test_dataset(db_session_with_containers, account) - db.session.commit() + db_session_with_containers.commit() document = self._create_test_document(db_session_with_containers, dataset, account) # Update document doc_form document.doc_form = doc_form - db.session.commit() + db_session_with_containers.commit() segment = self._create_test_document_segment(db_session_with_containers, document, account) @@ -549,20 +551,20 @@ class TestBatchCleanDocumentTask: # The task should handle different document forms correctly # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that segment is deleted - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None except Exception as e: # If the task fails due to external service issues (e.g., plugin daemon), # we should still verify that the database state is consistent # This is a common scenario in test environments where external services may not be available - db.session.commit() + db_session_with_containers.commit() # Check if the segment still exists (task may have failed before deletion) - existing_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + existing_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() if existing_segment is not None: # If segment still exists, the task failed before deletion # This is acceptable in test environments with external service issues @@ -572,7 +574,7 @@ class TestBatchCleanDocumentTask: pass def test_batch_clean_document_task_large_batch_performance( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test cleanup performance with a large batch of documents. @@ -604,7 +606,7 @@ class TestBatchCleanDocumentTask: segments.append(segment) upload_files.append(upload_file) - db.session.commit() + db_session_with_containers.commit() # Store original IDs for verification document_ids = [doc.id for doc in documents] @@ -629,20 +631,20 @@ class TestBatchCleanDocumentTask: # The task should handle large batches efficiently # Verify database cleanup for all resources - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that all upload files are deleted for file_id in file_ids: - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None def test_batch_clean_document_task_integration_with_real_database( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test full integration with real database operations. @@ -671,7 +673,7 @@ class TestBatchCleanDocumentTask: tokens=25 + i * 5, index_node_id=str(uuid.uuid4()), created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) segments.append(segment) @@ -683,12 +685,12 @@ class TestBatchCleanDocumentTask: # Add all to database for segment in segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Verify initial state - assert db.session.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 - assert db.session.query(UploadFile).filter_by(id=upload_file.id).first() is not None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).count() == 3 + assert db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).first() is not None # Store original IDs for verification document_id = document.id @@ -704,17 +706,17 @@ class TestBatchCleanDocumentTask: # The task should process all segments and clean up all associated resources # Verify database cleanup - db.session.commit() + db_session_with_containers.commit() # Check that all segments are deleted for segment_id in segment_ids: - deleted_segment = db.session.query(DocumentSegment).filter_by(id=segment_id).first() + deleted_segment = db_session_with_containers.query(DocumentSegment).filter_by(id=segment_id).first() assert deleted_segment is None # Check that upload file is deleted - deleted_file = db.session.query(UploadFile).filter_by(id=file_id).first() + deleted_file = db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() assert deleted_file is None # Verify final database state - assert db.session.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 - assert db.session.query(UploadFile).filter_by(id=file_id).first() is None + assert db_session_with_containers.query(DocumentSegment).filter_by(document_id=document_id).count() == 0 + assert db_session_with_containers.query(UploadFile).filter_by(id=file_id).first() is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 2156743c17..202ccb0098 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -17,10 +17,12 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session +from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment -from models.enums import CreatorUserRole +from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from models.model import UploadFile from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task @@ -29,20 +31,19 @@ class TestBatchCreateSegmentToIndexTask: """Integration tests for batch_create_segment_to_index_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" - from extensions.ext_database import db from extensions.ext_redis import redis_client # Clear all test data - db.session.query(DocumentSegment).delete() - db.session.query(Document).delete() - db.session.query(Dataset).delete() - db.session.query(UploadFile).delete() - db.session.query(TenantAccountJoin).delete() - db.session.query(Tenant).delete() - db.session.query(Account).delete() - db.session.commit() + db_session_with_containers.query(DocumentSegment).delete() + db_session_with_containers.query(Document).delete() + db_session_with_containers.query(Dataset).delete() + db_session_with_containers.query(UploadFile).delete() + db_session_with_containers.query(TenantAccountJoin).delete() + db_session_with_containers.query(Tenant).delete() + db_session_with_containers.query(Account).delete() + db_session_with_containers.commit() # Clear Redis cache redis_client.flushdb() @@ -75,7 +76,7 @@ class TestBatchCreateSegmentToIndexTask: "embedding_model": mock_embedding_model, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -95,18 +96,16 @@ class TestBatchCreateSegmentToIndexTask: status="active", ) - from extensions.ext_database import db - - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant for the account tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -115,15 +114,15 @@ class TestBatchCreateSegmentToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -141,21 +140,19 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", embedding_model="text-embedding-ada-002", embedding_model_provider="openai", created_by=account.id, ) - from extensions.ext_database import db - - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -174,26 +171,24 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="text_model", word_count=0, ) - from extensions.ext_database import db - - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -209,7 +204,7 @@ class TestBatchCreateSegmentToIndexTask: upload_file = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{fake.file_name()}", name=fake.file_name(), size=1024, @@ -221,10 +216,8 @@ class TestBatchCreateSegmentToIndexTask: used=False, ) - from extensions.ext_database import db - - db.session.add(upload_file) - db.session.commit() + db_session_with_containers.add(upload_file) + db_session_with_containers.commit() return upload_file @@ -252,7 +245,7 @@ class TestBatchCreateSegmentToIndexTask: return csv_content def test_batch_create_segment_to_index_task_success_text_model( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful batch creation of segments for text model documents. @@ -293,11 +286,10 @@ class TestBatchCreateSegmentToIndexTask: ) # Verify results - from extensions.ext_database import db # Check that segments were created segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -310,13 +302,13 @@ class TestBatchCreateSegmentToIndexTask: assert segment.dataset_id == dataset.id assert segment.document_id == document.id assert segment.position == i + 1 - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.answer is None # text_model doesn't have answers # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called @@ -331,7 +323,7 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"completed" def test_batch_create_segment_to_index_task_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when dataset does not exist. @@ -370,17 +362,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created (since dataset doesn't exist) - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify no documents were modified - documents = db.session.query(Document).all() + documents = db_session_with_containers.query(Document).all() assert len(documents) == 0 def test_batch_create_segment_to_index_task_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document does not exist. @@ -419,18 +410,17 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify dataset remains unchanged (no segments were added to the dataset) - db.session.refresh(dataset) - segments_for_dataset = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() + db_session_with_containers.refresh(dataset) + segments_for_dataset = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all() assert len(segments_for_dataset) == 0 def test_batch_create_segment_to_index_task_document_not_available( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when document is not available for indexing. @@ -453,12 +443,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="disabled_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # Document is disabled archived=False, doc_form="text_model", @@ -469,12 +459,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="archived_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # Document is archived doc_form="text_model", @@ -485,12 +475,12 @@ class TestBatchCreateSegmentToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="incomplete_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="indexing", # Not completed + indexing_status=IndexingStatus.INDEXING, # Not completed enabled=True, archived=False, doc_form="text_model", @@ -498,11 +488,9 @@ class TestBatchCreateSegmentToIndexTask: ), ] - from extensions.ext_database import db - for document in test_cases: - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Test each unavailable document for document in test_cases: @@ -524,11 +512,11 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - segments = db.session.query(DocumentSegment).filter_by(document_id=document.id).all() + segments = db_session_with_containers.query(DocumentSegment).filter_by(document_id=document.id).all() assert len(segments) == 0 def test_batch_create_segment_to_index_task_upload_file_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when upload file does not exist. @@ -567,17 +555,16 @@ class TestBatchCreateSegmentToIndexTask: assert cache_value == b"error" # Verify no segments were created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_empty_csv_file( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test task failure when CSV file is empty. @@ -619,17 +606,16 @@ class TestBatchCreateSegmentToIndexTask: # Verify error handling # Since exception was raised, no segments should be created - from extensions.ext_database import db - segments = db.session.query(DocumentSegment).all() + segments = db_session_with_containers.query(DocumentSegment).all() assert len(segments) == 0 # Verify document remains unchanged - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count == 0 def test_batch_create_segment_to_index_task_position_calculation( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test proper position calculation for segments when existing segments exist. @@ -658,17 +644,15 @@ class TestBatchCreateSegmentToIndexTask: word_count=len(f"Existing segment {i + 1}"), tokens=10, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash=f"hash_{i}", ) existing_segments.append(segment) - from extensions.ext_database import db - for segment in existing_segments: - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() # Create CSV content csv_content = self._create_test_csv_content("text_model") @@ -695,7 +679,7 @@ class TestBatchCreateSegmentToIndexTask: # Verify results # Check that new segments were created with correct positions all_segments = ( - db.session.query(DocumentSegment) + db_session_with_containers.query(DocumentSegment) .filter_by(document_id=document.id) .order_by(DocumentSegment.position) .all() @@ -711,12 +695,12 @@ class TestBatchCreateSegmentToIndexTask: for i, segment in enumerate(new_segments): expected_position = 4 + i # Should start at position 4 assert segment.position == expected_position - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None # Check that document word count was updated - db.session.refresh(document) + db_session_with_containers.refresh(document) assert document.word_count > 0 # Verify vector service was called diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py index cd99b2965f..1cd698b870 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_dataset_task.py @@ -16,7 +16,9 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session +from extensions.storage.storage_type import StorageType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -28,7 +30,14 @@ from models.dataset import ( Document, DocumentSegment, ) -from models.enums import CreatorUserRole +from models.enums import ( + CreatorUserRole, + DatasetMetadataType, + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + SegmentStatus, +) from models.model import UploadFile from tasks.clean_dataset_task import clean_dataset_task @@ -37,7 +46,7 @@ class TestCleanDatasetTask: """Integration tests for clean_dataset_task using testcontainers.""" @pytest.fixture(autouse=True) - def cleanup_database(self, db_session_with_containers): + def cleanup_database(self, db_session_with_containers: Session): """Clean up database before each test to ensure isolation.""" from extensions.ext_redis import redis_client @@ -82,7 +91,7 @@ class TestCleanDatasetTask: "index_processor": mock_index_processor, } - def _create_test_account_and_tenant(self, db_session_with_containers): + def _create_test_account_and_tenant(self, db_session_with_containers: Session): """ Helper method to create a test account and tenant for testing. @@ -109,7 +118,7 @@ class TestCleanDatasetTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) @@ -127,7 +136,7 @@ class TestCleanDatasetTask: return account, tenant - def _create_test_dataset(self, db_session_with_containers, account, tenant): + def _create_test_dataset(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test dataset for testing. @@ -157,7 +166,7 @@ class TestCleanDatasetTask: return dataset - def _create_test_document(self, db_session_with_containers, account, tenant, dataset): + def _create_test_document(self, db_session_with_containers: Session, account, tenant, dataset): """ Helper method to create a test document for testing. @@ -175,12 +184,12 @@ class TestCleanDatasetTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name="test_document", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form="paragraph_index", @@ -194,7 +203,7 @@ class TestCleanDatasetTask: return document - def _create_test_segment(self, db_session_with_containers, account, tenant, dataset, document): + def _create_test_segment(self, db_session_with_containers: Session, account, tenant, dataset, document): """ Helper method to create a test document segment for testing. @@ -218,7 +227,7 @@ class TestCleanDatasetTask: word_count=20, tokens=30, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash", created_at=datetime.now(), @@ -230,7 +239,7 @@ class TestCleanDatasetTask: return segment - def _create_test_upload_file(self, db_session_with_containers, account, tenant): + def _create_test_upload_file(self, db_session_with_containers: Session, account, tenant): """ Helper method to create a test upload file for testing. @@ -246,7 +255,7 @@ class TestCleanDatasetTask: upload_file = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{fake.file_name()}", name=fake.file_name(), size=1024, @@ -264,7 +273,7 @@ class TestCleanDatasetTask: return upload_file def test_clean_dataset_task_success_basic_cleanup( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful basic dataset cleanup with minimal data. @@ -325,7 +334,7 @@ class TestCleanDatasetTask: mock_storage.delete.assert_not_called() def test_clean_dataset_task_success_with_documents_and_segments( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with documents and segments. @@ -372,7 +381,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name="test_metadata", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) metadata.id = str(uuid.uuid4()) @@ -433,7 +442,7 @@ class TestCleanDatasetTask: assert mock_storage.delete.call_count == 3 def test_clean_dataset_task_success_with_invalid_doc_form( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful dataset cleanup with invalid doc_form handling. @@ -493,7 +502,7 @@ class TestCleanDatasetTask: assert mock_factory.call_count == 4 def test_clean_dataset_task_error_handling_and_rollback( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test error handling and rollback mechanism when database operations fail. @@ -542,7 +551,7 @@ class TestCleanDatasetTask: # This demonstrates the resilience of the cleanup process def test_clean_dataset_task_with_image_file_references( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with image file references in document segments. @@ -586,7 +595,7 @@ class TestCleanDatasetTask: word_count=len(segment_content), tokens=50, created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash", created_at=datetime.now(), @@ -634,7 +643,7 @@ class TestCleanDatasetTask: mock_get_image_ids.assert_called_once() def test_clean_dataset_task_performance_with_large_dataset( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup performance with large amounts of data. @@ -685,7 +694,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name=f"test_metadata_{i}", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) metadata.id = str(uuid.uuid4()) @@ -704,11 +713,9 @@ class TestCleanDatasetTask: binding.created_at = datetime.now() bindings.append(binding) - from extensions.ext_database import db - - db.session.add_all(metadata_items) - db.session.add_all(bindings) - db.session.commit() + db_session_with_containers.add_all(metadata_items) + db_session_with_containers.add_all(bindings) + db_session_with_containers.commit() # Measure cleanup performance import time @@ -772,7 +779,7 @@ class TestCleanDatasetTask: print(f"Average time per document: {cleanup_duration / len(documents):.3f} seconds") def test_clean_dataset_task_storage_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup when storage operations fail. @@ -838,7 +845,7 @@ class TestCleanDatasetTask: # consistency in the database def test_clean_dataset_task_edge_cases_and_boundary_conditions( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test dataset cleanup with edge cases and boundary conditions. @@ -881,11 +888,11 @@ class TestCleanDatasetTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, data_source_info="{}", batch="test_batch", name=f"test_doc_{special_content}", - created_from="test", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, created_at=datetime.now(), updated_at=datetime.now(), @@ -906,7 +913,7 @@ class TestCleanDatasetTask: word_count=len(segment_content.split()), tokens=len(segment_content) // 4, # Rough token estimation created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, index_node_id=str(uuid.uuid4()), index_node_hash="test_hash_" + "x" * 50, # Long hash within limits created_at=datetime.now(), @@ -919,7 +926,7 @@ class TestCleanDatasetTask: special_filename = f"test_file_{special_content}.txt" upload_file = UploadFile( tenant_id=tenant.id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test_files/{special_filename}", name=special_filename, size=1024, @@ -947,7 +954,7 @@ class TestCleanDatasetTask: dataset_id=dataset.id, tenant_id=tenant.id, name=f"metadata_{special_content}", - type="string", + type=DatasetMetadataType.STRING, created_by=account.id, ) special_metadata.id = str(uuid.uuid4()) diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index 379986c191..a2a190fd69 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -13,8 +13,10 @@ import pytest from faker import Faker from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService from tasks.clean_notion_document_task import clean_notion_document_task +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestCleanNotionDocumentTask: @@ -76,7 +78,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -87,7 +89,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -104,17 +106,17 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", # Set doc_form to ensure dataset.doc_form works doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -133,7 +135,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) segments.append(segment) @@ -208,7 +210,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -219,7 +221,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -252,7 +254,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -268,7 +270,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=f"{fake.company()}_{index_type}", description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -280,17 +282,17 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form=index_type, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -307,7 +309,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id="test_node", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -345,7 +347,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -356,7 +358,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -368,16 +370,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -396,7 +398,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=None, # No index node ID created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) segments.append(segment) @@ -431,7 +433,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -442,7 +444,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -459,16 +461,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -487,7 +489,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -546,7 +548,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -557,7 +559,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -569,22 +571,22 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() # Create segments with different statuses - segment_statuses = ["waiting", "processing", "completed", "error"] + segment_statuses = [SegmentStatus.WAITING, SegmentStatus.INDEXING, SegmentStatus.COMPLETED, SegmentStatus.ERROR] segments = [] index_node_ids = [] @@ -642,7 +644,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -653,7 +655,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -665,16 +667,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": "workspace_test", "notion_page_id": "page_test", "type": "page"} ), batch="test_batch", name="Test Notion Page", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -691,7 +693,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id="test_node", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) db_session_with_containers.commit() @@ -724,7 +726,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -735,7 +737,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -753,16 +755,16 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -782,7 +784,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -834,7 +836,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -847,7 +849,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=f"{fake.company()}_{i}", description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -865,16 +867,16 @@ class TestCleanNotionDocumentTask: tenant_id=account.current_tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) db_session_with_containers.add(document) db_session_with_containers.flush() @@ -893,7 +895,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -951,7 +953,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -962,14 +964,22 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, ) db_session_with_containers.add(dataset) db_session_with_containers.flush() # Create documents with different indexing statuses - document_statuses = ["waiting", "parsing", "cleaning", "splitting", "indexing", "completed", "error"] + document_statuses = [ + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + ] documents = [] all_segments = [] all_index_node_ids = [] @@ -980,13 +990,13 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( {"notion_workspace_id": f"workspace_{i}", "notion_page_id": f"page_{i}", "type": "page"} ), batch="test_batch", name=f"Notion Page {i}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", indexing_status=status, @@ -1008,7 +1018,7 @@ class TestCleanNotionDocumentTask: tokens=50, index_node_id=f"node_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, ) db_session_with_containers.add(segment) all_segments.append(segment) @@ -1054,7 +1064,7 @@ class TestCleanNotionDocumentTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -1065,7 +1075,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, created_by=account.id, built_in_field_enabled=True, ) @@ -1078,7 +1088,7 @@ class TestCleanNotionDocumentTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps( { "notion_workspace_id": "workspace_test", @@ -1090,10 +1100,10 @@ class TestCleanNotionDocumentTask: ), batch="test_batch", name="Test Notion Page with Metadata", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_metadata={ "document_name": "Test Notion Page with Metadata", "uploader": account.name, @@ -1121,7 +1131,7 @@ class TestCleanNotionDocumentTask: tokens=75, index_node_id=f"node_{i}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, keywords={"key1": ["value1", "value2"], "key2": ["value3"]}, ) db_session_with_containers.add(segment) diff --git a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py index 4fa52ff2a9..132f43c320 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_create_segment_to_index_task.py @@ -15,6 +15,7 @@ from faker import Faker from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.create_segment_to_index_task import create_segment_to_index_task @@ -118,7 +119,7 @@ class TestCreateSegmentToIndexTask: name=fake.company(), description=fake.text(max_nb_chars=100), tenant_id=tenant_id, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", embedding_model_provider="openai", embedding_model="text-embedding-ada-002", @@ -133,13 +134,13 @@ class TestCreateSegmentToIndexTask: dataset_id=dataset.id, tenant_id=tenant_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account_id, enabled=True, archived=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="qa_model", ) db_session_with_containers.add(document) @@ -148,7 +149,7 @@ class TestCreateSegmentToIndexTask: return dataset, document def _create_test_segment( - self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status="waiting" + self, db_session_with_containers, dataset_id, document_id, tenant_id, account_id, status=SegmentStatus.WAITING ): """ Helper method to create a test document segment for testing. @@ -200,7 +201,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -208,7 +209,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify segment status changes db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.error is None @@ -257,7 +258,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="completed" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.COMPLETED ) # Act: Execute the task @@ -268,7 +269,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status unchanged db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is None # Verify no index processor calls were made @@ -293,20 +294,25 @@ class TestCreateSegmentToIndexTask: dataset_id=invalid_dataset_id, tenant_id=tenant.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, enabled=True, archived=False, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, doc_form="text_model", ) db_session_with_containers.add(document) db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, invalid_dataset_id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, + invalid_dataset_id, + document.id, + tenant.id, + account.id, + status=SegmentStatus.WAITING, ) # Act: Execute the task @@ -317,7 +323,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -337,7 +343,12 @@ class TestCreateSegmentToIndexTask: invalid_document_id = str(uuid4()) segment = self._create_test_segment( - db_session_with_containers, dataset.id, invalid_document_id, tenant.id, account.id, status="waiting" + db_session_with_containers, + dataset.id, + invalid_document_id, + tenant.id, + account.id, + status=SegmentStatus.WAITING, ) # Act: Execute the task @@ -348,7 +359,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -373,7 +384,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -384,7 +395,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -409,7 +420,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -420,7 +431,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -445,7 +456,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -456,7 +467,7 @@ class TestCreateSegmentToIndexTask: # Verify segment status changed to indexing (task updates status before checking document) db_session_with_containers.refresh(segment) - assert segment.status == "indexing" + assert segment.status == SegmentStatus.INDEXING # Verify no index processor calls were made mock_external_service_dependencies["index_processor_factory"].assert_not_called() @@ -477,7 +488,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Mock processor to raise exception @@ -488,7 +499,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify error handling db_session_with_containers.refresh(segment) - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.enabled is False assert segment.disabled_at is not None assert segment.error == "Processor failed" @@ -512,7 +523,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) custom_keywords = ["custom", "keywords", "test"] @@ -521,7 +532,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -555,7 +566,7 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.commit() segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -563,7 +574,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED # Verify correct doc_form was passed to factory mock_external_service_dependencies["index_processor_factory"].assert_called_with(doc_form) @@ -583,7 +594,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task and measure time @@ -597,7 +608,7 @@ class TestCreateSegmentToIndexTask: # Verify successful completion db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED def test_create_segment_to_index_concurrent_execution( self, db_session_with_containers, mock_external_service_dependencies @@ -617,7 +628,7 @@ class TestCreateSegmentToIndexTask: segments = [] for i in range(3): segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) segments.append(segment) @@ -629,7 +640,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify all segments processed for segment in segments: db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -665,7 +676,7 @@ class TestCreateSegmentToIndexTask: keywords=["large", "content", "test"], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -681,7 +692,7 @@ class TestCreateSegmentToIndexTask: assert execution_time < 10.0 # Should complete within 10 seconds db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -700,7 +711,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Set up Redis cache key to simulate indexing in progress @@ -718,7 +729,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify indexing still completed successfully despite Redis failure db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -740,7 +751,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Simulate an error during indexing to trigger rollback path @@ -752,7 +763,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify error handling and rollback db_session_with_containers.refresh(segment) - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.enabled is False assert segment.disabled_at is not None assert segment.error is not None @@ -772,7 +783,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task @@ -780,7 +791,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED # Verify index processor was called with correct metadata mock_processor = mock_external_service_dependencies["index_processor"] @@ -814,11 +825,11 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Verify initial state - assert segment.status == "waiting" + assert segment.status == SegmentStatus.WAITING assert segment.indexing_at is None assert segment.completed_at is None @@ -827,7 +838,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify final state db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -861,7 +872,7 @@ class TestCreateSegmentToIndexTask: keywords=[], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -872,7 +883,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -907,7 +918,7 @@ class TestCreateSegmentToIndexTask: keywords=["special", "unicode", "test"], index_node_id=str(uuid4()), index_node_hash=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, created_by=account.id, ) db_session_with_containers.add(segment) @@ -918,7 +929,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -937,7 +948,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Create long keyword list @@ -948,7 +959,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -979,10 +990,10 @@ class TestCreateSegmentToIndexTask: ) segment1 = self._create_test_segment( - db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status="waiting" + db_session_with_containers, dataset1.id, document1.id, tenant1.id, account1.id, status=SegmentStatus.WAITING ) segment2 = self._create_test_segment( - db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status="waiting" + db_session_with_containers, dataset2.id, document2.id, tenant2.id, account2.id, status=SegmentStatus.WAITING ) # Act: Execute tasks for both tenants @@ -993,8 +1004,8 @@ class TestCreateSegmentToIndexTask: db_session_with_containers.refresh(segment1) db_session_with_containers.refresh(segment2) - assert segment1.status == "completed" - assert segment2.status == "completed" + assert segment1.status == SegmentStatus.COMPLETED + assert segment2.status == SegmentStatus.COMPLETED assert segment1.tenant_id == tenant1.id assert segment2.tenant_id == tenant2.id assert segment1.tenant_id != segment2.tenant_id @@ -1014,7 +1025,7 @@ class TestCreateSegmentToIndexTask: account, tenant = self._create_test_account_and_tenant(db_session_with_containers) dataset, document = self._create_test_dataset_and_document(db_session_with_containers, tenant.id, account.id) segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) # Act: Execute the task with None keywords @@ -1022,7 +1033,7 @@ class TestCreateSegmentToIndexTask: # Assert: Verify successful indexing db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None @@ -1050,7 +1061,7 @@ class TestCreateSegmentToIndexTask: segments = [] for i in range(5): segment = self._create_test_segment( - db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting" + db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status=SegmentStatus.WAITING ) segments.append(segment) @@ -1067,7 +1078,7 @@ class TestCreateSegmentToIndexTask: # Verify all segments processed successfully for segment in segments: db_session_with_containers.refresh(segment) - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED assert segment.indexing_at is not None assert segment.completed_at is not None assert segment.error is None diff --git a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py index 207bdad751..67f9dc7011 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_dataset_indexing_task.py @@ -11,6 +11,7 @@ from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( _document_indexing, _document_indexing_with_tenant_queue, @@ -139,7 +140,7 @@ class TestDatasetIndexingTaskIntegration: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -155,12 +156,12 @@ class TestDatasetIndexingTaskIntegration: tenant_id=tenant.id, dataset_id=dataset.id, position=position, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=f"doc-{position}.txt", - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -181,7 +182,7 @@ class TestDatasetIndexingTaskIntegration: for document_id in document_ids: updated = self._query_document(db_session_with_containers, document_id) assert updated is not None - assert updated.indexing_status == "parsing" + assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None def _assert_documents_error_contains( @@ -195,7 +196,7 @@ class TestDatasetIndexingTaskIntegration: for document_id in document_ids: updated = self._query_document(db_session_with_containers, document_id) assert updated is not None - assert updated.indexing_status == "error" + assert updated.indexing_status == IndexingStatus.ERROR assert updated.error is not None assert expected_error_substring in updated.error assert updated.stopped_at is not None @@ -322,11 +323,14 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - task_dispatch_spy.delay.assert_called_once_with( - tenant_id=next_task["tenant_id"], - dataset_id=next_task["dataset_id"], - document_ids=next_task["document_ids"], - ) + # apply_async is used by implementation; assert it was called once with expected kwargs + assert task_dispatch_spy.apply_async.call_count == 1 + call_kwargs = task_dispatch_spy.apply_async.call_args.kwargs.get("kwargs", {}) + assert call_kwargs == { + "tenant_id": next_task["tenant_id"], + "dataset_id": next_task["dataset_id"], + "document_ids": next_task["document_ids"], + } set_waiting_spy.assert_called_once() delete_key_spy.assert_not_called() @@ -352,7 +356,7 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - task_dispatch_spy.delay.assert_not_called() + task_dispatch_spy.apply_async.assert_not_called() delete_key_spy.assert_called_once() def test_validation_failure_sets_error_status_when_vector_space_at_limit( @@ -447,7 +451,7 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - task_dispatch_spy.delay.assert_called_once() + task_dispatch_spy.apply_async.assert_called_once() def test_sessions_close_on_successful_indexing( self, @@ -534,7 +538,7 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - assert task_dispatch_spy.delay.call_count == concurrency_limit + assert task_dispatch_spy.apply_async.call_count == concurrency_limit assert set_waiting_spy.call_count == concurrency_limit def test_task_queue_fifo_ordering(self, db_session_with_containers, patched_external_dependencies): @@ -565,9 +569,10 @@ class TestDatasetIndexingTaskIntegration: _document_indexing_with_tenant_queue(dataset.tenant_id, dataset.id, document_ids, task_dispatch_spy) # Assert - assert task_dispatch_spy.delay.call_count == 3 + assert task_dispatch_spy.apply_async.call_count == 3 for index, expected_task in enumerate(ordered_tasks): - assert task_dispatch_spy.delay.call_args_list[index].kwargs["document_ids"] == expected_task["document_ids"] + call_kwargs = task_dispatch_spy.apply_async.call_args_list[index].kwargs.get("kwargs", {}) + assert call_kwargs.get("document_ids") == expected_task["document_ids"] def test_billing_disabled_skips_limit_checks(self, db_session_with_containers, patched_external_dependencies): """Skip limit checks when billing feature is disabled.""" diff --git a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py index 58c3ab5509..e80b37ac1b 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_deal_dataset_vector_index_task.py @@ -13,8 +13,10 @@ import pytest from faker import Faker from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from services.account_service import AccountService, TenantService from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task +from tests.test_containers_integration_tests.helpers import generate_valid_password class TestDealDatasetVectorIndexTask: @@ -61,7 +63,7 @@ class TestDealDatasetVectorIndexTask: email=fake.email(), name=fake.name(), interface_language="en-US", - password=fake.password(length=12), + password=generate_valid_password(fake), ) TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) tenant = account.current_tenant @@ -89,7 +91,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -101,13 +103,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -149,7 +151,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -161,13 +163,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -181,13 +183,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -208,7 +210,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -219,7 +221,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load method was called mock_factory = mock_index_processor_factory.return_value @@ -250,7 +252,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -262,13 +264,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="parent_child_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -282,13 +284,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="parent_child_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -309,7 +311,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -320,7 +322,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor clean and load methods were called mock_factory = mock_index_processor_factory.return_value @@ -366,7 +368,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -398,7 +400,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -410,13 +412,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -429,7 +431,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify that no index processor load was called since no segments exist mock_factory = mock_index_processor_factory.return_value @@ -454,7 +456,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -487,7 +489,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -499,13 +501,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -519,13 +521,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -546,7 +548,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -562,7 +564,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to error updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Test exception during indexing" in updated_document.error def test_deal_dataset_vector_index_task_with_custom_index_type( @@ -583,7 +585,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -595,13 +597,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="qa_index", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -622,7 +624,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -633,7 +635,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with custom index type mock_index_processor_factory.assert_called_once_with("qa_index") @@ -659,7 +661,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -671,13 +673,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -698,7 +700,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -709,7 +711,7 @@ class TestDealDatasetVectorIndexTask: # Verify document status was updated to indexing then completed updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor was initialized with the document's index type mock_index_processor_factory.assert_called_once_with("text_model") @@ -735,7 +737,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -747,13 +749,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -769,13 +771,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name=f"Test Document {i}", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -800,7 +802,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{i}_{j}", index_node_hash=f"hash_{i}_{j}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -813,7 +815,7 @@ class TestDealDatasetVectorIndexTask: # Verify all documents were processed for document in documents: updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED # Verify index processor load was called multiple times mock_factory = mock_index_processor_factory.return_value @@ -838,7 +840,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -850,13 +852,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -870,13 +872,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Test Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -897,7 +899,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -915,7 +917,7 @@ class TestDealDatasetVectorIndexTask: # Verify final document status updated_document = db_session_with_containers.query(Document).filter_by(id=document.id).first() - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED def test_deal_dataset_vector_index_task_with_disabled_documents( self, db_session_with_containers, mock_index_processor_factory, account_and_tenant @@ -935,7 +937,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -947,13 +949,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -967,13 +969,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Enabled Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -986,13 +988,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Disabled Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, # This document should be skipped archived=False, batch="test_batch", @@ -1014,7 +1016,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1025,13 +1027,13 @@ class TestDealDatasetVectorIndexTask: # Verify only enabled document was processed updated_enabled_document = db_session_with_containers.query(Document).filter_by(id=enabled_document.id).first() - assert updated_enabled_document.indexing_status == "completed" + assert updated_enabled_document.indexing_status == IndexingStatus.COMPLETED # Verify disabled document status remains unchanged updated_disabled_document = ( db_session_with_containers.query(Document).filter_by(id=disabled_document.id).first() ) - assert updated_disabled_document.indexing_status == "completed" # Should not change + assert updated_disabled_document.indexing_status == IndexingStatus.COMPLETED # Should not change # Verify index processor load was called only once (for enabled document) mock_factory = mock_index_processor_factory.return_value @@ -1056,7 +1058,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -1068,13 +1070,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1088,13 +1090,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Active Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1107,13 +1109,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Archived Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=True, # This document should be skipped batch="test_batch", @@ -1135,7 +1137,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1146,13 +1148,13 @@ class TestDealDatasetVectorIndexTask: # Verify only active document was processed updated_active_document = db_session_with_containers.query(Document).filter_by(id=active_document.id).first() - assert updated_active_document.indexing_status == "completed" + assert updated_active_document.indexing_status == IndexingStatus.COMPLETED # Verify archived document status remains unchanged updated_archived_document = ( db_session_with_containers.query(Document).filter_by(id=archived_document.id).first() ) - assert updated_archived_document.indexing_status == "completed" # Should not change + assert updated_archived_document.indexing_status == IndexingStatus.COMPLETED # Should not change # Verify index processor load was called only once (for active document) mock_factory = mock_index_processor_factory.return_value @@ -1177,7 +1179,7 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=account.id, ) db_session_with_containers.add(dataset) @@ -1189,13 +1191,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Document for doc_form", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1209,13 +1211,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Completed Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, batch="test_batch", @@ -1228,13 +1230,13 @@ class TestDealDatasetVectorIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="file_import", + data_source_type=DataSourceType.UPLOAD_FILE, name="Incomplete Document", - created_from="file_import", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, doc_form="text_model", doc_language="en", - indexing_status="indexing", # This document should be skipped + indexing_status=IndexingStatus.INDEXING, # This document should be skipped enabled=True, archived=False, batch="test_batch", @@ -1256,7 +1258,7 @@ class TestDealDatasetVectorIndexTask: index_node_id=f"node_{uuid.uuid4()}", index_node_hash=f"hash_{uuid.uuid4()}", created_by=account.id, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, ) db_session_with_containers.add(segment) @@ -1269,13 +1271,13 @@ class TestDealDatasetVectorIndexTask: updated_completed_document = ( db_session_with_containers.query(Document).filter_by(id=completed_document.id).first() ) - assert updated_completed_document.indexing_status == "completed" + assert updated_completed_document.indexing_status == IndexingStatus.COMPLETED # Verify incomplete document status remains unchanged updated_incomplete_document = ( db_session_with_containers.query(Document).filter_by(id=incomplete_document.id).first() ) - assert updated_incomplete_document.indexing_status == "indexing" # Should not change + assert updated_incomplete_document.indexing_status == IndexingStatus.INDEXING # Should not change # Verify index processor load was called only once (for completed document) mock_factory = mock_index_processor_factory.return_value diff --git a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py index bc0ed3bd2b..6fc2a53f9c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_delete_segment_from_index_task.py @@ -14,6 +14,7 @@ from faker import Faker from core.rag.index_processor.constant.index_type import IndexStructureType from models import Account, Dataset, Document, DocumentSegment, Tenant +from models.enums import DataSourceType, DocumentCreatedFrom, DocumentDocType, IndexingStatus, SegmentStatus from tasks.delete_segment_from_index_task import delete_segment_from_index_task logger = logging.getLogger(__name__) @@ -48,7 +49,7 @@ class TestDeleteSegmentFromIndexTask: Tenant: Created test tenant instance """ fake = fake or Faker() - tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="active") + tenant = Tenant(name=f"Test Tenant {fake.company()}", plan="basic", status="normal") tenant.id = fake.uuid4() tenant.created_at = fake.date_time_this_year() tenant.updated_at = tenant.created_at @@ -106,7 +107,7 @@ class TestDeleteSegmentFromIndexTask: dataset.description = fake.text(max_nb_chars=200) dataset.provider = "vendor" dataset.permission = "only_me" - dataset.data_source_type = "upload_file" + dataset.data_source_type = DataSourceType.UPLOAD_FILE dataset.indexing_technique = "high_quality" dataset.index_struct = '{"type": "paragraph"}' dataset.created_by = account.id @@ -145,7 +146,7 @@ class TestDeleteSegmentFromIndexTask: document.data_source_info = kwargs.get("data_source_info", "{}") document.batch = kwargs.get("batch", fake.uuid4()) document.name = kwargs.get("name", f"Test Document {fake.word()}") - document.created_from = kwargs.get("created_from", "api") + document.created_from = kwargs.get("created_from", DocumentCreatedFrom.API) document.created_by = account.id document.created_at = fake.date_time_this_year() document.processing_started_at = kwargs.get("processing_started_at", fake.date_time_this_year()) @@ -162,7 +163,7 @@ class TestDeleteSegmentFromIndexTask: document.enabled = kwargs.get("enabled", True) document.archived = kwargs.get("archived", False) document.updated_at = fake.date_time_this_year() - document.doc_type = kwargs.get("doc_type", "text") + document.doc_type = kwargs.get("doc_type", DocumentDocType.PERSONAL_DOCUMENT) document.doc_metadata = kwargs.get("doc_metadata", {}) document.doc_form = kwargs.get("doc_form", IndexStructureType.PARAGRAPH_INDEX) document.doc_language = kwargs.get("doc_language", "en") @@ -204,7 +205,7 @@ class TestDeleteSegmentFromIndexTask: segment.index_node_hash = fake.sha256() segment.hit_count = 0 segment.enabled = True - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.created_by = account.id segment.created_at = fake.date_time_this_year() segment.updated_by = account.id @@ -386,7 +387,7 @@ class TestDeleteSegmentFromIndexTask: account = self._create_test_account(db_session_with_containers, tenant, fake) dataset = self._create_test_dataset(db_session_with_containers, tenant, account, fake) document = self._create_test_document( - db_session_with_containers, dataset, account, fake, indexing_status="indexing" + db_session_with_containers, dataset, account, fake, indexing_status=IndexingStatus.INDEXING ) segments = self._create_test_document_segments(db_session_with_containers, document, account, 3, fake) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py index 8785c948d1..da42fc7167 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segment_from_index_task.py @@ -13,11 +13,12 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from extensions.ext_redis import redis_client from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.disable_segment_from_index_task import disable_segment_from_index_task logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ class TestDisableSegmentFromIndexTask: mock_processor.clean.return_value = None yield mock_processor - def _create_test_account_and_tenant(self, db_session_with_containers) -> tuple[Account, Tenant]: + def _create_test_account_and_tenant(self, db_session_with_containers: Session) -> tuple[Account, Tenant]: """ Helper method to create a test account and tenant for testing. @@ -53,8 +54,8 @@ class TestDisableSegmentFromIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( @@ -62,8 +63,8 @@ class TestDisableSegmentFromIndexTask: status="normal", plan="basic", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join with owner role join = TenantAccountJoin( @@ -72,15 +73,15 @@ class TestDisableSegmentFromIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Set current tenant for account account.current_tenant = tenant return account, tenant - def _create_test_dataset(self, tenant: Tenant, account: Account) -> Dataset: + def _create_test_dataset(self, db_session_with_containers: Session, tenant: Tenant, account: Account) -> Dataset: """ Helper method to create a test dataset. @@ -97,17 +98,22 @@ class TestDisableSegmentFromIndexTask: tenant_id=tenant.id, name=fake.sentence(nb_words=3), description=fake.text(max_nb_chars=200), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() return dataset def _create_test_document( - self, dataset: Dataset, tenant: Tenant, account: Account, doc_form: str = "text_model" + self, + db_session_with_containers: Session, + dataset: Dataset, + tenant: Tenant, + account: Account, + doc_form: str = "text_model", ) -> Document: """ Helper method to create a test document. @@ -127,12 +133,12 @@ class TestDisableSegmentFromIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch=fake.uuid4(), name=fake.file_name(), - created_from="api", + created_from=DocumentCreatedFrom.API, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, doc_form=doc_form, @@ -140,13 +146,14 @@ class TestDisableSegmentFromIndexTask: tokens=500, completed_at=datetime.now(UTC), ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() return document def _create_test_segment( self, + db_session_with_containers: Session, document: Document, dataset: Dataset, tenant: Tenant, @@ -183,14 +190,14 @@ class TestDisableSegmentFromIndexTask: status=status, enabled=enabled, created_by=account.id, - completed_at=datetime.now(UTC) if status == "completed" else None, + completed_at=datetime.now(UTC) if status == SegmentStatus.COMPLETED else None, ) - db.session.add(segment) - db.session.commit() + db_session_with_containers.add(segment) + db_session_with_containers.commit() return segment - def test_disable_segment_success(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_success(self, db_session_with_containers: Session, mock_index_processor): """ Test successful segment disabling from index. @@ -202,9 +209,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -226,10 +233,10 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Verify segment is still in database - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_not_found(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_found(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not found. @@ -251,7 +258,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_not_completed(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment is not in completed status. @@ -262,9 +269,11 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with non-completed segment account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account, status="indexing", enabled=True) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment( + db_session_with_containers, document, dataset, tenant, account, status=SegmentStatus.INDEXING, enabled=True + ) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -275,7 +284,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_dataset(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_dataset(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated dataset. @@ -286,13 +295,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove dataset association segment.dataset_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -303,7 +312,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_no_document(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_no_document(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when segment has no associated document. @@ -314,13 +323,13 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Manually remove document association segment.document_id = "00000000-0000-0000-0000-000000000000" - db.session.commit() + db_session_with_containers.commit() # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -331,7 +340,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_disabled(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_disabled(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is disabled. @@ -342,12 +351,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with disabled document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.enabled = False - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -358,7 +367,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_archived(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_archived(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when document is archived. @@ -369,12 +378,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with archived document account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.archived = True - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -385,7 +394,9 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_document_indexing_not_completed(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_document_indexing_not_completed( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test handling when document indexing is not completed. @@ -396,12 +407,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data with incomplete indexing account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -412,7 +423,7 @@ class TestDisableSegmentFromIndexTask: # Verify index processor was not called mock_index_processor.clean.assert_not_called() - def test_disable_segment_index_processor_exception(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_index_processor_exception(self, db_session_with_containers: Session, mock_index_processor): """ Test handling when index processor raises an exception. @@ -424,9 +435,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Set up Redis cache indexing_cache_key = f"segment_{segment.id}_indexing" @@ -449,13 +460,13 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][1] == [segment.index_node_id] # Check index node IDs # Verify segment was re-enabled - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is True # Verify Redis cache was still cleared assert redis_client.get(indexing_cache_key) is None - def test_disable_segment_different_doc_forms(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_different_doc_forms(self, db_session_with_containers: Session, mock_index_processor): """ Test disabling segments with different document forms. @@ -470,9 +481,11 @@ class TestDisableSegmentFromIndexTask: for doc_form in doc_forms: # Arrange: Create test data for each form account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account, doc_form=doc_form) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document( + db_session_with_containers, dataset, tenant, account, doc_form=doc_form + ) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Reset mock for each iteration mock_index_processor.reset_mock() @@ -489,7 +502,7 @@ class TestDisableSegmentFromIndexTask: assert call_args[0][0].id == dataset.id # Check dataset ID assert call_args[0][1] == [segment.index_node_id] # Check index node IDs - def test_disable_segment_redis_cache_handling(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_redis_cache_handling(self, db_session_with_containers: Session, mock_index_processor): """ Test Redis cache handling during segment disabling. @@ -500,9 +513,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Test with cache present indexing_cache_key = f"segment_{segment.id}_indexing" @@ -517,13 +530,13 @@ class TestDisableSegmentFromIndexTask: assert redis_client.get(indexing_cache_key) is None # Test with no cache present - segment2 = self._create_test_segment(document, dataset, tenant, account) + segment2 = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) result2 = disable_segment_from_index_task(segment2.id) # Assert: Verify task still works without cache assert result2 is None - def test_disable_segment_performance_timing(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_performance_timing(self, db_session_with_containers: Session, mock_index_processor): """ Test performance timing of segment disabling task. @@ -534,9 +547,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task and measure time start_time = time.perf_counter() @@ -548,7 +561,9 @@ class TestDisableSegmentFromIndexTask: execution_time = end_time - start_time assert execution_time < 5.0 # Should complete within 5 seconds - def test_disable_segment_database_session_management(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_database_session_management( + self, db_session_with_containers: Session, mock_index_processor + ): """ Test database session management during task execution. @@ -559,9 +574,9 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create test data account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) - segment = self._create_test_segment(document, dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) # Act: Execute the task result = disable_segment_from_index_task(segment.id) @@ -570,10 +585,10 @@ class TestDisableSegmentFromIndexTask: assert result is None # Verify segment is still accessible (session was properly managed) - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.id is not None - def test_disable_segment_concurrent_execution(self, db_session_with_containers, mock_index_processor): + def test_disable_segment_concurrent_execution(self, db_session_with_containers: Session, mock_index_processor): """ Test concurrent execution of segment disabling tasks. @@ -584,12 +599,12 @@ class TestDisableSegmentFromIndexTask: """ # Arrange: Create multiple test segments account, tenant = self._create_test_account_and_tenant(db_session_with_containers) - dataset = self._create_test_dataset(tenant, account) - document = self._create_test_document(dataset, tenant, account) + dataset = self._create_test_dataset(db_session_with_containers, tenant, account) + document = self._create_test_document(db_session_with_containers, dataset, tenant, account) segments = [] for i in range(3): - segment = self._create_test_segment(document, dataset, tenant, account) + segment = self._create_test_segment(db_session_with_containers, document, dataset, tenant, account) segments.append(segment) # Act: Execute tasks concurrently (simulated) diff --git a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py index a93a80e231..4bc9bb4749 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_disable_segments_from_index_task.py @@ -9,10 +9,12 @@ The task is responsible for removing document segments from the search index whe from unittest.mock import MagicMock, patch from faker import Faker +from sqlalchemy.orm import Session from models import Account, Dataset, DocumentSegment from models import Document as DatasetDocument from models.dataset import DatasetProcessRule +from models.enums import DataSourceType, DocumentCreatedFrom, ProcessRuleMode, SegmentStatus from tasks.disable_segments_from_index_task import disable_segments_from_index_task @@ -31,7 +33,7 @@ class TestDisableSegmentsFromIndexTask: and realistic testing environment with actual database interactions. """ - def _create_test_account(self, db_session_with_containers, fake=None): + def _create_test_account(self, db_session_with_containers: Session, fake=None): """ Helper method to create a test account with realistic data. @@ -64,7 +66,7 @@ class TestDisableSegmentsFromIndexTask: tenant = Tenant( name=f"Test Tenant {fake.company()}", plan="basic", - status="active", + status="normal", ) tenant.id = account.tenant_id tenant.created_at = fake.date_time_this_year() @@ -79,7 +81,7 @@ class TestDisableSegmentsFromIndexTask: return account - def _create_test_dataset(self, db_session_with_containers, account, fake=None): + def _create_test_dataset(self, db_session_with_containers: Session, account, fake=None): """ Helper method to create a test dataset with realistic data. @@ -99,7 +101,7 @@ class TestDisableSegmentsFromIndexTask: description=fake.text(max_nb_chars=200), provider="vendor", permission="only_me", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, updated_by=account.id, @@ -113,7 +115,7 @@ class TestDisableSegmentsFromIndexTask: return dataset - def _create_test_document(self, db_session_with_containers, dataset, account, fake=None): + def _create_test_document(self, db_session_with_containers: Session, dataset, account, fake=None): """ Helper method to create a test document with realistic data. @@ -133,11 +135,11 @@ class TestDisableSegmentsFromIndexTask: document.tenant_id = dataset.tenant_id document.dataset_id = dataset.id document.position = 1 - document.data_source_type = "upload_file" + document.data_source_type = DataSourceType.UPLOAD_FILE document.data_source_info = '{"upload_file_id": "test_file_id"}' document.batch = fake.uuid4() document.name = f"Test Document {fake.word()}.txt" - document.created_from = "upload_file" + document.created_from = DocumentCreatedFrom.WEB document.created_by = account.id document.created_api_request_id = fake.uuid4() document.processing_started_at = fake.date_time_this_year() @@ -158,7 +160,9 @@ class TestDisableSegmentsFromIndexTask: return document - def _create_test_segments(self, db_session_with_containers, document, dataset, account, count=3, fake=None): + def _create_test_segments( + self, db_session_with_containers: Session, document, dataset, account, count=3, fake=None + ): """ Helper method to create test document segments with realistic data. @@ -194,7 +198,7 @@ class TestDisableSegmentsFromIndexTask: segment.enabled = True segment.disabled_at = None segment.disabled_by = None - segment.status = "completed" + segment.status = SegmentStatus.COMPLETED segment.created_by = account.id segment.updated_by = account.id segment.indexing_at = fake.date_time_this_year() @@ -210,7 +214,7 @@ class TestDisableSegmentsFromIndexTask: return segments - def _create_dataset_process_rule(self, db_session_with_containers, dataset, fake=None): + def _create_dataset_process_rule(self, db_session_with_containers: Session, dataset, fake=None): """ Helper method to create a dataset process rule. @@ -227,7 +231,7 @@ class TestDisableSegmentsFromIndexTask: process_rule.id = fake.uuid4() process_rule.tenant_id = dataset.tenant_id process_rule.dataset_id = dataset.id - process_rule.mode = "automatic" + process_rule.mode = ProcessRuleMode.AUTOMATIC process_rule.rules = ( "{" '"mode": "automatic", ' @@ -239,14 +243,12 @@ class TestDisableSegmentsFromIndexTask: process_rule.created_by = dataset.created_by process_rule.updated_by = dataset.updated_by - from extensions.ext_database import db - - db.session.add(process_rule) - db.session.commit() + db_session_with_containers.add(process_rule) + db_session_with_containers.commit() return process_rule - def test_disable_segments_success(self, db_session_with_containers): + def test_disable_segments_success(self, db_session_with_containers: Session): """ Test successful disabling of segments from index. @@ -297,7 +299,7 @@ class TestDisableSegmentsFromIndexTask: expected_key = f"segment_{segment.id}_indexing" mock_redis.delete.assert_any_call(expected_key) - def test_disable_segments_dataset_not_found(self, db_session_with_containers): + def test_disable_segments_dataset_not_found(self, db_session_with_containers: Session): """ Test handling when dataset is not found. @@ -320,7 +322,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when dataset is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_not_found(self, db_session_with_containers): + def test_disable_segments_document_not_found(self, db_session_with_containers: Session): """ Test handling when document is not found. @@ -344,7 +346,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when document is not found mock_redis.delete.assert_not_called() - def test_disable_segments_document_invalid_status(self, db_session_with_containers): + def test_disable_segments_document_invalid_status(self, db_session_with_containers: Session): """ Test handling when document has invalid status for disabling. @@ -360,9 +362,8 @@ class TestDisableSegmentsFromIndexTask: # Test case 1: Document not enabled document.enabled = False - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() segment_ids = [segment.id for segment in segments] @@ -379,7 +380,7 @@ class TestDisableSegmentsFromIndexTask: # Test case 2: Document archived document.enabled = True document.archived = True - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -393,7 +394,7 @@ class TestDisableSegmentsFromIndexTask: document.enabled = True document.archived = False document.indexing_status = "indexing" - db.session.commit() + db_session_with_containers.commit() with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis: # Act @@ -403,7 +404,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_redis.delete.assert_not_called() - def test_disable_segments_no_segments_found(self, db_session_with_containers): + def test_disable_segments_no_segments_found(self, db_session_with_containers: Session): """ Test handling when no segments are found for the given IDs. @@ -430,7 +431,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are found mock_redis.delete.assert_not_called() - def test_disable_segments_index_processor_error(self, db_session_with_containers): + def test_disable_segments_index_processor_error(self, db_session_with_containers: Session): """ Test handling when index processor encounters an error. @@ -464,13 +465,14 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Verify segments were rolled back to enabled state - from extensions.ext_database import db - db.session.refresh(segments[0]) - db.session.refresh(segments[1]) + db_session_with_containers.refresh(segments[0]) + db_session_with_containers.refresh(segments[1]) # Check that segments are re-enabled after error - updated_segments = db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + updated_segments = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).all() + ) for segment in updated_segments: assert segment.enabled is True @@ -480,7 +482,7 @@ class TestDisableSegmentsFromIndexTask: # Verify Redis cache cleanup was still called assert mock_redis.delete.call_count == len(segments) - def test_disable_segments_with_different_doc_forms(self, db_session_with_containers): + def test_disable_segments_with_different_doc_forms(self, db_session_with_containers: Session): """ Test disabling segments with different document forms. @@ -503,9 +505,8 @@ class TestDisableSegmentsFromIndexTask: for doc_form in doc_forms: # Update document form document.doc_form = doc_form - from extensions.ext_database import db - db.session.commit() + db_session_with_containers.commit() # Mock the index processor factory with patch("tasks.disable_segments_from_index_task.IndexProcessorFactory") as mock_factory: @@ -523,7 +524,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value mock_factory.assert_called_with(doc_form) - def test_disable_segments_performance_timing(self, db_session_with_containers): + def test_disable_segments_performance_timing(self, db_session_with_containers: Session): """ Test that the task properly measures and logs performance timing. @@ -568,7 +569,7 @@ class TestDisableSegmentsFromIndexTask: assert performance_log is not None assert "0.5" in performance_log # Should log the execution time - def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers): + def test_disable_segments_redis_cache_cleanup(self, db_session_with_containers: Session): """ Test that Redis cache is properly cleaned up for all segments. @@ -610,7 +611,7 @@ class TestDisableSegmentsFromIndexTask: for expected_key in expected_keys: assert expected_key in actual_calls - def test_disable_segments_database_session_cleanup(self, db_session_with_containers): + def test_disable_segments_database_session_cleanup(self, db_session_with_containers: Session): """ Test that database session is properly closed after task execution. @@ -643,7 +644,7 @@ class TestDisableSegmentsFromIndexTask: assert result is None # Task should complete without returning a value # Session lifecycle is managed by context manager; no explicit close assertion - def test_disable_segments_empty_segment_ids(self, db_session_with_containers): + def test_disable_segments_empty_segment_ids(self, db_session_with_containers: Session): """ Test handling when empty segment IDs list is provided. @@ -669,7 +670,7 @@ class TestDisableSegmentsFromIndexTask: # Redis should not be called when no segments are provided mock_redis.delete.assert_not_called() - def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers): + def test_disable_segments_mixed_valid_invalid_ids(self, db_session_with_containers: Session): """ Test handling when some segment IDs are valid and others are invalid. diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py index df5c5dc54b..6a17a19a54 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_sync_task.py @@ -16,6 +16,7 @@ import pytest from core.indexing_runner import DocumentIsPausedError, IndexingRunner from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_sync_task import document_indexing_sync_task @@ -54,7 +55,7 @@ class DocumentIndexingSyncTaskTestDataFactory: tenant_id=tenant_id, name=f"dataset-{uuid4()}", description="sync test dataset", - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, indexing_technique="high_quality", created_by=created_by, ) @@ -76,11 +77,11 @@ class DocumentIndexingSyncTaskTestDataFactory: tenant_id=tenant_id, dataset_id=dataset_id, position=0, - data_source_type="notion_import", + data_source_type=DataSourceType.NOTION_IMPORT, data_source_info=json.dumps(data_source_info) if data_source_info is not None else None, batch="test-batch", name=f"doc-{uuid4()}", - created_from="notion_import", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, indexing_status=indexing_status, enabled=True, @@ -113,7 +114,7 @@ class DocumentIndexingSyncTaskTestDataFactory: word_count=10, tokens=5, index_node_id=f"node-{document_id}-{i}", - status="completed", + status=SegmentStatus.COMPLETED, created_by=created_by, ) db_session_with_containers.add(segment) @@ -181,7 +182,7 @@ class TestDocumentIndexingSyncTask: dataset_id=dataset.id, created_by=account.id, data_source_info=notion_info, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) segments = DocumentIndexingSyncTaskTestDataFactory.create_segments( @@ -276,7 +277,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Datasource credential not found" in updated_document.error assert updated_document.stopped_at is not None mock_external_dependencies["indexing_runner"].run.assert_not_called() @@ -301,7 +302,7 @@ class TestDocumentIndexingSyncTask: .count() ) assert updated_document is not None - assert updated_document.indexing_status == "completed" + assert updated_document.indexing_status == IndexingStatus.COMPLETED assert updated_document.processing_started_at is None assert remaining_segments == 3 mock_external_dependencies["index_processor"].clean.assert_not_called() @@ -327,7 +328,7 @@ class TestDocumentIndexingSyncTask: ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None assert updated_document.data_source_info_dict.get("last_edited_time") == "2024-01-02T00:00:00Z" assert remaining_segments == 0 @@ -369,7 +370,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING mock_external_dependencies["index_processor"].clean.assert_not_called() mock_external_dependencies["indexing_runner"].run.assert_called_once() @@ -393,7 +394,7 @@ class TestDocumentIndexingSyncTask: .count() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert remaining_segments == 0 mock_external_dependencies["indexing_runner"].run.assert_called_once() @@ -412,7 +413,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.error is None def test_indexing_runner_general_error(self, db_session_with_containers, mock_external_dependencies): @@ -430,7 +431,7 @@ class TestDocumentIndexingSyncTask: db_session_with_containers.query(Document).where(Document.id == context["document"].id).first() ) assert updated_document is not None - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert "Indexing error" in updated_document.error assert updated_document.stopped_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py index 4be1180c73..9421b07285 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_task.py @@ -8,6 +8,7 @@ from core.entities.document_task import DocumentTask from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus from tasks.document_indexing_task import ( _document_indexing, # Core function _document_indexing_with_tenant_queue, # Tenant queue wrapper function @@ -97,7 +98,7 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -112,12 +113,12 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -179,7 +180,7 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -194,12 +195,12 @@ class TestDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -250,7 +251,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -320,7 +321,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in existing_document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents @@ -367,7 +368,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing close the session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def test_document_indexing_task_mixed_document_states( @@ -397,12 +398,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=2, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="completed", # Already completed + indexing_status=IndexingStatus.COMPLETED, # Already completed enabled=True, ) db_session_with_containers.add(doc1) @@ -414,12 +415,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=False, # Disabled ) db_session_with_containers.add(doc2) @@ -444,7 +445,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with all documents @@ -482,12 +483,12 @@ class TestDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=i + 3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, ) db_session_with_containers.add(document) @@ -507,7 +508,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error assert updated_document.stopped_at is not None @@ -548,7 +549,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None def test_document_indexing_task_document_is_paused_error( @@ -591,7 +592,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # ==================== NEW TESTS FOR REFACTORED FUNCTIONS ==================== @@ -702,7 +703,7 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -762,11 +763,12 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify task function was called for each waiting task - assert mock_task_func.delay.call_count == 1 + assert mock_task_func.apply_async.call_count == 1 # Verify correct parameters for each call - calls = mock_task_func.delay.call_args_list - assert calls[0][1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + calls = mock_task_func.apply_async.call_args_list + sent_kwargs = calls[0][1]["kwargs"] + assert sent_kwargs == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} # Verify queue is empty after processing (tasks were pulled) remaining_tasks = queue.pull_tasks(count=10) # Pull more than we added @@ -826,15 +828,19 @@ class TestDocumentIndexingTasks: # Re-query documents from database since _document_indexing uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify waiting task was still processed despite core processing error - mock_task_func.delay.assert_called_once() + mock_task_func.apply_async.assert_called_once() # Verify correct parameters for the call - call = mock_task_func.delay.call_args - assert call[1] == {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["waiting-doc-1"]} + call = mock_task_func.apply_async.call_args + assert call[1]["kwargs"] == { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": ["waiting-doc-1"], + } # Verify queue is empty after processing (task was pulled) remaining_tasks = queue.pull_tasks(count=10) @@ -896,9 +902,13 @@ class TestDocumentIndexingTasks: mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() # Verify only tenant1's waiting task was processed - mock_task_func.delay.assert_called_once() - call = mock_task_func.delay.call_args - assert call[1] == {"tenant_id": tenant1_id, "dataset_id": dataset1_id, "document_ids": ["tenant1-doc-1"]} + mock_task_func.apply_async.assert_called_once() + call = mock_task_func.apply_async.call_args + assert call[1]["kwargs"] == { + "tenant_id": tenant1_id, + "dataset_id": dataset1_id, + "document_ids": ["tenant1-doc-1"], + } # Verify tenant1's queue is empty remaining_tasks1 = queue1.pull_tasks(count=10) diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py index 9da9a4132e..2fbea1388c 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -5,6 +5,7 @@ from faker import Faker from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.document_indexing_update_task import document_indexing_update_task @@ -61,7 +62,7 @@ class TestDocumentIndexingUpdateTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=64), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -72,12 +73,12 @@ class TestDocumentIndexingUpdateTask: tenant_id=tenant.id, dataset_id=dataset.id, position=0, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -98,7 +99,7 @@ class TestDocumentIndexingUpdateTask: word_count=10, tokens=5, index_node_id=node_id, - status="completed", + status=SegmentStatus.COMPLETED, created_by=account.id, ) db_session_with_containers.add(seg) @@ -122,7 +123,7 @@ class TestDocumentIndexingUpdateTask: # Assert document status updated before reindex updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() - assert updated.indexing_status == "parsing" + assert updated.indexing_status == IndexingStatus.PARSING assert updated.processing_started_at is not None # Segments should be deleted diff --git a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py index b2e1ce3b89..f1f5a4b105 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_duplicate_document_indexing_task.py @@ -3,9 +3,11 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from core.indexing_runner import DocumentIsPausedError from enums.cloud_plan import CloudPlan from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.duplicate_document_indexing_task import ( _duplicate_document_indexing_task, # Core function _duplicate_document_indexing_task_with_tenant_queue, # Tenant queue wrapper function @@ -106,7 +108,7 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -121,12 +123,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -176,7 +178,7 @@ class TestDuplicateDocumentIndexingTasks: content=fake.text(max_nb_chars=200), word_count=50, tokens=100, - status="completed", + status=SegmentStatus.COMPLETED, enabled=True, indexing_at=fake.date_time_this_year(), created_by=dataset.created_by, # Add required field @@ -241,7 +243,7 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) @@ -256,12 +258,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=tenant.id, dataset_id=dataset.id, position=i, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -282,7 +284,7 @@ class TestDuplicateDocumentIndexingTasks: return dataset, documents - def test_duplicate_document_indexing_task_success( + def _test_duplicate_document_indexing_task_success( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -315,7 +317,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with correct documents @@ -324,7 +326,7 @@ class TestDuplicateDocumentIndexingTasks: processed_documents = call_args[0][0] # First argument should be documents list assert len(processed_documents) == 3 - def test_duplicate_document_indexing_task_with_segment_cleanup( + def _test_duplicate_document_indexing_task_with_segment_cleanup( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -367,14 +369,14 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were updated to parsing status for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify indexing runner was called mock_external_service_dependencies["indexing_runner"].assert_called_once() mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() - def test_duplicate_document_indexing_task_dataset_not_found( + def _test_duplicate_document_indexing_task_dataset_not_found( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -436,7 +438,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in existing_document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None # Verify the run method was called with only existing documents @@ -445,7 +447,7 @@ class TestDuplicateDocumentIndexingTasks: processed_documents = call_args[0][0] # First argument should be documents list assert len(processed_documents) == 2 # Only existing documents - def test_duplicate_document_indexing_task_indexing_runner_exception( + def _test_duplicate_document_indexing_task_indexing_runner_exception( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -483,10 +485,10 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task close the session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING assert updated_document.processing_started_at is not None - def test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + def _test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -515,12 +517,12 @@ class TestDuplicateDocumentIndexingTasks: tenant_id=dataset.tenant_id, dataset_id=dataset.id, position=i + 3, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=dataset.created_by, - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, enabled=True, doc_form="text_model", ) @@ -541,7 +543,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "batch upload" in updated_document.error.lower() assert updated_document.stopped_at is not None @@ -549,7 +551,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify indexing runner was not called due to early validation error mock_external_service_dependencies["indexing_runner_instance"].run.assert_not_called() - def test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + def _test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( self, db_session_with_containers, mock_external_service_dependencies ): """ @@ -583,7 +585,7 @@ class TestDuplicateDocumentIndexingTasks: # Re-query documents from database since _duplicate_document_indexing_task uses a different session for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "error" + assert updated_document.indexing_status == IndexingStatus.ERROR assert updated_document.error is not None assert "limit" in updated_document.error.lower() assert updated_document.stopped_at is not None @@ -647,7 +649,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_normal_duplicate_document_indexing_task_with_tenant_queue( @@ -690,7 +692,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_priority_duplicate_document_indexing_task_with_tenant_queue( @@ -734,7 +736,7 @@ class TestDuplicateDocumentIndexingTasks: # Verify documents were processed for doc_id in document_ids: updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() - assert updated_document.indexing_status == "parsing" + assert updated_document.indexing_status == IndexingStatus.PARSING @patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) def test_tenant_queue_wrapper_processes_next_tasks( @@ -783,3 +785,90 @@ class TestDuplicateDocumentIndexingTasks: document_ids=document_ids, ) mock_queue.delete_task_key.assert_not_called() + + def test_successful_duplicate_document_indexing( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test successful duplicate document indexing flow.""" + self._test_duplicate_document_indexing_task_success( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_dataset_not_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when dataset is not found.""" + self._test_duplicate_document_indexing_task_dataset_not_found( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing with billing enabled and sandbox plan.""" + self._test_duplicate_document_indexing_task_billing_sandbox_plan_batch_limit( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_with_billing_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when billing limit is exceeded.""" + self._test_duplicate_document_indexing_task_billing_vector_space_limit_exceeded( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_runner_error( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when IndexingRunner raises an error.""" + self._test_duplicate_document_indexing_task_indexing_runner_exception( + db_session_with_containers, mock_external_service_dependencies + ) + + def _test_duplicate_document_indexing_task_document_is_paused( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when document is paused.""" + # Arrange + dataset, documents = self._create_test_dataset_and_documents( + db_session_with_containers, mock_external_service_dependencies, document_count=2 + ) + for document in documents: + document.is_paused = True + db_session_with_containers.add(document) + db_session_with_containers.commit() + + document_ids = [doc.id for doc in documents] + mock_external_service_dependencies["indexing_runner_instance"].run.side_effect = DocumentIsPausedError( + "Document paused" + ) + + # Act + _duplicate_document_indexing_task(dataset.id, document_ids) + db_session_with_containers.expire_all() + + # Assert + for doc_id in document_ids: + updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first() + assert updated_document.is_paused is True + assert updated_document.indexing_status == IndexingStatus.PARSING + assert updated_document.display_status == "paused" + assert updated_document.processing_started_at is not None + mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once() + + def test_duplicate_document_indexing_document_is_paused( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test duplicate document indexing when document is paused.""" + self._test_duplicate_document_indexing_task_document_is_paused( + db_session_with_containers, mock_external_service_dependencies + ) + + def test_duplicate_document_indexing_cleans_old_segments( + self, db_session_with_containers, mock_external_service_dependencies + ): + """Test that duplicate document indexing cleans old segments.""" + self._test_duplicate_document_indexing_task_with_segment_cleanup( + db_session_with_containers, mock_external_service_dependencies + ) diff --git a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py index b3d9e49b30..54b50016a8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_enable_segments_to_index_task.py @@ -2,12 +2,13 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType -from extensions.ext_database import db from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, Document, DocumentSegment +from models.enums import DataSourceType, DocumentCreatedFrom, IndexingStatus, SegmentStatus from tasks.enable_segments_to_index_task import enable_segments_to_index_task @@ -31,7 +32,9 @@ class TestEnableSegmentsToIndexTask: "index_processor": mock_processor, } - def _create_test_dataset_and_document(self, db_session_with_containers, mock_external_service_dependencies): + def _create_test_dataset_and_document( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Helper method to create a test dataset and document for testing. @@ -51,15 +54,15 @@ class TestEnableSegmentsToIndexTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -68,8 +71,8 @@ class TestEnableSegmentsToIndexTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create dataset dataset = Dataset( @@ -77,12 +80,12 @@ class TestEnableSegmentsToIndexTask: tenant_id=tenant.id, name=fake.company(), description=fake.text(max_nb_chars=100), - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, indexing_technique="high_quality", created_by=account.id, ) - db.session.add(dataset) - db.session.commit() + db_session_with_containers.add(dataset) + db_session_with_containers.commit() # Create document document = Document( @@ -90,25 +93,31 @@ class TestEnableSegmentsToIndexTask: tenant_id=tenant.id, dataset_id=dataset.id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="test_batch", name=fake.file_name(), - created_from="upload_file", + created_from=DocumentCreatedFrom.WEB, created_by=account.id, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, doc_form=IndexStructureType.PARAGRAPH_INDEX, ) - db.session.add(document) - db.session.commit() + db_session_with_containers.add(document) + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property works correctly - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) return dataset, document def _create_test_segments( - self, db_session_with_containers, document, dataset, count=3, enabled=False, status="completed" + self, + db_session_with_containers: Session, + document, + dataset, + count=3, + enabled=False, + status=SegmentStatus.COMPLETED, ): """ Helper method to create test document segments. @@ -144,14 +153,14 @@ class TestEnableSegmentsToIndexTask: status=status, created_by=document.created_by, ) - db.session.add(segment) + db_session_with_containers.add(segment) segments.append(segment) - db.session.commit() + db_session_with_containers.commit() return segments def test_enable_segments_to_index_with_different_index_type( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with different index types. @@ -169,10 +178,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use different index type document.doc_form = IndexStructureType.QA_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -204,7 +213,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_dataset_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent dataset. @@ -229,7 +238,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_document_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of non-existent document. @@ -256,7 +265,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_invalid_document_status( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling of document with invalid status. @@ -276,7 +285,7 @@ class TestEnableSegmentsToIndexTask: invalid_statuses = [ ("disabled", {"enabled": False}), ("archived", {"archived": True}), - ("not_completed", {"indexing_status": "processing"}), + ("not_completed", {"indexing_status": IndexingStatus.INDEXING}), ] for _, status_attrs in invalid_statuses: @@ -284,12 +293,12 @@ class TestEnableSegmentsToIndexTask: document.enabled = True document.archived = False document.indexing_status = "completed" - db.session.commit() + db_session_with_containers.commit() # Set invalid status for attr, value in status_attrs.items(): setattr(document, attr, value) - db.session.commit() + db_session_with_containers.commit() # Create segments segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -304,11 +313,11 @@ class TestEnableSegmentsToIndexTask: # Clean up segments for next iteration for segment in segments: - db.session.delete(segment) - db.session.commit() + db_session_with_containers.delete(segment) + db_session_with_containers.commit() def test_enable_segments_to_index_segments_not_found( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test handling when no segments are found. @@ -338,7 +347,7 @@ class TestEnableSegmentsToIndexTask: mock_external_service_dependencies["index_processor"].load.assert_not_called() def test_enable_segments_to_index_with_parent_child_structure( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test segments indexing with parent-child structure. @@ -357,10 +366,10 @@ class TestEnableSegmentsToIndexTask: # Update document to use parent-child index type document.doc_form = IndexStructureType.PARENT_CHILD_INDEX - db.session.commit() + db_session_with_containers.commit() # Refresh dataset to ensure doc_form property reflects the updated document - db.session.refresh(dataset) + db_session_with_containers.refresh(dataset) # Create segments with mock child chunks segments = self._create_test_segments(db_session_with_containers, document, dataset) @@ -410,7 +419,7 @@ class TestEnableSegmentsToIndexTask: assert redis_client.exists(indexing_cache_key) == 0 def test_enable_segments_to_index_general_exception_handling( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test general exception handling during indexing process. @@ -443,9 +452,9 @@ class TestEnableSegmentsToIndexTask: # Assert: Verify error handling for segment in segments: - db.session.refresh(segment) + db_session_with_containers.refresh(segment) assert segment.enabled is False - assert segment.status == "error" + assert segment.status == SegmentStatus.ERROR assert segment.error is not None assert "Index processing failed" in segment.error assert segment.disabled_at is not None diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py index 6c3a9ef20a..ff72232d12 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_account_deletion_task.py @@ -2,8 +2,8 @@ from unittest.mock import patch import pytest from faker import Faker +from sqlalchemy.orm import Session -from extensions.ext_database import db from libs.email_i18n import EmailType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from tasks.mail_account_deletion_task import send_account_deletion_verification_code, send_deletion_success_task @@ -30,7 +30,7 @@ class TestMailAccountDeletionTask: "email_service": mock_email_service, } - def _create_test_account(self, db_session_with_containers): + def _create_test_account(self, db_session_with_containers: Session): """ Helper method to create a test account for testing. @@ -49,16 +49,16 @@ class TestMailAccountDeletionTask: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() # Create tenant tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -67,12 +67,14 @@ class TestMailAccountDeletionTask: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() return account - def test_send_deletion_success_task_success(self, db_session_with_containers, mock_external_service_dependencies): + def test_send_deletion_success_task_success( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): """ Test successful account deletion success email sending. @@ -109,7 +111,7 @@ class TestMailAccountDeletionTask: ) def test_send_deletion_success_task_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when mail service is not initialized. @@ -132,7 +134,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_deletion_success_task_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion success email when email service raises exception. @@ -154,7 +156,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_called_once() def test_send_account_deletion_verification_code_success( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test successful account deletion verification code email sending. @@ -193,7 +195,7 @@ class TestMailAccountDeletionTask: ) def test_send_account_deletion_verification_code_mail_not_initialized( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when mail service is not initialized. @@ -217,7 +219,7 @@ class TestMailAccountDeletionTask: mock_external_service_dependencies["email_service"].send_email.assert_not_called() def test_send_account_deletion_verification_code_email_service_exception( - self, db_session_with_containers, mock_external_service_dependencies + self, db_session_with_containers: Session, mock_external_service_dependencies ): """ Test account deletion verification code email when email service raises exception. diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py index 3cdec70df7..c0ddc27286 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_email_code_login_task.py @@ -118,7 +118,7 @@ class TestSendEmailCodeLoginMailTask: tenant = Tenant( name=fake.company(), plan="basic", - status="active", + status="normal", ) db_session_with_containers.add(tenant) diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 5fd6c56f7a..0876a39f82 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -9,8 +9,8 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import ( +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -18,7 +18,7 @@ from core.workflow.nodes.human_input.entities import ( HumanInputNodeData, MemberRecipient, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -96,8 +96,7 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - engine = db_session_with_containers.get_bind() - repo = HumanInputFormRepositoryImpl(session_factory=engine, tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) params = FormCreateParams( app_id=app_id, workflow_execution_id=workflow_execution_id, diff --git a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py index b9977b1fb6..f01fcc1742 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py +++ b/api/tests/test_containers_integration_tests/tasks/test_rag_pipeline_run_tasks.py @@ -1,14 +1,14 @@ import json import uuid -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker +from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from extensions.ext_database import db from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Pipeline from models.workflow import Workflow @@ -52,7 +52,7 @@ class TestRagPipelineRunTasks: "delete_file": mock_delete_file, } - def _create_test_pipeline_and_workflow(self, db_session_with_containers): + def _create_test_pipeline_and_workflow(self, db_session_with_containers: Session): """ Helper method to create test pipeline and workflow for testing. @@ -71,15 +71,15 @@ class TestRagPipelineRunTasks: interface_language="en-US", status="active", ) - db.session.add(account) - db.session.commit() + db_session_with_containers.add(account) + db_session_with_containers.commit() tenant = Tenant( name=fake.company(), status="normal", ) - db.session.add(tenant) - db.session.commit() + db_session_with_containers.add(tenant) + db_session_with_containers.commit() # Create tenant-account join join = TenantAccountJoin( @@ -88,8 +88,8 @@ class TestRagPipelineRunTasks: role=TenantAccountRole.OWNER, current=True, ) - db.session.add(join) - db.session.commit() + db_session_with_containers.add(join) + db_session_with_containers.commit() # Create workflow workflow = Workflow( @@ -107,8 +107,8 @@ class TestRagPipelineRunTasks: conversation_variables=[], rag_pipeline_variables=[], ) - db.session.add(workflow) - db.session.commit() + db_session_with_containers.add(workflow) + db_session_with_containers.commit() # Create pipeline pipeline = Pipeline( @@ -119,14 +119,14 @@ class TestRagPipelineRunTasks: created_by=account.id, ) pipeline.id = str(uuid.uuid4()) - db.session.add(pipeline) - db.session.commit() + db_session_with_containers.add(pipeline) + db_session_with_containers.commit() # Refresh entities to ensure they're properly loaded - db.session.refresh(account) - db.session.refresh(tenant) - db.session.refresh(workflow) - db.session.refresh(pipeline) + db_session_with_containers.refresh(account) + db_session_with_containers.refresh(tenant) + db_session_with_containers.refresh(workflow) + db_session_with_containers.refresh(pipeline) return account, tenant, pipeline, workflow @@ -209,7 +209,7 @@ class TestRagPipelineRunTasks: return json.dumps(entities_data) def test_priority_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful priority RAG pipeline run task execution. @@ -254,7 +254,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_rag_pipeline_run_task_success( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test successful regular RAG pipeline run task execution. @@ -299,7 +299,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_priority_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with waiting tasks in queue using real Redis. @@ -351,7 +351,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 1 # 2 original - 1 pulled = 1 remaining def test_rag_pipeline_run_task_legacy_compatibility( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with legacy Redis queue format for backward compatibility. @@ -388,8 +388,10 @@ class TestRagPipelineRunTasks: # Set the task key to indicate there are waiting tasks (legacy behavior) redis_client.set(legacy_task_key, 1, ex=60 * 60) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the priority task with new code but legacy queue data rag_pipeline_run_task(file_id, tenant.id) @@ -398,13 +400,14 @@ class TestRagPipelineRunTasks: mock_file_service["delete_file"].assert_called_once_with(file_id) assert mock_pipeline_generator.call_count == 1 - # Verify waiting tasks were processed, pull 1 task a time by default - assert mock_delay.call_count == 1 + # Verify waiting tasks were processed via group, pull 1 task a time by default + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0] - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == legacy_file_ids[0] + assert first_kwargs.get("tenant_id") == tenant.id # Verify that new code can process legacy queue entries # The new TenantIsolatedTaskQueue should be able to read from the legacy format @@ -419,7 +422,7 @@ class TestRagPipelineRunTasks: redis_client.delete(legacy_task_key) def test_rag_pipeline_run_task_with_waiting_tasks( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with waiting tasks in queue using real Redis. @@ -446,8 +449,10 @@ class TestRagPipelineRunTasks: waiting_file_ids = [str(uuid.uuid4()) for _ in range(3)] queue.push_tasks(waiting_file_ids) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the regular task rag_pipeline_run_task(file_id, tenant.id) @@ -456,20 +461,21 @@ class TestRagPipelineRunTasks: mock_file_service["delete_file"].assert_called_once_with(file_id) assert mock_pipeline_generator.call_count == 1 - # Verify waiting tasks were processed, pull 1 task a time by default - assert mock_delay.call_count == 1 + # Verify waiting tasks were processed via group.apply_async + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_ids[0] + assert first_kwargs.get("tenant_id") == tenant.id # Verify queue still has remaining tasks (only 1 was pulled) remaining_tasks = queue.pull_tasks(count=10) assert len(remaining_tasks) == 2 # 3 original - 1 pulled = 2 remaining def test_priority_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in priority RAG pipeline run task using real Redis. @@ -526,7 +532,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_error_handling( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test error handling in regular RAG pipeline run task using real Redis. @@ -557,8 +563,10 @@ class TestRagPipelineRunTasks: waiting_file_id = str(uuid.uuid4()) queue.push_tasks([waiting_file_id]) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the regular task (should not raise exception) rag_pipeline_run_task(file_id, tenant.id) @@ -569,19 +577,20 @@ class TestRagPipelineRunTasks: assert mock_pipeline_generator.call_count == 1 # Verify waiting task was still processed despite core processing error - mock_delay.assert_called_once() + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert first_kwargs.get("tenant_id") == tenant.id # Verify queue is empty after processing (task was pulled) remaining_tasks = queue.pull_tasks(count=10) assert len(remaining_tasks) == 0 def test_priority_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in priority RAG pipeline run task using real Redis. @@ -648,7 +657,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_rag_pipeline_run_task_tenant_isolation( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test tenant isolation in regular RAG pipeline run task using real Redis. @@ -684,8 +693,10 @@ class TestRagPipelineRunTasks: queue1.push_tasks([waiting_file_id1]) queue2.push_tasks([waiting_file_id2]) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act: Execute the regular task for tenant1 only rag_pipeline_run_task(file_id1, tenant1.id) @@ -694,11 +705,12 @@ class TestRagPipelineRunTasks: assert mock_file_service["delete_file"].call_count == 1 assert mock_pipeline_generator.call_count == 1 - # Verify only tenant1's waiting task was processed - mock_delay.assert_called_once() - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 - assert call_kwargs.get("tenant_id") == tenant1.id + # Verify only tenant1's waiting task was processed (via group) + assert mock_group.return_value.apply_async.called + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id1 + assert first_kwargs.get("tenant_id") == tenant1.id # Verify tenant1's queue is empty remaining_tasks1 = queue1.pull_tasks(count=10) @@ -713,7 +725,7 @@ class TestRagPipelineRunTasks: assert queue1._task_key != queue2._task_key def test_run_single_rag_pipeline_task_success( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test successful run_single_rag_pipeline_task execution. @@ -748,7 +760,7 @@ class TestRagPipelineRunTasks: assert isinstance(call_kwargs["application_generate_entity"], RagPipelineGenerateEntity) def test_run_single_rag_pipeline_task_entity_validation_error( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with invalid entity data. @@ -793,7 +805,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_run_single_rag_pipeline_task_database_entity_not_found( - self, db_session_with_containers, mock_pipeline_generator, flask_app_with_containers + self, db_session_with_containers: Session, mock_pipeline_generator, flask_app_with_containers ): """ Test run_single_rag_pipeline_task with non-existent database entities. @@ -838,7 +850,7 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() def test_priority_rag_pipeline_run_task_file_not_found( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test priority RAG pipeline run task with non-existent file. @@ -888,7 +900,7 @@ class TestRagPipelineRunTasks: assert len(remaining_tasks) == 0 def test_rag_pipeline_run_task_file_not_found( - self, db_session_with_containers, mock_pipeline_generator, mock_file_service + self, db_session_with_containers: Session, mock_pipeline_generator, mock_file_service ): """ Test regular RAG pipeline run task with non-existent file. @@ -913,8 +925,10 @@ class TestRagPipelineRunTasks: waiting_file_id = str(uuid.uuid4()) queue.push_tasks([waiting_file_id]) - # Mock the task function calls - with patch("tasks.rag_pipeline.rag_pipeline_run_task.rag_pipeline_run_task.delay") as mock_delay: + # Mock the Celery group scheduling used by the implementation + with patch("tasks.rag_pipeline.rag_pipeline_run_task.group") as mock_group: + mock_group.return_value.apply_async = MagicMock() + # Act & Assert: Execute the regular task (should raise Exception) with pytest.raises(Exception, match="File not found"): rag_pipeline_run_task(file_id, tenant.id) @@ -924,12 +938,13 @@ class TestRagPipelineRunTasks: mock_pipeline_generator.assert_not_called() # Verify waiting task was still processed despite file error - mock_delay.assert_called_once() + assert mock_group.return_value.apply_async.called - # Verify correct parameters for the call - call_kwargs = mock_delay.call_args[1] if mock_delay.call_args else {} - assert call_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id - assert call_kwargs.get("tenant_id") == tenant.id + # Verify correct parameters for the first scheduled job signature + jobs = mock_group.call_args.args[0] if mock_group.call_args else [] + first_kwargs = jobs[0].kwargs if jobs else {} + assert first_kwargs.get("rag_pipeline_invoke_entities_file_id") == waiting_file_id + assert first_kwargs.get("tenant_id") == tenant.id # Verify queue is empty after processing (task was pulled) remaining_tasks = queue.pull_tasks(count=10) diff --git a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 8501a8e39b..5bded4d670 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -4,8 +4,9 @@ from unittest.mock import ANY, call, patch import pytest from core.db.session_factory import session_factory -from core.workflow.variables.segments import StringSegment -from core.workflow.variables.types import SegmentType +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import SegmentType +from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Tenant from models.enums import CreatorUserRole @@ -78,7 +79,7 @@ def _create_offload_data(db_session_with_containers, *, tenant_id: str, app_id: for i in range(count): upload_file = UploadFile( tenant_id=tenant_id, - storage_type="local", + storage_type=StorageType.LOCAL, key=f"test/file-{uuid.uuid4()}-{i}.json", name=f"file-{i}.json", size=1024 + i, diff --git a/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py new file mode 100644 index 0000000000..34a1941c39 --- /dev/null +++ b/api/tests/test_containers_integration_tests/test_opendal_fs_default_root.py @@ -0,0 +1,56 @@ +from pathlib import Path + +from extensions.storage.opendal_storage import OpenDALStorage + + +class TestOpenDALFsDefaultRoot: + """Test that OpenDALStorage with scheme='fs' works correctly when no root is provided.""" + + def test_fs_without_root_uses_default(self, tmp_path, monkeypatch): + """When no root is specified, the default 'storage' should be used and passed to the Operator.""" + # Change to tmp_path so the default "storage" dir is created there + monkeypatch.chdir(tmp_path) + # Ensure no OPENDAL_FS_ROOT env var is set + monkeypatch.delenv("OPENDAL_FS_ROOT", raising=False) + + storage = OpenDALStorage(scheme="fs") + + # The default directory should have been created + assert (tmp_path / "storage").is_dir() + # The storage should be functional + storage.save("test_default_root.txt", b"hello") + assert storage.exists("test_default_root.txt") + assert storage.load_once("test_default_root.txt") == b"hello" + + # Cleanup + storage.delete("test_default_root.txt") + + def test_fs_with_explicit_root(self, tmp_path): + """When root is explicitly provided, it should be used.""" + custom_root = str(tmp_path / "custom_storage") + storage = OpenDALStorage(scheme="fs", root=custom_root) + + assert Path(custom_root).is_dir() + storage.save("test_explicit_root.txt", b"world") + assert storage.exists("test_explicit_root.txt") + assert storage.load_once("test_explicit_root.txt") == b"world" + + # Cleanup + storage.delete("test_explicit_root.txt") + + def test_fs_with_env_var_root(self, tmp_path, monkeypatch): + """When OPENDAL_FS_ROOT env var is set, it should be picked up via _get_opendal_kwargs.""" + env_root = str(tmp_path / "env_storage") + monkeypatch.setenv("OPENDAL_FS_ROOT", env_root) + # Ensure .env file doesn't interfere + monkeypatch.chdir(tmp_path) + + storage = OpenDALStorage(scheme="fs") + + assert Path(env_root).is_dir() + storage.save("test_env_root.txt", b"env_data") + assert storage.exists("test_env_root.txt") + assert storage.load_once("test_env_root.txt") == b"env_data" + + # Cleanup + storage.delete("test_env_root.txt") diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index 5f4f28cf4f..ca76fa0a4b 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -27,8 +27,8 @@ import pytest from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker -from core.workflow.entities import WorkflowExecution -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.entities import WorkflowExecution +from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account diff --git a/api/tests/test_containers_integration_tests/trigger/conftest.py b/api/tests/test_containers_integration_tests/trigger/conftest.py index 9c1fd5e0ec..e3832fb2ef 100644 --- a/api/tests/test_containers_integration_tests/trigger/conftest.py +++ b/api/tests/test_containers_integration_tests/trigger/conftest.py @@ -105,18 +105,26 @@ def app_model( class MockCeleryGroup: - """Mock for celery group() function that collects dispatched tasks.""" + """Mock for celery group() function that collects dispatched tasks. + + Matches the Celery group API loosely, accepting arbitrary kwargs on apply_async + (e.g. producer) so production code can pass broker-related options without + breaking tests. + """ def __init__(self) -> None: self.collected: list[dict[str, Any]] = [] self._applied = False + self.last_apply_async_kwargs: dict[str, Any] | None = None def __call__(self, items: Any) -> MockCeleryGroup: self.collected = list(items) return self - def apply_async(self) -> None: + def apply_async(self, **kwargs: Any) -> None: + # Accept arbitrary kwargs like producer to be compatible with Celery self._applied = True + self.last_apply_async_kwargs = kwargs @property def applied(self) -> bool: diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 604d68f257..4ea8d8c1c7 100644 --- a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -14,11 +14,16 @@ from sqlalchemy.orm import Session from configs import dify_config from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus @@ -48,10 +53,10 @@ WEBHOOK_ID_DEBUG = "whdebug1234567890123456" TEST_TRIGGER_URL = "https://trigger.example.com/base" -def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: +def _build_workflow_graph(root_node_id: str, trigger_type: str) -> str: """Build a minimal workflow graph JSON for testing.""" - node_data: dict[str, Any] = {"type": trigger_type.value, "title": "trigger"} - if trigger_type == NodeType.TRIGGER_WEBHOOK: + node_data: dict[str, Any] = {"type": trigger_type, "title": "trigger"} + if trigger_type == TRIGGER_WEBHOOK_NODE_TYPE: node_data.update( { "method": "POST", @@ -64,7 +69,7 @@ def _build_workflow_graph(root_node_id: str, trigger_type: NodeType) -> str: graph = { "nodes": [ {"id": root_node_id, "data": node_data}, - {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + {"id": "answer-1", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "answer"}}, ], "edges": [{"source": root_node_id, "target": "answer-1", "sourceHandle": "success"}], } @@ -82,8 +87,8 @@ def test_publish_blocks_start_and_trigger_coexistence( graph = { "nodes": [ - {"id": "start", "data": {"type": NodeType.START.value}}, - {"id": "trig", "data": {"type": NodeType.TRIGGER_WEBHOOK.value}}, + {"id": "start", "data": {"type": BuiltinNodeTypes.START}}, + {"id": "trig", "data": {"type": TRIGGER_WEBHOOK_NODE_TYPE}}, ], "edges": [], } @@ -152,7 +157,7 @@ def test_webhook_trigger_creates_trigger_log( tenant, account = tenant_and_account webhook_node_id = "webhook-node" - graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + graph_json = _build_workflow_graph(webhook_node_id, TRIGGER_WEBHOOK_NODE_TYPE) published_workflow = Workflow.new( tenant_id=tenant.id, app_id=app_model.id, @@ -282,7 +287,7 @@ def test_schedule_visual_debug_poll_generates_event(monkeypatch: pytest.MonkeyPa node_config = { "id": "schedule-visual", "data": { - "type": NodeType.TRIGGER_SCHEDULE.value, + "type": TRIGGER_SCHEDULE_NODE_TYPE, "mode": "visual", "frequency": "daily", "visual_config": {"time": "3:00 PM"}, @@ -372,7 +377,7 @@ def test_webhook_debug_dispatches_event( """Webhook single-step debug should dispatch debug event and be pollable.""" tenant, account = tenant_and_account webhook_node_id = "webhook-debug-node" - graph_json = _build_workflow_graph(webhook_node_id, NodeType.TRIGGER_WEBHOOK) + graph_json = _build_workflow_graph(webhook_node_id, TRIGGER_WEBHOOK_NODE_TYPE) draft_workflow = Workflow.new( tenant_id=tenant.id, app_id=app_model.id, @@ -443,7 +448,7 @@ def test_plugin_single_step_debug_flow( node_config = { "id": node_id, "data": { - "type": NodeType.TRIGGER_PLUGIN.value, + "type": TRIGGER_PLUGIN_NODE_TYPE, "title": "plugin", "plugin_id": "plugin-1", "plugin_unique_identifier": "plugin-1", @@ -519,14 +524,14 @@ def test_schedule_trigger_creates_trigger_log( { "id": schedule_node_id, "data": { - "type": NodeType.TRIGGER_SCHEDULE.value, + "type": TRIGGER_SCHEDULE_NODE_TYPE, "title": "schedule", "mode": "cron", "cron_expression": "0 9 * * *", "timezone": "UTC", }, }, - {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + {"id": "answer-1", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "answer"}}, ], "edges": [{"source": schedule_node_id, "target": "answer-1", "sourceHandle": "success"}], } @@ -639,7 +644,7 @@ def test_schedule_visual_cron_conversion( node_config: dict[str, Any] = { "id": "schedule-node", "data": { - "type": NodeType.TRIGGER_SCHEDULE.value, + "type": TRIGGER_SCHEDULE_NODE_TYPE, "mode": mode, "timezone": "UTC", }, @@ -680,7 +685,7 @@ def test_plugin_trigger_full_chain_with_db_verification( { "id": plugin_node_id, "data": { - "type": NodeType.TRIGGER_PLUGIN.value, + "type": TRIGGER_PLUGIN_NODE_TYPE, "title": "plugin", "plugin_id": "test-plugin", "plugin_unique_identifier": "test-plugin", @@ -690,7 +695,7 @@ def test_plugin_trigger_full_chain_with_db_verification( "parameters": {}, }, }, - {"id": "answer-1", "data": {"type": NodeType.ANSWER.value, "title": "answer"}}, + {"id": "answer-1", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "answer"}}, ], "edges": [{"source": plugin_node_id, "target": "answer-1", "sourceHandle": "success"}], } @@ -826,7 +831,7 @@ def test_plugin_debug_via_http_endpoint( node_config = { "id": node_id, "data": { - "type": NodeType.TRIGGER_PLUGIN.value, + "type": TRIGGER_PLUGIN_NODE_TYPE, "title": "plugin-debug", "plugin_id": "debug-plugin", "plugin_unique_identifier": "debug-plugin", diff --git a/api/tests/unit_tests/commands/test_clean_expired_messages.py b/api/tests/unit_tests/commands/test_clean_expired_messages.py new file mode 100644 index 0000000000..5375988a69 --- /dev/null +++ b/api/tests/unit_tests/commands/test_clean_expired_messages.py @@ -0,0 +1,184 @@ +import datetime +import re +from unittest.mock import MagicMock, patch + +import click +import pytest + +from commands import clean_expired_messages + + +def _mock_service() -> MagicMock: + service = MagicMock() + service.run.return_value = { + "batches": 1, + "total_messages": 10, + "filtered_messages": 5, + "total_deleted": 5, + } + return service + + +def test_absolute_mode_calls_from_time_range(): + policy = object() + service = _mock_service() + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 2, 1, 0, 0, 0) + + with ( + patch("commands.retention.create_message_clean_policy", return_value=policy), + patch("commands.retention.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range, + patch("commands.retention.MessagesCleanService.from_days") as mock_from_days, + ): + clean_expired_messages.callback( + batch_size=200, + graceful_period=21, + start_from=start_from, + end_before=end_before, + from_days_ago=None, + before_days=None, + dry_run=True, + ) + + mock_from_time_range.assert_called_once_with( + policy=policy, + start_from=start_from, + end_before=end_before, + batch_size=200, + dry_run=True, + task_label="custom", + ) + mock_from_days.assert_not_called() + + +def test_relative_mode_before_days_only_calls_from_days(): + policy = object() + service = _mock_service() + + with ( + patch("commands.retention.create_message_clean_policy", return_value=policy), + patch("commands.retention.MessagesCleanService.from_days", return_value=service) as mock_from_days, + patch("commands.retention.MessagesCleanService.from_time_range") as mock_from_time_range, + ): + clean_expired_messages.callback( + batch_size=500, + graceful_period=14, + start_from=None, + end_before=None, + from_days_ago=None, + before_days=30, + dry_run=False, + ) + + mock_from_days.assert_called_once_with( + policy=policy, + days=30, + batch_size=500, + dry_run=False, + task_label="before-30", + ) + mock_from_time_range.assert_not_called() + + +def test_relative_mode_with_from_days_ago_calls_from_time_range(): + policy = object() + service = _mock_service() + fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0) + + with ( + patch("commands.retention.create_message_clean_policy", return_value=policy), + patch("commands.retention.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range, + patch("commands.retention.MessagesCleanService.from_days") as mock_from_days, + patch("commands.retention.naive_utc_now", return_value=fixed_now), + ): + clean_expired_messages.callback( + batch_size=1000, + graceful_period=21, + start_from=None, + end_before=None, + from_days_ago=60, + before_days=30, + dry_run=False, + ) + + mock_from_time_range.assert_called_once_with( + policy=policy, + start_from=fixed_now - datetime.timedelta(days=60), + end_before=fixed_now - datetime.timedelta(days=30), + batch_size=1000, + dry_run=False, + task_label="60to30", + ) + mock_from_days.assert_not_called() + + +@pytest.mark.parametrize( + ("kwargs", "message"), + [ + ( + { + "start_from": datetime.datetime(2024, 1, 1), + "end_before": datetime.datetime(2024, 2, 1), + "from_days_ago": None, + "before_days": 30, + }, + "mutually exclusive", + ), + ( + { + "start_from": datetime.datetime(2024, 1, 1), + "end_before": None, + "from_days_ago": None, + "before_days": None, + }, + "Both --start-from and --end-before are required", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": 10, + "before_days": None, + }, + "--from-days-ago must be used together with --before-days", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": None, + "before_days": -1, + }, + "--before-days must be >= 0", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": 30, + "before_days": 30, + }, + "--from-days-ago must be greater than --before-days", + ), + ( + { + "start_from": None, + "end_before": None, + "from_days_ago": None, + "before_days": None, + }, + "You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])", + ), + ], +) +def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str): + with pytest.raises(click.UsageError, match=re.escape(message)): + clean_expired_messages.callback( + batch_size=1000, + graceful_period=21, + start_from=kwargs["start_from"], + end_before=kwargs["end_before"], + from_days_ago=kwargs["from_days_ago"], + before_days=kwargs["before_days"], + dry_run=False, + ) diff --git a/api/tests/unit_tests/commands/test_upgrade_db.py b/api/tests/unit_tests/commands/test_upgrade_db.py index 80173f5d46..5aa0313429 100644 --- a/api/tests/unit_tests/commands/test_upgrade_db.py +++ b/api/tests/unit_tests/commands/test_upgrade_db.py @@ -4,6 +4,7 @@ import types from unittest.mock import MagicMock import commands +from commands import system as system_commands from libs.db_migration_lock import LockNotOwnedError, RedisError HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0 @@ -24,11 +25,11 @@ def _invoke_upgrade_db() -> int: def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys): - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234) lock = MagicMock() lock.acquire.return_value = False - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock exit_code = _invoke_upgrade_db() captured = capsys.readouterr() @@ -36,18 +37,18 @@ def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys): assert exit_code == 0 assert "Database migration skipped" in captured.out - commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False) + system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False) lock.acquire.assert_called_once_with(blocking=False) lock.release.assert_not_called() def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys): - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321) lock = MagicMock() lock.acquire.return_value = True lock.release.side_effect = LockNotOwnedError("simulated") - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock def _upgrade(): raise RuntimeError("boom") @@ -60,18 +61,18 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys): assert exit_code == 1 assert "Database migration failed: boom" in captured.out - commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False) + system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False) lock.acquire.assert_called_once_with(blocking=False) lock.release.assert_called_once() def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys): - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999) lock = MagicMock() lock.acquire.return_value = True lock.release.side_effect = LockNotOwnedError("simulated") - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock _install_fake_flask_migrate(monkeypatch, lambda: None) @@ -81,7 +82,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy assert exit_code == 0 assert "Database migration successful!" in captured.out - commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False) + system_commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False) lock.acquire.assert_called_once_with(blocking=False) lock.release.assert_called_once() @@ -92,11 +93,11 @@ def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys): """ # Use a small TTL so the heartbeat interval triggers quickly. - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) lock = MagicMock() lock.acquire.return_value = True - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock renewed = threading.Event() @@ -120,11 +121,11 @@ def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys): def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys): # Use a small TTL so heartbeat runs during the upgrade call. - monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) + monkeypatch.setattr(system_commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) lock = MagicMock() lock.acquire.return_value = True - commands.redis_client.lock.return_value = lock + system_commands.redis_client.lock.return_value = lock attempted = threading.Event() diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index cf52980e57..d6933e2180 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -25,7 +25,8 @@ def test_dify_config(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "300") # Custom value for testing # load dotenv file with pydantic-settings - config = DifyConfig() + # Disable `.env` loading to ensure test stability across environments + config = DifyConfig(_env_file=None) # constant values assert config.COMMIT_SHA == "" @@ -59,7 +60,8 @@ def test_http_timeout_defaults(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("DB_PORT", "5432") monkeypatch.setenv("DB_DATABASE", "dify") - config = DifyConfig() + # Disable `.env` loading to ensure test stability across environments + config = DifyConfig(_env_file=None) # Verify default timeout values assert config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT == 10 @@ -86,7 +88,8 @@ def test_flask_configs(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*") monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/") - flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore + # Disable `.env` loading to ensure test stability across environments + flask_app.config.from_mapping(DifyConfig(_env_file=None).model_dump()) # pyright: ignore config = flask_app.config # configs read from pydantic-settings diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index d2111ebac8..3f75fd2851 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -32,11 +32,6 @@ os.environ.setdefault("OPENDAL_SCHEME", "fs") os.environ.setdefault("OPENDAL_FS_ROOT", "/tmp/dify-storage") os.environ.setdefault("STORAGE_TYPE", "opendal") -# Add the API directory to Python path to ensure proper imports -import sys - -sys.path.insert(0, PROJECT_DIR) - from core.db.session_factory import configure_session_factory, session_factory from extensions import ext_redis diff --git a/api/tests/unit_tests/controllers/common/test_errors.py b/api/tests/unit_tests/controllers/common/test_errors.py new file mode 100644 index 0000000000..25a9fe5b66 --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_errors.py @@ -0,0 +1,70 @@ +from controllers.common.errors import ( + BlockedFileExtensionError, + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + RemoteFileUploadError, + TooManyFilesError, + UnsupportedFileTypeError, +) + + +class TestFilenameNotExistsError: + def test_defaults(self): + error = FilenameNotExistsError() + + assert error.code == 400 + assert error.description == "The specified filename does not exist." + + +class TestRemoteFileUploadError: + def test_defaults(self): + error = RemoteFileUploadError() + + assert error.code == 400 + assert error.description == "Error uploading remote file." + + +class TestFileTooLargeError: + def test_defaults(self): + error = FileTooLargeError() + + assert error.code == 413 + assert error.error_code == "file_too_large" + assert error.description == "File size exceeded. {message}" + + +class TestUnsupportedFileTypeError: + def test_defaults(self): + error = UnsupportedFileTypeError() + + assert error.code == 415 + assert error.error_code == "unsupported_file_type" + assert error.description == "File type not allowed." + + +class TestBlockedFileExtensionError: + def test_defaults(self): + error = BlockedFileExtensionError() + + assert error.code == 400 + assert error.error_code == "file_extension_blocked" + assert error.description == "The file extension is blocked for security reasons." + + +class TestTooManyFilesError: + def test_defaults(self): + error = TooManyFilesError() + + assert error.code == 400 + assert error.error_code == "too_many_files" + assert error.description == "Only one file is allowed." + + +class TestNoFileUploadedError: + def test_defaults(self): + error = NoFileUploadedError() + + assert error.code == 400 + assert error.error_code == "no_file_uploaded" + assert error.description == "Please upload your file." diff --git a/api/tests/unit_tests/controllers/common/test_file_response.py b/api/tests/unit_tests/controllers/common/test_file_response.py index 2487c362bd..b7500fb7f9 100644 --- a/api/tests/unit_tests/controllers/common/test_file_response.py +++ b/api/tests/unit_tests/controllers/common/test_file_response.py @@ -1,22 +1,95 @@ from flask import Response -from controllers.common.file_response import enforce_download_for_html, is_html_content +from controllers.common.file_response import ( + _normalize_mime_type, + enforce_download_for_html, + is_html_content, +) -class TestFileResponseHelpers: - def test_is_html_content_detects_mime_type(self): +class TestNormalizeMimeType: + def test_returns_empty_string_for_none(self): + assert _normalize_mime_type(None) == "" + + def test_returns_empty_string_for_empty_string(self): + assert _normalize_mime_type("") == "" + + def test_normalizes_mime_type(self): + assert _normalize_mime_type("Text/HTML; Charset=UTF-8") == "text/html" + + +class TestIsHtmlContent: + def test_detects_html_via_mime_type(self): mime_type = "text/html; charset=UTF-8" - result = is_html_content(mime_type, filename="file.txt", extension="txt") + result = is_html_content( + mime_type=mime_type, + filename="file.txt", + extension="txt", + ) assert result is True - def test_is_html_content_detects_extension(self): - result = is_html_content("text/plain", filename="report.html", extension=None) + def test_detects_html_via_extension_argument(self): + result = is_html_content( + mime_type="text/plain", + filename=None, + extension="html", + ) assert result is True - def test_enforce_download_for_html_sets_headers(self): + def test_detects_html_via_filename_extension(self): + result = is_html_content( + mime_type="text/plain", + filename="report.html", + extension=None, + ) + + assert result is True + + def test_returns_false_when_no_html_detected_anywhere(self): + """ + Missing negative test: + - MIME type is not HTML + - filename has no HTML extension + - extension argument is not HTML + """ + result = is_html_content( + mime_type="application/json", + filename="data.json", + extension="json", + ) + + assert result is False + + def test_returns_false_when_all_inputs_are_none(self): + result = is_html_content( + mime_type=None, + filename=None, + extension=None, + ) + + assert result is False + + +class TestEnforceDownloadForHtml: + def test_sets_attachment_when_filename_missing(self): + response = Response("payload", mimetype="text/html") + + updated = enforce_download_for_html( + response, + mime_type="text/html", + filename=None, + extension="html", + ) + + assert updated is True + assert response.headers["Content-Disposition"] == "attachment" + assert response.headers["Content-Type"] == "application/octet-stream" + assert response.headers["X-Content-Type-Options"] == "nosniff" + + def test_sets_headers_when_filename_present(self): response = Response("payload", mimetype="text/html") updated = enforce_download_for_html( @@ -27,11 +100,12 @@ class TestFileResponseHelpers: ) assert updated is True - assert "attachment" in response.headers["Content-Disposition"] + assert response.headers["Content-Disposition"].startswith("attachment") + assert "unsafe.html" in response.headers["Content-Disposition"] assert response.headers["Content-Type"] == "application/octet-stream" assert response.headers["X-Content-Type-Options"] == "nosniff" - def test_enforce_download_for_html_no_change_for_non_html(self): + def test_does_not_modify_response_for_non_html_content(self): response = Response("payload", mimetype="text/plain") updated = enforce_download_for_html( diff --git a/api/tests/unit_tests/controllers/common/test_helpers.py b/api/tests/unit_tests/controllers/common/test_helpers.py new file mode 100644 index 0000000000..59c463177c --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_helpers.py @@ -0,0 +1,188 @@ +from uuid import UUID + +import httpx +import pytest + +from controllers.common import helpers +from controllers.common.helpers import FileInfo, guess_file_info_from_response + + +def make_response( + url="https://example.com/file.txt", + headers=None, + content=None, +): + return httpx.Response( + 200, + request=httpx.Request("GET", url), + headers=headers or {}, + content=content or b"", + ) + + +class TestGuessFileInfoFromResponse: + def test_filename_from_url(self): + response = make_response( + url="https://example.com/test.pdf", + content=b"Hello World", + ) + + info = guess_file_info_from_response(response) + + assert info.filename == "test.pdf" + assert info.extension == ".pdf" + assert info.mimetype == "application/pdf" + + def test_filename_from_content_disposition(self): + headers = { + "Content-Disposition": "attachment; filename=myfile.csv", + "Content-Type": "text/csv", + } + response = make_response( + url="https://example.com/", + headers=headers, + content=b"Hello World", + ) + + info = guess_file_info_from_response(response) + + assert info.filename == "myfile.csv" + assert info.extension == ".csv" + assert info.mimetype == "text/csv" + + @pytest.mark.parametrize( + ("magic_available", "expected_ext"), + [ + (True, "txt"), + (False, "bin"), + ], + ) + def test_generated_filename_when_missing(self, monkeypatch, magic_available, expected_ext): + if magic_available: + if helpers.magic is None: + pytest.skip("python-magic is not installed, cannot run 'magic_available=True' test variant") + else: + monkeypatch.setattr(helpers, "magic", None) + + response = make_response( + url="https://example.com/", + content=b"Hello World", + ) + + info = guess_file_info_from_response(response) + + name, ext = info.filename.split(".") + UUID(name) + assert ext == expected_ext + + def test_mimetype_from_header_when_unknown(self): + headers = {"Content-Type": "application/json"} + response = make_response( + url="https://example.com/file.unknown", + headers=headers, + content=b'{"a": 1}', + ) + + info = guess_file_info_from_response(response) + + assert info.mimetype == "application/json" + + def test_extension_added_when_missing(self): + headers = {"Content-Type": "image/png"} + response = make_response( + url="https://example.com/image", + headers=headers, + content=b"fakepngdata", + ) + + info = guess_file_info_from_response(response) + + assert info.extension == ".png" + assert info.filename.endswith(".png") + + def test_content_length_used_as_size(self): + headers = { + "Content-Length": "1234", + "Content-Type": "text/plain", + } + response = make_response( + url="https://example.com/a.txt", + headers=headers, + content=b"a" * 1234, + ) + + info = guess_file_info_from_response(response) + + assert info.size == 1234 + + def test_size_minus_one_when_header_missing(self): + response = make_response(url="https://example.com/a.txt") + + info = guess_file_info_from_response(response) + + assert info.size == -1 + + def test_fallback_to_bin_extension(self): + headers = {"Content-Type": "application/octet-stream"} + response = make_response( + url="https://example.com/download", + headers=headers, + content=b"\x00\x01\x02\x03", + ) + + info = guess_file_info_from_response(response) + + assert info.extension == ".bin" + assert info.filename.endswith(".bin") + + def test_return_type(self): + response = make_response() + + info = guess_file_info_from_response(response) + + assert isinstance(info, FileInfo) + + +class TestMagicImportWarnings: + @pytest.mark.parametrize( + ("platform_name", "expected_message"), + [ + ("Windows", "pip install python-magic-bin"), + ("Darwin", "brew install libmagic"), + ("Linux", "sudo apt-get install libmagic1"), + ("Other", "install `libmagic`"), + ], + ) + def test_magic_import_warning_per_platform( + self, + monkeypatch, + platform_name, + expected_message, + ): + import builtins + import importlib + + # Force ImportError when "magic" is imported + real_import = builtins.__import__ + + def fake_import(name, *args, **kwargs): + if name == "magic": + raise ImportError("No module named magic") + return real_import(name, *args, **kwargs) + + monkeypatch.setattr(builtins, "__import__", fake_import) + monkeypatch.setattr("platform.system", lambda: platform_name) + + # Remove helpers so it imports fresh + import sys + + original_helpers = sys.modules.get(helpers.__name__) + sys.modules.pop(helpers.__name__, None) + + try: + with pytest.warns(UserWarning, match="To use python-magic") as warning: + imported_helpers = importlib.import_module(helpers.__name__) + assert expected_message in str(warning[0].message) + finally: + if original_helpers is not None: + sys.modules[helpers.__name__] = original_helpers diff --git a/api/tests/unit_tests/controllers/common/test_schema.py b/api/tests/unit_tests/controllers/common/test_schema.py new file mode 100644 index 0000000000..56c8160f02 --- /dev/null +++ b/api/tests/unit_tests/controllers/common/test_schema.py @@ -0,0 +1,189 @@ +import sys +from enum import StrEnum +from unittest.mock import MagicMock, patch + +import pytest +from flask_restx import Namespace +from pydantic import BaseModel + + +class UserModel(BaseModel): + id: int + name: str + + +class ProductModel(BaseModel): + id: int + price: float + + +@pytest.fixture(autouse=True) +def mock_console_ns(): + """Mock the console_ns to avoid circular imports during test collection.""" + mock_ns = MagicMock(spec=Namespace) + mock_ns.models = {} + + # Inject mock before importing schema module + with patch.dict(sys.modules, {"controllers.console": MagicMock(console_ns=mock_ns)}): + yield mock_ns + + +def test_default_ref_template_value(): + from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0 + + assert DEFAULT_REF_TEMPLATE_SWAGGER_2_0 == "#/definitions/{model}" + + +def test_register_schema_model_calls_namespace_schema_model(): + from controllers.common.schema import register_schema_model + + namespace = MagicMock(spec=Namespace) + + register_schema_model(namespace, UserModel) + + namespace.schema_model.assert_called_once() + + model_name, schema = namespace.schema_model.call_args.args + + assert model_name == "UserModel" + assert isinstance(schema, dict) + assert "properties" in schema + + +def test_register_schema_model_passes_schema_from_pydantic(): + from controllers.common.schema import DEFAULT_REF_TEMPLATE_SWAGGER_2_0, register_schema_model + + namespace = MagicMock(spec=Namespace) + + register_schema_model(namespace, UserModel) + + schema = namespace.schema_model.call_args.args[1] + + expected_schema = UserModel.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0) + + assert schema == expected_schema + + +def test_register_schema_models_registers_multiple_models(): + from controllers.common.schema import register_schema_models + + namespace = MagicMock(spec=Namespace) + + register_schema_models(namespace, UserModel, ProductModel) + + assert namespace.schema_model.call_count == 2 + + called_names = [call.args[0] for call in namespace.schema_model.call_args_list] + assert called_names == ["UserModel", "ProductModel"] + + +def test_register_schema_models_calls_register_schema_model(monkeypatch): + from controllers.common.schema import register_schema_models + + namespace = MagicMock(spec=Namespace) + + calls = [] + + def fake_register(ns, model): + calls.append((ns, model)) + + monkeypatch.setattr( + "controllers.common.schema.register_schema_model", + fake_register, + ) + + register_schema_models(namespace, UserModel, ProductModel) + + assert calls == [ + (namespace, UserModel), + (namespace, ProductModel), + ] + + +class StatusEnum(StrEnum): + ACTIVE = "active" + INACTIVE = "inactive" + + +class PriorityEnum(StrEnum): + HIGH = "high" + LOW = "low" + + +def test_get_or_create_model_returns_existing_model(mock_console_ns): + from controllers.common.schema import get_or_create_model + + existing_model = MagicMock() + mock_console_ns.models = {"TestModel": existing_model} + + result = get_or_create_model("TestModel", {"key": "value"}) + + assert result == existing_model + mock_console_ns.model.assert_not_called() + + +def test_get_or_create_model_creates_new_model_when_not_exists(mock_console_ns): + from controllers.common.schema import get_or_create_model + + mock_console_ns.models = {} + new_model = MagicMock() + mock_console_ns.model.return_value = new_model + field_def = {"name": {"type": "string"}} + + result = get_or_create_model("NewModel", field_def) + + assert result == new_model + mock_console_ns.model.assert_called_once_with("NewModel", field_def) + + +def test_get_or_create_model_does_not_call_model_if_exists(mock_console_ns): + from controllers.common.schema import get_or_create_model + + existing_model = MagicMock() + mock_console_ns.models = {"ExistingModel": existing_model} + + result = get_or_create_model("ExistingModel", {"key": "value"}) + + assert result == existing_model + mock_console_ns.model.assert_not_called() + + +def test_register_enum_models_registers_single_enum(): + from controllers.common.schema import register_enum_models + + namespace = MagicMock(spec=Namespace) + + register_enum_models(namespace, StatusEnum) + + namespace.schema_model.assert_called_once() + + model_name, schema = namespace.schema_model.call_args.args + + assert model_name == "StatusEnum" + assert isinstance(schema, dict) + + +def test_register_enum_models_registers_multiple_enums(): + from controllers.common.schema import register_enum_models + + namespace = MagicMock(spec=Namespace) + + register_enum_models(namespace, StatusEnum, PriorityEnum) + + assert namespace.schema_model.call_count == 2 + + called_names = [call.args[0] for call in namespace.schema_model.call_args_list] + assert called_names == ["StatusEnum", "PriorityEnum"] + + +def test_register_enum_models_uses_correct_ref_template(): + from controllers.common.schema import register_enum_models + + namespace = MagicMock(spec=Namespace) + + register_enum_models(namespace, StatusEnum) + + schema = namespace.schema_model.call_args.args[1] + + # Verify the schema contains enum values + assert "enum" in schema or "anyOf" in schema diff --git a/api/tests/unit_tests/controllers/console/app/__init__.py b/api/tests/unit_tests/controllers/console/app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_api.py b/api/tests/unit_tests/controllers/console/app/test_annotation_api.py new file mode 100644 index 0000000000..fecbd7f7b0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_api.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from controllers.console.app import annotation as annotation_module + + +def test_annotation_reply_payload_valid(): + """Test AnnotationReplyPayload with valid data.""" + payload = annotation_module.AnnotationReplyPayload( + score_threshold=0.5, + embedding_provider_name="openai", + embedding_model_name="text-embedding-3-small", + ) + assert payload.score_threshold == 0.5 + assert payload.embedding_provider_name == "openai" + assert payload.embedding_model_name == "text-embedding-3-small" + + +def test_annotation_setting_update_payload_valid(): + """Test AnnotationSettingUpdatePayload with valid data.""" + payload = annotation_module.AnnotationSettingUpdatePayload( + score_threshold=0.75, + ) + assert payload.score_threshold == 0.75 + + +def test_annotation_list_query_defaults(): + """Test AnnotationListQuery with default parameters.""" + query = annotation_module.AnnotationListQuery() + assert query.page == 1 + assert query.limit == 20 + assert query.keyword == "" + + +def test_annotation_list_query_custom_page(): + """Test AnnotationListQuery with custom page.""" + query = annotation_module.AnnotationListQuery(page=3, limit=50) + assert query.page == 3 + assert query.limit == 50 + + +def test_annotation_list_query_with_keyword(): + """Test AnnotationListQuery with keyword.""" + query = annotation_module.AnnotationListQuery(keyword="test") + assert query.keyword == "test" + + +def test_create_annotation_payload_with_message_id(): + """Test CreateAnnotationPayload with message ID.""" + payload = annotation_module.CreateAnnotationPayload( + message_id="550e8400-e29b-41d4-a716-446655440000", + question="What is AI?", + ) + assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000" + assert payload.question == "What is AI?" + + +def test_create_annotation_payload_with_text(): + """Test CreateAnnotationPayload with text content.""" + payload = annotation_module.CreateAnnotationPayload( + question="What is ML?", + answer="Machine learning is...", + ) + assert payload.question == "What is ML?" + assert payload.answer == "Machine learning is..." + + +def test_update_annotation_payload(): + """Test UpdateAnnotationPayload.""" + payload = annotation_module.UpdateAnnotationPayload( + question="Updated question", + answer="Updated answer", + ) + assert payload.question == "Updated question" + assert payload.answer == "Updated answer" + + +def test_annotation_reply_status_query_enable(): + """Test AnnotationReplyStatusQuery with enable action.""" + query = annotation_module.AnnotationReplyStatusQuery(action="enable") + assert query.action == "enable" + + +def test_annotation_reply_status_query_disable(): + """Test AnnotationReplyStatusQuery with disable action.""" + query = annotation_module.AnnotationReplyStatusQuery(action="disable") + assert query.action == "disable" + + +def test_annotation_file_payload_valid(): + """Test AnnotationFilePayload with valid message ID.""" + payload = annotation_module.AnnotationFilePayload(message_id="550e8400-e29b-41d4-a716-446655440000") + assert payload.message_id == "550e8400-e29b-41d4-a716-446655440000" diff --git a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py index 06a7b98baf..9f1ff9b40f 100644 --- a/api/tests/unit_tests/controllers/console/app/test_annotation_security.py +++ b/api/tests/unit_tests/controllers/console/app/test_annotation_security.py @@ -13,6 +13,9 @@ from pandas.errors import ParserError from werkzeug.datastructures import FileStorage from configs import dify_config +from controllers.console.wraps import annotation_import_concurrency_limit, annotation_import_rate_limit +from services.annotation_service import AppAnnotationService +from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task class TestAnnotationImportRateLimiting: @@ -33,8 +36,6 @@ class TestAnnotationImportRateLimiting: def test_rate_limit_per_minute_enforced(self, mock_redis, mock_current_account): """Test that per-minute rate limit is enforced.""" - from controllers.console.wraps import annotation_import_rate_limit - # Simulate exceeding per-minute limit mock_redis.zcard.side_effect = [ dify_config.ANNOTATION_IMPORT_RATE_LIMIT_PER_MINUTE + 1, # Minute check @@ -54,7 +55,6 @@ class TestAnnotationImportRateLimiting: def test_rate_limit_per_hour_enforced(self, mock_redis, mock_current_account): """Test that per-hour rate limit is enforced.""" - from controllers.console.wraps import annotation_import_rate_limit # Simulate exceeding per-hour limit mock_redis.zcard.side_effect = [ @@ -74,7 +74,6 @@ class TestAnnotationImportRateLimiting: def test_rate_limit_within_limits_passes(self, mock_redis, mock_current_account): """Test that requests within limits are allowed.""" - from controllers.console.wraps import annotation_import_rate_limit # Simulate being under both limits mock_redis.zcard.return_value = 2 @@ -110,7 +109,6 @@ class TestAnnotationImportConcurrencyControl: def test_concurrency_limit_enforced(self, mock_redis, mock_current_account): """Test that concurrent task limit is enforced.""" - from controllers.console.wraps import annotation_import_concurrency_limit # Simulate max concurrent tasks already running mock_redis.zcard.return_value = dify_config.ANNOTATION_IMPORT_MAX_CONCURRENT @@ -127,7 +125,6 @@ class TestAnnotationImportConcurrencyControl: def test_concurrency_within_limit_passes(self, mock_redis, mock_current_account): """Test that requests within concurrency limits are allowed.""" - from controllers.console.wraps import annotation_import_concurrency_limit # Simulate being under concurrent task limit mock_redis.zcard.return_value = 1 @@ -142,7 +139,6 @@ class TestAnnotationImportConcurrencyControl: def test_stale_jobs_are_cleaned_up(self, mock_redis, mock_current_account): """Test that old/stale job entries are removed.""" - from controllers.console.wraps import annotation_import_concurrency_limit mock_redis.zcard.return_value = 0 @@ -203,7 +199,6 @@ class TestAnnotationImportServiceValidation: def test_max_records_limit_enforced(self, mock_app, mock_db_session): """Test that files with too many records are rejected.""" - from services.annotation_service import AppAnnotationService # Create CSV with too many records max_records = dify_config.ANNOTATION_IMPORT_MAX_RECORDS @@ -229,7 +224,6 @@ class TestAnnotationImportServiceValidation: def test_min_records_limit_enforced(self, mock_app, mock_db_session): """Test that files with too few valid records are rejected.""" - from services.annotation_service import AppAnnotationService # Create CSV with only header (no data rows) csv_content = "question,answer\n" @@ -249,7 +243,6 @@ class TestAnnotationImportServiceValidation: def test_invalid_csv_format_handled(self, mock_app, mock_db_session): """Test that invalid CSV format is handled gracefully.""" - from services.annotation_service import AppAnnotationService # Any content is fine once we force ParserError csv_content = 'invalid,csv,format\nwith,unbalanced,quotes,and"stuff' @@ -270,7 +263,6 @@ class TestAnnotationImportServiceValidation: def test_valid_import_succeeds(self, mock_app, mock_db_session): """Test that valid import request succeeds.""" - from services.annotation_service import AppAnnotationService # Create valid CSV csv_content = "question,answer\nWhat is AI?,Artificial Intelligence\nWhat is ML?,Machine Learning\n" @@ -300,18 +292,10 @@ class TestAnnotationImportServiceValidation: class TestAnnotationImportTaskOptimization: """Test optimizations in batch import task.""" - def test_task_has_timeout_configured(self): - """Test that task has proper timeout configuration.""" - from tasks.annotation.batch_import_annotations_task import batch_import_annotations_task - - # Verify task configuration - assert hasattr(batch_import_annotations_task, "time_limit") - assert hasattr(batch_import_annotations_task, "soft_time_limit") - - # Check timeout values are reasonable - # Hard limit should be 6 minutes (360s) - # Soft limit should be 5 minutes (300s) - # Note: actual values depend on Celery configuration + def test_task_is_registered_with_queue(self): + """Test that task is registered with the correct queue.""" + assert hasattr(batch_import_annotations_task, "apply_async") + assert hasattr(batch_import_annotations_task, "delay") class TestConfigurationValues: diff --git a/api/tests/unit_tests/controllers/console/app/test_app_apis.py b/api/tests/unit_tests/controllers/console/app/test_app_apis.py new file mode 100644 index 0000000000..60b8ee96fe --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_apis.py @@ -0,0 +1,586 @@ +""" +Additional tests to improve coverage for low-coverage modules in controllers/console/app. +Target: increase coverage for files with <75% coverage. +""" + +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.app import ( + annotation as annotation_module, +) +from controllers.console.app import ( + completion as completion_module, +) +from controllers.console.app import ( + message as message_module, +) +from controllers.console.app import ( + ops_trace as ops_trace_module, +) +from controllers.console.app import ( + site as site_module, +) +from controllers.console.app import ( + statistic as statistic_module, +) +from controllers.console.app import ( + workflow_app_log as workflow_app_log_module, +) +from controllers.console.app import ( + workflow_draft_variable as workflow_draft_variable_module, +) +from controllers.console.app import ( + workflow_statistic as workflow_statistic_module, +) +from controllers.console.app import ( + workflow_trigger as workflow_trigger_module, +) +from controllers.console.app import ( + wraps as wraps_module, +) +from controllers.console.app.completion import ChatMessagePayload, CompletionMessagePayload +from controllers.console.app.mcp_server import MCPServerCreatePayload, MCPServerUpdatePayload +from controllers.console.app.ops_trace import TraceConfigPayload, TraceProviderQuery +from controllers.console.app.site import AppSiteUpdatePayload +from controllers.console.app.workflow import AdvancedChatWorkflowRunPayload, SyncDraftWorkflowPayload +from controllers.console.app.workflow_app_log import WorkflowAppLogQuery +from controllers.console.app.workflow_draft_variable import WorkflowDraftVariableUpdatePayload +from controllers.console.app.workflow_statistic import WorkflowStatisticQuery +from controllers.console.app.workflow_trigger import Parser, ParserEnable + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _ConnContext: + def __init__(self, rows): + self._rows = rows + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _query, _args): + return self._rows + + +# ========== Completion Tests ========== +class TestCompletionEndpoints: + """Tests for completion API endpoints.""" + + def test_completion_create_payload(self): + """Test completion creation payload.""" + payload = CompletionMessagePayload(inputs={"prompt": "test"}, model_config={}) + assert payload.inputs == {"prompt": "test"} + + def test_chat_message_payload_uuid_validation(self): + payload = ChatMessagePayload( + inputs={}, + model_config={}, + query="hi", + conversation_id=str(uuid.uuid4()), + parent_message_id=str(uuid.uuid4()), + ) + assert payload.query == "hi" + + def test_completion_api_success(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: {"text": "ok"}, + ) + monkeypatch.setattr( + completion_module.helper, + "compact_generate_response", + lambda response: {"result": response}, + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + resp = method(app_model=MagicMock(id="app-1")) + + assert resp == {"result": {"text": "ok"}} + + def test_completion_api_conversation_not_exists(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: (_ for _ in ()).throw( + completion_module.services.errors.conversation.ConversationNotExistsError() + ), + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + with pytest.raises(NotFound): + method(app_model=MagicMock(id="app-1")) + + def test_completion_api_provider_not_initialized(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: (_ for _ in ()).throw(completion_module.ProviderTokenNotInitError("x")), + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + with pytest.raises(completion_module.ProviderNotInitializeError): + method(app_model=MagicMock(id="app-1")) + + def test_completion_api_quota_exceeded(self, app, monkeypatch): + api = completion_module.CompletionMessageApi() + method = _unwrap(api.post) + + class DummyAccount: + pass + + dummy_account = DummyAccount() + + monkeypatch.setattr(completion_module, "current_user", dummy_account) + monkeypatch.setattr(completion_module, "Account", DummyAccount) + monkeypatch.setattr( + completion_module.AppGenerateService, + "generate", + lambda **_kwargs: (_ for _ in ()).throw(completion_module.QuotaExceededError()), + ) + + with app.test_request_context( + "/", + json={"inputs": {}, "model_config": {}, "query": "hi"}, + ): + with pytest.raises(completion_module.ProviderQuotaExceededError): + method(app_model=MagicMock(id="app-1")) + + +# ========== OpsTrace Tests ========== +class TestOpsTraceEndpoints: + """Tests for ops_trace endpoint.""" + + def test_ops_trace_query_basic(self): + """Test ops_trace query.""" + query = TraceProviderQuery(tracing_provider="langfuse") + assert query.tracing_provider == "langfuse" + + def test_ops_trace_config_payload(self): + payload = TraceConfigPayload(tracing_provider="langfuse", tracing_config={"api_key": "k"}) + assert payload.tracing_config["api_key"] == "k" + + def test_trace_app_config_get_empty(self, app, monkeypatch): + api = ops_trace_module.TraceAppConfigApi() + method = _unwrap(api.get) + + monkeypatch.setattr( + ops_trace_module.OpsService, + "get_tracing_app_config", + lambda **_kwargs: None, + ) + + with app.test_request_context("/?tracing_provider=langfuse"): + result = method(app_id="app-1") + + assert result == {"has_not_configured": True} + + def test_trace_app_config_post_invalid(self, app, monkeypatch): + api = ops_trace_module.TraceAppConfigApi() + method = _unwrap(api.post) + + monkeypatch.setattr( + ops_trace_module.OpsService, + "create_tracing_app_config", + lambda **_kwargs: {"error": True}, + ) + + with app.test_request_context( + "/", + json={"tracing_provider": "langfuse", "tracing_config": {"api_key": "k"}}, + ): + with pytest.raises(BadRequest): + method(app_id="app-1") + + def test_trace_app_config_delete_not_found(self, app, monkeypatch): + api = ops_trace_module.TraceAppConfigApi() + method = _unwrap(api.delete) + + monkeypatch.setattr( + ops_trace_module.OpsService, + "delete_tracing_app_config", + lambda **_kwargs: False, + ) + + with app.test_request_context("/?tracing_provider=langfuse"): + with pytest.raises(BadRequest): + method(app_id="app-1") + + +# ========== Site Tests ========== +class TestSiteEndpoints: + """Tests for site endpoint.""" + + def test_site_response_structure(self): + """Test site response structure.""" + payload = AppSiteUpdatePayload(title="My Site", description="Test site") + assert payload.title == "My Site" + + def test_site_default_language_validation(self): + payload = AppSiteUpdatePayload(default_language="en-US") + assert payload.default_language == "en-US" + + def test_app_site_update_post(self, app, monkeypatch): + api = site_module.AppSite() + method = _unwrap(api.post) + + site = MagicMock() + query = MagicMock() + query.where.return_value.first.return_value = site + monkeypatch.setattr( + site_module.db, + "session", + MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + ) + monkeypatch.setattr( + site_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") + + with app.test_request_context("/", json={"title": "My Site"}): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result is site + + def test_app_site_access_token_reset(self, app, monkeypatch): + api = site_module.AppSiteAccessTokenReset() + method = _unwrap(api.post) + + site = MagicMock() + query = MagicMock() + query.where.return_value.first.return_value = site + monkeypatch.setattr( + site_module.db, + "session", + MagicMock(query=lambda *_args, **_kwargs: query, commit=lambda: None), + ) + monkeypatch.setattr(site_module.Site, "generate_code", lambda *_args, **_kwargs: "code") + monkeypatch.setattr( + site_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr(site_module, "naive_utc_now", lambda: "now") + + with app.test_request_context("/"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result is site + + +# ========== Workflow Tests ========== +class TestWorkflowEndpoints: + """Tests for workflow endpoints.""" + + def test_workflow_copy_payload(self): + """Test workflow copy payload.""" + payload = SyncDraftWorkflowPayload(graph={}, features={}) + assert payload.graph == {} + + def test_workflow_mode_query(self): + """Test workflow mode query.""" + payload = AdvancedChatWorkflowRunPayload(inputs={}, query="hi") + assert payload.query == "hi" + + +# ========== Workflow App Log Tests ========== +class TestWorkflowAppLogEndpoints: + """Tests for workflow app log endpoints.""" + + def test_workflow_app_log_query(self): + """Test workflow app log query.""" + query = WorkflowAppLogQuery(keyword="test", page=1, limit=20) + assert query.keyword == "test" + + def test_workflow_app_log_query_detail_bool(self): + query = WorkflowAppLogQuery(detail="true") + assert query.detail is True + + def test_workflow_app_log_api_get(self, app, monkeypatch): + api = workflow_app_log_module.WorkflowAppLogApi() + method = _unwrap(api.get) + + monkeypatch.setattr(workflow_app_log_module, "db", SimpleNamespace(engine=MagicMock())) + + class DummySession: + def __enter__(self): + return "session" + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_app_log_module, "Session", lambda *args, **kwargs: DummySession()) + + def fake_get_paginate(self, **_kwargs): + return {"items": [], "total": 0} + + monkeypatch.setattr( + workflow_app_log_module.WorkflowAppService, + "get_paginate_workflow_app_logs", + fake_get_paginate, + ) + + with app.test_request_context("/?page=1&limit=20"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result == {"items": [], "total": 0} + + +# ========== Workflow Draft Variable Tests ========== +class TestWorkflowDraftVariableEndpoints: + """Tests for workflow draft variable endpoints.""" + + def test_workflow_variable_creation(self): + """Test workflow variable creation.""" + payload = WorkflowDraftVariableUpdatePayload(name="var1", value="test") + assert payload.name == "var1" + + def test_workflow_variable_collection_get(self, app, monkeypatch): + api = workflow_draft_variable_module.WorkflowVariableCollectionApi() + method = _unwrap(api.get) + + monkeypatch.setattr(workflow_draft_variable_module, "db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr(workflow_draft_variable_module, "current_user", SimpleNamespace(id="user-1")) + + class DummySession: + def __enter__(self): + return "session" + + def __exit__(self, exc_type, exc, tb): + return False + + class DummyDraftService: + def __init__(self, session): + self.session = session + + def list_variables_without_values(self, **_kwargs): + return {"items": [], "total": 0} + + monkeypatch.setattr(workflow_draft_variable_module, "Session", lambda *args, **kwargs: DummySession()) + + class DummyWorkflowService: + def is_workflow_exist(self, *args, **kwargs): + return True + + monkeypatch.setattr(workflow_draft_variable_module, "WorkflowDraftVariableService", DummyDraftService) + monkeypatch.setattr(workflow_draft_variable_module, "WorkflowService", DummyWorkflowService) + + with app.test_request_context("/?page=1&limit=20"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result == {"items": [], "total": 0} + + +# ========== Workflow Statistic Tests ========== +class TestWorkflowStatisticEndpoints: + """Tests for workflow statistic endpoints.""" + + def test_workflow_statistic_time_range(self): + """Test workflow statistic time range query.""" + query = WorkflowStatisticQuery(start="2024-01-01", end="2024-12-31") + assert query.start == "2024-01-01" + + def test_workflow_statistic_blank_to_none(self): + query = WorkflowStatisticQuery(start="", end="") + assert query.start is None + assert query.end is None + + def test_workflow_daily_runs_statistic(self, app, monkeypatch): + monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr( + workflow_statistic_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: SimpleNamespace(get_daily_runs_statistics=lambda **_kw: [{"date": "2024-01-01"}]), + ) + monkeypatch.setattr( + workflow_statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + workflow_statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: (None, None), + ) + + api = workflow_statistic_module.WorkflowDailyRunsStatistic() + method = _unwrap(api.get) + + with app.test_request_context("/"): + response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-01"}]} + + def test_workflow_daily_terminals_statistic(self, app, monkeypatch): + monkeypatch.setattr(workflow_statistic_module, "db", SimpleNamespace(engine=MagicMock())) + monkeypatch.setattr( + workflow_statistic_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: SimpleNamespace( + get_daily_terminals_statistics=lambda **_kw: [{"date": "2024-01-02"}] + ), + ) + monkeypatch.setattr( + workflow_statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + workflow_statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: (None, None), + ) + + api = workflow_statistic_module.WorkflowDailyTerminalsStatistic() + method = _unwrap(api.get) + + with app.test_request_context("/"): + response = method(app_model=SimpleNamespace(tenant_id="t1", id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-02"}]} + + +# ========== Workflow Trigger Tests ========== +class TestWorkflowTriggerEndpoints: + """Tests for workflow trigger endpoints.""" + + def test_webhook_trigger_payload(self): + """Test webhook trigger payload.""" + payload = Parser(node_id="node-1") + assert payload.node_id == "node-1" + + enable_payload = ParserEnable(trigger_id="trigger-1", enable_trigger=True) + assert enable_payload.enable_trigger is True + + def test_webhook_trigger_api_get(self, app, monkeypatch): + api = workflow_trigger_module.WebhookTriggerApi() + method = _unwrap(api.get) + + monkeypatch.setattr(workflow_trigger_module, "db", SimpleNamespace(engine=MagicMock())) + + trigger = MagicMock() + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = trigger + + class DummySession: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(workflow_trigger_module, "Session", lambda *_args, **_kwargs: DummySession()) + + with app.test_request_context("/?node_id=node-1"): + result = method(app_model=SimpleNamespace(id="app-1")) + + assert result is trigger + + +# ========== Wraps Tests ========== +class TestWrapsEndpoints: + """Tests for wraps utility functions.""" + + def test_get_app_model_context(self): + """Test get_app_model wrapper context.""" + # These are decorator functions, so we test their availability + assert hasattr(wraps_module, "get_app_model") + + +# ========== MCP Server Tests ========== +class TestMCPServerEndpoints: + """Tests for MCP server endpoints.""" + + def test_mcp_server_connection(self): + """Test MCP server connection.""" + payload = MCPServerCreatePayload(parameters={"url": "http://localhost:3000"}) + assert payload.parameters["url"] == "http://localhost:3000" + + def test_mcp_server_update_payload(self): + payload = MCPServerUpdatePayload(id="server-1", parameters={"timeout": 30}, status="active") + assert payload.status == "active" + + +# ========== Error Handling Tests ========== +class TestErrorHandling: + """Tests for error handling in various endpoints.""" + + def test_annotation_list_query_validation(self): + """Test annotation list query validation.""" + with pytest.raises(ValueError): + annotation_module.AnnotationListQuery(page=0) + + +# ========== Integration-like Tests ========== +class TestPayloadIntegration: + """Integration tests for payload handling.""" + + def test_multiple_payload_types(self): + """Test handling of multiple payload types.""" + payloads = [ + annotation_module.AnnotationReplyPayload( + score_threshold=0.5, embedding_provider_name="openai", embedding_model_name="text-embedding-3-small" + ), + message_module.MessageFeedbackPayload(message_id=str(uuid.uuid4()), rating="like"), + statistic_module.StatisticTimeRangeQuery(start="2024-01-01"), + ] + assert len(payloads) == 3 + assert all(p is not None for p in payloads) diff --git a/api/tests/unit_tests/controllers/console/app/test_app_import_api.py b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py new file mode 100644 index 0000000000..91f58460ac --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_app_import_api.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import app_import as app_import_module +from services.app_dsl_service import ImportStatus + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _Result: + def __init__(self, status: ImportStatus, app_id: str | None = "app-1"): + self.status = status + self.app_id = app_id + + def model_dump(self, mode: str = "json"): + return {"status": self.status, "app_id": self.app_id} + + +class _SessionContext: + def __init__(self, session): + self._session = session + + def __enter__(self): + return self._session + + def __exit__(self, exc_type, exc, tb): + return False + + +def _install_session(monkeypatch: pytest.MonkeyPatch, session: MagicMock) -> None: + monkeypatch.setattr(app_import_module, "Session", lambda *_: _SessionContext(session)) + monkeypatch.setattr(app_import_module, "db", SimpleNamespace(engine=object())) + + +def _install_features(monkeypatch: pytest.MonkeyPatch, enabled: bool) -> None: + features = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=enabled)) + monkeypatch.setattr(app_import_module.FeatureService, "get_system_features", lambda: features) + + +def test_import_post_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED, app_id=None), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + +def test_import_post_returns_pending_status(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + _install_features(monkeypatch, enabled=False) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.PENDING), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once() + assert status == 202 + assert response["status"] == ImportStatus.PENDING + + +def test_import_post_updates_webapp_auth_when_enabled(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + _install_features(monkeypatch, enabled=True) + monkeypatch.setattr( + app_import_module.AppDslService, + "import_app", + lambda *_args, **_kwargs: _Result(ImportStatus.COMPLETED, app_id="app-123"), + ) + update_access = MagicMock() + monkeypatch.setattr(app_import_module.EnterpriseService.WebAppAuth, "update_app_access_mode", update_access) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports", method="POST", json={"mode": "yaml-content"}): + response, status = method() + + session.commit.assert_called_once() + update_access.assert_called_once_with("app-123", "private") + assert status == 200 + assert response["status"] == ImportStatus.COMPLETED + + +def test_import_confirm_returns_failed_status(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportConfirmApi() + method = _unwrap(api.post) + + session = MagicMock() + _install_session(monkeypatch, session) + monkeypatch.setattr( + app_import_module.AppDslService, + "confirm_import", + lambda *_args, **_kwargs: _Result(ImportStatus.FAILED), + ) + monkeypatch.setattr(app_import_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + with app.test_request_context("/console/api/apps/imports/import-1/confirm", method="POST"): + response, status = method(import_id="import-1") + + session.commit.assert_called_once() + assert status == 400 + assert response["status"] == ImportStatus.FAILED + + +def test_import_check_dependencies_returns_result(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = app_import_module.AppImportCheckDependenciesApi() + method = _unwrap(api.get) + + session = MagicMock() + _install_session(monkeypatch, session) + monkeypatch.setattr( + app_import_module.AppDslService, + "check_dependencies", + lambda *_args, **_kwargs: SimpleNamespace(model_dump=lambda mode="json": {"leaked_dependencies": []}), + ) + + with app.test_request_context("/console/api/apps/imports/app-1/check-dependencies", method="GET"): + response, status = method(app_model=SimpleNamespace(id="app-1")) + + assert status == 200 + assert response["leaked_dependencies"] == [] diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py new file mode 100644 index 0000000000..021e9a0784 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import InternalServerError + +from controllers.console.app.audio import ChatMessageAudioApi, ChatMessageTextApi, TextModesApi +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.audio_service import AudioService +from services.errors.app_model_config import AppModelConfigBrokenError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + ProviderNotSupportTextToSpeechLanageServiceError, + UnsupportedAudioTypeServiceError, +) + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def _file_data(): + return FileStorage(stream=io.BytesIO(b"audio"), filename="audio.wav", content_type="audio/wav") + + +def test_console_audio_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "ok"}) + api = ChatMessageAudioApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): + response = handler(app_model=app_model) + + assert response == {"text": "ok"} + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (AppModelConfigBrokenError(), AppUnavailableError), + (NoAudioUploadedServiceError(), NoAudioUploadedError), + (AudioTooLargeServiceError("too big"), AudioTooLargeError), + (UnsupportedAudioTypeServiceError(), UnsupportedAudioTypeError), + (ProviderNotSupportSpeechToTextServiceError(), ProviderNotSupportSpeechToTextError), + (ProviderTokenNotInitError("token"), ProviderNotInitializeError), + (QuotaExceededError(), ProviderQuotaExceededError), + (ModelCurrentlyNotSupportError(), ProviderModelCurrentlyNotSupportError), + (InvokeError("invoke"), CompletionRequestError), + ], +) +def test_console_audio_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch, exc, expected) -> None: + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(exc)) + api = ChatMessageAudioApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): + with pytest.raises(expected): + handler(app_model=app_model) + + +def test_console_audio_api_unhandled_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("boom"))) + api = ChatMessageAudioApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context("/console/api/apps/app/audio-to-text", method="POST", data={"file": _file_data()}): + with pytest.raises(InternalServerError): + handler(app_model=app_model) + + +def test_console_text_api_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) + + api = ChatMessageTextApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context( + "/console/api/apps/app/text-to-audio", + method="POST", + json={"text": "hello", "voice": "v"}, + ): + response = handler(app_model=app_model) + + assert response == {"audio": "ok"} + + +def test_console_text_api_error_mapping(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: (_ for _ in ()).throw(QuotaExceededError())) + + api = ChatMessageTextApi() + handler = _unwrap(api.post) + app_model = SimpleNamespace(id="a1") + + with app.test_request_context( + "/console/api/apps/app/text-to-audio", + method="POST", + json={"text": "hello"}, + ): + with pytest.raises(ProviderQuotaExceededError): + handler(app_model=app_model) + + +def test_console_text_modes_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) + + api = TextModesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(tenant_id="t1") + + with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"): + response = handler(app_model=app_model) + + assert response == ["voice-1"] + + +def test_console_text_modes_language_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + AudioService, + "transcript_tts_voices", + lambda **_kwargs: (_ for _ in ()).throw(ProviderNotSupportTextToSpeechLanageServiceError()), + ) + + api = TextModesApi() + handler = _unwrap(api.get) + app_model = SimpleNamespace(tenant_id="t1") + + with app.test_request_context("/console/api/apps/app/text-to-audio/voices?language=en", method="GET"): + with pytest.raises(AppUnavailableError): + handler(app_model=app_model) + + +def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageAudioApi() + method = _unwrap(api.post) + + response_payload = {"text": "hello"} + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: response_payload) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + response = method(app_model=app_model) + + assert response == response_payload + + +def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr( + AudioService, + "transcript_asr", + lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")), + ) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + with pytest.raises(AudioTooLargeError): + method(app_model=app_model) + + +def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello"}, + ): + response = method(app_model=app_model) + + assert response == {"audio": "ok"} + + +def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr(AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices", + method="GET", + query_string={"language": "en-US"}, + ): + response = method(app_model=app_model) + + assert response == ["voice-1"] + + +def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr(AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"}) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + # Should not raise, AudioService is mocked + response = method(app_model=app_model) + assert response == {"text": "test"} + + +def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello", "language": "en-US"}, + ): + response = method(app_model=app_model) + assert response == {"audio": "test"} + + +def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr( + AudioService, + "transcript_tts_voices", + lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}], + ) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices?language=en-US", + method="GET", + ): + response = method(app_model=app_model) + assert isinstance(response, list) diff --git a/api/tests/unit_tests/controllers/console/app/test_audio_api.py b/api/tests/unit_tests/controllers/console/app/test_audio_api.py new file mode 100644 index 0000000000..8b71837c29 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_audio_api.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest + +from controllers.console.app import audio as audio_module +from controllers.console.app.error import AudioTooLargeError +from services.errors.audio import AudioTooLargeServiceError + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_audio_to_text_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageAudioApi() + method = _unwrap(api.post) + + response_payload = {"text": "hello"} + monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: response_payload) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + response = method(app_model=app_model) + + assert response == response_payload + + +def test_audio_to_text_maps_audio_too_large(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr( + audio_module.AudioService, + "transcript_asr", + lambda **_kwargs: (_ for _ in ()).throw(AudioTooLargeServiceError("too large")), + ) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"x"), "sample.wav")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + with pytest.raises(AudioTooLargeError): + method(app_model=app_model) + + +def test_text_to_audio_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "ok"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello"}, + ): + response = method(app_model=app_model) + + assert response == {"audio": "ok"} + + +def test_text_to_audio_voices_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr(audio_module.AudioService, "transcript_tts_voices", lambda **_kwargs: ["voice-1"]) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices", + method="GET", + query_string={"language": "en-US"}, + ): + response = method(app_model=app_model) + + assert response == ["voice-1"] + + +def test_audio_to_text_with_invalid_file(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageAudioApi() + method = _unwrap(api.post) + + monkeypatch.setattr(audio_module.AudioService, "transcript_asr", lambda **_kwargs: {"text": "test"}) + + app_model = SimpleNamespace(id="app-1") + + data = {"file": (io.BytesIO(b"invalid"), "sample.xyz")} + with app.test_request_context( + "/console/api/apps/app-1/audio-to-text", + method="POST", + data=data, + content_type="multipart/form-data", + ): + # Should not raise, AudioService is mocked + response = method(app_model=app_model) + assert response == {"text": "test"} + + +def test_text_to_audio_with_language_param(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.ChatMessageTextApi() + method = _unwrap(api.post) + + monkeypatch.setattr(audio_module.AudioService, "transcript_tts", lambda **_kwargs: {"audio": "test"}) + + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio", + method="POST", + json={"text": "hello", "language": "en-US"}, + ): + response = method(app_model=app_model) + assert response == {"audio": "test"} + + +def test_text_to_audio_voices_with_language_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = audio_module.TextModesApi() + method = _unwrap(api.get) + + monkeypatch.setattr( + audio_module.AudioService, + "transcript_tts_voices", + lambda **_kwargs: [{"id": "voice-1", "name": "Voice 1"}], + ) + + app_model = SimpleNamespace(tenant_id="tenant-1") + + with app.test_request_context( + "/console/api/apps/app-1/text-to-audio/voices?language=en-US", + method="GET", + ): + response = method(app_model=app_model) + assert isinstance(response, list) diff --git a/api/tests/unit_tests/controllers/console/app/test_conversation_api.py b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py new file mode 100644 index 0000000000..5db8e5c332 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_conversation_api.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.app import conversation as conversation_module +from models.model import AppMode +from services.errors.conversation import ConversationNotExistsError + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def _make_account(): + return SimpleNamespace(timezone="UTC", id="u1") + + +def test_completion_conversation_list_returns_paginated_result(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.CompletionConversationApi() + method = _unwrap(api.get) + + account = _make_account() + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) + monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) + + paginate_result = MagicMock() + monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) + + with app.test_request_context("/console/api/apps/app-1/completion-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response is paginate_result + + +def test_completion_conversation_list_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.CompletionConversationApi() + method = _unwrap(api.get) + + account = _make_account() + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) + monkeypatch.setattr( + conversation_module, + "parse_time_range", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad range")), + ) + + with app.test_request_context( + "/console/api/apps/app-1/completion-conversations", + method="GET", + query_string={"start": "bad"}, + ): + with pytest.raises(BadRequest): + method(app_model=SimpleNamespace(id="app-1")) + + +def test_chat_conversation_list_advanced_chat_calls_paginate(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.ChatConversationApi() + method = _unwrap(api.get) + + account = _make_account() + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (account, "t1")) + monkeypatch.setattr(conversation_module, "parse_time_range", lambda *_args, **_kwargs: (None, None)) + + paginate_result = MagicMock() + monkeypatch.setattr(conversation_module.db, "paginate", lambda *_args, **_kwargs: paginate_result) + + with app.test_request_context("/console/api/apps/app-1/chat-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1", mode=AppMode.ADVANCED_CHAT)) + + assert response is paginate_result + + +def test_get_conversation_updates_read_at(monkeypatch: pytest.MonkeyPatch) -> None: + conversation = SimpleNamespace(id="c1", app_id="app-1") + + query = MagicMock() + query.where.return_value = query + query.first.return_value = conversation + + session = MagicMock() + session.query.return_value = query + + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) + monkeypatch.setattr(conversation_module.db, "session", session) + + result = conversation_module._get_conversation(SimpleNamespace(id="app-1"), "c1") + + assert result is conversation + session.execute.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once_with(conversation) + + +def test_get_conversation_missing_raises_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + + session = MagicMock() + session.query.return_value = query + + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) + monkeypatch.setattr(conversation_module.db, "session", session) + + with pytest.raises(NotFound): + conversation_module._get_conversation(SimpleNamespace(id="app-1"), "missing") + + +def test_completion_conversation_delete_maps_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + api = conversation_module.CompletionConversationDetailApi() + method = _unwrap(api.delete) + + monkeypatch.setattr(conversation_module, "current_account_with_tenant", lambda: (_make_account(), "t1")) + monkeypatch.setattr( + conversation_module.ConversationService, + "delete", + lambda *_args, **_kwargs: (_ for _ in ()).throw(ConversationNotExistsError()), + ) + + with pytest.raises(NotFound): + method(app_model=SimpleNamespace(id="app-1"), conversation_id="c1") diff --git a/api/tests/unit_tests/controllers/console/app/test_generator_api.py b/api/tests/unit_tests/controllers/console/app/test_generator_api.py new file mode 100644 index 0000000000..f83bc18da3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_generator_api.py @@ -0,0 +1,260 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from controllers.console.app import generator as generator_module +from controllers.console.app.error import ProviderNotInitializeError +from core.errors.error import ProviderTokenNotInitError + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def _model_config_payload(): + return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}} + + +def _install_workflow_service(monkeypatch: pytest.MonkeyPatch, workflow): + class _Service: + def get_draft_workflow(self, app_model): + return workflow + + monkeypatch.setattr(generator_module, "WorkflowService", lambda: _Service()) + + +def test_rule_generate_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.RuleGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr(generator_module.LLMGenerator, "generate_rule_config", lambda **_kwargs: {"rules": []}) + + with app.test_request_context( + "/console/api/rule-generate", + method="POST", + json={"instruction": "do it", "model_config": _model_config_payload()}, + ): + response = method() + + assert response == {"rules": []} + + +def test_rule_code_generate_maps_token_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.RuleCodeGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + def _raise(*_args, **_kwargs): + raise ProviderTokenNotInitError("missing token") + + monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", _raise) + + with app.test_request_context( + "/console/api/rule-code-generate", + method="POST", + json={"instruction": "do it", "model_config": _model_config_payload()}, + ): + with pytest.raises(ProviderNotInitializeError): + method() + + +def test_instruction_generate_app_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: None) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "app app-1 not found" + + +def test_instruction_generate_workflow_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + app_model = SimpleNamespace(id="app-1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + _install_workflow_service(monkeypatch, workflow=None) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "workflow app-1 not found" + + +def test_instruction_generate_node_missing(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + app_model = SimpleNamespace(id="app-1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + workflow = SimpleNamespace(graph_dict={"nodes": []}) + _install_workflow_service(monkeypatch, workflow=workflow) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "node node-1 not found" + + +def test_instruction_generate_code_node(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + app_model = SimpleNamespace(id="app-1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + monkeypatch.setattr(generator_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + workflow = SimpleNamespace( + graph_dict={ + "nodes": [ + {"id": "node-1", "data": {"type": "code"}}, + ] + } + ) + _install_workflow_service(monkeypatch, workflow=workflow) + monkeypatch.setattr(generator_module.LLMGenerator, "generate_code", lambda **_kwargs: {"code": "x"}) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "node-1", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response = method() + + assert response == {"code": "x"} + + +def test_instruction_generate_legacy_modify(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr( + generator_module.LLMGenerator, + "instruction_modify_legacy", + lambda **_kwargs: {"instruction": "ok"}, + ) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "", + "current": "old", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response = method() + + assert response == {"instruction": "ok"} + + +def test_instruction_generate_incompatible_params(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = generator_module.InstructionGenerateApi() + method = _unwrap(api.post) + + monkeypatch.setattr(generator_module, "current_account_with_tenant", lambda: (None, "t1")) + + with app.test_request_context( + "/console/api/instruction-generate", + method="POST", + json={ + "flow_id": "app-1", + "node_id": "", + "current": "", + "instruction": "do", + "model_config": _model_config_payload(), + }, + ): + response, status = method() + + assert status == 400 + assert response["error"] == "incompatible parameters" + + +def test_instruction_template_prompt(app) -> None: + api = generator_module.InstructionGenerationTemplateApi() + method = _unwrap(api.post) + + with app.test_request_context( + "/console/api/instruction-generate/template", + method="POST", + json={"type": "prompt"}, + ): + response = method() + + assert "data" in response + + +def test_instruction_template_invalid_type(app) -> None: + api = generator_module.InstructionGenerationTemplateApi() + method = _unwrap(api.post) + + with app.test_request_context( + "/console/api/instruction-generate/template", + method="POST", + json={"type": "unknown"}, + ): + with pytest.raises(ValueError): + method() diff --git a/api/tests/unit_tests/controllers/console/app/test_message.py b/api/tests/unit_tests/controllers/console/app/test_message.py new file mode 100644 index 0000000000..e6dfc0d3bd --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_message.py @@ -0,0 +1,320 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, request +from werkzeug.exceptions import InternalServerError, NotFound +from werkzeug.local import LocalProxy + +from controllers.console.app.error import ( + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.app.message import ( + ChatMessageListApi, + ChatMessagesQuery, + FeedbackExportQuery, + MessageAnnotationCountApi, + MessageApi, + MessageFeedbackApi, + MessageFeedbackExportApi, + MessageFeedbackPayload, + MessageSuggestedQuestionApi, +) +from controllers.console.explore.error import AppSuggestedQuestionsAfterAnswerDisabledError +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from models import App, AppMode +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" + return flask_app + + +@pytest.fixture +def mock_account(): + from models.account import Account, AccountStatus + + account = MagicMock(spec=Account) + account.id = "user_123" + account.timezone = "UTC" + account.status = AccountStatus.ACTIVE + account.is_admin_or_owner = True + account.current_tenant.current_role = "owner" + account.has_edit_permission = True + return account + + +@pytest.fixture +def mock_app_model(): + app_model = MagicMock(spec=App) + app_model.id = "app_123" + app_model.mode = AppMode.CHAT + app_model.tenant_id = "tenant_123" + return app_model + + +@pytest.fixture(autouse=True) +def mock_csrf(): + with patch("libs.login.check_csrf_token") as mock: + yield mock + + +import contextlib + + +@contextlib.contextmanager +def setup_test_context( + test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None, qs=None +): + with ( + patch("extensions.ext_database.db") as mock_db, + patch("controllers.console.app.wraps.db", mock_db), + patch("controllers.console.wraps.db", mock_db), + patch("controllers.console.app.message.db", mock_db), + patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.app.message.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + ): + # Set up a generic query mock that usually returns mock_app_model when getting app + app_query_mock = MagicMock() + app_query_mock.filter.return_value.first.return_value = mock_app_model + app_query_mock.filter.return_value.filter.return_value.first.return_value = mock_app_model + app_query_mock.where.return_value.first.return_value = mock_app_model + app_query_mock.where.return_value.where.return_value.first.return_value = mock_app_model + + data_query_mock = MagicMock() + + def query_side_effect(*args, **kwargs): + if args and hasattr(args[0], "__name__") and args[0].__name__ == "App": + return app_query_mock + return data_query_mock + + mock_db.session.query.side_effect = query_side_effect + mock_db.data_query = data_query_mock + + # Let the caller override the stat db query logic + proxy_mock = LocalProxy(lambda: mock_account) + + query_string = "&".join([f"{k}={v}" for k, v in (qs or {}).items()]) + full_path = f"{route_path}?{query_string}" if qs else route_path + + with ( + patch("libs.login.current_user", proxy_mock), + patch("flask_login.current_user", proxy_mock), + patch("controllers.console.app.message.attach_message_extra_contents", return_value=None), + ): + with test_app.test_request_context(full_path, method=method, json=payload): + request.view_args = {"app_id": "app_123"} + + if "suggested-questions" in route_path: + # simplistic extraction for message_id + parts = route_path.split("chat-messages/") + if len(parts) > 1: + request.view_args["message_id"] = parts[1].split("/")[0] + elif "messages/" in route_path and "chat-messages" not in route_path: + parts = route_path.split("messages/") + if len(parts) > 1: + request.view_args["message_id"] = parts[1].split("/")[0] + + api_instance = endpoint_class() + + # Check if it has a dispatch_request or method + if hasattr(api_instance, method.lower()): + yield api_instance, mock_db, request.view_args + + +class TestMessageValidators: + def test_chat_messages_query_validators(self): + # Test empty_to_none + assert ChatMessagesQuery.empty_to_none("") is None + assert ChatMessagesQuery.empty_to_none("val") == "val" + + # Test validate_uuid + assert ChatMessagesQuery.validate_uuid(None) is None + assert ( + ChatMessagesQuery.validate_uuid("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_message_feedback_validators(self): + assert ( + MessageFeedbackPayload.validate_message_id("123e4567-e89b-12d3-a456-426614174000") + == "123e4567-e89b-12d3-a456-426614174000" + ) + + def test_feedback_export_validators(self): + assert FeedbackExportQuery.parse_bool(None) is None + assert FeedbackExportQuery.parse_bool(True) is True + assert FeedbackExportQuery.parse_bool("1") is True + assert FeedbackExportQuery.parse_bool("0") is False + assert FeedbackExportQuery.parse_bool("off") is False + + with pytest.raises(ValueError): + FeedbackExportQuery.parse_bool("invalid") + + +class TestMessageEndpoints: + def test_chat_message_list_not_found(self, app, mock_account, mock_app_model): + with setup_test_context( + app, + ChatMessageListApi, + "/apps/app_123/chat-messages", + "GET", + mock_account, + mock_app_model, + qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000"}, + ) as (api, mock_db, v_args): + mock_db.session.scalar.return_value = None + + with pytest.raises(NotFound): + api.get(**v_args) + + def test_chat_message_list_success(self, app, mock_account, mock_app_model): + with setup_test_context( + app, + ChatMessageListApi, + "/apps/app_123/chat-messages", + "GET", + mock_account, + mock_app_model, + qs={"conversation_id": "123e4567-e89b-12d3-a456-426614174000", "limit": 1}, + ) as (api, mock_db, v_args): + mock_conv = MagicMock() + mock_conv.id = "123e4567-e89b-12d3-a456-426614174000" + mock_msg = MagicMock() + mock_msg.id = "msg_123" + mock_msg.feedbacks = [] + mock_msg.annotation = None + mock_msg.annotation_hit_history = None + mock_msg.agent_thoughts = [] + mock_msg.message_files = [] + mock_msg.extra_contents = [] + mock_msg.message = {} + mock_msg.message_metadata_dict = {} + + # scalar() is called twice: first for conversation lookup, second for has_more check + mock_db.session.scalar.side_effect = [mock_conv, False] + scalars_result = MagicMock() + scalars_result.all.return_value = [mock_msg] + mock_db.session.scalars.return_value = scalars_result + + resp = api.get(**v_args) + assert resp["limit"] == 1 + assert resp["has_more"] is False + assert len(resp["data"]) == 1 + + def test_message_feedback_not_found(self, app, mock_account, mock_app_model): + with setup_test_context( + app, + MessageFeedbackApi, + "/apps/app_123/feedbacks", + "POST", + mock_account, + mock_app_model, + payload={"message_id": "123e4567-e89b-12d3-a456-426614174000"}, + ) as (api, mock_db, v_args): + mock_db.session.scalar.return_value = None + + with pytest.raises(NotFound): + api.post(**v_args) + + def test_message_feedback_success(self, app, mock_account, mock_app_model): + payload = {"message_id": "123e4567-e89b-12d3-a456-426614174000", "rating": "like"} + with setup_test_context( + app, MessageFeedbackApi, "/apps/app_123/feedbacks", "POST", mock_account, mock_app_model, payload=payload + ) as (api, mock_db, v_args): + mock_msg = MagicMock() + mock_msg.admin_feedback = None + mock_db.session.scalar.return_value = mock_msg + + resp = api.post(**v_args) + assert resp == {"result": "success"} + + def test_message_annotation_count(self, app, mock_account, mock_app_model): + with setup_test_context( + app, MessageAnnotationCountApi, "/apps/app_123/annotations/count", "GET", mock_account, mock_app_model + ) as (api, mock_db, v_args): + mock_db.session.scalar.return_value = 5 + + resp = api.get(**v_args) + assert resp == {"count": 5} + + @patch("controllers.console.app.message.MessageService") + def test_message_suggested_questions_success(self, mock_msg_srv, app, mock_account, mock_app_model): + mock_msg_srv.get_suggested_questions_after_answer.return_value = ["q1", "q2"] + + with setup_test_context( + app, + MessageSuggestedQuestionApi, + "/apps/app_123/chat-messages/msg_123/suggested-questions", + "GET", + mock_account, + mock_app_model, + ) as (api, mock_db, v_args): + resp = api.get(**v_args) + assert resp == {"data": ["q1", "q2"]} + + @pytest.mark.parametrize( + ("exc", "expected_exc"), + [ + (MessageNotExistsError, NotFound), + (ConversationNotExistsError, NotFound), + (ProviderTokenNotInitError, ProviderNotInitializeError), + (QuotaExceededError, ProviderQuotaExceededError), + (ModelCurrentlyNotSupportError, ProviderModelCurrentlyNotSupportError), + (SuggestedQuestionsAfterAnswerDisabledError, AppSuggestedQuestionsAfterAnswerDisabledError), + (Exception, InternalServerError), + ], + ) + @patch("controllers.console.app.message.MessageService") + def test_message_suggested_questions_errors( + self, mock_msg_srv, exc, expected_exc, app, mock_account, mock_app_model + ): + mock_msg_srv.get_suggested_questions_after_answer.side_effect = exc() + + with setup_test_context( + app, + MessageSuggestedQuestionApi, + "/apps/app_123/chat-messages/msg_123/suggested-questions", + "GET", + mock_account, + mock_app_model, + ) as (api, mock_db, v_args): + with pytest.raises(expected_exc): + api.get(**v_args) + + @patch("services.feedback_service.FeedbackService.export_feedbacks") + def test_message_feedback_export_success(self, mock_export, app, mock_account, mock_app_model): + mock_export.return_value = {"exported": True} + + with setup_test_context( + app, MessageFeedbackExportApi, "/apps/app_123/feedbacks/export", "GET", mock_account, mock_app_model + ) as (api, mock_db, v_args): + resp = api.get(**v_args) + assert resp == {"exported": True} + + def test_message_api_get_success(self, app, mock_account, mock_app_model): + with setup_test_context( + app, MessageApi, "/apps/app_123/messages/msg_123", "GET", mock_account, mock_app_model + ) as (api, mock_db, v_args): + mock_msg = MagicMock() + mock_msg.id = "msg_123" + mock_msg.feedbacks = [] + mock_msg.annotation = None + mock_msg.annotation_hit_history = None + mock_msg.agent_thoughts = [] + mock_msg.message_files = [] + mock_msg.extra_contents = [] + mock_msg.message = {} + mock_msg.message_metadata_dict = {} + + mock_db.session.scalar.return_value = mock_msg + + resp = api.get(**v_args) + assert resp["id"] == "msg_123" diff --git a/api/tests/unit_tests/controllers/console/app/test_message_api.py b/api/tests/unit_tests/controllers/console/app/test_message_api.py new file mode 100644 index 0000000000..a76e958829 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_message_api.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import pytest + +from controllers.console.app import message as message_module + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_chat_messages_query_valid(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test valid ChatMessagesQuery with all fields.""" + query = message_module.ChatMessagesQuery( + conversation_id="550e8400-e29b-41d4-a716-446655440000", + first_id="550e8400-e29b-41d4-a716-446655440001", + limit=50, + ) + assert query.limit == 50 + + +def test_chat_messages_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test ChatMessagesQuery with defaults.""" + query = message_module.ChatMessagesQuery(conversation_id="550e8400-e29b-41d4-a716-446655440000") + assert query.first_id is None + assert query.limit == 20 + + +def test_chat_messages_query_empty_first_id(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test ChatMessagesQuery converts empty first_id to None.""" + query = message_module.ChatMessagesQuery( + conversation_id="550e8400-e29b-41d4-a716-446655440000", + first_id="", + ) + assert query.first_id is None + + +def test_message_feedback_payload_valid_like(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageFeedbackPayload with like rating.""" + payload = message_module.MessageFeedbackPayload( + message_id="550e8400-e29b-41d4-a716-446655440000", + rating="like", + content="Good answer", + ) + assert payload.rating == "like" + assert payload.content == "Good answer" + + +def test_message_feedback_payload_valid_dislike(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageFeedbackPayload with dislike rating.""" + payload = message_module.MessageFeedbackPayload( + message_id="550e8400-e29b-41d4-a716-446655440000", + rating="dislike", + ) + assert payload.rating == "dislike" + + +def test_message_feedback_payload_no_rating(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test MessageFeedbackPayload without rating.""" + payload = message_module.MessageFeedbackPayload(message_id="550e8400-e29b-41d4-a716-446655440000") + assert payload.rating is None + + +def test_feedback_export_query_defaults(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with default format.""" + query = message_module.FeedbackExportQuery() + assert query.format == "csv" + assert query.from_source is None + + +def test_feedback_export_query_json_format(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with JSON format.""" + query = message_module.FeedbackExportQuery(format="json") + assert query.format == "json" + + +def test_feedback_export_query_has_comment_true(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as true string.""" + query = message_module.FeedbackExportQuery(has_comment="true") + assert query.has_comment is True + + +def test_feedback_export_query_has_comment_false(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as false string.""" + query = message_module.FeedbackExportQuery(has_comment="false") + assert query.has_comment is False + + +def test_feedback_export_query_has_comment_1(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as 1.""" + query = message_module.FeedbackExportQuery(has_comment="1") + assert query.has_comment is True + + +def test_feedback_export_query_has_comment_0(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with has_comment as 0.""" + query = message_module.FeedbackExportQuery(has_comment="0") + assert query.has_comment is False + + +def test_feedback_export_query_rating_filter(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test FeedbackExportQuery with rating filter.""" + query = message_module.FeedbackExportQuery(rating="like") + assert query.rating == "like" + + +def test_annotation_count_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test AnnotationCountResponse creation.""" + response = message_module.AnnotationCountResponse(count=10) + assert response.count == 10 + + +def test_suggested_questions_response(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test SuggestedQuestionsResponse creation.""" + response = message_module.SuggestedQuestionsResponse(data=["What is AI?", "How does ML work?"]) + assert len(response.data) == 2 + assert response.data[0] == "What is AI?" diff --git a/api/tests/unit_tests/controllers/console/app/test_model_config_api.py b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py new file mode 100644 index 0000000000..61d92bb5c7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_model_config_api.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from controllers.console.app import model_config as model_config_module +from models.model import AppMode, AppModelConfig + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def test_post_updates_app_model_config_for_chat(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = model_config_module.ModelConfigResource() + method = _unwrap(api.post) + + app_model = SimpleNamespace( + id="app-1", + mode=AppMode.CHAT.value, + is_agent=False, + app_model_config_id=None, + updated_by=None, + updated_at=None, + ) + monkeypatch.setattr( + model_config_module.AppModelConfigService, + "validate_configuration", + lambda **_kwargs: {"pre_prompt": "hi"}, + ) + monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + session = MagicMock() + monkeypatch.setattr(model_config_module.db, "session", session) + + def _from_model_config_dict(self, model_config): + self.pre_prompt = model_config["pre_prompt"] + self.id = "config-1" + return self + + monkeypatch.setattr(AppModelConfig, "from_model_config_dict", _from_model_config_dict) + send_mock = MagicMock() + monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) + + with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): + response = method(app_model=app_model) + + session.add.assert_called_once() + session.flush.assert_called_once() + session.commit.assert_called_once() + send_mock.assert_called_once() + assert app_model.app_model_config_id == "config-1" + assert response["result"] == "success" + + +def test_post_encrypts_agent_tool_parameters(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = model_config_module.ModelConfigResource() + method = _unwrap(api.post) + + app_model = SimpleNamespace( + id="app-1", + mode=AppMode.AGENT_CHAT.value, + is_agent=True, + app_model_config_id="config-0", + updated_by=None, + updated_at=None, + ) + + original_config = AppModelConfig(app_id="app-1", created_by="u1", updated_by="u1") + original_config.agent_mode = json.dumps( + { + "enabled": True, + "strategy": "function-calling", + "tools": [ + { + "provider_id": "provider", + "provider_type": "builtin", + "tool_name": "tool", + "tool_parameters": {"secret": "masked"}, + } + ], + "prompt": None, + } + ) + + session = MagicMock() + query = MagicMock() + query.where.return_value = query + query.first.return_value = original_config + session.query.return_value = query + monkeypatch.setattr(model_config_module.db, "session", session) + + monkeypatch.setattr( + model_config_module.AppModelConfigService, + "validate_configuration", + lambda **_kwargs: { + "pre_prompt": "hi", + "agent_mode": { + "enabled": True, + "strategy": "function-calling", + "tools": [ + { + "provider_id": "provider", + "provider_type": "builtin", + "tool_name": "tool", + "tool_parameters": {"secret": "masked"}, + } + ], + "prompt": None, + }, + }, + ) + monkeypatch.setattr(model_config_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), "t1")) + + monkeypatch.setattr(model_config_module.ToolManager, "get_agent_tool_runtime", lambda **_kwargs: object()) + + class _ParamManager: + def __init__(self, **_kwargs): + self.delete_called = False + + def decrypt_tool_parameters(self, _value): + return {"secret": "decrypted"} + + def mask_tool_parameters(self, _value): + return {"secret": "masked"} + + def encrypt_tool_parameters(self, _value): + return {"secret": "encrypted"} + + def delete_tool_parameters_cache(self): + self.delete_called = True + + monkeypatch.setattr(model_config_module, "ToolParameterConfigurationManager", _ParamManager) + send_mock = MagicMock() + monkeypatch.setattr(model_config_module.app_model_config_was_updated, "send", send_mock) + + with app.test_request_context("/console/api/apps/app-1/model-config", method="POST", json={"pre_prompt": "hi"}): + response = method(app_model=app_model) + + stored_config = session.add.call_args[0][0] + stored_agent_mode = json.loads(stored_config.agent_mode) + assert stored_agent_mode["tools"][0]["tool_parameters"]["secret"] == "encrypted" + assert response["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic.py b/api/tests/unit_tests/controllers/console/app/test_statistic.py new file mode 100644 index 0000000000..beba23385d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_statistic.py @@ -0,0 +1,275 @@ +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, request +from werkzeug.local import LocalProxy + +from controllers.console.app.statistic import ( + AverageResponseTimeStatistic, + AverageSessionInteractionStatistic, + DailyConversationStatistic, + DailyMessageStatistic, + DailyTerminalsStatistic, + DailyTokenCostStatistic, + TokensPerSecondStatistic, + UserSatisfactionRateStatistic, +) +from models import App, AppMode + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture +def mock_account(): + from models.account import Account, AccountStatus + + account = MagicMock(spec=Account) + account.id = "user_123" + account.timezone = "UTC" + account.status = AccountStatus.ACTIVE + account.is_admin_or_owner = True + account.current_tenant.current_role = "owner" + account.has_edit_permission = True + return account + + +@pytest.fixture +def mock_app_model(): + app_model = MagicMock(spec=App) + app_model.id = "app_123" + app_model.mode = AppMode.CHAT + app_model.tenant_id = "tenant_123" + return app_model + + +@pytest.fixture(autouse=True) +def mock_csrf(): + with patch("libs.login.check_csrf_token") as mock: + yield mock + + +def setup_test_context( + test_app, endpoint_class, route_path, mock_account, mock_app_model, mock_rs, mock_parse_ret=(None, None) +): + with ( + patch("controllers.console.app.statistic.db") as mock_db_stat, + patch("controllers.console.app.wraps.db") as mock_db_wraps, + patch("controllers.console.wraps.db", mock_db_wraps), + patch( + "controllers.console.app.statistic.current_account_with_tenant", return_value=(mock_account, "tenant_123") + ), + patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + ): + mock_conn = MagicMock() + mock_conn.execute.return_value = mock_rs + + mock_begin = MagicMock() + mock_begin.__enter__.return_value = mock_conn + mock_db_stat.engine.begin.return_value = mock_begin + + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = mock_app_model + mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model + mock_query.where.return_value.first.return_value = mock_app_model + mock_query.where.return_value.where.return_value.first.return_value = mock_app_model + mock_db_wraps.session.query.return_value = mock_query + + proxy_mock = LocalProxy(lambda: mock_account) + + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with test_app.test_request_context(route_path, method="GET"): + request.view_args = {"app_id": "app_123"} + api_instance = endpoint_class() + response = api_instance.get(app_id="app_123") + return response + + +class TestStatisticEndpoints: + def test_daily_message_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.message_count = 10 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyMessageStatistic, + "/apps/app_123/statistics/daily-messages?start=2023-01-01 00:00&end=2023-01-02 00:00", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["message_count"] == 10 + + def test_daily_conversation_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.conversation_count = 5 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyConversationStatistic, + "/apps/app_123/statistics/daily-conversations", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["conversation_count"] == 5 + + def test_daily_terminals_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.terminal_count = 2 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyTerminalsStatistic, + "/apps/app_123/statistics/daily-end-users", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["terminal_count"] == 2 + + def test_daily_token_cost_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.token_count = 100 + mock_row.total_price = Decimal("0.02") + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + DailyTokenCostStatistic, + "/apps/app_123/statistics/token-costs", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["token_count"] == 100 + assert response.json["data"][0]["total_price"] == "0.02" + + def test_average_session_interaction_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.interactions = Decimal("3.523") + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + AverageSessionInteractionStatistic, + "/apps/app_123/statistics/average-session-interactions", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["interactions"] == 3.52 + + def test_user_satisfaction_rate_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.message_count = 100 + mock_row.feedback_count = 10 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + UserSatisfactionRateStatistic, + "/apps/app_123/statistics/user-satisfaction-rate", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["rate"] == 100.0 + + def test_average_response_time_statistic(self, app, mock_account, mock_app_model): + mock_app_model.mode = AppMode.COMPLETION + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.latency = 1.234 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + AverageResponseTimeStatistic, + "/apps/app_123/statistics/average-response-time", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["latency"] == 1234.0 + + def test_tokens_per_second_statistic(self, app, mock_account, mock_app_model): + mock_row = MagicMock() + mock_row.date = "2023-01-01" + mock_row.tokens_per_second = 15.5 + mock_row.interactions = Decimal(0) + + with patch("controllers.console.app.statistic.parse_time_range", return_value=(None, None)): + response = setup_test_context( + app, + TokensPerSecondStatistic, + "/apps/app_123/statistics/tokens-per-second", + mock_account, + mock_app_model, + [mock_row], + ) + assert response.status_code == 200 + assert response.json["data"][0]["tps"] == 15.5 + + @patch("controllers.console.app.statistic.parse_time_range") + def test_invalid_time_range(self, mock_parse, app, mock_account, mock_app_model): + mock_parse.side_effect = ValueError("Invalid time") + + from werkzeug.exceptions import BadRequest + + with pytest.raises(BadRequest): + setup_test_context( + app, + DailyMessageStatistic, + "/apps/app_123/statistics/daily-messages?start=invalid&end=invalid", + mock_account, + mock_app_model, + [], + ) + + @patch("controllers.console.app.statistic.parse_time_range") + def test_time_range_params_passed(self, mock_parse, app, mock_account, mock_app_model): + import datetime + + start = datetime.datetime.now() + end = datetime.datetime.now() + mock_parse.return_value = (start, end) + + response = setup_test_context( + app, + DailyMessageStatistic, + "/apps/app_123/statistics/daily-messages?start=something&end=something", + mock_account, + mock_app_model, + [], + ) + assert response.status_code == 200 + mock_parse.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/app/test_statistic_api.py b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py new file mode 100644 index 0000000000..15459994f9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_statistic_api.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from decimal import Decimal +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import BadRequest + +from controllers.console.app import statistic as statistic_module + + +def _unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +class _ConnContext: + def __init__(self, rows): + self._rows = rows + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, _query, _args): + return self._rows + + +def _install_db(monkeypatch: pytest.MonkeyPatch, rows) -> None: + engine = SimpleNamespace(begin=lambda: _ConnContext(rows)) + monkeypatch.setattr(statistic_module, "db", SimpleNamespace(engine=engine)) + + +def _install_common(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: (None, None), + ) + monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) + + +def test_daily_message_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-01", message_count=3)] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-01", "message_count": 3}]} + + +def test_daily_conversation_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyConversationStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} + + +def test_daily_token_cost_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyTokenCostStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-03", token_count=10, total_price=0.25, currency="USD")] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + data = response.get_json() + assert len(data["data"]) == 1 + assert data["data"][0]["date"] == "2024-01-03" + assert data["data"][0]["token_count"] == 10 + assert data["data"][0]["total_price"] == 0.25 + + +def test_daily_terminals_statistic_returns_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyTerminalsStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-04", terminal_count=7)] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-end-users", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-04", "terminal_count": 7}]} + + +def test_average_session_interaction_statistic_requires_chat_mode(app, monkeypatch: pytest.MonkeyPatch) -> None: + """Test that AverageSessionInteractionStatistic is limited to chat/agent modes.""" + # This just verifies the decorator is applied correctly + # Actual endpoint testing would require complex JOIN mocking + api = statistic_module.AverageSessionInteractionStatistic() + method = _unwrap(api.get) + assert callable(method) + + +def test_daily_message_statistic_with_invalid_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + def mock_parse(*args, **kwargs): + raise ValueError("Invalid time range") + + _install_db(monkeypatch, []) + monkeypatch.setattr( + statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr(statistic_module, "parse_time_range", mock_parse) + monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + with pytest.raises(BadRequest): + method(app_model=SimpleNamespace(id="app-1")) + + +def test_daily_message_statistic_multiple_rows(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + rows = [ + SimpleNamespace(date="2024-01-01", message_count=10), + SimpleNamespace(date="2024-01-02", message_count=15), + SimpleNamespace(date="2024-01-03", message_count=12), + ] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + data = response.get_json() + assert len(data["data"]) == 3 + + +def test_daily_message_statistic_empty_result(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyMessageStatistic() + method = _unwrap(api.get) + + _install_common(monkeypatch) + _install_db(monkeypatch, []) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-messages", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": []} + + +def test_daily_conversation_statistic_with_time_range(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyConversationStatistic() + method = _unwrap(api.get) + + rows = [SimpleNamespace(date="2024-01-02", conversation_count=5)] + _install_db(monkeypatch, rows) + monkeypatch.setattr( + statistic_module, + "current_account_with_tenant", + lambda: (SimpleNamespace(timezone="UTC"), "t1"), + ) + monkeypatch.setattr( + statistic_module, + "parse_time_range", + lambda *_args, **_kwargs: ("s", "e"), + ) + monkeypatch.setattr(statistic_module, "convert_datetime_to_date", lambda field: field) + + with app.test_request_context("/console/api/apps/app-1/statistics/daily-conversations", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + assert response.get_json() == {"data": [{"date": "2024-01-02", "conversation_count": 5}]} + + +def test_daily_token_cost_with_multiple_currencies(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = statistic_module.DailyTokenCostStatistic() + method = _unwrap(api.get) + + rows = [ + SimpleNamespace(date="2024-01-01", token_count=100, total_price=Decimal("0.50"), currency="USD"), + SimpleNamespace(date="2024-01-02", token_count=200, total_price=Decimal("1.00"), currency="USD"), + ] + _install_common(monkeypatch) + _install_db(monkeypatch, rows) + + with app.test_request_context("/console/api/apps/app-1/statistics/token-costs", method="GET"): + response = method(app_model=SimpleNamespace(id="app-1")) + + data = response.get_json() + assert len(data["data"]) == 2 diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py new file mode 100644 index 0000000000..0e22db9f9b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from werkzeug.exceptions import HTTPException, NotFound + +from controllers.console.app import workflow as workflow_module +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_parse_file_no_config(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: None) + workflow = SimpleNamespace(features_dict={}, tenant_id="t1") + + assert workflow_module._parse_file(workflow, files=[{"id": "f"}]) == [] + + +def test_parse_file_with_config(monkeypatch: pytest.MonkeyPatch) -> None: + config = object() + file_list = [ + File( + tenant_id="t1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="http://u", + ) + ] + build_mock = Mock(return_value=file_list) + monkeypatch.setattr(workflow_module.FileUploadConfigManager, "convert", lambda *_args, **_kwargs: config) + monkeypatch.setattr(workflow_module.file_factory, "build_from_mappings", build_mock) + + workflow = SimpleNamespace(features_dict={}, tenant_id="t1") + result = workflow_module._parse_file(workflow, files=[{"id": "f"}]) + + assert result == file_list + build_mock.assert_called_once() + + +def test_sync_draft_workflow_invalid_content_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + with app.test_request_context("/apps/app/workflows/draft", method="POST", data="x", content_type="text/html"): + with pytest.raises(HTTPException) as exc: + handler(api, app_model=SimpleNamespace(id="app")) + + assert exc.value.code == 415 + + +def test_sync_draft_workflow_invalid_json(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + with app.test_request_context( + "/apps/app/workflows/draft", + method="POST", + data="[]", + content_type="application/json", + ): + response, status = handler(api, app_model=SimpleNamespace(id="app")) + + assert status == 400 + assert response["message"] == "Invalid JSON data" + + +def test_sync_draft_workflow_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + unique_hash="h", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + monkeypatch.setattr( + workflow_module.variable_factory, "build_environment_variable_from_mapping", lambda *_args: "env" + ) + monkeypatch.setattr( + workflow_module.variable_factory, "build_conversation_variable_from_mapping", lambda *_args: "conv" + ) + + service = SimpleNamespace(sync_draft_workflow=lambda **_kwargs: workflow) + monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service) + + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft", + method="POST", + json={"graph": {}, "features": {}, "hash": "h"}, + ): + response = handler(api, app_model=SimpleNamespace(id="app")) + + assert response["result"] == "success" + + +def test_sync_draft_workflow_hash_mismatch(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + def _raise(*_args, **_kwargs): + raise workflow_module.WorkflowHashNotEqualError() + + service = SimpleNamespace(sync_draft_workflow=_raise) + monkeypatch.setattr(workflow_module, "WorkflowService", lambda: service) + + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft", + method="POST", + json={"graph": {}, "features": {}, "hash": "h"}, + ): + with pytest.raises(DraftWorkflowNotSync): + handler(api, app_model=SimpleNamespace(id="app")) + + +def test_restore_published_workflow_to_draft_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow = SimpleNamespace( + unique_hash="restored-hash", + updated_at=None, + created_at=datetime(2024, 1, 1), + ) + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace(restore_published_workflow_to_draft=lambda **_kwargs: workflow), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + response = handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert response["result"] == "success" + assert response["hash"] == "restored-hash" + + +def test_restore_published_workflow_to_draft_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.WorkflowNotFoundError("Workflow not found") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(NotFound): + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + +def test_restore_published_workflow_to_draft_returns_400_for_draft_source(app, monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + workflow_module.IsDraftWorkflowError( + "Cannot use draft workflow version. Workflow ID: draft-workflow. " + "Please use a published workflow version or leave workflow_id empty." + ) + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/draft-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="draft-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == workflow_module.RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE + + +def test_restore_published_workflow_to_draft_returns_400_for_invalid_structure( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + user = SimpleNamespace(id="account-1") + + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (user, "t1")) + monkeypatch.setattr( + workflow_module, + "WorkflowService", + lambda: SimpleNamespace( + restore_published_workflow_to_draft=lambda **_kwargs: (_ for _ in ()).throw( + ValueError("invalid workflow graph") + ) + ), + ) + + api = workflow_module.DraftWorkflowRestoreApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/workflows/published-workflow/restore", + method="POST", + ): + with pytest.raises(HTTPException) as exc: + handler( + api, + app_model=SimpleNamespace(id="app", tenant_id="tenant-1"), + workflow_id="published-workflow", + ) + + assert exc.value.code == 400 + assert exc.value.description == "invalid workflow graph" + + +def test_draft_workflow_get_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + workflow_module, "WorkflowService", lambda: SimpleNamespace(get_draft_workflow=lambda **_k: None) + ) + + api = workflow_module.DraftWorkflowApi() + handler = _unwrap(api.get) + + with pytest.raises(DraftWorkflowNotExist): + handler(api, app_model=SimpleNamespace(id="app")) + + +def test_advanced_chat_run_conversation_not_exists(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + workflow_module.AppGenerateService, + "generate", + lambda *_args, **_kwargs: (_ for _ in ()).throw( + workflow_module.services.errors.conversation.ConversationNotExistsError() + ), + ) + monkeypatch.setattr(workflow_module, "current_account_with_tenant", lambda: (SimpleNamespace(), "t1")) + + api = workflow_module.AdvancedChatDraftWorkflowRunApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/apps/app/advanced-chat/workflows/draft/run", + method="POST", + json={"inputs": {}}, + ): + with pytest.raises(NotFound): + handler(api, app_model=SimpleNamespace(id="app")) diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py new file mode 100644 index 0000000000..9b5d47c208 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_draft_variable.py @@ -0,0 +1,313 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, request +from werkzeug.local import LocalProxy + +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.app.workflow_draft_variable import ( + ConversationVariableCollectionApi, + EnvironmentVariableCollectionApi, + NodeVariableCollectionApi, + SystemVariableCollectionApi, + VariableApi, + VariableResetApi, + WorkflowVariableCollectionApi, +) +from controllers.web.error import InvalidArgumentError, NotFoundError +from models import App, AppMode +from models.enums import DraftVariableType + + +@pytest.fixture +def app(): + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + flask_app.config["RESTX_MASK_HEADER"] = "X-Fields" + return flask_app + + +@pytest.fixture +def mock_account(): + from models.account import Account, AccountStatus + + account = MagicMock(spec=Account) + account.id = "user_123" + account.timezone = "UTC" + account.status = AccountStatus.ACTIVE + account.is_admin_or_owner = True + account.current_tenant.current_role = "owner" + account.has_edit_permission = True + return account + + +@pytest.fixture +def mock_app_model(): + app_model = MagicMock(spec=App) + app_model.id = "app_123" + app_model.mode = AppMode.WORKFLOW + app_model.tenant_id = "tenant_123" + return app_model + + +@pytest.fixture(autouse=True) +def mock_csrf(): + with patch("libs.login.check_csrf_token") as mock: + yield mock + + +def setup_test_context(test_app, endpoint_class, route_path, method, mock_account, mock_app_model, payload=None): + with ( + patch("controllers.console.app.wraps.db") as mock_db_wraps, + patch("controllers.console.wraps.db", mock_db_wraps), + patch("controllers.console.app.workflow_draft_variable.db"), + patch("controllers.console.app.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + ): + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = mock_app_model + mock_query.filter.return_value.filter.return_value.first.return_value = mock_app_model + mock_query.where.return_value.first.return_value = mock_app_model + mock_query.where.return_value.where.return_value.first.return_value = mock_app_model + mock_db_wraps.session.query.return_value = mock_query + + proxy_mock = LocalProxy(lambda: mock_account) + + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + with test_app.test_request_context(route_path, method=method, json=payload): + request.view_args = {"app_id": "app_123"} + # extract node_id or variable_id from path manually since view_args overrides + if "nodes/" in route_path: + request.view_args["node_id"] = route_path.split("nodes/")[1].split("/")[0] + if "variables/" in route_path: + # simplistic extraction + parts = route_path.split("variables/") + if len(parts) > 1 and parts[1] and parts[1] != "reset": + request.view_args["variable_id"] = parts[1].split("/")[0] + + api_instance = endpoint_class() + # we just call dispatch_request to avoid manual argument passing + if hasattr(api_instance, method.lower()): + func = getattr(api_instance, method.lower()) + return func(**request.view_args) + + +class TestWorkflowDraftVariableEndpoints: + @staticmethod + def _mock_workflow_variable(variable_type: DraftVariableType = DraftVariableType.NODE) -> MagicMock: + class DummyValueType: + def exposed_type(self): + return DraftVariableType.NODE + + mock_var = MagicMock() + mock_var.app_id = "app_123" + mock_var.id = "var_123" + mock_var.name = "test_var" + mock_var.description = "" + mock_var.get_variable_type.return_value = variable_type + mock_var.get_selector.return_value = [] + mock_var.value_type = DummyValueType() + mock_var.edited = False + mock_var.visible = True + mock_var.file_id = None + mock_var.variable_file = None + mock_var.is_truncated.return_value = False + mock_var.get_value.return_value.model_copy.return_value.value = "test_value" + return mock_var + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_workflow_variable_collection_get_success( + self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model + ): + mock_wf_srv.return_value.is_workflow_exist.return_value = True + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_variables_without_values.return_value = WorkflowDraftVariableList( + variables=[], total=0 + ) + + resp = setup_test_context( + app, + WorkflowVariableCollectionApi, + "/apps/app_123/workflows/draft/variables?page=1&limit=20", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": [], "total": 0} + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + def test_workflow_variable_collection_get_not_exist(self, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf_srv.return_value.is_workflow_exist.return_value = False + + with pytest.raises(DraftWorkflowNotExist): + setup_test_context( + app, + WorkflowVariableCollectionApi, + "/apps/app_123/workflows/draft/variables", + "GET", + mock_account, + mock_app_model, + ) + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_workflow_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): + resp = setup_test_context( + app, + WorkflowVariableCollectionApi, + "/apps/app_123/workflows/draft/variables", + "DELETE", + mock_account, + mock_app_model, + ) + assert resp.status_code == 204 + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_node_variable_collection_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_node_variables.return_value = WorkflowDraftVariableList(variables=[]) + resp = setup_test_context( + app, + NodeVariableCollectionApi, + "/apps/app_123/workflows/draft/nodes/node_123/variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} + + def test_node_variable_collection_get_invalid_node_id(self, app, mock_account, mock_app_model): + with pytest.raises(InvalidArgumentError): + setup_test_context( + app, + NodeVariableCollectionApi, + "/apps/app_123/workflows/draft/nodes/sys/variables", + "GET", + mock_account, + mock_app_model, + ) + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_node_variable_collection_delete(self, mock_draft_srv, app, mock_account, mock_app_model): + resp = setup_test_context( + app, + NodeVariableCollectionApi, + "/apps/app_123/workflows/draft/nodes/node_123/variables", + "DELETE", + mock_account, + mock_app_model, + ) + assert resp.status_code == 204 + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_get_success(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + + resp = setup_test_context( + app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model + ) + assert resp["id"] == "var_123" + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_get_not_found(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = None + + with pytest.raises(NotFoundError): + setup_test_context( + app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "GET", mock_account, mock_app_model + ) + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_patch_success(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + + resp = setup_test_context( + app, + VariableApi, + "/apps/app_123/workflows/draft/variables/var_123", + "PATCH", + mock_account, + mock_app_model, + payload={"name": "new_name"}, + ) + assert resp["id"] == "var_123" + mock_draft_srv.return_value.update_variable.assert_called_once() + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_api_delete_success(self, mock_draft_srv, app, mock_account, mock_app_model): + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + + resp = setup_test_context( + app, VariableApi, "/apps/app_123/workflows/draft/variables/var_123", "DELETE", mock_account, mock_app_model + ) + assert resp.status_code == 204 + mock_draft_srv.return_value.delete_variable.assert_called_once() + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_variable_reset_api_put_success(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() + mock_draft_srv.return_value.get_variable.return_value = self._mock_workflow_variable() + mock_draft_srv.return_value.reset_variable.return_value = None # means no content + + resp = setup_test_context( + app, + VariableResetApi, + "/apps/app_123/workflows/draft/variables/var_123/reset", + "PUT", + mock_account, + mock_app_model, + ) + assert resp.status_code == 204 + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_conversation_variable_collection_get(self, mock_draft_srv, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf_srv.return_value.get_draft_workflow.return_value = MagicMock() + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_conversation_variables.return_value = WorkflowDraftVariableList(variables=[]) + + resp = setup_test_context( + app, + ConversationVariableCollectionApi, + "/apps/app_123/workflows/draft/conversation-variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} + + @patch("controllers.console.app.workflow_draft_variable.WorkflowDraftVariableService") + def test_system_variable_collection_get(self, mock_draft_srv, app, mock_account, mock_app_model): + from services.workflow_draft_variable_service import WorkflowDraftVariableList + + mock_draft_srv.return_value.list_system_variables.return_value = WorkflowDraftVariableList(variables=[]) + + resp = setup_test_context( + app, + SystemVariableCollectionApi, + "/apps/app_123/workflows/draft/system-variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} + + @patch("controllers.console.app.workflow_draft_variable.WorkflowService") + def test_environment_variable_collection_get(self, mock_wf_srv, app, mock_account, mock_app_model): + mock_wf = MagicMock() + mock_wf.environment_variables = [] + mock_wf_srv.return_value.get_draft_workflow.return_value = mock_wf + + resp = setup_test_context( + app, + EnvironmentVariableCollectionApi, + "/apps/app_123/workflows/draft/environment-variables", + "GET", + mock_account, + mock_app_model, + ) + assert resp == {"items": []} diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index f9788e2e50..83601dc1b9 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -10,10 +10,10 @@ from flask import Flask from controllers.console import wraps as console_wraps from controllers.console.app import workflow_run as workflow_run_module from controllers.web.error import NotFoundError -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/controllers/console/app/test_wraps.py b/api/tests/unit_tests/controllers/console/app/test_wraps.py new file mode 100644 index 0000000000..7664e492da --- /dev/null +++ b/api/tests/unit_tests/controllers/console/app/test_wraps.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from controllers.console.app import wraps as wraps_module +from controllers.console.app.error import AppNotFoundError +from models.model import AppMode + + +def test_get_app_model_injects_model(monkeypatch: pytest.MonkeyPatch) -> None: + app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + + monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + @wraps_module.get_app_model + def handler(app_model): + return app_model.id + + assert handler(app_id="app-1") == "app-1" + + +def test_get_app_model_rejects_wrong_mode(monkeypatch: pytest.MonkeyPatch) -> None: + app_model = SimpleNamespace(id="app-1", mode=AppMode.CHAT.value, status="normal", tenant_id="t1") + query = SimpleNamespace(where=lambda *_args, **_kwargs: query, first=lambda: app_model) + + monkeypatch.setattr(wraps_module, "current_account_with_tenant", lambda: (None, "t1")) + monkeypatch.setattr(wraps_module.db, "session", SimpleNamespace(query=lambda *_args, **_kwargs: query)) + + @wraps_module.get_app_model(mode=[AppMode.COMPLETION]) + def handler(app_model): + return app_model.id + + with pytest.raises(AppNotFoundError): + handler(app_id="app-1") + + +def test_get_app_model_requires_app_id() -> None: + @wraps_module.get_app_model + def handler(app_model): + return app_model.id + + with pytest.raises(ValueError): + handler() diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index cf10182ad3..f34702a257 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,8 +13,8 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.variables.types import SegmentType +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -310,8 +310,8 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from core.workflow.file.enums import FileTransferMethod, FileType - from core.workflow.file.models import File + from dify_graph.file.enums import FileTransferMethod, FileType + from dify_graph.file.models import File # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( @@ -368,8 +368,8 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from core.workflow.file.enums import FileTransferMethod, FileType - from core.workflow.file.models import File + from dify_graph.file.enums import FileTransferMethod, FileType + from dify_graph.file.models import File # Create a File object with REMOTE_URL transfer method test_file = File( diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py new file mode 100644 index 0000000000..bc4c7e0993 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_bearer_auth.py @@ -0,0 +1,209 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.auth.data_source_bearer_auth import ( + ApiKeyAuthDataSource, + ApiKeyAuthDataSourceBinding, + ApiKeyAuthDataSourceBindingDelete, +) +from controllers.console.auth.error import ApiKeyAuthFailedError + + +class TestApiKeyAuthDataSource: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + return app + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") + def test_get_api_key_auth_data_source(self, mock_get_list, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + mock_binding = MagicMock() + mock_binding.id = "bind_123" + mock_binding.category = "api_key" + mock_binding.provider = "custom_provider" + mock_binding.disabled = False + mock_binding.created_at.timestamp.return_value = 1620000000 + mock_binding.updated_at.timestamp.return_value = 1620000001 + + mock_get_list.return_value = [mock_binding] + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSource() + response = api_instance.get() + + assert "sources" in response + assert len(response["sources"]) == 1 + assert response["sources"][0]["provider"] == "custom_provider" + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.get_provider_auth_list") + def test_get_api_key_auth_data_source_empty(self, mock_get_list, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + mock_get_list.return_value = None + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context("/console/api/api-key-auth/data-source", method="GET"): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSource() + response = api_instance.get() + + assert "sources" in response + assert len(response["sources"]) == 0 + + +class TestApiKeyAuthDataSourceBinding: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + return app + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") + def test_create_binding_successful(self, mock_validate, mock_create, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context( + "/console/api/api-key-auth/data-source/binding", + method="POST", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + ): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSourceBinding() + response = api_instance.post() + + assert response[0]["result"] == "success" + assert response[1] == 200 + mock_validate.assert_called_once() + mock_create.assert_called_once() + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.create_provider_auth") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.validate_api_key_auth_args") + def test_create_binding_failure(self, mock_validate, mock_create, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + mock_create.side_effect = ValueError("Invalid structure") + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context( + "/console/api/api-key-auth/data-source/binding", + method="POST", + json={"category": "api_key", "provider": "custom", "credentials": {"key": "value"}}, + ): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSourceBinding() + with pytest.raises(ApiKeyAuthFailedError, match="Invalid structure"): + api_instance.post() + + +class TestApiKeyAuthDataSourceBindingDelete: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + return app + + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_bearer_auth.ApiKeyAuthService.delete_provider_auth") + def test_delete_binding_successful(self, mock_delete, mock_db, mock_csrf, app): + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with ( + patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, "tenant_123")), + patch( + "controllers.console.auth.data_source_bearer_auth.current_account_with_tenant", + return_value=(mock_account, "tenant_123"), + ), + ): + with app.test_request_context("/console/api/api-key-auth/data-source/binding_123", method="DELETE"): + proxy_mock = MagicMock() + proxy_mock._get_current_object.return_value = mock_account + with patch("libs.login.current_user", proxy_mock), patch("flask_login.current_user", proxy_mock): + api_instance = ApiKeyAuthDataSourceBindingDelete() + response = api_instance.delete("binding_123") + + assert response[0]["result"] == "success" + assert response[1] == 204 + mock_delete.assert_called_once_with("tenant_123", "binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py new file mode 100644 index 0000000000..f369565946 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_data_source_oauth.py @@ -0,0 +1,192 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.local import LocalProxy + +from controllers.console.auth.data_source_oauth import ( + OAuthDataSource, + OAuthDataSourceBinding, + OAuthDataSourceCallback, + OAuthDataSourceSync, +) + + +class TestOAuthDataSource: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + @patch("flask_login.current_user") + @patch("libs.login.current_user") + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.data_source_oauth.dify_config.NOTION_INTEGRATION_TYPE", None) + def test_get_oauth_url_successful( + self, mock_db, mock_csrf, mock_libs_user, mock_flask_user, mock_get_providers, app + ): + mock_oauth_provider = MagicMock() + mock_oauth_provider.get_authorization_url.return_value = "http://oauth.provider/auth" + mock_get_providers.return_value = {"notion": mock_oauth_provider} + + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + mock_libs_user.return_value = mock_account + mock_flask_user.return_value = mock_account + + # also patch current_account_with_tenant + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): + with app.test_request_context("/console/api/oauth/data-source/notion", method="GET"): + proxy_mock = LocalProxy(lambda: mock_account) + with patch("libs.login.current_user", proxy_mock): + api_instance = OAuthDataSource() + response = api_instance.get("notion") + + assert response[0]["data"] == "http://oauth.provider/auth" + assert response[1] == 200 + mock_oauth_provider.get_authorization_url.assert_called_once() + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + @patch("flask_login.current_user") + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + def test_get_oauth_url_invalid_provider(self, mock_db, mock_csrf, mock_flask_user, mock_get_providers, app): + mock_get_providers.return_value = {"notion": MagicMock()} + + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): + with app.test_request_context("/console/api/oauth/data-source/unknown_provider", method="GET"): + proxy_mock = LocalProxy(lambda: mock_account) + with patch("libs.login.current_user", proxy_mock): + api_instance = OAuthDataSource() + response = api_instance.get("unknown_provider") + + assert response[0]["error"] == "Invalid provider" + assert response[1] == 400 + + +class TestOAuthDataSourceCallback: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_oauth_callback_successful(self, mock_get_providers, app): + provider_mock = MagicMock() + mock_get_providers.return_value = {"notion": provider_mock} + + with app.test_request_context("/console/api/oauth/data-source/notion/callback?code=mock_code", method="GET"): + api_instance = OAuthDataSourceCallback() + response = api_instance.get("notion") + + assert response.status_code == 302 + assert "code=mock_code" in response.location + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_oauth_callback_missing_code(self, mock_get_providers, app): + provider_mock = MagicMock() + mock_get_providers.return_value = {"notion": provider_mock} + + with app.test_request_context("/console/api/oauth/data-source/notion/callback", method="GET"): + api_instance = OAuthDataSourceCallback() + response = api_instance.get("notion") + + assert response.status_code == 302 + assert "error=Access denied" in response.location + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_oauth_callback_invalid_provider(self, mock_get_providers, app): + mock_get_providers.return_value = {"notion": MagicMock()} + + with app.test_request_context("/console/api/oauth/data-source/invalid/callback?code=mock_code", method="GET"): + api_instance = OAuthDataSourceCallback() + response = api_instance.get("invalid") + + assert response[0]["error"] == "Invalid provider" + assert response[1] == 400 + + +class TestOAuthDataSourceBinding: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_get_binding_successful(self, mock_get_providers, app): + mock_provider = MagicMock() + mock_provider.get_access_token.return_value = None + mock_get_providers.return_value = {"notion": mock_provider} + + with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=auth_code_123", method="GET"): + api_instance = OAuthDataSourceBinding() + response = api_instance.get("notion") + + assert response[0]["result"] == "success" + assert response[1] == 200 + mock_provider.get_access_token.assert_called_once_with("auth_code_123") + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + def test_get_binding_missing_code(self, mock_get_providers, app): + mock_get_providers.return_value = {"notion": MagicMock()} + + with app.test_request_context("/console/api/oauth/data-source/notion/binding?code=", method="GET"): + api_instance = OAuthDataSourceBinding() + response = api_instance.get("notion") + + assert response[0]["error"] == "Invalid code" + assert response[1] == 400 + + +class TestOAuthDataSourceSync: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @patch("controllers.console.auth.data_source_oauth.get_oauth_providers") + @patch("libs.login.check_csrf_token") + @patch("controllers.console.wraps.db") + def test_sync_successful(self, mock_db, mock_csrf, mock_get_providers, app): + mock_provider = MagicMock() + mock_provider.sync_data_source.return_value = None + mock_get_providers.return_value = {"notion": mock_provider} + + from models.account import Account, AccountStatus + + mock_account = MagicMock(spec=Account) + mock_account.id = "user_123" + mock_account.status = AccountStatus.ACTIVE + mock_account.is_admin_or_owner = True + mock_account.current_tenant.current_role = "owner" + + with patch("controllers.console.wraps.current_account_with_tenant", return_value=(mock_account, MagicMock())): + with app.test_request_context("/console/api/oauth/data-source/notion/binding_123/sync", method="GET"): + proxy_mock = LocalProxy(lambda: mock_account) + with patch("libs.login.current_user", proxy_mock): + api_instance = OAuthDataSourceSync() + # The route pattern uses , so we just pass a string for unit testing + response = api_instance.get("notion", "binding_123") + + assert response[0]["result"] == "success" + assert response[1] == 200 + mock_provider.sync_data_source.assert_called_once_with("binding_123") diff --git a/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py new file mode 100644 index 0000000000..fc5663e72d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/auth/test_oauth_server.py @@ -0,0 +1,417 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest, NotFound + +from controllers.console.auth.oauth_server import ( + OAuthServerAppApi, + OAuthServerUserAccountApi, + OAuthServerUserAuthorizeApi, + OAuthServerUserTokenApi, +) + + +class TestOAuthServerAppApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + from models.model import OAuthProviderApp + + oauth_app = MagicMock(spec=OAuthProviderApp) + oauth_app.client_id = "test_client_id" + oauth_app.redirect_uris = ["http://localhost/callback"] + oauth_app.app_icon = "icon_url" + oauth_app.app_label = "Test App" + oauth_app.scope = "read,write" + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_successful_post(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider", + method="POST", + json={"client_id": "test_client_id", "redirect_uri": "http://localhost/callback"}, + ): + api_instance = OAuthServerAppApi() + response = api_instance.post() + + assert response["app_icon"] == "icon_url" + assert response["app_label"] == "Test App" + assert response["scope"] == "read,write" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider", + method="POST", + json={"client_id": "test_client_id", "redirect_uri": "http://invalid/callback"}, + ): + api_instance = OAuthServerAppApi() + with pytest.raises(BadRequest, match="redirect_uri is invalid"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_client_id(self, mock_get_app, mock_db, app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = None + + with app.test_request_context( + "/oauth/provider", + method="POST", + json={"client_id": "test_invalid_client_id", "redirect_uri": "http://localhost/callback"}, + ): + api_instance = OAuthServerAppApi() + with pytest.raises(NotFound, match="client_id is invalid"): + api_instance.post() + + +class TestOAuthServerUserAuthorizeApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + oauth_app = MagicMock() + oauth_app.client_id = "test_client_id" + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.current_account_with_tenant") + @patch("controllers.console.wraps.current_account_with_tenant") + @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_authorization_code") + @patch("libs.login.check_csrf_token") + def test_successful_authorize( + self, mock_csrf, mock_sign, mock_wrap_current, mock_current, mock_get_app, mock_db, app, mock_oauth_provider_app + ): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + mock_account = MagicMock() + mock_account.id = "user_123" + from models.account import AccountStatus + + mock_account.status = AccountStatus.ACTIVE + + mock_current.return_value = (mock_account, MagicMock()) + mock_wrap_current.return_value = (mock_account, MagicMock()) + + mock_sign.return_value = "auth_code_123" + + with app.test_request_context("/oauth/provider/authorize", method="POST", json={"client_id": "test_client_id"}): + with patch("libs.login.current_user", mock_account): + api_instance = OAuthServerUserAuthorizeApi() + response = api_instance.post() + + assert response["code"] == "auth_code_123" + mock_sign.assert_called_once_with("test_client_id", "user_123") + + +class TestOAuthServerUserTokenApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + from models.model import OAuthProviderApp + + oauth_app = MagicMock(spec=OAuthProviderApp) + oauth_app.client_id = "test_client_id" + oauth_app.client_secret = "test_secret" + oauth_app.redirect_uris = ["http://localhost/callback"] + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") + def test_authorization_code_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + mock_sign.return_value = ("access_123", "refresh_123") + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + response = api_instance.post() + + assert response["access_token"] == "access_123" + assert response["refresh_token"] == "refresh_123" + assert response["token_type"] == "Bearer" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_authorization_code_grant_missing_code(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "client_secret": "test_secret", + "redirect_uri": "http://localhost/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="code is required"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_authorization_code_grant_invalid_secret(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "invalid_secret", + "redirect_uri": "http://localhost/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="client_secret is invalid"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_authorization_code_grant_invalid_redirect_uri(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "authorization_code", + "code": "auth_code", + "client_secret": "test_secret", + "redirect_uri": "http://invalid/callback", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="redirect_uri is invalid"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.sign_oauth_access_token") + def test_refresh_token_grant(self, mock_sign, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + mock_sign.return_value = ("new_access", "new_refresh") + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={"client_id": "test_client_id", "grant_type": "refresh_token", "refresh_token": "refresh_123"}, + ): + api_instance = OAuthServerUserTokenApi() + response = api_instance.post() + + assert response["access_token"] == "new_access" + assert response["refresh_token"] == "new_refresh" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_refresh_token_grant_missing_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "refresh_token", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="refresh_token is required"): + api_instance.post() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_grant_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/token", + method="POST", + json={ + "client_id": "test_client_id", + "grant_type": "invalid_grant", + }, + ): + api_instance = OAuthServerUserTokenApi() + with pytest.raises(BadRequest, match="invalid grant_type"): + api_instance.post() + + +class TestOAuthServerUserAccountApi: + @pytest.fixture + def app(self): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + @pytest.fixture + def mock_oauth_provider_app(self): + from models.model import OAuthProviderApp + + oauth_app = MagicMock(spec=OAuthProviderApp) + oauth_app.client_id = "test_client_id" + return oauth_app + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") + def test_successful_account_retrieval(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + mock_account = MagicMock() + mock_account.name = "Test User" + mock_account.email = "test@example.com" + mock_account.avatar = "avatar_url" + mock_account.interface_language = "en-US" + mock_account.timezone = "UTC" + mock_validate.return_value = mock_account + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer valid_access_token"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response["name"] == "Test User" + assert response["email"] == "test@example.com" + assert response["avatar"] == "avatar_url" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_missing_authorization_header(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context("/oauth/provider/account", method="POST", json={"client_id": "test_client_id"}): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "Authorization header is required" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_authorization_header_format(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "InvalidFormat"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "Invalid Authorization header format" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_invalid_token_type(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Basic something"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "token_type is invalid" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + def test_missing_access_token(self, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer "}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "Invalid Authorization header format" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.auth.oauth_server.OAuthServerService.get_oauth_provider_app") + @patch("controllers.console.auth.oauth_server.OAuthServerService.validate_oauth_access_token") + def test_invalid_access_token(self, mock_validate, mock_get_app, mock_db, app, mock_oauth_provider_app): + mock_db.session.query.return_value.first.return_value = MagicMock() + mock_get_app.return_value = mock_oauth_provider_app + mock_validate.return_value = None + + with app.test_request_context( + "/oauth/provider/account", + method="POST", + json={"client_id": "test_client_id"}, + headers={"Authorization": "Bearer invalid_token"}, + ): + api_instance = OAuthServerUserAccountApi() + response = api_instance.post() + + assert response.status_code == 401 + assert response.json["error"] == "access_token or client_id is invalid" diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py new file mode 100644 index 0000000000..9014edc39e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -0,0 +1,817 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.datasource_auth import ( + DatasourceAuth, + DatasourceAuthDefaultApi, + DatasourceAuthDeleteApi, + DatasourceAuthListApi, + DatasourceAuthOauthCustomClient, + DatasourceAuthUpdateApi, + DatasourceHardCodeAuthListApi, + DatasourceOAuthCallback, + DatasourcePluginOAuthAuthorizationUrl, + DatasourceUpdateProviderNameApi, +) +from core.plugin.impl.oauth import OAuthHandler +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from services.datasource_provider_service import DatasourceProviderService +from services.plugin.oauth_service import OAuthProxyService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDatasourcePluginOAuthAuthorizationUrl: + def test_get_success(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + user = MagicMock(id="user-1") + + with ( + app.test_request_context("/?credential_id=cred-1"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthProxyService, + "create_proxy_context", + return_value="ctx-1", + ), + patch.object( + OAuthHandler, + "get_authorization_url", + return_value={"url": "http://auth"}, + ), + ): + response = method(api, "notion") + + assert response.status_code == 200 + + def test_get_no_oauth_config(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_without_credential_id_sets_cookie(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + user = MagicMock(id="user-1") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthProxyService, + "create_proxy_context", + return_value="ctx-123", + ), + patch.object( + OAuthHandler, + "get_authorization_url", + return_value={"url": "http://auth"}, + ), + ): + response = method(api, "notion") + + assert response.status_code == 200 + assert "context_id" in response.headers.get("Set-Cookie") + + +class TestDatasourceOAuthCallback: + def test_callback_success_new_credential(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {"name": "test"} + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": None, + } + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "add_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + + def test_callback_missing_context(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "notion") + + def test_callback_invalid_context(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + with ( + app.test_request_context("/?context_id=bad"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "notion") + + def test_callback_oauth_config_not_found(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + context = {"user_id": "u", "tenant_id": "t"} + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "notion") + + def test_callback_reauthorize_existing_credential(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {} # avatar + name missing + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": "cred-1", + } + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "reauthorize_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + assert "/oauth-callback" in response.location + + def test_callback_context_id_from_cookie(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {} + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": None, + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "add_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + + +class TestDatasourceAuth: + def test_post_success(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {"credentials": {"key": "val"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "add_datasource_api_key_provider", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_post_invalid_credentials(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {"credentials": {"key": "bad"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "add_datasource_api_key_provider", + side_effect=CredentialsValidateFailedError("invalid"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_success(self, app): + api = DatasourceAuth() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "list_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api, "notion") + + assert status == 200 + assert response["result"] + + def test_post_missing_credentials(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_empty_list(self, app): + api = DatasourceAuth() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "list_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api, "notion") + + assert status == 200 + assert response["result"] == [] + + +class TestDatasourceAuthDeleteApi: + def test_delete_success(self, app): + api = DatasourceAuthDeleteApi() + method = unwrap(api.post) + + payload = {"credential_id": "cred-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "remove_datasource_credentials", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_delete_missing_credential_id(self, app): + api = DatasourceAuthDeleteApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + +class TestDatasourceAuthUpdateApi: + def test_update_success(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": {"k": "v"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 201 + + def test_update_with_credentials_none(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": None} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ) as update_mock, + ): + response, status = method(api, "notion") + + update_mock.assert_called_once() + assert status == 201 + + def test_update_name_only(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "name": "New Name"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ), + ): + _, status = method(api, "notion") + + assert status == 201 + + def test_update_with_empty_credentials_dict(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": {}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ) as update_mock, + ): + _, status = method(api, "notion") + + update_mock.assert_called_once() + assert status == 201 + + +class TestDatasourceAuthListApi: + def test_list_success(self, app): + api = DatasourceAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_all_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api) + + assert status == 200 + + def test_auth_list_empty(self, app): + api = DatasourceAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_all_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api) + + assert status == 200 + assert response["result"] == [] + + def test_hardcode_list_empty(self, app): + api = DatasourceHardCodeAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_hard_code_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api) + + assert status == 200 + assert response["result"] == [] + + +class TestDatasourceHardCodeAuthListApi: + def test_list_success(self, app): + api = DatasourceHardCodeAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_hard_code_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api) + + assert status == 200 + + +class TestDatasourceAuthOauthCustomClient: + def test_post_success(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = {"client_params": {}, "enable_oauth_custom_client": True} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_delete_success(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "remove_oauth_custom_client_params", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_post_empty_payload(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ), + ): + _, status = method(api, "notion") + + assert status == 200 + + def test_post_disabled_flag(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = { + "client_params": {"a": 1}, + "enable_oauth_custom_client": False, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ) as setup_mock, + ): + _, status = method(api, "notion") + + setup_mock.assert_called_once() + assert status == 200 + + +class TestDatasourceAuthDefaultApi: + def test_set_default_success(self, app): + api = DatasourceAuthDefaultApi() + method = unwrap(api.post) + + payload = {"id": "cred-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "set_default_datasource_provider", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_default_missing_id(self, app): + api = DatasourceAuthDefaultApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + +class TestDatasourceUpdateProviderNameApi: + def test_update_name_success(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "name": "New Name"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_provider_name", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_update_name_too_long(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = { + "credential_id": "id", + "name": "x" * 101, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_update_name_missing_credential_id(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = {"name": "Valid"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py new file mode 100644 index 0000000000..7a8ccde55a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.datasource_content_preview import ( + DataSourceContentPreviewApi, +) +from models import Account +from models.dataset import Pipeline + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDataSourceContentPreviewApi: + def _valid_payload(self): + return { + "inputs": {"query": "hello"}, + "datasource_type": "notion", + "credential_id": "cred-1", + } + + def test_post_success(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = self._valid_payload() + + pipeline = MagicMock(spec=Pipeline) + node_id = "node-1" + account = MagicMock(spec=Account) + + preview_result = {"content": "preview data"} + + service_instance = MagicMock() + service_instance.run_datasource_node_preview.return_value = preview_result + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", + return_value=service_instance, + ), + ): + response, status = method(api, pipeline, node_id) + + service_instance.run_datasource_node_preview.assert_called_once_with( + pipeline=pipeline, + node_id=node_id, + user_inputs=payload["inputs"], + account=account, + datasource_type=payload["datasource_type"], + is_published=True, + credential_id=payload["credential_id"], + ) + assert status == 200 + assert response == preview_result + + def test_post_forbidden_non_account_user(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = self._valid_payload() + + pipeline = MagicMock(spec=Pipeline) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + MagicMock(), # NOT Account + ), + ): + with pytest.raises(Forbidden): + method(api, pipeline, "node-1") + + def test_post_invalid_payload(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = { + "inputs": {"query": "hello"}, + # datasource_type missing + } + + pipeline = MagicMock(spec=Pipeline) + account = MagicMock(spec=Account) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + ): + with pytest.raises(ValueError): + method(api, pipeline, "node-1") + + def test_post_without_credential_id(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = { + "inputs": {"query": "hello"}, + "datasource_type": "notion", + "credential_id": None, + } + + pipeline = MagicMock(spec=Pipeline) + account = MagicMock(spec=Account) + + service_instance = MagicMock() + service_instance.run_datasource_node_preview.return_value = {"ok": True} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", + return_value=service_instance, + ), + ): + response, status = method(api, pipeline, "node-1") + + service_instance.run_datasource_node_preview.assert_called_once() + assert status == 200 + assert response == {"ok": True} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py new file mode 100644 index 0000000000..ebbb34e069 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -0,0 +1,225 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.rag_pipeline import ( + CustomizedPipelineTemplateApi, + PipelineTemplateDetailApi, + PipelineTemplateListApi, + PublishCustomizedPipelineTemplateApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestPipelineTemplateListApi: + def test_get_success(self, app): + api = PipelineTemplateListApi() + method = unwrap(api.get) + + templates = [{"id": "t1"}] + + with ( + app.test_request_context("/?type=built-in&language=en-US"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates", + return_value=templates, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == templates + + +class TestPipelineTemplateDetailApi: + def test_get_success(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + template = {"id": "tpl-1"} + + service = MagicMock() + service.get_pipeline_template_detail.return_value = template + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "tpl-1") + + assert status == 200 + assert response == template + + def test_get_returns_404_when_template_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + + def test_get_returns_404_for_customized_type_not_found(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_pipeline_template_detail.return_value = None + + with ( + app.test_request_context("/?type=customized"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "non-existent-id") + + assert status == 404 + assert "error" in response + + +class TestCustomizedPipelineTemplateApi: + def test_patch_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.patch) + + payload = { + "name": "Template", + "description": "Desc", + "icon_info": {"icon": "📘"}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template" + ) as update_mock, + ): + response = method(api, "tpl-1") + + update_mock.assert_called_once() + assert response == 200 + + def test_delete_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template" + ) as delete_mock, + ): + response = method(api, "tpl-1") + + delete_mock.assert_called_once_with("tpl-1") + assert response == 200 + + def test_post_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.post) + + template = MagicMock() + template.yaml_content = "yaml-data" + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = template + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", + return_value=session_ctx, + ), + ): + response, status = method(api, "tpl-1") + + assert status == 200 + assert response == {"data": "yaml-data"} + + def test_post_template_not_found(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.post) + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", + return_value=session_ctx, + ), + ): + with pytest.raises(ValueError): + method(api, "tpl-1") + + +class TestPublishCustomizedPipelineTemplateApi: + def test_post_success(self, app): + api = PublishCustomizedPipelineTemplateApi() + method = unwrap(api.post) + + payload = { + "name": "Template", + "description": "Desc", + "icon_info": {"icon": "📘"}, + } + + service = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response = method(api, "pipeline-1") + + service.publish_customized_pipeline_template.assert_called_once() + assert response == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py new file mode 100644 index 0000000000..fd38fcbb5e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -0,0 +1,187 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +import services +from controllers.console import console_ns +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import ( + CreateEmptyRagPipelineDatasetApi, + CreateRagPipelineDatasetApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestCreateRagPipelineDatasetApi: + def _valid_payload(self): + return {"yaml_content": "name: test"} + + def test_post_success(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=True) + import_info = {"dataset_id": "ds-1"} + + mock_service = MagicMock() + mock_service.create_rag_pipeline_dataset.return_value = import_info + + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__.return_value = MagicMock() + mock_session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", + return_value=mock_session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", + return_value=mock_service, + ), + ): + response, status = method(api) + + assert status == 201 + assert response == import_info + + def test_post_forbidden_non_editor(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_dataset_name_duplicate(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=True) + + mock_service = MagicMock() + mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError() + + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__.return_value = MagicMock() + mock_session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", + return_value=mock_session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", + return_value=mock_service, + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_post_invalid_payload(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = {} + user = MagicMock(is_dataset_editor=True) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestCreateEmptyRagPipelineDatasetApi: + def test_post_success(self, app): + api = CreateEmptyRagPipelineDatasetApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal", + return_value={"id": "ds-1"}, + ), + ): + response, status = method(api) + + assert status == 201 + assert response == {"id": "ds-1"} + + def test_post_forbidden_non_editor(self, app): + api = CreateEmptyRagPipelineDatasetApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py new file mode 100644 index 0000000000..b4c0903f63 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -0,0 +1,324 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Response + +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import ( + RagPipelineEnvironmentVariableCollectionApi, + RagPipelineNodeVariableCollectionApi, + RagPipelineSystemVariableCollectionApi, + RagPipelineVariableApi, + RagPipelineVariableCollectionApi, + RagPipelineVariableResetApi, +) +from controllers.web.error import InvalidArgumentError, NotFoundError +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType +from models.account import Account + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def fake_db(): + db = MagicMock() + db.engine = MagicMock() + db.session.return_value = MagicMock() + return db + + +@pytest.fixture +def editor_user(): + user = MagicMock(spec=Account) + user.has_edit_permission = True + return user + + +@pytest.fixture +def restx_config(app): + return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"}) + + +class TestRagPipelineVariableCollectionApi: + def test_get_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + rag_srv = MagicMock() + rag_srv.is_workflow_exist.return_value = True + + # IMPORTANT: RESTX expects .variables + var_list = MagicMock() + var_list.variables = [] + + draft_srv = MagicMock() + draft_srv.list_variables_without_values.return_value = var_list + + with ( + app.test_request_context("/?page=1&limit=10"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=draft_srv, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [] + + def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + rag_srv = MagicMock() + rag_srv.is_workflow_exist.return_value = False + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + ): + with pytest.raises(DraftWorkflowNotExist): + method(api, pipeline) + + def test_delete_variables_success(self, app, fake_db, editor_user): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.delete) + + pipeline = MagicMock(id="p1") + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"), + ): + result = method(api, pipeline) + + assert isinstance(result, Response) + assert result.status_code == 204 + + +class TestRagPipelineNodeVariableCollectionApi: + def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineNodeVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + var_list = MagicMock() + var_list.variables = [] + + srv = MagicMock() + srv.list_node_variables.return_value = var_list + + with ( + app.test_request_context("/"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline, "node1") + + assert result["items"] == [] + + def test_get_node_variables_invalid_node(self, app, editor_user): + api = RagPipelineNodeVariableCollectionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + ): + with pytest.raises(InvalidArgumentError): + method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID) + + +class TestRagPipelineVariableApi: + def test_get_variable_not_found(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.get) + + srv = MagicMock() + srv.get_variable.return_value = None + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + with pytest.raises(NotFoundError): + method(api, MagicMock(), "v1") + + def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.patch) + + pipeline = MagicMock(id="p1", tenant_id="t1") + variable = MagicMock(app_id="p1", value_type=SegmentType.FILE) + + srv = MagicMock() + srv.get_variable.return_value = variable + + payload = {"value": "invalid"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + with pytest.raises(InvalidArgumentError): + method(api, pipeline, "v1") + + def test_delete_variable_success(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.delete) + + pipeline = MagicMock(id="p1") + variable = MagicMock(app_id="p1") + + srv = MagicMock() + srv.get_variable.return_value = variable + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline, "v1") + + assert result.status_code == 204 + + +class TestRagPipelineVariableResetApi: + def test_reset_variable_success(self, app, fake_db, editor_user): + api = RagPipelineVariableResetApi() + method = unwrap(api.put) + + pipeline = MagicMock(id="p1") + workflow = MagicMock() + variable = MagicMock(app_id="p1") + + srv = MagicMock() + srv.get_variable.return_value = variable + srv.reset_variable.return_value = variable + + rag_srv = MagicMock() + rag_srv.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal", + return_value={"id": "v1"}, + ), + ): + result = method(api, pipeline, "v1") + + assert result == {"id": "v1"} + + +class TestSystemAndEnvironmentVariablesApi: + def test_system_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineSystemVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + var_list = MagicMock() + var_list.variables = [] + + srv = MagicMock() + srv.list_system_variables.return_value = var_list + + with ( + app.test_request_context("/"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [] + + def test_environment_variables_success(self, app, editor_user): + api = RagPipelineEnvironmentVariableCollectionApi() + method = unwrap(api.get) + + env_var = MagicMock( + id="e1", + name="ENV", + description="d", + selector="s", + value_type=MagicMock(value="string"), + value="x", + ) + + workflow = MagicMock(environment_variables=[env_var]) + pipeline = MagicMock(id="p1") + + rag_srv = MagicMock() + rag_srv.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + ): + result = method(api, pipeline) + + assert len(result["items"]) == 1 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py new file mode 100644 index 0000000000..a72ad45110 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -0,0 +1,329 @@ +from unittest.mock import MagicMock, patch + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( + RagPipelineExportApi, + RagPipelineImportApi, + RagPipelineImportCheckDependenciesApi, + RagPipelineImportConfirmApi, +) +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestRagPipelineImportApi: + def _payload(self, mode="create"): + return { + "mode": mode, + "yaml_content": "content", + "name": "Test", + } + + def test_post_success_200(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = "completed" + result.model_dump.return_value = {"status": "success"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == {"status": "success"} + + def test_post_failed_400(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.FAILED + result.model_dump.return_value = {"status": "failed"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 400 + assert response == {"status": "failed"} + + def test_post_pending_202(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.PENDING + result.model_dump.return_value = {"status": "pending"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 202 + assert response == {"status": "pending"} + + +class TestRagPipelineImportConfirmApi: + def test_confirm_success(self, app): + api = RagPipelineImportConfirmApi() + method = unwrap(api.post) + + user = MagicMock() + result = MagicMock() + result.status = "completed" + result.model_dump.return_value = {"ok": True} + + service = MagicMock() + service.confirm_import.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, "import-1") + + assert status == 200 + assert response == {"ok": True} + + def test_confirm_failed(self, app): + api = RagPipelineImportConfirmApi() + method = unwrap(api.post) + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.FAILED + result.model_dump.return_value = {"ok": False} + + service = MagicMock() + service.confirm_import.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, "import-1") + + assert status == 400 + assert response == {"ok": False} + + +class TestRagPipelineImportCheckDependenciesApi: + def test_get_success(self, app): + api = RagPipelineImportCheckDependenciesApi() + method = unwrap(api.get) + + pipeline = MagicMock(spec=Pipeline) + result = MagicMock() + result.model_dump.return_value = {"deps": []} + + service = MagicMock() + service.check_dependencies.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, pipeline) + + assert status == 200 + assert response == {"deps": []} + + +class TestRagPipelineExportApi: + def test_get_with_include_secret(self, app): + api = RagPipelineExportApi() + method = unwrap(api.get) + + pipeline = MagicMock(spec=Pipeline) + service = MagicMock() + service.export_rag_pipeline_dsl.return_value = {"yaml": "data"} + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/?include_secret=true"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, pipeline) + + assert status == 200 + assert response == {"data": {"yaml": "data"}} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py new file mode 100644 index 0000000000..472d133349 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -0,0 +1,769 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, HTTPException, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( + DefaultRagPipelineBlockConfigApi, + DraftRagPipelineApi, + DraftRagPipelineRunApi, + PublishedAllRagPipelineApi, + PublishedRagPipelineApi, + PublishedRagPipelineRunApi, + RagPipelineByIdApi, + RagPipelineDatasourceVariableApi, + RagPipelineDraftNodeRunApi, + RagPipelineDraftRunIterationNodeApi, + RagPipelineDraftRunLoopNodeApi, + RagPipelineDraftWorkflowRestoreApi, + RagPipelineRecommendedPluginApi, + RagPipelineTaskStopApi, + RagPipelineTransformApi, + RagPipelineWorkflowLastRunApi, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDraftWorkflowApi: + def test_get_draft_success(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + workflow = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + assert result == workflow + + def test_get_draft_not_exist(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + service = MagicMock() + service.get_draft_workflow.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(DraftWorkflowNotExist): + method(api, pipeline) + + def test_sync_hash_not_match(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + service = MagicMock() + service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError() + + with ( + app.test_request_context("/", json={"graph": {}, "features": {}}), + patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(DraftWorkflowNotSync): + method(api, pipeline) + + def test_sync_invalid_text_plain(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + response, status = method(api, pipeline) + assert status == 400 + + def test_restore_published_workflow_to_draft_success(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + workflow = MagicMock(unique_hash="restored-hash", updated_at=None, created_at=datetime(2024, 1, 1)) + + service = MagicMock() + service.restore_published_workflow_to_draft.return_value = workflow + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "published-workflow") + + assert result["result"] == "success" + assert result["hash"] == "restored-hash" + + def test_restore_published_workflow_to_draft_not_found(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = WorkflowNotFoundError("Workflow not found") + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "published-workflow") + + def test_restore_published_workflow_to_draft_returns_400_for_draft_source(self, app): + api = RagPipelineDraftWorkflowRestoreApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="account-1") + + service = MagicMock() + service.restore_published_workflow_to_draft.side_effect = IsDraftWorkflowError( + "source workflow must be published" + ) + + with ( + app.test_request_context("/", method="POST"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(HTTPException) as exc: + method(api, pipeline, "draft-workflow") + + assert exc.value.code == 400 + assert exc.value.description == "source workflow must be published" + + +class TestDraftRunNodes: + def test_iteration_node_success(self, app): + api = RagPipelineDraftRunIterationNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + result = method(api, pipeline, "node") + assert result == {"ok": True} + + def test_iteration_node_conversation_not_exists(self, app): + api = RagPipelineDraftRunIterationNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", + side_effect=services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "node") + + def test_loop_node_success(self, app): + api = RagPipelineDraftRunLoopNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + assert method(api, pipeline, "node") == {"ok": True} + + +class TestPipelineRunApis: + def test_draft_run_success(self, app): + api = DraftRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + assert method(api, pipeline) == {"ok": True} + + def test_draft_run_rate_limit(self, app): + api = DraftRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context( + "/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"} + ), + patch.object( + type(console_ns), + "payload", + {"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, pipeline) + + +class TestDraftNodeRun: + def test_execution_not_found(self, app): + api = RagPipelineDraftNodeRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + service = MagicMock() + service.run_draft_workflow_node.return_value = None + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(ValueError): + method(api, pipeline, "node") + + +class TestPublishedPipelineApis: + def test_publish_success(self, app): + api = PublishedRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + workflow = MagicMock( + id="w1", + created_at=datetime.utcnow(), + ) + + session = MagicMock() + session.merge.return_value = pipeline + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + service = MagicMock() + service.publish_workflow.return_value = workflow + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + + assert result["result"] == "success" + assert "created_at" in result + + +class TestMiscApis: + def test_task_stop(self, app): + api = RagPipelineTaskStopApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag" + ) as stop_mock, + ): + result = method(api, pipeline, "task-1") + stop_mock.assert_called_once() + assert result["result"] == "success" + + def test_transform_forbidden(self, app): + api = RagPipelineTransformApi() + method = unwrap(api.post) + + user = MagicMock(has_edit_permission=False, is_dataset_operator=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds1") + + def test_recommended_plugins(self, app): + api = RagPipelineRecommendedPluginApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_recommended_plugins.return_value = [{"id": "p1"}] + + with ( + app.test_request_context("/?type=all"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api) + assert result == [{"id": "p1"}] + + +class TestPublishedRagPipelineRunApi: + def test_published_run_success(self, app): + api = PublishedRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + "response_mode": "blocking", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + result = method(api, pipeline) + assert result == {"ok": True} + + def test_published_run_rate_limit(self, app): + api = PublishedRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, pipeline) + + +class TestDefaultBlockConfigApi: + def test_get_block_config_success(self, app): + api = DefaultRagPipelineBlockConfigApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + service = MagicMock() + service.get_default_block_config.return_value = {"k": "v"} + + with ( + app.test_request_context("/?q={}"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "llm") + assert result == {"k": "v"} + + def test_get_block_config_invalid_json(self, app): + api = DefaultRagPipelineBlockConfigApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + with app.test_request_context("/?q=bad-json"): + with pytest.raises(ValueError): + method(api, pipeline, "llm") + + +class TestPublishedAllRagPipelineApi: + def test_get_published_workflows_success(self, app): + api = PublishedAllRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + service = MagicMock() + service.get_all_published_workflow.return_value = ([{"id": "w1"}], False) + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [{"id": "w1"}] + assert result["has_more"] is False + + def test_get_published_workflows_forbidden(self, app): + api = PublishedAllRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + with ( + app.test_request_context("/?user_id=u2"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + with pytest.raises(Forbidden): + method(api, pipeline) + + +class TestRagPipelineByIdApi: + def test_patch_success(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.patch) + + pipeline = MagicMock(tenant_id="t1") + user = MagicMock(id="u1") + + workflow = MagicMock() + + service = MagicMock() + service.update_workflow.return_value = workflow + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + payload = {"marked_name": "test"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "w1") + + assert result == workflow + + def test_patch_no_fields(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.patch) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={}), + patch.object(type(console_ns), "payload", {}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + result, status = method(api, pipeline, "w1") + assert status == 400 + + +class TestRagPipelineWorkflowLastRunApi: + def test_last_run_success(self, app): + api = RagPipelineWorkflowLastRunApi() + method = unwrap(api.get) + + pipeline = MagicMock() + workflow = MagicMock() + node_exec = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = workflow + service.get_node_last_run.return_value = node_exec + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "node1") + assert result == node_exec + + def test_last_run_not_found(self, app): + api = RagPipelineWorkflowLastRunApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "node1") + + +class TestRagPipelineDatasourceVariableApi: + def test_set_datasource_variables_success(self, app): + api = RagPipelineDatasourceVariableApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "datasource_type": "db", + "datasource_info": {}, + "start_node_id": "n1", + "start_node_title": "Node", + } + + service = MagicMock() + service.set_datasource_variables.return_value = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + assert result is not None diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py new file mode 100644 index 0000000000..3060062adf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -0,0 +1,444 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.console.datasets import data_source +from controllers.console.datasets.data_source import ( + DataSourceApi, + DataSourceNotionApi, + DataSourceNotionDatasetSyncApi, + DataSourceNotionDocumentSyncApi, + DataSourceNotionListApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_ctx(): + return (MagicMock(id="u1"), "tenant-1") + + +@pytest.fixture +def patch_tenant(tenant_ctx): + with patch( + "controllers.console.datasets.data_source.current_account_with_tenant", + return_value=tenant_ctx, + ): + yield + + +@pytest.fixture +def mock_engine(): + with patch.object( + type(data_source.db), + "engine", + new_callable=PropertyMock, + return_value=MagicMock(), + ): + yield + + +class TestDataSourceApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceApi() + method = unwrap(api.get) + + binding = MagicMock( + id="b1", + provider="notion", + created_at="now", + disabled=False, + source_info={}, + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.db.session.scalars", + return_value=MagicMock(all=lambda: [binding]), + ), + ): + response, status = method(api) + + assert status == 200 + assert response["data"][0]["is_bound"] is True + + def test_get_no_bindings(self, app, patch_tenant): + api = DataSourceApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.db.session.scalars", + return_value=MagicMock(all=lambda: []), + ), + ): + response, status = method(api) + + assert status == 200 + assert response["data"] == [] + + def test_patch_enable_binding(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + response, status = method(api, "b1", "enable") + + assert status == 200 + assert binding.disabled is False + + def test_patch_disable_binding(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=False) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + response, status = method(api, "b1", "disable") + + assert status == 200 + assert binding.disabled is True + + def test_patch_binding_not_found(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = None + + with pytest.raises(NotFound): + method(api, "b1", "enable") + + def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=False) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + with pytest.raises(ValueError): + method(api, "b1", "enable") + + def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + with pytest.raises(ValueError): + method(api, "b1", "disable") + + +class TestDataSourceNotionListApi: + def test_get_credential_not_found(self, app, patch_tenant): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api) + + def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + page = MagicMock( + page_id="p1", + page_name="Page 1", + type="page", + parent_id="parent", + page_icon=None, + ) + + online_document_message = MagicMock( + result=[ + MagicMock( + workspace_id="w1", + workspace_name="My Workspace", + workspace_icon="icon", + pages=[page], + ) + ] + ) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=MagicMock( + get_online_document_pages=lambda **kw: iter([online_document_message]), + datasource_provider_type=lambda: None, + ), + ), + ): + response, status = method(api) + + assert status == 200 + + def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + page = MagicMock( + page_id="p1", + page_name="Page 1", + type="page", + parent_id="parent", + page_icon=None, + ) + + online_document_message = MagicMock( + result=[ + MagicMock( + workspace_id="w1", + workspace_name="My Workspace", + workspace_icon="icon", + pages=[page], + ) + ] + ) + + dataset = MagicMock(data_source_type="notion_import") + document = MagicMock(data_source_info='{"notion_page_id": "p1"}') + + with ( + app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=MagicMock( + get_online_document_pages=lambda **kw: iter([online_document_message]), + datasource_provider_type=lambda: None, + ), + ), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [document] + + response, status = method(api) + + assert status == 200 + + def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + dataset = MagicMock(data_source_type="other_type") + + with ( + app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.data_source.Session"), + ): + with pytest.raises(ValueError): + method(api) + + +class TestDataSourceNotionApi: + def test_get_preview_success(self, app, patch_tenant): + api = DataSourceNotionApi() + method = unwrap(api.get) + + extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")]) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"integration_secret": "t"}, + ), + patch( + "controllers.console.datasets.data_source.NotionExtractor", + return_value=extractor, + ), + ): + response, status = method(api, "p1", "page") + + assert status == 200 + + def test_post_indexing_estimate_success(self, app, patch_tenant): + api = DataSourceNotionApi() + method = unwrap(api.post) + + payload = { + "notion_info_list": [ + { + "workspace_id": "w1", + "credential_id": "c1", + "pages": [{"page_id": "p1", "type": "page"}], + } + ], + "process_rule": {"rules": {}}, + "doc_form": "text_model", + "doc_language": "English", + } + + with ( + app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}), + patch( + "controllers.console.datasets.data_source.DocumentService.estimate_args_validate", + ), + patch( + "controllers.console.datasets.data_source.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"total_pages": 1}), + ), + ): + response, status = method(api) + + assert status == 200 + + +class TestDataSourceNotionDatasetSyncApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceNotionDatasetSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id", + return_value=[MagicMock(id="d1")], + ), + patch( + "controllers.console.datasets.data_source.document_indexing_sync_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 200 + + def test_get_dataset_not_found(self, app, patch_tenant): + api = DataSourceNotionDatasetSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + +class TestDataSourceNotionDocumentSyncApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceNotionDocumentSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.document_indexing_sync_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_document_not_found(self, app, patch_tenant): + api = DataSourceNotionDocumentSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py new file mode 100644 index 0000000000..0ee76e504b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -0,0 +1,1927 @@ +import datetime +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.datasets import ( + DatasetApi, + DatasetApiBaseUrlApi, + DatasetApiDeleteApi, + DatasetApiKeyApi, + DatasetAutoDisableLogApi, + DatasetEnableApiApi, + DatasetErrorDocs, + DatasetIndexingEstimateApi, + DatasetIndexingStatusApi, + DatasetListApi, + DatasetPermissionUserListApi, + DatasetQueryApi, + DatasetRelatedAppListApi, + DatasetRetrievalSettingApi, + DatasetRetrievalSettingMockApi, + DatasetUseCheckApi, +) +from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.provider_manager import ProviderManager +from extensions.storage.storage_type import StorageType +from models.enums import CreatorUserRole +from models.model import ApiToken, UploadFile +from services.dataset_service import DatasetPermissionService, DatasetService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDatasetList: + def _mock_dataset_dict(self, **overrides): + base = { + "id": "ds-1", + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "permission": "only_me", + } + base.update(overrides) + return base + + def _mock_user(self): + user = MagicMock() + user.is_dataset_editor = True + return user + + def test_get_success_basic(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 1 + assert resp["data"][0]["embedding_available"] is True + + def test_get_with_ids_filter(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets?ids=1&ids=2"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets_by_ids", + return_value=(datasets, 2), + ) as by_ids_mock, + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + by_ids_mock.assert_called_once() + assert status == 200 + assert resp["total"] == 2 + + def test_get_with_tag_ids(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets?tag_ids=tag1"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert status == 200 + + def test_embedding_available_false(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [ + self._mock_dataset_dict( + indexing_technique="high_quality", + embedding_model="text-embed", + embedding_model_provider="openai", + ) + ] + + config = MagicMock() + config.get_models.return_value = [] # model not available + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=config, + ), + ): + resp, status = method(api) + + assert resp["data"][0]["embedding_available"] is False + + def test_partial_members_permission(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict(permission="partial_members")] + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.db.session.execute", + return_value=MagicMock(all=lambda: [("ds-1", "u1")]), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert resp["data"][0]["partial_member_list"] == ["u1"] + + +class TestDatasetListApiPost: + def test_post_success(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "My Dataset", + "description": "desc", + "indexing_technique": "economy", + "provider": "vendor", + } + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + # ---- minimal required fields for marshal ---- + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasetService, + "create_empty_dataset", + return_value=dataset, + ), + ): + _, status = method(api) + + assert status == 201 + + def test_post_forbidden(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = {"name": "test"} + + user = MagicMock() + user.is_dataset_editor = False + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_duplicate_name(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = {"name": "duplicate"} + + user = MagicMock() + user.is_dataset_editor = True + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasetService, + "create_empty_dataset", + side_effect=services.errors.dataset.DatasetNameDuplicateError(), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_post_invalid_payload_missing_name(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + with app.test_request_context("/datasets", json={}), patch.object(type(console_ns), "payload", {}): + with pytest.raises(ValueError): + method(api) + + def test_post_invalid_indexing_technique(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "bad", + "indexing_technique": "invalid-tech", + } + + with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError, match="Invalid indexing technique"): + method(api) + + def test_post_invalid_provider(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "bad", + "provider": "unknown", + } + + with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError, match="Invalid provider"): + method(api) + + +class TestDatasetApiGet: + def test_get_success_basic(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "123e4567-e89b-12d3-a456-426614174000" + + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + dataset.permission = "only_me" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + # embedding models exist → embedding_available stays True + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, status = method(api, dataset_id) + + assert status == 200 + assert data["embedding_available"] is True + + def test_get_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "missing-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_get_permission_denied(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + dataset = MagicMock() + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden, match="no access"): + method(api, dataset_id) + + def test_get_high_quality_embedding_unavailable(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model = "text-embedding" + dataset.embedding_model_provider = "openai" + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + dataset.permission = "only_me" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + # embedding model NOT configured + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, _ = method(api, dataset_id) + + assert data["embedding_available"] is False + + def test_get_partial_members_permission(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + dataset.permission = "partial_members" + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + partial_members = [{"id": "u1"}, {"id": "u2"}] + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=partial_members, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, _ = method(api, dataset_id) + + assert data["partial_member_list"] == partial_members + + +class TestDatasetApiPatch: + def test_patch_success_basic(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "name": "updated-name", + "description": "updated description", + } + + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.permission = "only_me" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=[], + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result["partial_member_list"] == [] + + def test_patch_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/datasets/missing"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, "missing") + + def test_patch_permission_denied(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + dataset = MagicMock() + + payload = {"name": "x"} + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetPermissionService, + "check_permission", + side_effect=Forbidden("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_patch_partial_members_update(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "permission": "partial_members", + "partial_member_list": [{"id": "u1"}, {"id": "u2"}], + } + + dataset = MagicMock() + dataset.id = dataset_id + dataset.permission = "partial_members" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "update_partial_member_list", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=payload["partial_member_list"], + ), + ): + result, _ = method(api, dataset_id) + + assert result["partial_member_list"] == payload["partial_member_list"] + + def test_patch_clear_partial_members(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "permission": "only_me", + } + + dataset = MagicMock() + dataset.id = dataset_id + dataset.permission = "only_me" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "clear_partial_member_list", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=[], + ), + ): + result, _ = method(api, dataset_id) + + assert result["partial_member_list"] == [] + + +class TestDatasetApiDelete: + def test_delete_success(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + return_value=True, + ), + patch.object( + DatasetPermissionService, + "clear_partial_member_list", + return_value=None, + ), + ): + result, status = method(api, dataset_id) + + assert status == 204 + assert result == {"result": "success"} + + def test_delete_forbidden_no_permission(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = False + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_delete_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "missing-dataset" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + return_value=False, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_delete_dataset_in_use(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + side_effect=services.errors.dataset.DatasetInUseError(), + ), + ): + with pytest.raises(DatasetInUseError): + method(api, dataset_id) + + +class TestDatasetUseCheckApi: + def test_get_use_check_true(self, app): + api = DatasetUseCheckApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}/use-check"), + patch.object( + DatasetService, + "dataset_use_check", + return_value=True, + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result == {"is_using": True} + + def test_get_use_check_false(self, app): + api = DatasetUseCheckApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}/use-check"), + patch.object( + DatasetService, + "dataset_use_check", + return_value=False, + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result == {"is_using": False} + + +class TestDatasetQueryApi: + def test_get_queries_success(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + current_user = MagicMock() + + dataset = MagicMock() + dataset.id = dataset_id + + queries = [MagicMock(), MagicMock()] + + with ( + app.test_request_context("/datasets/queries?page=1&limit=20"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetService, + "get_dataset_queries", + return_value=(queries, 2), + ), + ): + response, status = method(api, dataset_id) + + assert status == 200 + assert response["total"] == 2 + assert response["page"] == 1 + assert response["limit"] == 20 + assert response["has_more"] is False + assert len(response["data"]) == 2 + + def test_get_queries_dataset_not_found(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + with ( + app.test_request_context("/datasets/queries"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_get_queries_permission_denied(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + dataset = MagicMock() + + with ( + app.test_request_context("/datasets/queries"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_get_queries_pagination_has_more(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + dataset = MagicMock() + dataset.id = dataset_id + + queries = [MagicMock() for _ in range(20)] + + with ( + app.test_request_context("/datasets/queries?page=1&limit=20"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetService, + "get_dataset_queries", + return_value=(queries, 40), + ), + ): + response, status = method(api, dataset_id) + + assert status == 200 + assert response["has_more"] is True + assert len(response["data"]) == 20 + + +class TestDatasetIndexingEstimateApi: + def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type=StorageType.LOCAL, + key="key", + name="name.txt", + size=1, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-1", + created_at=datetime.datetime.now(tz=datetime.UTC), + used=False, + ) + upload_file.id = file_id + return upload_file + + def _base_payload(self): + return { + "info_list": { + "data_source_type": "upload_file", + "file_info_list": { + "file_ids": ["file-1"], + }, + }, + "process_rule": {"chunk_size": 100}, + "indexing_technique": "high_quality", + "doc_form": "text_model", + "doc_language": "English", + "dataset_id": None, + } + + def test_post_success_upload_file(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + + payload = self._base_payload() + + mock_file = self._upload_file() + + mock_response = MagicMock() + mock_response.model_dump.return_value = {"tokens": 100} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + return_value=mock_response, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == {"tokens": 100} + + def test_post_file_not_found(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: None), + ), + ): + with pytest.raises(NotFound): + method(api) + + def test_post_llm_bad_request_error(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api) + + def test_post_provider_token_not_init(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api) + + def test_post_generic_exception(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(IndexingEstimateError): + method(api) + + +class TestDatasetRelatedAppListApi: + def test_get_success(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + dataset.id = "dataset-1" + + app1 = MagicMock() + app2 = MagicMock() + + join1 = MagicMock(app=app1) + join2 = MagicMock(app=app2) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_related_apps", + return_value=[join1, join2], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 2 + assert response["data"] == [app1, app2] + + def test_get_dataset_not_found(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") + + def test_get_permission_denied(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "dataset-1") + + def test_get_filters_none_apps(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + dataset.id = "dataset-1" + + app1 = MagicMock() + + join1 = MagicMock(app=app1) + join2 = MagicMock(app=None) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_related_apps", + return_value=[join1, join2], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 1 + assert response["data"] == [app1] + + +class TestDatasetIndexingStatusApi: + def test_get_success_with_documents(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + document = MagicMock() + document.id = "doc-1" + document.indexing_status = "completed" + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.completed_at = None + document.paused_at = None + document.error = None + document.stopped_at = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [document]), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert "data" in response + assert len(response["data"]) == 1 + + item = response["data"][0] + assert item["completed_segments"] == 3 + assert item["total_segments"] == 3 + + def test_get_success_no_documents(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: []), + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response == {"data": []} + + def test_segment_counts_different_values(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + document = MagicMock() + document.id = "doc-1" + document.indexing_status = "indexing" + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.completed_at = None + document.paused_at = None + document.error = None + document.stopped_at = None + + # First count = completed segments, second = total segments + query_mock = MagicMock() + query_mock.where.side_effect = [ + MagicMock(count=lambda: 2), + MagicMock(count=lambda: 5), + ] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [document]), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=query_mock, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + item = response["data"][0] + assert item["completed_segments"] == 2 + assert item["total_segments"] == 5 + + +class TestDatasetApiKeyApi: + def test_get_api_keys_success(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.get) + + mock_key_1 = MagicMock(spec=ApiToken) + mock_key_2 = MagicMock(spec=ApiToken) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_key_1, mock_key_2]), + ), + ): + response = method(api) + + assert "items" in response + assert response["items"] == [mock_key_1, mock_key_2] + + def test_post_create_api_key_success(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + ), + patch( + "controllers.console.datasets.datasets.ApiToken.generate_api_key", + return_value="dataset-abc123", + ), + patch( + "controllers.console.datasets.datasets.db.session.add", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.commit", + return_value=None, + ), + ): + response, status = method(api) + + assert status == 200 + assert isinstance(response, ApiToken) + assert response.token == "dataset-abc123" + assert response.type == "dataset" + + def test_post_exceed_max_keys(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + ), + ): + with pytest.raises(BadRequest) as exc_info: + method(api) + + assert exc_info.value.code == 400 + assert exc_info.value.data == { + "message": "Cannot create more than 10 API keys for this resource type.", + "custom": "max_keys_exceeded", + } + + +class TestDatasetApiDeleteApi: + def test_delete_success(self, app): + api = DatasetApiDeleteApi() + method = unwrap(api.delete) + + mock_key = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + ), + patch( + "controllers.console.datasets.datasets.db.session.commit", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.delete", + return_value=None, + ), + ): + response, status = method(api, "api-key-id") + + assert status == 204 + assert response["result"] == "success" + + def test_delete_key_not_found(self, app): + api = DatasetApiDeleteApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + ), + ): + with pytest.raises(NotFound): + method(api, "api-key-id") + + +class TestDatasetEnableApiApi: + def test_enable_api(self, app): + api = DatasetEnableApiApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", + return_value=None, + ), + ): + response, status = method(api, "dataset-1", "enable") + + assert status == 200 + assert response["result"] == "success" + + def test_disable_api(self, app): + api = DatasetEnableApiApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", + return_value=None, + ), + ): + response, status = method(api, "dataset-1", "disable") + + assert status == 200 + assert response["result"] == "success" + + +class TestDatasetApiBaseUrlApi: + def test_get_api_base_url_from_config(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + + def test_get_api_base_url_from_request(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("http://localhost:5000/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + None, + ), + ): + response = method(api) + + assert response["api_base_url"] == "http://localhost:5000/v1" + + +class TestDatasetRetrievalSettingApi: + def test_get_success(self, app): + api = DatasetRetrievalSettingApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.VECTOR_STORE", + "qdrant", + ), + patch( + "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", + return_value={"retrieval_method": ["semantic", "hybrid"]}, + ), + ): + response = method(api) + + assert "retrieval_method" in response + + +class TestDatasetRetrievalSettingMockApi: + def test_get_success(self, app): + api = DatasetRetrievalSettingMockApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", + return_value={"retrieval_method": ["semantic"]}, + ), + ): + response = method(api, "milvus") + + assert response["retrieval_method"] == ["semantic"] + + +class TestDatasetErrorDocs: + def test_get_success(self, app): + api = DatasetErrorDocs() + method = unwrap(api.get) + + dataset = MagicMock() + error_doc = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.get_error_documents_by_dataset_id", + return_value=[error_doc], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 1 + + def test_get_dataset_not_found(self, app): + api = DatasetErrorDocs() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") + + +class TestDatasetPermissionUserListApi: + def test_get_success(self, app): + api = DatasetPermissionUserListApi() + method = unwrap(api.get) + + dataset = MagicMock() + users = [{"id": "u1"}, {"id": "u2"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetPermissionService.get_dataset_partial_member_list", + return_value=users, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["data"] == users + + def test_get_permission_denied(self, app): + api = DatasetPermissionUserListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "dataset-1") + + +class TestDatasetAutoDisableLogApi: + def test_get_success(self, app): + api = DatasetAutoDisableLogApi() + method = unwrap(api.get) + + dataset = MagicMock() + logs = [{"reason": "quota"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset_auto_disable_logs", + return_value=logs, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response == logs + + def test_get_dataset_not_found(self, app): + api = DatasetAutoDisableLogApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py new file mode 100644 index 0000000000..f23dd5b44a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -0,0 +1,1380 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.datasets.datasets_document import ( + DatasetDocumentListApi, + DocumentApi, + DocumentBatchDownloadZipApi, + DocumentBatchIndexingEstimateApi, + DocumentBatchIndexingStatusApi, + DocumentDownloadApi, + DocumentGenerateSummaryApi, + DocumentIndexingEstimateApi, + DocumentIndexingStatusApi, + DocumentMetadataApi, + DocumentPipelineExecutionLogApi, + DocumentProcessingApi, + DocumentRetryApi, + DocumentStatusApi, + DocumentSummaryStatusApi, + GetProcessRuleApi, +) +from controllers.console.datasets.error import ( + DocumentAlreadyFinishedError, + DocumentIndexingError, + IndexingEstimateError, + InvalidActionError, + InvalidMetadataError, +) +from models.enums import DataSourceType, IndexingStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_ctx(): + return (MagicMock(is_dataset_editor=True, id="u1"), "tenant-1") + + +@pytest.fixture +def patch_tenant(tenant_ctx): + with patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=tenant_ctx, + ): + yield + + +@pytest.fixture +def dataset(): + return MagicMock(id="ds-1", indexing_technique="economy", summary_index_setting={"enable": True}) + + +@pytest.fixture +def document(): + return MagicMock( + id="doc-1", + tenant_id="tenant-1", + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info_dict={"upload_file_id": "file-1"}, + doc_form="text", + archived=False, + is_paused=False, + dataset_process_rule=None, + ) + + +@pytest.fixture +def patch_dataset(dataset): + with patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ): + yield + + +@pytest.fixture +def patch_permission(): + with patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ): + yield + + +class TestGetProcessRuleApi: + def test_get_default_success(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + response = method(api) + + assert "rules" in response + + def test_get_with_document_dataset_not_found(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api) + + +class TestDatasetDocumentListApi: + def test_get_with_fetch_true_counts_segments(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + doc = MagicMock(id="doc-1") + pagination = MagicMock(items=[doc], total=1) + + count_mock = MagicMock(return_value=2) + + with ( + app.test_request_context("/?fetch=true"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + resp = method(api, "ds-1") + + assert resp["data"] + + def test_get_with_search_status_and_created_at_sort(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?keyword=test&status=enabled&sort=created_at"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.apply_display_status_filter", + side_effect=lambda q, s: q, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + resp = method(api, "ds-1") + + assert resp["total"] == 1 + + def test_get_success(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_post_success(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.post) + + payload = {"indexing_technique": "economy"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", + return_value=([MagicMock()], "batch-1"), + ), + ): + response = method(api, "ds-1") + + assert "documents" in response + + def test_post_forbidden(self, app): + api = DatasetDocumentListApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/", json={}), + patch.object(type(console_ns), "payload", {}), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1") + + def test_get_with_fetch_true_and_invalid_fetch(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?fetch=maybe"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_get_sort_hit_count(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[], total=0) + + with ( + app.test_request_context("/?sort=hit_count"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 0 + + +class TestDocumentApi: + def test_get_success(self, app, patch_tenant): + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_invalid_metadata(self, app, patch_tenant): + api = DocumentApi() + method = unwrap(api.get) + + with app.test_request_context("/?metadata=wrong"), patch.object(api, "get_document", return_value=MagicMock()): + with pytest.raises(InvalidMetadataError): + method(api, "ds-1", "doc-1") + + def test_delete_success(self, app, patch_tenant, patch_dataset): + api = DocumentApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_document", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 204 + + def test_delete_indexing_error(self, app, patch_tenant, patch_dataset): + api = DocumentApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_document", + side_effect=services.errors.document.DocumentIndexingError(), + ), + ): + with pytest.raises(DocumentIndexingError): + method(api, "ds-1", "doc-1") + + +class TestDocumentDownloadApi: + def test_download_success(self, app, patch_tenant): + api = DocumentDownloadApi() + method = unwrap(api.get) + + document = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document_download_url", + return_value="url", + ), + ): + response = method(api, "ds-1", "doc-1") + + assert response["url"] == "url" + + +class TestDocumentProcessingApi: + def test_processing_forbidden_when_not_editor(self, app): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object(api, "get_document", return_value=MagicMock()), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1", "pause") + + def test_resume_from_error_state(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + doc = MagicMock(indexing_status=IndexingStatus.ERROR, is_paused=True) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + _, status = method(api, "ds-1", "doc-1", "resume") + + assert status == 200 + + def test_resume_success(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status=IndexingStatus.PAUSED, is_paused=True) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "resume") + + assert status == 200 + + def test_pause_success(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="indexing") + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "pause") + + assert status == 200 + + def test_pause_invalid(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status=IndexingStatus.COMPLETED) + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "pause") + + +class TestDocumentMetadataApi: + def test_put_metadata_schema_filtering(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + doc = MagicMock() + + payload = { + "doc_type": "invoice", + "doc_metadata": {"amount": 10, "invalid": "x"}, + } + + schema = {"amount": int} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"invoice": schema}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + method(api, "ds-1", "doc-1") + + assert doc.doc_metadata == {"amount": 10} + + def test_put_success(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + document = MagicMock() + + payload = {"doc_type": "others", "doc_metadata": {"a": 1}} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"others": {}}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_put_invalid_payload(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + with app.test_request_context("/", json={}), patch.object(api, "get_document", return_value=MagicMock()): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + def test_put_invalid_doc_type(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + payload = {"doc_type": "invalid", "doc_metadata": {}} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"others": {}}, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + +class TestDocumentStatusApi: + def test_patch_success(self, app, patch_tenant, patch_dataset): + api = DocumentStatusApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.batch_update_document_status", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "enable") + + assert status == 200 + + def test_patch_invalid_action(self, app, patch_tenant, patch_dataset): + api = DocumentStatusApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.batch_update_document_status", + side_effect=ValueError("x"), + ), + ): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "enable") + + +class TestDocumentRetryApi: + def test_retry_archived_document_skipped(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + doc = MagicMock(indexing_status="indexing") + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=doc, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.check_archived", + return_value=True, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + ) as retry_mock, + ): + resp, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", []) + + def test_retry_success(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + document = MagicMock(indexing_status=IndexingStatus.INDEXING, archived=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.check_archived", + return_value=False, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + return_value=None, + ) as retry_mock, + ): + response, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", [document]) + + def test_retry_skips_completed_document(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + document = MagicMock(indexing_status=IndexingStatus.COMPLETED, archived=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + return_value=None, + ) as retry_mock, + ): + response, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", []) + + +class TestDocumentPipelineExecutionLogApi: + def test_get_log_success(self, app, patch_tenant, patch_dataset): + api = DocumentPipelineExecutionLogApi() + method = unwrap(api.get) + + log = MagicMock( + datasource_info="{}", + datasource_type="file", + input_data={}, + datasource_node_id="n1", + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) + ), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentGenerateSummaryApi: + def test_generate_summary_missing_documents(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock( + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + payload = {"document_list": ["doc-1", "doc-2"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_documents_by_ids", + return_value=[MagicMock(id="doc-1")], + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + def test_generate_not_enabled(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock(indexing_technique="high_quality", summary_index_setting={"enable": False}) + + payload = {"document_list": ["doc-1"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1") + + def test_generate_summary_success_with_qa_skip(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock( + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + doc1 = MagicMock(id="doc-1", doc_form="qa_model") + doc2 = MagicMock(id="doc-2", doc_form="text") + + payload = {"document_list": ["doc-1", "doc-2"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_documents_by_ids", + return_value=[doc1, doc2], + ), + patch( + "controllers.console.datasets.datasets_document.generate_summary_index_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 200 + + +class TestDocumentSummaryStatusApi: + def test_get_success(self, app, patch_tenant, patch_permission): + api = DocumentSummaryStatusApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "services.summary_index_service.SummaryIndexService.get_document_summary_status_detail", + return_value={"total_segments": 0}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentIndexingEstimateApi: + def test_indexing_estimate_file_not_found(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + query_mock = MagicMock() + query_mock.where.return_value.first.return_value = None + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=query_mock, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_indexing_estimate_generic_exception(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + upload_file = MagicMock() + + mock_indexing_runner = MagicMock() + mock_indexing_runner.indexing_estimate.side_effect = RuntimeError("Some indexing error") + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) + ), + ), + patch( + "controllers.console.datasets.datasets_document.ExtractSetting", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner", + return_value=mock_indexing_runner, + ), + ): + with pytest.raises(IndexingEstimateError): + method(api, "ds-1", "doc-1") + + def test_get_finished(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock(indexing_status=IndexingStatus.COMPLETED) + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(DocumentAlreadyFinishedError): + method(api, "ds-1", "doc-1") + + +class TestDocumentBatchDownloadZipApi: + def test_post_no_documents(self, app, patch_tenant): + api = DocumentBatchDownloadZipApi() + method = unwrap(api.post) + + payload = {"document_ids": []} + + with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError): + method(api, "ds-1") + + +class TestDatasetDocumentListApiDelete: + def test_delete_success(self, app, patch_tenant, patch_dataset): + """Test successful deletion of documents""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1&document_id=doc-2"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_documents", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 204 + + def test_delete_indexing_error(self, app, patch_tenant, patch_dataset): + """Test deletion with indexing error""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_documents", + side_effect=services.errors.document.DocumentIndexingError(), + ), + ): + with pytest.raises(DocumentIndexingError): + method(api, "ds-1") + + def test_delete_dataset_not_found(self, app, patch_tenant): + """Test deletion when dataset not found""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + +class TestDocumentBatchIndexingEstimateApi: + def test_batch_indexing_estimate_website(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + doc = MagicMock( + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.WEBSITE_CRAWL, + data_source_info_dict={ + "provider": "firecrawl", + "job_id": "j1", + "url": "https://x.com", + "mode": "single", + "only_main_content": True, + }, + doc_form="text", + ) + + with ( + app.test_request_context("/"), + patch.object(api, "get_batch_documents", return_value=[doc]), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 2}), + ), + ): + resp, status = method(api, "ds-1", "batch-1") + + assert status == 200 + + def test_batch_indexing_estimate_notion(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + doc = MagicMock( + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.NOTION_IMPORT, + data_source_info_dict={ + "credential_id": "c1", + "notion_workspace_id": "w1", + "notion_page_id": "p1", + "type": "page", + }, + doc_form="text", + ) + + with ( + app.test_request_context("/"), + patch.object(api, "get_batch_documents", return_value=[doc]), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 1}), + ), + ): + resp, status = method(api, "ds-1", "batch-1") + + assert status == 200 + + def test_batch_estimate_unsupported_datasource(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status=IndexingStatus.INDEXING, + data_source_type="unknown", + data_source_info_dict={}, + doc_form="text", + ) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): + with pytest.raises(ValueError): + method(api, "ds-1", "batch-1") + + def test_get_batch_estimate_invalid_batch(self, app, patch_tenant): + """Test batch estimation with invalid batch""" + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-batch") + + +class TestDocumentBatchIndexingStatusApi: + def test_get_batch_status_invalid_batch(self, app, patch_tenant): + """Test batch status with invalid batch""" + api = DocumentBatchIndexingStatusApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-batch") + + +class TestDocumentIndexingStatusApi: + def test_get_status_document_not_found(self, app, patch_tenant): + """Test getting status for non-existent document""" + api = DocumentIndexingStatusApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_document", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-doc") + + +class TestDocumentApiMetadata: + def test_get_with_only_option(self, app, patch_tenant): + """Test get with 'only' metadata option""" + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None, doc_metadata_details=[]) + + with ( + app.test_request_context("/?metadata=only"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_with_without_option(self, app, patch_tenant): + """Test get with 'without' metadata option""" + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None) + + with ( + app.test_request_context("/?metadata=without"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentGenerateSummaryApiSuccess: + def test_generate_not_enabled_high_quality(self, app, patch_tenant, patch_permission): + """Test summary generation on non-high-quality dataset""" + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock(indexing_technique="economy", summary_index_setting={"enable": True}) + + payload = {"document_list": ["doc-1"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1") + + +class TestDocumentProcessingApiResume: + def test_resume_invalid_status(self, app, patch_tenant): + """Test resume on non-paused document""" + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status=IndexingStatus.COMPLETED, is_paused=False) + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "resume") + + +class TestDocumentPermissionCases: + def test_document_batch_get_permission_denied(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "batch-1") + + def test_document_batch_get_documents_not_found(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch.object(api, "get_batch_documents", return_value=None), + ): + response, status = method(api, "ds-1", "batch-1") + + assert status == 200 + assert response == { + "tokens": 0, + "total_price": 0, + "currency": "USD", + "total_segments": 0, + "preview": [], + } + + def test_document_tenant_mismatch(self, app): + api = DocumentApi() + method = unwrap(api.get) + + user = MagicMock(is_dataset_editor=True) + document = MagicMock( + tenant_id="other-tenant", + dataset_process_rule=None, + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), # ✅ prevents real DB call + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + def test_process_rule_get_by_document_success(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + process_rule = MagicMock(mode="custom", rules_dict={"a": 1}) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + where=lambda *a: MagicMock( + order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) + ) + ), + ), + ): + result = method(api) + + if isinstance(result, tuple): + response, status = result + else: + response, status = result, 200 + + assert status == 200 + assert response["mode"] == "custom" + + def test_process_rule_permission_denied(self, app): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(MagicMock(is_dataset_editor=True), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestDocumentListAdvancedCases: + def test_document_list_with_multiple_sort_options(self, app, patch_tenant, patch_dataset, patch_permission): + """Test document list with different sort options""" + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?sort=updated_at"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_document_metadata_with_schema_validation(self, app, patch_tenant): + """Test document metadata update with schema validation""" + api = DocumentMetadataApi() + method = unwrap(api.put) + + doc = MagicMock() + payload = { + "doc_type": "contract", + "doc_metadata": {"amount": 5000, "currency": "USD", "invalid_field": "x"}, + } + + schema = {"amount": int, "currency": str} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"contract": schema}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert doc.doc_metadata == {"amount": 5000, "currency": "USD"} + + +class TestDocumentIndexingEdgeCases: + def test_document_indexing_with_extraction_setting(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status=IndexingStatus.INDEXING, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + upload_file = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_document.ExtractSetting", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 5}), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py new file mode 100644 index 0000000000..e67e4daad9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -0,0 +1,1252 @@ +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.datasets_segments import ( + ChildChunkAddApi, + ChildChunkUpdateApi, + DatasetDocumentSegmentAddApi, + DatasetDocumentSegmentApi, + DatasetDocumentSegmentBatchImportApi, + DatasetDocumentSegmentListApi, + DatasetDocumentSegmentUpdateApi, + _get_segment_with_summary, +) +from controllers.console.datasets.error import ( + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, +) +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from models.dataset import ChildChunk, DocumentSegment +from models.model import UploadFile + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _segment(): + return SimpleNamespace( + id="s1", + position=1, + document_id="d1", + content="c", + sign_content="c", + answer="a", + word_count=1, + tokens=1, + keywords=[], + index_node_id="n1", + index_node_hash="h", + hit_count=0, + enabled=True, + disabled_at=None, + disabled_by=None, + status="normal", + created_by="u1", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + updated_by="u1", + indexing_at=None, + completed_at=None, + error=None, + stopped_at=None, + child_chunks=[], + attachments=[], + summary=None, + ) + + +def test_get_segment_with_summary(monkeypatch): + segment = _segment() + summary = SimpleNamespace(summary_content="summary") + + monkeypatch.setattr( + "services.summary_index_service.SummaryIndexService.get_segment_summary", + lambda *_args, **_kwargs: summary, + ) + + result = _get_segment_with_summary(segment, dataset_id="d1") + + assert result["summary"] == "summary" + + +class TestDatasetDocumentSegmentListApi: + def test_get_success(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + + segment = MagicMock(spec=DocumentSegment) + segment.id = "seg-1" + + pagination = MagicMock() + pagination.items = [segment] + pagination.total = 1 + pagination.pages = 1 + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ), + patch( + "services.summary_index_service.SummaryIndexService.get_segments_summaries", + return_value={}, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_dataset_not_found(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_get_permission_denied(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + +class TestDatasetDocumentSegmentApi: + def test_patch_success(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.id = "doc-1" + + with ( + app.test_request_context("/?segment_id=s1&segment_id=s2"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.update_segments_status", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "enable") + + assert status == 200 + assert response["result"] == "success" + + def test_patch_document_indexing_in_progress(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.id = "doc-1" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=b"running", + ), + ): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "disable") + + def test_patch_llm_bad_request(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock(id="doc-1") + + with ( + app.test_request_context("/?segment_id=s1"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "enable") + + def test_patch_provider_token_not_init(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock(id="doc-1") + + with ( + app.test_request_context("/?segment_id=s1"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "enable") + + +class TestDatasetDocumentSegmentAddApi: + def test_post_success(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "hello"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.doc_form = "text" + + segment = MagicMock() + segment.id = "seg-1" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.segment_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_segment", + return_value=segment, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "seg-1"}, + ), + patch( + "controllers.console.datasets.datasets_segments._get_segment_with_summary", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert response["data"]["id"] == "seg-1" + + def test_post_llm_bad_request(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1") + + def test_post_provider_token_not_init(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1") + + +class TestDatasetDocumentSegmentUpdateApi: + def test_patch_success(self, app): + api = DatasetDocumentSegmentUpdateApi() + method = unwrap(api.patch) + + payload = {"content": "updated"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.doc_form = "text" + + segment = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.segment_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.update_segment", + return_value=segment, + ), + patch( + "controllers.console.datasets.datasets_segments._get_segment_with_summary", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1") + + assert status == 200 + assert "data" in response + + def test_patch_llm_bad_request(self, app): + api = DatasetDocumentSegmentUpdateApi() + method = unwrap(api.patch) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "seg-1") + + +class TestDatasetDocumentSegmentBatchImportApi: + def test_post_success(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock(spec=UploadFile) + upload_file.name = "test.csv" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.setnx", + return_value=True, + ), + patch( + "controllers.console.datasets.datasets_segments.batch_create_segment_to_index_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert response["job_status"] == "waiting" + + def test_post_dataset_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_document_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_upload_file_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_invalid_file_type(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock() + upload_file.name = "test.txt" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + def test_post_async_task_failure(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock() + upload_file.name = "test.csv" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.setnx", + side_effect=Exception("redis down"), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 500 + assert "error" in response + + def test_get_job_not_found_in_redis(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, job_id="job-1") + + +class TestChildChunkAddApi: + def test_post_success(self, app): + api = ChildChunkAddApi() + method = unwrap(api.post) + + payload = {"content": "child"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock(spec=ChildChunk) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_child_chunk", + return_value=child_chunk, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "cc-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1") + + assert status == 200 + assert response["data"]["id"] == "cc-1" + + def test_post_child_chunk_indexing_error(self, app): + api = ChildChunkAddApi() + method = unwrap(api.post) + + payload = {"content": "child"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock(indexing_technique="economy") + document = MagicMock() + segment = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_child_chunk", + side_effect=services.errors.chunk.ChildChunkIndexingError("fail"), + ), + ): + with pytest.raises(ChildChunkIndexingError): + method(api, "ds-1", "doc-1", "seg-1") + + +class TestChildChunkUpdateApi: + def test_delete_success(self, app): + api = ChildChunkUpdateApi() + method = unwrap(api.delete) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + side_effect=[ + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), + ], + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.delete_child_chunk", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1") + + assert status == 204 + assert response["result"] == "success" + + def test_delete_child_chunk_index_error(self, app): + api = ChildChunkUpdateApi() + method = unwrap(api.delete) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock() + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + side_effect=[ + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), + ], + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.delete_child_chunk", + side_effect=services.errors.chunk.ChildChunkDeleteIndexError("fail"), + ), + ): + with pytest.raises(ChildChunkDeleteIndexError): + method(api, "ds-1", "doc-1", "seg-1", "cc-1") + + +class TestSegmentListAdvancedCases: + def test_segment_list_with_keyword_filter(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + + segment = MagicMock(spec=DocumentSegment) + segment.id = "seg-1" + segment.keywords = ["test"] + segment.enabled = True + + pagination = MagicMock(items=[segment], total=1, pages=1) + + with ( + app.test_request_context("/?keyword=test"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ), + patch( + "services.summary_index_service.SummaryIndexService.get_segments_summaries", + return_value={}, + ), + ): + result = method(api, "ds-1", "doc-1") + + if isinstance(result, tuple): + response, status = result + else: + response, status = result, 200 + + assert status == 200 + assert response["total"] == 1 + + def test_segment_list_permission_denied(self, app): + """Test segment list with permission denied""" + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + def test_segment_list_dataset_not_found(self, app): + """Test segment list with dataset not found""" + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + +class TestSegmentOperationCases: + def test_segment_add_with_provider_token_error(self, app): + """Test segment add with provider token not initialized""" + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + + payload = {"content": "new content", "answer": None} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_segment", + side_effect=ProviderTokenNotInitError("Token not init"), + ), + ): + with pytest.raises(ProviderTokenNotInitError): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_document_not_found(self, app): + """Test batch import with document not found""" + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_invalid_file(self, app): + """Test batch import with invalid file type""" + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + upload_file = None # File not found + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_async_task_failure(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + upload_file = MagicMock(spec=UploadFile, extension="csv", id="file-1") + upload_file.name = "test.csv" + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.batch_create_segment_to_index_task.delay", + side_effect=Exception("Task failed"), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 500 + assert "error" in response + + def test_batch_import_get_job_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.get) + + user = MagicMock(is_dataset_editor=True) + + with ( + app.test_request_context("/?job_id=invalid-job"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, "invalid-job") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py new file mode 100644 index 0000000000..161d0c41e8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -0,0 +1,399 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.datasets.external import ( + BedrockRetrievalApi, + ExternalApiTemplateApi, + ExternalApiTemplateListApi, + ExternalDatasetCreateApi, + ExternalKnowledgeHitTestingApi, +) +from services.dataset_service import DatasetService +from services.external_knowledge_service import ExternalDatasetService +from services.hit_testing_service import HitTestingService +from services.knowledge_service import ExternalDatasetTestService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_external_dataset") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def current_user(): + user = MagicMock() + user.id = "user-1" + user.is_dataset_editor = True + user.has_edit_permission = True + user.is_dataset_operator = True + return user + + +@pytest.fixture(autouse=True) +def mock_auth(mocker, current_user): + mocker.patch( + "controllers.console.datasets.external.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ) + + +class TestExternalApiTemplateListApi: + def test_get_success(self, app): + api = ExternalApiTemplateListApi() + method = unwrap(api.get) + + api_item = MagicMock() + api_item.to_dict.return_value = {"id": "1"} + + with ( + app.test_request_context("/?page=1&limit=20"), + patch.object( + ExternalDatasetService, + "get_external_knowledge_apis", + return_value=([api_item], 1), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 1 + assert resp["data"][0]["id"] == "1" + + def test_post_forbidden(self, app, current_user): + current_user.is_dataset_editor = False + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "x", "settings": {"k": "v"}} + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list"), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_duplicate_name(self, app): + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "x", "settings": {"k": "v"}} + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list"), + patch.object( + ExternalDatasetService, + "create_external_knowledge_api", + side_effect=services.errors.dataset.DatasetNameDuplicateError(), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + +class TestExternalApiTemplateApi: + def test_get_not_found(self, app): + api = ExternalApiTemplateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "get_external_knowledge_api", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "api-id") + + def test_delete_forbidden(self, app, current_user): + current_user.has_edit_permission = False + current_user.is_dataset_operator = False + + api = ExternalApiTemplateApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "api-id") + + +class TestExternalDatasetCreateApi: + def test_create_success(self, app): + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + payload = { + "external_knowledge_api_id": "api", + "external_knowledge_id": "kid", + "name": "dataset", + } + + dataset = MagicMock() + + dataset.embedding_available = False + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.enable_qa = False + dataset.enable_vector_store = False + dataset.vector_store_setting = None + dataset.is_multimodal = False + + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object( + ExternalDatasetService, + "create_external_dataset", + return_value=dataset, + ), + ): + _, status = method(api) + + assert status == 201 + + def test_create_forbidden(self, app, current_user): + current_user.is_dataset_editor = False + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + payload = { + "external_knowledge_api_id": "api", + "external_knowledge_id": "kid", + "name": "dataset", + } + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestExternalKnowledgeHitTestingApi: + def test_hit_testing_dataset_not_found(self, app): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-id") + + def test_hit_testing_success(self, app): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + payload = {"query": "hello"} + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission"), + patch.object( + HitTestingService, + "external_retrieve", + return_value={"ok": True}, + ), + ): + resp = method(api, "dataset-id") + + assert resp["ok"] is True + + +class TestBedrockRetrievalApi: + def test_bedrock_retrieval(self, app): + api = BedrockRetrievalApi() + method = unwrap(api.post) + + payload = { + "retrieval_setting": {}, + "query": "hello", + "knowledge_id": "kid", + } + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object( + ExternalDatasetTestService, + "knowledge_retrieval", + return_value={"ok": True}, + ), + ): + resp, status = method() + + assert status == 200 + assert resp["ok"] is True + + +class TestExternalApiTemplateListApiAdvanced: + def test_post_duplicate_name_error(self, app, mock_auth, current_user): + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "duplicate_api", "settings": {"key": "value"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"), + patch( + "controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api", + side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_get_with_pagination(self, app, mock_auth, current_user): + api = ExternalApiTemplateListApi() + method = unwrap(api.get) + + templates = [MagicMock(id=f"api-{i}") for i in range(3)] + + with ( + app.test_request_context("/?page=1&limit=20"), + patch( + "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis", + return_value=(templates, 25), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 25 + assert len(resp["data"]) == 3 + + +class TestExternalDatasetCreateApiAdvanced: + def test_create_forbidden(self, app, mock_auth, current_user): + """Test creating external dataset without permission""" + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + current_user.is_dataset_editor = False + + payload = { + "external_knowledge_api_id": "api-1", + "external_knowledge_id": "ek-1", + "name": "new_dataset", + "description": "A dataset", + } + + with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(Forbidden): + method(api) + + +class TestExternalKnowledgeHitTestingApiAdvanced: + def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user): + """Test hit testing on non-existent dataset""" + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "test query", + "external_retrieval_model": None, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + dataset = MagicMock() + payload = { + "query": "test query", + "external_retrieval_model": {"type": "bm25"}, + "metadata_filtering_conditions": {"status": "active"}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"), + patch( + "controllers.console.datasets.external.HitTestingService.external_retrieve", + return_value={"results": []}, + ), + ): + resp = method(api, "ds-1") + + assert resp["results"] == [] + + +class TestBedrockRetrievalApiAdvanced: + def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user): + api = BedrockRetrievalApi() + method = unwrap(api.post) + + payload = { + "retrieval_setting": {}, + "query": "test", + "knowledge_id": "k-1", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval", + side_effect=ValueError("Invalid settings"), + ), + ): + with pytest.raises(ValueError): + method() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py new file mode 100644 index 0000000000..726c0a5cf3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -0,0 +1,160 @@ +import uuid +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.console import console_ns +from controllers.console.datasets.hit_testing import HitTestingApi +from controllers.console.datasets.hit_testing_base import HitTestingPayload + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_hit_testing") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def dataset_id(): + return uuid.uuid4() + + +@pytest.fixture +def dataset(): + return MagicMock(id="dataset-1") + + +@pytest.fixture(autouse=True) +def bypass_decorators(mocker): + """Bypass all decorators on the API method.""" + mocker.patch( + "controllers.console.datasets.hit_testing.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.login_required", + return_value=lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.account_initialization_required", + return_value=lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check", + return_value=lambda *_: lambda f: f, + ) + + +class TestHitTestingApi: + def test_hit_testing_success(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + "top_k": 3, + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": "what is vector search", "records": []}, + ), + ): + result = method(api, dataset_id) + + assert "query" in result + assert "records" in result + assert result["records"] == [] + + def test_hit_testing_dataset_not_found(self, app, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "test", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + side_effect=NotFound("Dataset not found"), + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_hit_testing_invalid_args(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + side_effect=ValueError("Invalid parameters"), + ), + ): + with pytest.raises(ValueError, match="Invalid parameters"): + method(api, dataset_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py new file mode 100644 index 0000000000..e7ae37ae45 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -0,0 +1,207 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import DatasetNotInitializedError +from controllers.console.datasets.hit_testing_base import ( + DatasetsHitTestingBase, +) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from models.account import Account +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + + +@pytest.fixture +def account(): + acc = MagicMock(spec=Account) + return acc + + +@pytest.fixture(autouse=True) +def patch_current_user(mocker, account): + """Patch current_user to a valid Account.""" + mocker.patch( + "controllers.console.datasets.hit_testing_base.current_user", + account, + ) + + +@pytest.fixture +def dataset(): + return MagicMock(id="dataset-1") + + +class TestGetAndValidateDataset: + def test_success(self, dataset): + with ( + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + ): + result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + assert result == dataset + + def test_dataset_not_found(self): + with patch.object( + DatasetService, + "get_dataset", + return_value=None, + ): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + def test_permission_denied(self, dataset): + with ( + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden, match="no access"): + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + +class TestHitTestingArgsCheck: + def test_args_check_called(self): + args = {"query": "test"} + + with patch.object( + HitTestingService, + "hit_testing_args_check", + ) as check_mock: + DatasetsHitTestingBase.hit_testing_args_check(args) + + check_mock.assert_called_once_with(args) + + +class TestParseArgs: + def test_parse_args_success(self): + payload = {"query": "hello"} + + result = DatasetsHitTestingBase.parse_args(payload) + + assert result["query"] == "hello" + + def test_parse_args_invalid(self): + payload = {"query": "x" * 300} + + with pytest.raises(ValueError): + DatasetsHitTestingBase.parse_args(payload) + + +class TestPerformHitTesting: + def test_success(self, dataset): + response = { + "query": "hello", + "records": [], + } + + with patch.object( + HitTestingService, + "retrieve", + return_value=response, + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == "hello" + assert result["records"] == [] + + def test_index_not_initialized(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=services.errors.index.IndexNotInitializedError(), + ): + with pytest.raises(DatasetNotInitializedError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_provider_token_not_init(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ProviderTokenNotInitError("token missing"), + ): + with pytest.raises(ProviderNotInitializeError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_quota_exceeded(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=QuotaExceededError(), + ): + with pytest.raises(ProviderQuotaExceededError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_model_not_supported(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ModelCurrentlyNotSupportError(), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_llm_bad_request(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=LLMBadRequestError("bad request"), + ): + with pytest.raises(ProviderNotInitializeError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_invoke_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=InvokeError("invoke failed"), + ): + with pytest.raises(CompletionRequestError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_value_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ValueError("bad args"), + ): + with pytest.raises(ValueError, match="bad args"): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_unexpected_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=Exception("boom"), + ): + with pytest.raises(InternalServerError, match="boom"): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py new file mode 100644 index 0000000000..de834c2d4d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -0,0 +1,362 @@ +import uuid +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.console import console_ns +from controllers.console.datasets.metadata import ( + DatasetMetadataApi, + DatasetMetadataBuiltInFieldActionApi, + DatasetMetadataBuiltInFieldApi, + DatasetMetadataCreateApi, + DocumentMetadataEditApi, +) +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import ( + MetadataArgs, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_dataset_metadata") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def current_user(): + user = MagicMock() + user.id = "user-1" + return user + + +@pytest.fixture +def dataset(): + ds = MagicMock() + ds.id = "dataset-1" + return ds + + +@pytest.fixture +def dataset_id(): + return uuid.uuid4() + + +@pytest.fixture +def metadata_id(): + return uuid.uuid4() + + +@pytest.fixture(autouse=True) +def bypass_decorators(mocker): + """Bypass setup/login/license decorators.""" + mocker.patch( + "controllers.console.datasets.metadata.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.login_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.account_initialization_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.enterprise_license_required", + lambda f: f, + ) + + +class TestDatasetMetadataCreateApi: + def test_create_metadata_success(self, app, current_user, dataset, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.post) + + payload = {"name": "author"} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + MetadataArgs, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "create_metadata", + return_value={"id": "m1", "name": "author"}, + ), + ): + result, status = method(api, dataset_id) + + assert status == 201 + assert result["name"] == "author" + + def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.post) + + valid_payload = { + "type": "string", + "name": "author", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=valid_payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + MetadataArgs, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + +class TestDatasetMetadataGetApi: + def test_get_metadata_success(self, app, dataset, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + MetadataService, + "get_dataset_metadatas", + return_value=[{"id": "m1"}], + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert isinstance(result, list) + + def test_get_metadata_dataset_not_found(self, app, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, dataset_id) + + +class TestDatasetMetadataApi: + def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id): + api = DatasetMetadataApi() + method = unwrap(api.patch) + + payload = {"name": "updated-name"} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "update_metadata_name", + return_value={"id": "m1", "name": "updated-name"}, + ), + ): + result, status = method(api, dataset_id, metadata_id) + + assert status == 200 + assert result["name"] == "updated-name" + + def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id): + api = DatasetMetadataApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "delete_metadata", + ), + ): + result, status = method(api, dataset_id, metadata_id) + + assert status == 204 + assert result["result"] == "success" + + +class TestDatasetMetadataBuiltInFieldApi: + def test_get_built_in_fields(self, app): + api = DatasetMetadataBuiltInFieldApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + MetadataService, + "get_built_in_fields", + return_value=["title", "source"], + ), + ): + result, status = method(api) + + assert status == 200 + assert result["fields"] == ["title", "source"] + + +class TestDatasetMetadataBuiltInFieldActionApi: + def test_enable_built_in_field(self, app, current_user, dataset, dataset_id): + api = DatasetMetadataBuiltInFieldActionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "enable_built_in_field", + ), + ): + result, status = method(api, dataset_id, "enable") + + assert status == 200 + assert result["result"] == "success" + + +class TestDocumentMetadataEditApi: + def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id): + api = DocumentMetadataEditApi() + method = unwrap(api.post) + + payload = {"operation": "add", "metadata": {}} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataOperationData, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + MetadataService, + "update_documents_metadata", + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/datasets/test_website.py b/api/tests/unit_tests/controllers/console/datasets/test_website.py new file mode 100644 index 0000000000..9f0da6e76f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_website.py @@ -0,0 +1,233 @@ +from unittest.mock import Mock, PropertyMock, patch + +import pytest +from flask import Flask + +from controllers.console import console_ns +from controllers.console.datasets.error import WebsiteCrawlError +from controllers.console.datasets.website import ( + WebsiteCrawlApi, + WebsiteCrawlStatusApi, +) +from services.website_service import ( + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_website_crawl") + app.config["TESTING"] = True + return app + + +@pytest.fixture(autouse=True) +def bypass_auth_and_setup(mocker): + """Bypass setup/login/account decorators.""" + mocker.patch( + "controllers.console.datasets.website.login_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.website.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.website.account_initialization_required", + lambda f: f, + ) + + +class TestWebsiteCrawlApi: + def test_crawl_success(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "https://example.com", + "options": {"depth": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mock_request = Mock(spec=WebsiteCrawlApiRequest) + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "crawl_url", + return_value={"job_id": "job-1"}, + ) + + result, status = method(api) + + assert status == 200 + assert result["job_id"] == "job-1" + + def test_crawl_invalid_payload(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "bad-url", + "options": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + side_effect=ValueError("invalid payload"), + ) + + with pytest.raises(WebsiteCrawlError, match="invalid payload"): + method(api) + + def test_crawl_service_error(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "https://example.com", + "options": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mock_request = Mock(spec=WebsiteCrawlApiRequest) + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "crawl_url", + side_effect=Exception("crawl failed"), + ) + + with pytest.raises(WebsiteCrawlError, match="crawl failed"): + method(api) + + +class TestWebsiteCrawlStatusApi: + def test_get_status_success(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mock_request = Mock(spec=WebsiteCrawlStatusApiRequest) + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "get_crawl_status_typed", + return_value={"status": "completed"}, + ) + + result, status = method(api, job_id) + + assert status == 200 + assert result["status"] == "completed" + + def test_get_status_invalid_provider(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + side_effect=ValueError("invalid provider"), + ) + + with pytest.raises(WebsiteCrawlError, match="invalid provider"): + method(api, job_id) + + def test_get_status_service_error(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mock_request = Mock(spec=WebsiteCrawlStatusApiRequest) + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "get_crawl_status_typed", + side_effect=Exception("status lookup failed"), + ) + + with pytest.raises(WebsiteCrawlError, match="status lookup failed"): + method(api, job_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py new file mode 100644 index 0000000000..90f00711c1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -0,0 +1,117 @@ +from unittest.mock import Mock + +import pytest + +from controllers.console.datasets.error import PipelineNotFoundError +from controllers.console.datasets.wraps import get_rag_pipeline +from models.dataset import Pipeline + + +class TestGetRagPipeline: + def test_missing_pipeline_id(self): + @get_rag_pipeline + def dummy_view(**kwargs): + return "ok" + + with pytest.raises(ValueError, match="missing pipeline_id"): + dummy_view() + + def test_pipeline_not_found(self, mocker): + @get_rag_pipeline + def dummy_view(**kwargs): + return "ok" + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = None + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + with pytest.raises(PipelineNotFoundError): + dummy_view(pipeline_id="pipeline-1") + + def test_pipeline_found_and_injected(self, mocker): + pipeline = Mock(spec=Pipeline) + pipeline.id = "pipeline-1" + pipeline.tenant_id = "tenant-1" + + @get_rag_pipeline + def dummy_view(**kwargs): + return kwargs["pipeline"] + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = pipeline + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id="pipeline-1") + + assert result is pipeline + + def test_pipeline_id_removed_from_kwargs(self, mocker): + pipeline = Mock(spec=Pipeline) + + @get_rag_pipeline + def dummy_view(**kwargs): + assert "pipeline_id" not in kwargs + return "ok" + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = pipeline + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id="pipeline-1") + + assert result == "ok" + + def test_pipeline_id_cast_to_string(self, mocker): + pipeline = Mock(spec=Pipeline) + + @get_rag_pipeline + def dummy_view(**kwargs): + return kwargs["pipeline"] + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + def where_side_effect(*args, **kwargs): + assert args[0].right.value == "123" + return Mock(first=lambda: pipeline) + + mock_query = Mock() + mock_query.where.side_effect = where_side_effect + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id=123) + + assert result is pipeline diff --git a/api/tests/unit_tests/controllers/console/explore/__init__.py b/api/tests/unit_tests/controllers/console/explore/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py new file mode 100644 index 0000000000..0afbc5a8f7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -0,0 +1,402 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import InternalServerError + +import controllers.console.explore.audio as audio_module +from controllers.console.app.error import ( + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, +) + + +def unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +@pytest.fixture +def installed_app(): + app = MagicMock() + app.app = MagicMock() + return app + + +@pytest.fixture +def audio_file(): + return (BytesIO(b"audio"), "audio.wav") + + +class TestChatAudioApi: + def setup_method(self): + self.api = audio_module.ChatAudioApi() + self.method = unwrap(self.api.post) + + def test_post_success(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + return_value={"text": "ok"}, + ), + ): + resp = self.method(installed_app) + + assert resp == {"text": "ok"} + + def test_app_unavailable(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + self.method(installed_app) + + def test_no_audio_uploaded(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(NoAudioUploadedError): + self.method(installed_app) + + def test_audio_too_large(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=AudioTooLargeServiceError("too big"), + ), + ): + with pytest.raises(AudioTooLargeError): + self.method(installed_app) + + def test_provider_quota_exceeded(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + self.method(installed_app) + + def test_unknown_exception(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + self.method(installed_app) + + def test_unsupported_audio_type(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=audio_module.UnsupportedAudioTypeServiceError(), + ), + ): + with pytest.raises(audio_module.UnsupportedAudioTypeError): + self.method(installed_app) + + def test_provider_not_support_speech_to_text(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError): + self.method(installed_app) + + def test_provider_not_initialized(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + self.method(installed_app) + + def test_model_currently_not_supported(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + self.method(installed_app) + + def test_invoke_error_asr(self, app, installed_app, audio_file): + with ( + app.test_request_context( + "/", + data={"file": audio_file}, + content_type="multipart/form-data", + ), + patch.object( + audio_module.AudioService, + "transcript_asr", + side_effect=InvokeError("invoke failed"), + ), + ): + with pytest.raises(CompletionRequestError): + self.method(installed_app) + + +class TestChatTextApi: + def setup_method(self): + self.api = audio_module.ChatTextApi() + self.method = unwrap(self.api.post) + + def test_post_success(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"message_id": "m1", "text": "hello", "voice": "v1"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + return_value={"audio": "ok"}, + ), + ): + resp = self.method(installed_app) + + assert resp == {"audio": "ok"} + + def test_provider_not_initialized(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + self.method(installed_app) + + def test_model_not_supported(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + self.method(installed_app) + + def test_invoke_error(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=InvokeError("invoke failed"), + ), + ): + with pytest.raises(CompletionRequestError): + self.method(installed_app) + + def test_unknown_exception(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + self.method(installed_app) + + def test_app_unavailable_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=audio_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + self.method(installed_app) + + def test_no_audio_uploaded_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(NoAudioUploadedError): + self.method(installed_app) + + def test_audio_too_large_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=AudioTooLargeServiceError("too big"), + ), + ): + with pytest.raises(AudioTooLargeError): + self.method(installed_app) + + def test_unsupported_audio_type_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=audio_module.UnsupportedAudioTypeServiceError(), + ), + ): + with pytest.raises(audio_module.UnsupportedAudioTypeError): + self.method(installed_app) + + def test_provider_not_support_speech_to_text_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=audio_module.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(audio_module.ProviderNotSupportSpeechToTextError): + self.method(installed_app) + + def test_quota_exceeded_tts(self, app, installed_app): + with ( + app.test_request_context( + "/", + json={"text": "hi"}, + ), + patch.object( + audio_module.AudioService, + "transcript_tts", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + self.method(installed_app) diff --git a/api/tests/unit_tests/controllers/console/explore/test_banner.py b/api/tests/unit_tests/controllers/console/explore/test_banner.py new file mode 100644 index 0000000000..c8f674f515 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_banner.py @@ -0,0 +1,89 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import controllers.console.explore.banner as banner_module +from models.enums import BannerStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestBannerApi: + def test_get_banners_with_requested_language(self, app): + api = banner_module.BannerApi() + method = unwrap(api.get) + + banner = MagicMock() + banner.id = "b1" + banner.content = {"text": "hello"} + banner.link = "https://example.com" + banner.sort = 1 + banner.status = BannerStatus.ENABLED + banner.created_at = datetime(2024, 1, 1) + + session = MagicMock() + session.scalars.return_value.all.return_value = [banner] + + with app.test_request_context("/?language=fr-FR"), patch.object(banner_module.db, "session", session): + result = method(api) + + assert result == [ + { + "id": "b1", + "content": {"text": "hello"}, + "link": "https://example.com", + "sort": 1, + "status": "enabled", + "created_at": "2024-01-01T00:00:00", + } + ] + + def test_get_banners_fallback_to_en_us(self, app): + api = banner_module.BannerApi() + method = unwrap(api.get) + + banner = MagicMock() + banner.id = "b2" + banner.content = {"text": "fallback"} + banner.link = None + banner.sort = 1 + banner.status = BannerStatus.ENABLED + banner.created_at = None + + scalars_result = MagicMock() + scalars_result.all.side_effect = [ + [], + [banner], + ] + + session = MagicMock() + session.scalars.return_value = scalars_result + + with app.test_request_context("/?language=es-ES"), patch.object(banner_module.db, "session", session): + result = method(api) + + assert result == [ + { + "id": "b2", + "content": {"text": "fallback"}, + "link": None, + "sort": 1, + "status": "enabled", + "created_at": None, + } + ] + + def test_get_banners_default_language_en_us(self, app): + api = banner_module.BannerApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [] + + with app.test_request_context("/"), patch.object(banner_module.db, "session", session): + result = method(api) + + assert result == [] diff --git a/api/tests/unit_tests/controllers/console/explore/test_completion.py b/api/tests/unit_tests/controllers/console/explore/test_completion.py new file mode 100644 index 0000000000..1dd16f3c59 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_completion.py @@ -0,0 +1,459 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import InternalServerError + +import controllers.console.explore.completion as completion_module +from controllers.console.app.error import ( + ConversationCompletedError, +) +from controllers.console.explore.error import NotChatAppError, NotCompletionAppError +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from models import Account +from models.model import AppMode +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user(): + return MagicMock(spec=Account) + + +@pytest.fixture +def completion_app(): + return MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) + + +@pytest.fixture +def chat_app(): + return MagicMock(app=MagicMock(mode=AppMode.CHAT)) + + +@pytest.fixture +def payload_data(): + return {"inputs": {}, "query": "hi"} + + +@pytest.fixture +def payload_patch(payload_data): + return patch.object( + type(completion_module.console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload_data, + ) + + +class TestCompletionApi: + def test_post_success(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + return_value={"ok": True}, + ), + patch.object( + completion_module.helper, + "compact_generate_response", + return_value=("ok", 200), + ), + ): + result = method(completion_app) + + assert result == ("ok", 200) + + def test_post_wrong_app_mode(self): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT)) + + with pytest.raises(NotCompletionAppError): + method(installed_app) + + def test_conversation_completed(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationCompletedError(), + ), + ): + with pytest.raises(ConversationCompletedError): + method(completion_app) + + def test_internal_error(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + method(completion_app) + + def test_conversation_not_exists(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(completion_module.NotFound): + method(completion_app) + + def test_app_unavailable(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(completion_module.AppUnavailableError): + method(completion_app) + + def test_provider_not_initialized(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(completion_module.ProviderNotInitializeError): + method(completion_app) + + def test_quota_exceeded(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.QuotaExceededError(), + ), + ): + with pytest.raises(completion_module.ProviderQuotaExceededError): + method(completion_app) + + def test_model_not_supported(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError): + method(completion_app) + + def test_invoke_error(self, app, completion_app, user, payload_patch): + api = completion_module.CompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.InvokeError("invoke failed"), + ), + ): + with pytest.raises(completion_module.CompletionRequestError): + method(completion_app) + + +class TestCompletionStopApi: + def test_stop_success(self, completion_app, user): + api = completion_module.CompletionStopApi() + method = unwrap(api.post) + + user.id = "u1" + + with ( + patch.object(completion_module, "current_user", user), + patch.object(completion_module.AppTaskService, "stop_task"), + ): + resp, status = method(completion_app, "task-1") + + assert status == 200 + assert resp == {"result": "success"} + + def test_stop_wrong_app_mode(self): + api = completion_module.CompletionStopApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.CHAT)) + + with pytest.raises(NotCompletionAppError): + method(installed_app, "task") + + +class TestChatApi: + def test_post_success(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + return_value={"ok": True}, + ), + patch.object( + completion_module.helper, + "compact_generate_response", + return_value=("ok", 200), + ), + ): + result = method(chat_app) + + assert result == ("ok", 200) + + def test_post_not_chat_app(self): + api = completion_module.ChatApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) + + with pytest.raises(NotChatAppError): + method(installed_app) + + def test_rate_limit_error(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(chat_app) + + def test_conversation_completed_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationCompletedError(), + ), + ): + with pytest.raises(ConversationCompletedError): + method(chat_app) + + def test_conversation_not_exists_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(completion_module.NotFound): + method(chat_app) + + def test_app_unavailable_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(completion_module.AppUnavailableError): + method(chat_app) + + def test_provider_not_initialized_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ProviderTokenNotInitError("not init"), + ), + ): + with pytest.raises(completion_module.ProviderNotInitializeError): + method(chat_app) + + def test_quota_exceeded_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.QuotaExceededError(), + ), + ): + with pytest.raises(completion_module.ProviderQuotaExceededError): + method(chat_app) + + def test_model_not_supported_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(completion_module.ProviderModelCurrentlyNotSupportError): + method(chat_app) + + def test_invoke_error_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=completion_module.InvokeError("invoke failed"), + ), + ): + with pytest.raises(completion_module.CompletionRequestError): + method(chat_app) + + def test_internal_error_chat(self, app, chat_app, user, payload_patch): + api = completion_module.ChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + payload_patch, + patch.object(completion_module, "current_user", user), + patch.object( + completion_module.AppGenerateService, + "generate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + method(chat_app) + + +class TestChatStopApi: + def test_stop_success(self, chat_app, user): + api = completion_module.ChatStopApi() + method = unwrap(api.post) + + user.id = "u1" + + with ( + patch.object(completion_module, "current_user", user), + patch.object(completion_module.AppTaskService, "stop_task"), + ): + resp, status = method(chat_app, "task-1") + + assert status == 200 + assert resp == {"result": "success"} + + def test_stop_not_chat_app(self): + api = completion_module.ChatStopApi() + method = unwrap(api.post) + + installed_app = MagicMock(app=MagicMock(mode=AppMode.COMPLETION)) + + with pytest.raises(NotChatAppError): + method(installed_app, "task") diff --git a/api/tests/unit_tests/controllers/console/explore/test_conversation.py b/api/tests/unit_tests/controllers/console/explore/test_conversation.py new file mode 100644 index 0000000000..65cc209725 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_conversation.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +import controllers.console.explore.conversation as conversation_module +from controllers.console.explore.error import NotChatAppError +from models import Account +from models.model import AppMode +from services.errors.conversation import ( + ConversationNotExistsError, + LastConversationNotExistsError, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class FakeConversation: + def __init__(self, cid): + self.id = cid + self.name = "test" + self.inputs = {} + self.status = "normal" + self.introduction = "" + + +@pytest.fixture +def chat_app(): + app_model = MagicMock(mode=AppMode.CHAT, id="app-id") + return MagicMock(app=app_model) + + +@pytest.fixture +def non_chat_app(): + app_model = MagicMock(mode=AppMode.COMPLETION) + return MagicMock(app=app_model) + + +@pytest.fixture +def user(): + user = MagicMock(spec=Account) + user.id = "uid" + return user + + +@pytest.fixture(autouse=True) +def mock_db_and_session(): + with ( + patch.object( + conversation_module, + "db", + MagicMock(session=MagicMock(), engine=MagicMock()), + ), + patch( + "controllers.console.explore.conversation.Session", + MagicMock(), + ), + ): + yield + + +class TestConversationListApi: + def test_get_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationListApi() + method = unwrap(api.get) + + pagination = MagicMock( + limit=20, + has_more=False, + data=[FakeConversation("c1"), FakeConversation("c2")], + ) + + with ( + app.test_request_context("/?limit=20"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "pagination_by_last_id", + return_value=pagination, + ), + ): + result = method(chat_app) + + assert result["limit"] == 20 + assert result["has_more"] is False + assert len(result["data"]) == 2 + + def test_last_conversation_not_exists(self, app: Flask, chat_app, user): + api = conversation_module.ConversationListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "pagination_by_last_id", + side_effect=LastConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(chat_app) + + def test_wrong_app_mode(self, app: Flask, non_chat_app): + api = conversation_module.ConversationListApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(NotChatAppError): + method(non_chat_app) + + +class TestConversationApi: + def test_delete_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "delete", + ), + ): + result = method(chat_app, "cid") + + body, status = result + assert status == 204 + assert body["result"] == "success" + + def test_delete_not_found(self, app: Flask, chat_app, user): + api = conversation_module.ConversationApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "delete", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(chat_app, "cid") + + def test_delete_wrong_app_mode(self, app: Flask, non_chat_app): + api = conversation_module.ConversationApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with pytest.raises(NotChatAppError): + method(non_chat_app, "cid") + + +class TestConversationRenameApi: + def test_rename_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationRenameApi() + method = unwrap(api.post) + + conversation = FakeConversation("cid") + + with ( + app.test_request_context("/", json={"name": "new"}), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "rename", + return_value=conversation, + ), + ): + result = method(chat_app, "cid") + + assert result["id"] == "cid" + + def test_rename_not_found(self, app: Flask, chat_app, user): + api = conversation_module.ConversationRenameApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "new"}), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.ConversationService, + "rename", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(chat_app, "cid") + + +class TestConversationPinApi: + def test_pin_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationPinApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "pin", + ), + ): + result = method(chat_app, "cid") + + assert result == {"result": "success"} + + +class TestConversationUnPinApi: + def test_unpin_success(self, app: Flask, chat_app, user): + api = conversation_module.ConversationUnPinApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch.object(conversation_module, "current_user", user), + patch.object( + conversation_module.WebConversationService, + "unpin", + ), + ): + result = method(chat_app, "cid") + + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/explore/test_installed_app.py b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py new file mode 100644 index 0000000000..93652e75d2 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_installed_app.py @@ -0,0 +1,362 @@ +from datetime import datetime +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +import controllers.console.explore.installed_app as module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_id(): + return "t1" + + +@pytest.fixture +def current_user(tenant_id): + user = MagicMock() + user.id = "u1" + user.current_tenant = MagicMock(id=tenant_id) + return user + + +@pytest.fixture +def installed_app(): + app = MagicMock() + app.id = "ia1" + app.app = MagicMock(id="a1") + app.app_owner_tenant_id = "t2" + app.is_pinned = False + app.last_used_at = datetime(2024, 1, 1) + return app + + +@pytest.fixture +def payload_patch(): + def _patch(payload): + return patch.object( + type(module.console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return _patch + + +class TestInstalledAppsListApi: + def test_get_installed_apps(self, app, current_user, tenant_id, installed_app): + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), + ), + ): + result = method(api) + + assert "installed_apps" in result + assert result["installed_apps"][0]["editable"] is True + assert result["installed_apps"][0]["uninstallable"] is False + + def test_get_installed_apps_with_app_id_filter(self, app, current_user, tenant_id): + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [] + + with ( + app.test_request_context("/?app_id=a1"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="member"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), + ), + ): + result = method(api) + + assert result == {"installed_apps": []} + + def test_get_installed_apps_with_webapp_auth_enabled(self, app, current_user, tenant_id, installed_app): + """Test filtering when webapp_auth is enabled.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + mock_webapp_setting = MagicMock() + mock_webapp_setting.access_mode = "restricted" + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=True)), + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_get_app_access_mode_by_id", + return_value={"a1": mock_webapp_setting}, + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_is_user_allowed_to_access_webapps", + return_value={"a1": True}, + ), + ): + result = method(api) + + assert len(result["installed_apps"]) == 1 + + def test_get_installed_apps_with_webapp_auth_user_denied(self, app, current_user, tenant_id, installed_app): + """Test filtering when user doesn't have access.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + mock_webapp_setting = MagicMock() + mock_webapp_setting.access_mode = "restricted" + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="member"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=True)), + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_get_app_access_mode_by_id", + return_value={"a1": mock_webapp_setting}, + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_is_user_allowed_to_access_webapps", + return_value={"a1": False}, + ), + ): + result = method(api) + + assert result["installed_apps"] == [] + + def test_get_installed_apps_with_sso_verified_access(self, app, current_user, tenant_id, installed_app): + """Test that sso_verified access mode apps are skipped in filtering.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + mock_webapp_setting = MagicMock() + mock_webapp_setting.access_mode = "sso_verified" + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=True)), + ), + patch.object( + module.EnterpriseService.WebAppAuth, + "batch_get_app_access_mode_by_id", + return_value={"a1": mock_webapp_setting}, + ), + ): + result = method(api) + + assert len(result["installed_apps"]) == 0 + + def test_get_installed_apps_filters_null_apps(self, app, current_user, tenant_id): + """Test that installed apps with null app are filtered out.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + installed_app_with_null = MagicMock() + installed_app_with_null.app = None + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app_with_null] + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + patch.object(module.TenantService, "get_user_role", return_value="owner"), + patch.object( + module.FeatureService, + "get_system_features", + return_value=MagicMock(webapp_auth=MagicMock(enabled=False)), + ), + ): + result = method(api) + + assert result["installed_apps"] == [] + + def test_get_installed_apps_current_tenant_none(self, app, tenant_id, installed_app): + """Test error when current_user.current_tenant is None.""" + api = module.InstalledAppsListApi() + method = unwrap(api.get) + + current_user = MagicMock() + current_user.current_tenant = None + + session = MagicMock() + session.scalars.return_value.all.return_value = [installed_app] + + with ( + app.test_request_context("/"), + patch.object(module, "current_account_with_tenant", return_value=(current_user, tenant_id)), + patch.object(module.db, "session", session), + ): + with pytest.raises(ValueError, match="current_user.current_tenant must not be None"): + method(api) + + +class TestInstalledAppsCreateApi: + def test_post_success(self, app, tenant_id, payload_patch): + api = module.InstalledAppsListApi() + method = unwrap(api.post) + + recommended = MagicMock() + recommended.install_count = 0 + + app_entity = MagicMock() + app_entity.id = "a1" + app_entity.is_public = True + app_entity.tenant_id = "t2" + + session = MagicMock() + # scalar() is called for recommended_app and installed_app lookups + session.scalar.side_effect = [recommended, None] + # get() is called for app PK lookup + session.get.return_value = app_entity + + with ( + app.test_request_context("/", json={"app_id": "a1"}), + payload_patch({"app_id": "a1"}), + patch.object(module.db, "session", session), + patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), + ): + result = method(api) + + assert result == {"message": "App installed successfully"} + assert recommended.install_count == 1 + + def test_post_recommended_not_found(self, app, payload_patch): + api = module.InstalledAppsListApi() + method = unwrap(api.post) + + session = MagicMock() + session.scalar.return_value = None + + with ( + app.test_request_context("/", json={"app_id": "a1"}), + payload_patch({"app_id": "a1"}), + patch.object(module.db, "session", session), + ): + with pytest.raises(NotFound): + method(api) + + def test_post_app_not_public(self, app, tenant_id, payload_patch): + api = module.InstalledAppsListApi() + method = unwrap(api.post) + + recommended = MagicMock() + app_entity = MagicMock(is_public=False) + + session = MagicMock() + # scalar() returns recommended_app + session.scalar.return_value = recommended + # get() returns the app entity + session.get.return_value = app_entity + + with ( + app.test_request_context("/", json={"app_id": "a1"}), + payload_patch({"app_id": "a1"}), + patch.object(module.db, "session", session), + patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestInstalledAppApi: + def test_delete_success(self, tenant_id, installed_app): + api = module.InstalledAppApi() + method = unwrap(api.delete) + + with ( + patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)), + patch.object(module.db, "session"), + ): + resp, status = method(installed_app) + + assert status == 204 + assert resp["result"] == "success" + + def test_delete_owned_by_current_tenant(self, tenant_id): + api = module.InstalledAppApi() + method = unwrap(api.delete) + + installed_app = MagicMock(app_owner_tenant_id=tenant_id) + + with patch.object(module, "current_account_with_tenant", return_value=(None, tenant_id)): + with pytest.raises(BadRequest): + method(installed_app) + + def test_patch_update_pin(self, app, payload_patch, installed_app): + api = module.InstalledAppApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/", json={"is_pinned": True}), + payload_patch({"is_pinned": True}), + patch.object(module.db, "session"), + ): + result = method(installed_app) + + assert installed_app.is_pinned is True + assert result["result"] == "success" + + def test_patch_no_change(self, app, payload_patch, installed_app): + api = module.InstalledAppApi() + method = unwrap(api.patch) + + with app.test_request_context("/", json={}), payload_patch({}), patch.object(module.db, "session"): + result = method(installed_app) + + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py new file mode 100644 index 0000000000..6b5c304884 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -0,0 +1,552 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import InternalServerError, NotFound + +import controllers.console.explore.message as module +from controllers.console.app.error import ( + AppMoreLikeThisDisabledError, + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.explore.error import ( + AppSuggestedQuestionsAfterAnswerDisabledError, + NotChatAppError, + NotCompletionAppError, +) +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.errors.conversation import ConversationNotExistsError +from services.errors.message import ( + FirstMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) + + +def unwrap(func): + bound_self = getattr(func, "__self__", None) + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + if bound_self is not None: + return func.__get__(bound_self, bound_self.__class__) + return func + + +def make_message(): + msg = MagicMock() + msg.id = "m1" + msg.conversation_id = "11111111-1111-1111-1111-111111111111" + msg.parent_message_id = None + msg.inputs = {} + msg.query = "hello" + msg.re_sign_file_url_answer = "" + msg.user_feedback = MagicMock(rating=None) + msg.status = "normal" + msg.error = None + return msg + + +class TestMessageListApi: + def test_get_success(self, app): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + pagination = MagicMock( + limit=20, + has_more=False, + data=[make_message(), make_message()], + ) + + with ( + app.test_request_context( + "/", + query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "pagination_by_first_id", + return_value=pagination, + ), + ): + result = method(installed_app) + + assert result["limit"] == 20 + assert result["has_more"] is False + assert len(result["data"]) == 2 + + def test_get_not_chat_app(self): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotChatAppError): + method(installed_app) + + def test_conversation_not_exists(self, app): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + app.test_request_context( + "/", + query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "pagination_by_first_id", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app) + + def test_first_message_not_exists(self, app): + api = module.MessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + app.test_request_context( + "/", + query_string={"conversation_id": "11111111-1111-1111-1111-111111111111"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "pagination_by_first_id", + side_effect=FirstMessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app) + + +class TestMessageFeedbackApi: + def test_post_success(self, app): + api = module.MessageFeedbackApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock() + + with ( + app.test_request_context("/", json={"rating": "like"}), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "create_feedback", + ), + ): + result = method(installed_app, "mid") + + assert result["result"] == "success" + + def test_message_not_exists(self, app): + api = module.MessageFeedbackApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock() + + with ( + app.test_request_context("/", json={}), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "create_feedback", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + +class TestMessageMoreLikeThisApi: + def test_get_success(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + return_value={"ok": True}, + ), + patch.object( + module.helper, + "compact_generate_response", + return_value=("ok", 200), + ), + ): + resp = method(installed_app, "mid") + + assert resp == ("ok", 200) + + def test_not_completion_app(self): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotCompletionAppError): + method(installed_app, "mid") + + def test_more_like_this_disabled(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=module.MoreLikeThisDisabledError(), + ), + ): + with pytest.raises(AppMoreLikeThisDisabledError): + method(installed_app, "mid") + + def test_message_not_exists_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + def test_provider_not_init_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(installed_app, "mid") + + def test_quota_exceeded_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(installed_app, "mid") + + def test_model_not_support_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(installed_app, "mid") + + def test_invoke_error_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(installed_app, "mid") + + def test_unexpected_error_more_like_this(self, app): + api = module.MessageMoreLikeThisApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + app.test_request_context( + "/", + query_string={"response_mode": "blocking"}, + ), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.AppGenerateService, + "generate_more_like_this", + side_effect=Exception("unexpected"), + ), + ): + with pytest.raises(InternalServerError): + method(installed_app, "mid") + + +class TestMessageSuggestedQuestionApi: + def test_get_success(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ), + ): + result = method(installed_app, "mid") + + assert result["data"] == ["q1", "q2"] + + def test_not_chat_app(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotChatAppError): + method(installed_app, "mid") + + def test_disabled(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=SuggestedQuestionsAfterAnswerDisabledError(), + ), + ): + with pytest.raises(AppSuggestedQuestionsAfterAnswerDisabledError): + method(installed_app, "mid") + + def test_message_not_exists_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + def test_conversation_not_exists_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app, "mid") + + def test_provider_not_init_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(installed_app, "mid") + + def test_quota_exceeded_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(installed_app, "mid") + + def test_model_not_support_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(installed_app, "mid") + + def test_invoke_error_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(installed_app, "mid") + + def test_unexpected_error_suggested_question(self): + api = module.MessageSuggestedQuestionApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=Exception("unexpected"), + ), + ): + with pytest.raises(InternalServerError): + method(installed_app, "mid") diff --git a/api/tests/unit_tests/controllers/console/explore/test_parameter.py b/api/tests/unit_tests/controllers/console/explore/test_parameter.py new file mode 100644 index 0000000000..7aaecbff14 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_parameter.py @@ -0,0 +1,140 @@ +from unittest.mock import MagicMock, patch + +import pytest + +import controllers.console.explore.parameter as module +from controllers.console.app.error import AppUnavailableError +from models.model import AppMode + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAppParameterApi: + def test_get_app_none(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + installed_app = MagicMock(app=None) + + with pytest.raises(AppUnavailableError): + method(installed_app) + + def test_get_advanced_chat_workflow(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + workflow = MagicMock() + workflow.features_dict = {"f": "v"} + workflow.user_input_form.return_value = [{"name": "x"}] + + app = MagicMock( + mode=AppMode.ADVANCED_CHAT, + workflow=workflow, + ) + + installed_app = MagicMock(app=app) + + with ( + patch.object( + module, + "get_parameters_from_feature_dict", + return_value={"any": "thing"}, + ), + patch.object( + module.fields.Parameters, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: {"ok": True}), + ), + ): + result = method(installed_app) + + assert result == {"ok": True} + + def test_get_advanced_chat_workflow_missing(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + app = MagicMock( + mode=AppMode.ADVANCED_CHAT, + workflow=None, + ) + + installed_app = MagicMock(app=app) + + with pytest.raises(AppUnavailableError): + method(installed_app) + + def test_get_non_workflow_app(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + app_model_config = MagicMock() + app_model_config.to_dict.return_value = {"user_input_form": [{"name": "y"}]} + + app = MagicMock( + mode=AppMode.CHAT, + app_model_config=app_model_config, + ) + + installed_app = MagicMock(app=app) + + with ( + patch.object( + module, + "get_parameters_from_feature_dict", + return_value={"whatever": 123}, + ), + patch.object( + module.fields.Parameters, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: {"ok": True}), + ), + ): + result = method(installed_app) + + assert result == {"ok": True} + + def test_get_non_workflow_missing_config(self): + api = module.AppParameterApi() + method = unwrap(api.get) + + app = MagicMock( + mode=AppMode.CHAT, + app_model_config=None, + ) + + installed_app = MagicMock(app=app) + + with pytest.raises(AppUnavailableError): + method(installed_app) + + +class TestExploreAppMetaApi: + def test_get_meta_success(self): + api = module.ExploreAppMetaApi() + method = unwrap(api.get) + + app = MagicMock() + installed_app = MagicMock(app=app) + + with patch.object( + module.AppService, + "get_app_meta", + return_value={"meta": "ok"}, + ): + result = method(installed_app) + + assert result == {"meta": "ok"} + + def test_get_meta_app_missing(self): + api = module.ExploreAppMetaApi() + method = unwrap(api.get) + + installed_app = MagicMock(app=None) + + with pytest.raises(ValueError): + method(installed_app) diff --git a/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py new file mode 100644 index 0000000000..02c7507ea7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_recommended_app.py @@ -0,0 +1,92 @@ +from unittest.mock import MagicMock, patch + +import controllers.console.explore.recommended_app as module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestRecommendedAppListApi: + def test_get_with_language_param(self, app): + api = module.RecommendedAppListApi() + method = unwrap(api.get) + + result_data = {"recommended_apps": [], "categories": []} + + with ( + app.test_request_context("/", query_string={"language": "en-US"}), + patch.object(module, "current_user", MagicMock(interface_language="fr-FR")), + patch.object( + module.RecommendedAppService, + "get_recommended_apps_and_categories", + return_value=result_data, + ) as service_mock, + ): + result = method(api) + + service_mock.assert_called_once_with("en-US") + assert result == result_data + + def test_get_fallback_to_user_language(self, app): + api = module.RecommendedAppListApi() + method = unwrap(api.get) + + result_data = {"recommended_apps": [], "categories": []} + + with ( + app.test_request_context("/", query_string={"language": "invalid"}), + patch.object(module, "current_user", MagicMock(interface_language="fr-FR")), + patch.object( + module.RecommendedAppService, + "get_recommended_apps_and_categories", + return_value=result_data, + ) as service_mock, + ): + result = method(api) + + service_mock.assert_called_once_with("fr-FR") + assert result == result_data + + def test_get_fallback_to_default_language(self, app): + api = module.RecommendedAppListApi() + method = unwrap(api.get) + + result_data = {"recommended_apps": [], "categories": []} + + with ( + app.test_request_context("/"), + patch.object(module, "current_user", MagicMock(interface_language=None)), + patch.object( + module.RecommendedAppService, + "get_recommended_apps_and_categories", + return_value=result_data, + ) as service_mock, + ): + result = method(api) + + service_mock.assert_called_once_with(module.languages[0]) + assert result == result_data + + +class TestRecommendedAppApi: + def test_get_success(self, app): + api = module.RecommendedAppApi() + method = unwrap(api.get) + + result_data = {"id": "app1"} + + with ( + app.test_request_context("/"), + patch.object( + module.RecommendedAppService, + "get_recommend_app_detail", + return_value=result_data, + ) as service_mock, + ): + result = method(api, "11111111-1111-1111-1111-111111111111") + + service_mock.assert_called_once_with("11111111-1111-1111-1111-111111111111") + assert result == result_data diff --git a/api/tests/unit_tests/controllers/console/explore/test_saved_message.py b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py new file mode 100644 index 0000000000..bb7cdd55c4 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_saved_message.py @@ -0,0 +1,154 @@ +from unittest.mock import MagicMock, PropertyMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import NotFound + +import controllers.console.explore.saved_message as module +from controllers.console.explore.error import NotCompletionAppError +from services.errors.message import MessageNotExistsError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def make_saved_message(): + msg = MagicMock() + msg.id = str(uuid4()) + msg.message_id = str(uuid4()) + msg.app_id = str(uuid4()) + msg.inputs = {} + msg.query = "hello" + msg.answer = "world" + msg.user_feedback = MagicMock(rating="like") + msg.created_at = None + return msg + + +@pytest.fixture +def payload_patch(): + def _patch(payload): + return patch.object( + type(module.console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return _patch + + +class TestSavedMessageListApi: + def test_get_success(self, app): + api = module.SavedMessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + pagination = MagicMock( + limit=20, + has_more=False, + data=[make_saved_message(), make_saved_message()], + ) + + with ( + app.test_request_context("/", query_string={}), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.SavedMessageService, + "pagination_by_last_id", + return_value=pagination, + ), + ): + result = method(installed_app) + + assert result["limit"] == 20 + assert result["has_more"] is False + assert len(result["data"]) == 2 + + def test_get_not_completion_app(self): + api = module.SavedMessageListApi() + method = unwrap(api.get) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotCompletionAppError): + method(installed_app) + + def test_post_success(self, app, payload_patch): + api = module.SavedMessageListApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + payload = {"message_id": str(uuid4())} + + with ( + app.test_request_context("/", json=payload), + payload_patch(payload), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object(module.SavedMessageService, "save") as save_mock, + ): + result = method(installed_app) + + save_mock.assert_called_once() + assert result == {"result": "success"} + + def test_post_message_not_exists(self, app, payload_patch): + api = module.SavedMessageListApi() + method = unwrap(api.post) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + payload = {"message_id": str(uuid4())} + + with ( + app.test_request_context("/", json=payload), + payload_patch(payload), + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object( + module.SavedMessageService, + "save", + side_effect=MessageNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(installed_app) + + +class TestSavedMessageApi: + def test_delete_success(self): + api = module.SavedMessageApi() + method = unwrap(api.delete) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="completion") + + with ( + patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)), + patch.object(module.SavedMessageService, "delete") as delete_mock, + ): + result, status = method(installed_app, str(uuid4())) + + delete_mock.assert_called_once() + assert status == 204 + assert result == {"result": "success"} + + def test_delete_not_completion_app(self): + api = module.SavedMessageApi() + method = unwrap(api.delete) + + installed_app = MagicMock() + installed_app.app = MagicMock(mode="chat") + + with patch.object(module, "current_account_with_tenant", return_value=(MagicMock(), None)): + with pytest.raises(NotCompletionAppError): + method(installed_app, str(uuid4())) diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py new file mode 100644 index 0000000000..5a03daecbc --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -0,0 +1,1101 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import controllers.console.explore.trial as module +from controllers.console.app.error import ( + AppUnavailableError, + CompletionRequestError, + ConversationCompletedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.explore.error import ( + NotChatAppError, + NotCompletionAppError, + NotWorkflowAppError, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from core.errors.error import ( + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from models import Account +from models.account import TenantStatus +from models.model import AppMode +from services.errors.conversation import ConversationNotExistsError +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def account(): + acc = MagicMock(spec=Account) + acc.id = "u1" + return acc + + +@pytest.fixture +def trial_app_chat(): + app = MagicMock() + app.id = "a-chat" + app.mode = AppMode.CHAT + return app + + +@pytest.fixture +def trial_app_completion(): + app = MagicMock() + app.id = "a-comp" + app.mode = AppMode.COMPLETION + return app + + +@pytest.fixture +def trial_app_workflow(): + app = MagicMock() + app.id = "a-workflow" + app.mode = AppMode.WORKFLOW + return app + + +@pytest.fixture +def valid_parameters(): + return { + "user_input_form": [], + "system_parameters": {}, + "suggested_questions": {}, + "suggested_questions_after_answer": {}, + "speech_to_text": {}, + "text_to_speech": {}, + "retriever_resource": {}, + "annotation_reply": {}, + "more_like_this": {}, + "sensitive_word_avoidance": {}, + "file_upload": {}, + } + + +class TestTrialAppWorkflowRunApi: + def test_not_workflow_app(self, app): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with app.test_request_context("/"): + with pytest.raises(NotWorkflowAppError): + method(MagicMock(mode=AppMode.CHAT)) + + def test_success(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(trial_app_workflow) + + assert result is not None + + def test_workflow_provider_not_init(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(trial_app_workflow) + + def test_workflow_quota_exceeded(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(trial_app_workflow) + + def test_workflow_model_not_support(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(trial_app_workflow) + + def test_workflow_invoke_error(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(trial_app_workflow) + + def test_workflow_rate_limit_error(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("test"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(trial_app_workflow) + + def test_workflow_value_error(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "files": []}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ValueError("test error"), + ), + ): + with pytest.raises(ValueError): + method(trial_app_workflow) + + def test_workflow_generic_exception(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "files": []}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=RuntimeError("unexpected error"), + ), + ): + with pytest.raises(InternalServerError): + method(trial_app_workflow) + + +class TestTrialChatApi: + def test_not_chat_app(self, app): + api = module.TrialChatApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={"inputs": {}, "query": "hi"}): + with pytest.raises(NotChatAppError): + method(api, MagicMock(mode="completion")) + + def test_success(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_chat) + + assert result is not None + + def test_chat_conversation_not_exists(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, trial_app_chat) + + def test_chat_conversation_completed(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.conversation.ConversationCompletedError(), + ), + ): + with pytest.raises(ConversationCompletedError): + method(api, trial_app_chat) + + def test_chat_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + method(api, trial_app_chat) + + def test_chat_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_chat_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + def test_chat_model_not_support(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(api, trial_app_chat) + + def test_chat_invoke_error(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_chat) + + def test_chat_rate_limit_error(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("test"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, trial_app_chat) + + def test_chat_value_error(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ValueError("test error"), + ), + ): + with pytest.raises(ValueError): + method(api, trial_app_chat) + + def test_chat_generic_exception(self, app, trial_app_chat, account): + api = module.TrialChatApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": "hi"}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=RuntimeError("unexpected error"), + ), + ): + with pytest.raises(InternalServerError): + method(api, trial_app_chat) + + +class TestTrialCompletionApi: + def test_not_completion_app(self, app): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={"inputs": {}, "query": ""}): + with pytest.raises(NotCompletionAppError): + method(api, MagicMock(mode=AppMode.CHAT)) + + def test_success(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object(module.AppGenerateService, "generate", return_value=MagicMock()), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_completion) + + assert result is not None + + def test_completion_app_config_broken(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(AppUnavailableError): + method(api, trial_app_completion) + + def test_completion_provider_not_init(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_completion) + + def test_completion_quota_exceeded(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_completion) + + def test_completion_model_not_support(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ModelCurrentlyNotSupportError(), + ), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(api, trial_app_completion) + + def test_completion_invoke_error(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_completion) + + def test_completion_rate_limit_error(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=InvokeRateLimitError("test"), + ), + ): + with pytest.raises(InternalServerError): + method(api, trial_app_completion) + + def test_completion_value_error(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=ValueError("test error"), + ), + ): + with pytest.raises(ValueError): + method(api, trial_app_completion) + + def test_completion_generic_exception(self, app, trial_app_completion, account): + api = module.TrialCompletionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"inputs": {}, "query": ""}), + patch.object(module, "current_user", account), + patch.object( + module.AppGenerateService, + "generate", + side_effect=RuntimeError("unexpected error"), + ), + ): + with pytest.raises(InternalServerError): + method(api, trial_app_completion) + + +class TestTrialMessageSuggestedQuestionApi: + def test_not_chat_app(self, app): + api = module.TrialMessageSuggestedQuestionApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(NotChatAppError): + method(api, MagicMock(mode="completion"), str(uuid4())) + + def test_success(self, app, trial_app_chat, account): + api = module.TrialMessageSuggestedQuestionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(module, "current_user", account), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + return_value=["q1", "q2"], + ), + ): + result = method(api, trial_app_chat, str(uuid4())) + + assert result == {"data": ["q1", "q2"]} + + def test_conversation_not_exists(self, app, trial_app_chat, account): + api = module.TrialMessageSuggestedQuestionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object(module, "current_user", account), + patch.object( + module.MessageService, + "get_suggested_questions_after_answer", + side_effect=ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, trial_app_chat, str(uuid4())) + + +class TestTrialAppParameterApi: + def test_app_unavailable(self): + api = module.TrialAppParameterApi() + method = unwrap(api.get) + + with pytest.raises(AppUnavailableError): + method(api, None) + + def test_success_non_workflow(self, valid_parameters): + api = module.TrialAppParameterApi() + method = unwrap(api.get) + + app_model = MagicMock( + mode=AppMode.CHAT, + app_model_config=MagicMock(to_dict=lambda: {"user_input_form": []}), + ) + + with ( + patch.object( + module, + "get_parameters_from_feature_dict", + return_value=valid_parameters, + ), + patch.object( + module.ParametersResponse, + "model_validate", + return_value=MagicMock(model_dump=lambda mode=None: {"ok": True}), + ), + ): + result = method(api, app_model) + + assert result == {"ok": True} + + +class TestTrialChatAudioApi: + def test_success(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_asr", return_value={"text": "hello"}), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_chat) + + assert result == {"text": "hello"} + + def test_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(module.AppUnavailableError): + method(api, trial_app_chat) + + def test_no_audio_uploaded(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(module.NoAudioUploadedError): + method(api, trial_app_chat) + + def test_audio_too_large(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.AudioTooLargeServiceError("Too large"), + ), + ): + with pytest.raises(module.AudioTooLargeError): + method(api, trial_app_chat) + + def test_unsupported_audio_type(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.UnsupportedAudioTypeServiceError(), + ), + ): + with pytest.raises(module.UnsupportedAudioTypeError): + method(api, trial_app_chat) + + def test_provider_not_support_tts(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=module.services.errors.audio.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(module.ProviderNotSupportSpeechToTextError): + method(api, trial_app_chat) + + def test_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_asr", side_effect=ProviderTokenNotInitError("test")), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_asr", side_effect=QuotaExceededError()), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + +class TestTrialChatTextApi: + def test_success(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", return_value={"audio": "base64_data"}), + patch.object(module.RecommendedAppService, "add_trial_app_record"), + ): + result = method(api, trial_app_chat) + + assert result == {"audio": "base64_data"} + + def test_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(module.AppUnavailableError): + method(api, trial_app_chat) + + def test_provider_not_support(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.ProviderNotSupportSpeechToTextServiceError(), + ), + ): + with pytest.raises(module.ProviderNotSupportSpeechToTextError): + method(api, trial_app_chat) + + def test_audio_too_large(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.AudioTooLargeServiceError("Too large"), + ), + ): + with pytest.raises(module.AudioTooLargeError): + method(api, trial_app_chat) + + def test_no_audio_uploaded(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.NoAudioUploadedServiceError(), + ), + ): + with pytest.raises(module.NoAudioUploadedError): + method(api, trial_app_chat) + + def test_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=ProviderTokenNotInitError("test")), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=QuotaExceededError()), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + def test_model_not_support(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=ModelCurrentlyNotSupportError()), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + method(api, trial_app_chat) + + def test_invoke_error(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object(module.AudioService, "transcript_tts", side_effect=InvokeError("test error")), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_chat) + + +class TestTrialAppWorkflowTaskStopApi: + def test_not_workflow_app(self, app, trial_app_chat): + api = module.TrialAppWorkflowTaskStopApi() + method = unwrap(api.post) + + with app.test_request_context("/"): + with pytest.raises(NotWorkflowAppError): + method(trial_app_chat, str(uuid4())) + + def test_success(self, app, trial_app_workflow, account): + api = module.TrialAppWorkflowTaskStopApi() + method = unwrap(api.post) + + task_id = str(uuid4()) + with ( + app.test_request_context("/"), + patch.object(module, "current_user", account), + patch.object(module.AppQueueManager, "set_stop_flag_no_user_check") as mock_set_flag, + patch.object(module.GraphEngineManager, "send_stop_command") as mock_send_cmd, + ): + result = method(trial_app_workflow, task_id) + + assert result == {"result": "success"} + mock_set_flag.assert_called_once_with(task_id) + mock_send_cmd.assert_called_once_with(task_id) + + +class TestTrialSitApi: + def test_no_site(self, app): + api = module.TrialSitApi() + method = unwrap(api.get) + app_model = MagicMock() + app_model.id = "a1" + + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = None + with pytest.raises(Forbidden): + method(api, app_model) + + def test_archived_tenant(self, app): + api = module.TrialSitApi() + method = unwrap(api.get) + + site = MagicMock() + app_model = MagicMock() + app_model.id = "a1" + app_model.tenant = MagicMock() + app_model.tenant.status = TenantStatus.ARCHIVE + + with app.test_request_context("/"), patch.object(module.db.session, "scalar") as mock_scalar: + mock_scalar.return_value = site + with pytest.raises(Forbidden): + method(api, app_model) + + def test_success(self, app): + api = module.TrialSitApi() + method = unwrap(api.get) + + site = MagicMock() + app_model = MagicMock() + app_model.id = "a1" + app_model.tenant = MagicMock() + app_model.tenant.status = TenantStatus.NORMAL + + with ( + app.test_request_context("/"), + patch.object(module.db.session, "scalar") as mock_scalar, + patch.object(module.SiteResponse, "model_validate") as mock_validate, + ): + mock_scalar.return_value = site + mock_validate_result = MagicMock() + mock_validate_result.model_dump.return_value = {"name": "test", "icon": "icon"} + mock_validate.return_value = mock_validate_result + result = method(api, app_model) + + assert result == {"name": "test", "icon": "icon"} + + +class TestTrialChatAudioApiExceptionHandlers: + def test_provider_not_init(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=ProviderTokenNotInitError("test"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, trial_app_chat) + + def test_quota_exceeded(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=QuotaExceededError(), + ), + ): + with pytest.raises(ProviderQuotaExceededError): + method(api, trial_app_chat) + + def test_invoke_error(self, app, trial_app_chat, account): + api = module.TrialChatAudioApi() + method = unwrap(api.post) + + file_data = BytesIO(b"fake audio data") + file_data.filename = "test.wav" + + with ( + app.test_request_context( + "/", method="POST", data={"file": (file_data, "test.wav")}, content_type="multipart/form-data" + ), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_asr", + side_effect=InvokeError("test error"), + ), + ): + with pytest.raises(CompletionRequestError): + method(api, trial_app_chat) + + +class TestTrialChatTextApiExceptionHandlers: + def test_app_config_broken(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.app_model_config.AppModelConfigBrokenError(), + ), + ): + with pytest.raises(module.AppUnavailableError): + method(api, trial_app_chat) + + def test_unsupported_audio_type(self, app, trial_app_chat, account): + api = module.TrialChatTextApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"text": "hello", "voice": "en-US"}), + patch.object(module, "current_user", account), + patch.object( + module.AudioService, + "transcript_tts", + side_effect=module.services.errors.audio.UnsupportedAudioTypeServiceError("test"), + ), + ): + with pytest.raises(module.UnsupportedAudioTypeError): + method(api, trial_app_chat) diff --git a/api/tests/unit_tests/controllers/console/explore/test_workflow.py b/api/tests/unit_tests/controllers/console/explore/test_workflow.py new file mode 100644 index 0000000000..445f887fd3 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_workflow.py @@ -0,0 +1,151 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import InternalServerError + +from controllers.console.explore.error import NotWorkflowAppError +from controllers.console.explore.workflow import ( + InstalledAppWorkflowRunApi, + InstalledAppWorkflowTaskStopApi, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from models.model import AppMode +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.config["TESTING"] = True + return app + + +@pytest.fixture +def user(): + return MagicMock() + + +@pytest.fixture +def workflow_app(): + app = MagicMock() + app.mode = AppMode.WORKFLOW + return app + + +@pytest.fixture +def installed_workflow_app(workflow_app): + return MagicMock(app=workflow_app) + + +@pytest.fixture +def non_workflow_installed_app(): + app = MagicMock() + app.mode = AppMode.CHAT + return MagicMock(app=app) + + +@pytest.fixture +def payload(): + return {"inputs": {"a": 1}} + + +class TestInstalledAppWorkflowRunApi: + def test_not_workflow_app(self, app, non_workflow_installed_app): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(NotWorkflowAppError): + method(non_workflow_installed_app) + + def test_success(self, app, installed_workflow_app, user, payload): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(user, None), + ), + patch( + "controllers.console.explore.workflow.AppGenerateService.generate", + return_value=MagicMock(), + ) as generate_mock, + ): + result = method(installed_workflow_app) + + generate_mock.assert_called_once() + assert result is not None + + def test_rate_limit_error(self, app, installed_workflow_app, user, payload): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(user, None), + ), + patch( + "controllers.console.explore.workflow.AppGenerateService.generate", + side_effect=InvokeRateLimitError("rate limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(installed_workflow_app) + + def test_unexpected_exception(self, app, installed_workflow_app, user, payload): + api = InstalledAppWorkflowRunApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.explore.workflow.current_account_with_tenant", + return_value=(user, None), + ), + patch( + "controllers.console.explore.workflow.AppGenerateService.generate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(InternalServerError): + method(installed_workflow_app) + + +class TestInstalledAppWorkflowTaskStopApi: + def test_not_workflow_app(self, non_workflow_installed_app): + api = InstalledAppWorkflowTaskStopApi() + method = unwrap(api.post) + + with pytest.raises(NotWorkflowAppError): + method(non_workflow_installed_app, "task-1") + + def test_success(self, installed_workflow_app): + api = InstalledAppWorkflowTaskStopApi() + method = unwrap(api.post) + + with ( + patch("controllers.console.explore.workflow.AppQueueManager.set_stop_flag_no_user_check") as stop_flag, + patch("controllers.console.explore.workflow.GraphEngineManager.send_stop_command") as send_stop, + ): + result = method(installed_workflow_app, "task-1") + + stop_flag.assert_called_once_with("task-1") + send_stop.assert_called_once_with("task-1") + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/explore/test_wraps.py b/api/tests/unit_tests/controllers/console/explore/test_wraps.py new file mode 100644 index 0000000000..2c1acfc3d6 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/explore/test_wraps.py @@ -0,0 +1,244 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.console.explore.error import ( + AppAccessDeniedError, + TrialAppLimitExceeded, + TrialAppNotAllowed, +) +from controllers.console.explore.wraps import ( + InstalledAppResource, + TrialAppResource, + installed_app_required, + trial_app_required, + trial_feature_enable, + user_allowed_to_access_app, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_installed_app_required_not_found(): + @installed_app_required + def view(installed_app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, + ): + scalar_mock.return_value = None + + with pytest.raises(NotFound): + view("app-id") + + +def test_installed_app_required_app_deleted(): + installed_app = MagicMock(app=None) + + @installed_app_required + def view(installed_app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, + patch("controllers.console.explore.wraps.db.session.delete"), + patch("controllers.console.explore.wraps.db.session.commit"), + ): + scalar_mock.return_value = installed_app + + with pytest.raises(NotFound): + view("app-id") + + +def test_installed_app_required_success(): + installed_app = MagicMock(app=MagicMock()) + + @installed_app_required + def view(installed_app): + return installed_app + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, + ): + scalar_mock.return_value = installed_app + + result = view("app-id") + assert result == installed_app + + +def test_user_allowed_to_access_app_denied(): + installed_app = MagicMock(app_id="app-1") + + @user_allowed_to_access_app + def view(installed_app): + return "ok" + + feature = MagicMock() + feature.webapp_auth.enabled = True + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=feature, + ), + patch( + "controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", + return_value=False, + ), + ): + with pytest.raises(AppAccessDeniedError): + view(installed_app) + + +def test_user_allowed_to_access_app_success(): + installed_app = MagicMock(app_id="app-1") + + @user_allowed_to_access_app + def view(installed_app): + return "ok" + + feature = MagicMock() + feature.webapp_auth.enabled = True + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=feature, + ), + patch( + "controllers.console.explore.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", + return_value=True, + ), + ): + assert view(installed_app) == "ok" + + +def test_trial_app_required_not_allowed(): + @trial_app_required + def view(app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, + ): + scalar_mock.return_value = None + + with pytest.raises(TrialAppNotAllowed): + view("app-id") + + +def test_trial_app_required_limit_exceeded(): + trial_app = MagicMock(trial_limit=1, app=MagicMock()) + record = MagicMock(count=1) + + @trial_app_required + def view(app): + return "ok" + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, + ): + scalar_mock.side_effect = [ + trial_app, + record, + ] + + with pytest.raises(TrialAppLimitExceeded): + view("app-id") + + +def test_trial_app_required_success(): + trial_app = MagicMock(trial_limit=2, app=MagicMock()) + record = MagicMock(count=1) + + @trial_app_required + def view(app): + return app + + with ( + patch( + "controllers.console.explore.wraps.current_account_with_tenant", + return_value=(MagicMock(id="user-1"), None), + ), + patch("controllers.console.explore.wraps.db.session.scalar") as scalar_mock, + ): + scalar_mock.side_effect = [ + trial_app, + record, + ] + + result = view("app-id") + assert result == trial_app.app + + +def test_trial_feature_enable_disabled(): + @trial_feature_enable + def view(): + return "ok" + + features = MagicMock(enable_trial_app=False) + + with patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=features, + ): + with pytest.raises(Forbidden): + view() + + +def test_trial_feature_enable_enabled(): + @trial_feature_enable + def view(): + return "ok" + + features = MagicMock(enable_trial_app=True) + + with patch( + "controllers.console.explore.wraps.FeatureService.get_system_features", + return_value=features, + ): + assert view() == "ok" + + +def test_installed_app_resource_decorators(): + decorators = InstalledAppResource.method_decorators + assert len(decorators) == 4 + + +def test_trial_app_resource_decorators(): + decorators = TrialAppResource.method_decorators + assert len(decorators) == 3 diff --git a/api/tests/unit_tests/controllers/console/tag/test_tags.py b/api/tests/unit_tests/controllers/console/tag/test_tags.py new file mode 100644 index 0000000000..e89b89c8b1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/tag/test_tags.py @@ -0,0 +1,279 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console.tag.tags import ( + TagBindingCreateApi, + TagBindingDeleteApi, + TagListApi, + TagUpdateDeleteApi, +) +from models.enums import TagType + + +def unwrap(func): + """ + Recursively unwrap decorated functions. + """ + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_tag") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def admin_user(): + return MagicMock( + id="user-1", + has_edit_permission=True, + is_dataset_editor=True, + ) + + +@pytest.fixture +def readonly_user(): + return MagicMock( + id="user-2", + has_edit_permission=False, + is_dataset_editor=False, + ) + + +@pytest.fixture +def tag(): + tag = MagicMock() + tag.id = "tag-1" + tag.name = "test-tag" + tag.type = TagType.KNOWLEDGE + return tag + + +@pytest.fixture +def payload_patch(): + def _patch(payload): + return patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return _patch + + +class TestTagListApi: + def test_get_success(self, app): + api = TagListApi() + method = unwrap(api.get) + + with app.test_request_context("/?type=knowledge"): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.tag.tags.TagService.get_tags", + return_value=[{"id": "1", "name": "tag"}], + ), + ): + result, status = method(api) + + assert status == 200 + assert isinstance(result, list) + + def test_post_success(self, app, admin_user, tag, payload_patch): + api = TagListApi() + method = unwrap(api.post) + + payload = {"name": "test-tag", "type": "knowledge"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch( + "controllers.console.tag.tags.TagService.save_tags", + return_value=tag, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["name"] == "test-tag" + + def test_post_forbidden(self, app, readonly_user, payload_patch): + api = TagListApi() + method = unwrap(api.post) + + payload = {"name": "x"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch(payload), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestTagUpdateDeleteApi: + def test_patch_success(self, app, admin_user, tag, payload_patch): + api = TagUpdateDeleteApi() + method = unwrap(api.patch) + + payload = {"name": "updated", "type": "knowledge"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch( + "controllers.console.tag.tags.TagService.update_tags", + return_value=tag, + ), + patch( + "controllers.console.tag.tags.TagService.get_tag_binding_count", + return_value=3, + ), + ): + result, status = method(api, "tag-1") + + assert status == 200 + assert result["binding_count"] == 3 + + def test_patch_forbidden(self, app, readonly_user, payload_patch): + api = TagUpdateDeleteApi() + method = unwrap(api.patch) + + payload = {"name": "x"} + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch(payload), + ): + with pytest.raises(Forbidden): + method(api, "tag-1") + + def test_delete_success(self, app, admin_user): + api = TagUpdateDeleteApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, "tenant-1"), + ), + patch("controllers.console.tag.tags.TagService.delete_tag") as delete_mock, + ): + result, status = method(api, "tag-1") + + delete_mock.assert_called_once_with("tag-1") + assert status == 204 + + +class TestTagBindingCreateApi: + def test_create_success(self, app, admin_user, payload_patch): + api = TagBindingCreateApi() + method = unwrap(api.post) + + payload = { + "tag_ids": ["tag-1"], + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.save_tag_binding") as save_mock, + ): + result, status = method(api) + + save_mock.assert_called_once() + assert status == 200 + assert result["result"] == "success" + + def test_create_forbidden(self, app, readonly_user, payload_patch): + api = TagBindingCreateApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={}): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch({}), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestTagBindingDeleteApi: + def test_remove_success(self, app, admin_user, payload_patch): + api = TagBindingDeleteApi() + method = unwrap(api.post) + + payload = { + "tag_id": "tag-1", + "target_id": "target-1", + "type": "knowledge", + } + + with app.test_request_context("/", json=payload): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(admin_user, None), + ), + payload_patch(payload), + patch("controllers.console.tag.tags.TagService.delete_tag_binding") as delete_mock, + ): + result, status = method(api) + + delete_mock.assert_called_once() + assert status == 200 + assert result["result"] == "success" + + def test_remove_forbidden(self, app, readonly_user, payload_patch): + api = TagBindingDeleteApi() + method = unwrap(api.post) + + with app.test_request_context("/", json={}): + with ( + patch( + "controllers.console.tag.tags.current_account_with_tenant", + return_value=(readonly_user, None), + ), + payload_patch({}), + ): + with pytest.raises(Forbidden): + method(api) diff --git a/api/tests/unit_tests/controllers/console/test_admin.py b/api/tests/unit_tests/controllers/console/test_admin.py index e0ddf6542e..16197fcd0c 100644 --- a/api/tests/unit_tests/controllers/console/test_admin.py +++ b/api/tests/unit_tests/controllers/console/test_admin.py @@ -1,13 +1,483 @@ """Final working unit tests for admin endpoints - tests business logic directly.""" import uuid -from unittest.mock import Mock, patch +from unittest.mock import Mock, PropertyMock, patch import pytest from werkzeug.exceptions import NotFound, Unauthorized -from controllers.console.admin import InsertExploreAppPayload -from models.model import App, RecommendedApp +from controllers.console.admin import ( + DeleteExploreBannerApi, + InsertExploreAppApi, + InsertExploreAppListApi, + InsertExploreAppPayload, + InsertExploreBannerApi, + InsertExploreBannerPayload, +) +from models.model import App, InstalledApp, RecommendedApp + + +@pytest.fixture(autouse=True) +def bypass_only_edition_cloud(mocker): + """ + Bypass only_edition_cloud decorator by setting EDITION to "CLOUD". + """ + mocker.patch( + "controllers.console.wraps.dify_config.EDITION", + new="CLOUD", + ) + + +@pytest.fixture +def mock_admin_auth(mocker): + """ + Provide valid admin authentication for controller tests. + """ + mocker.patch( + "controllers.console.admin.dify_config.ADMIN_API_KEY", + "test-admin-key", + ) + mocker.patch( + "controllers.console.admin.extract_access_token", + return_value="test-admin-key", + ) + + +@pytest.fixture +def mock_console_payload(mocker): + payload = { + "app_id": str(uuid.uuid4()), + "language": "en-US", + "category": "Productivity", + "position": 1, + } + + mocker.patch( + "flask_restx.namespace.Namespace.payload", + new_callable=PropertyMock, + return_value=payload, + ) + + return payload + + +@pytest.fixture +def mock_banner_payload(mocker): + mocker.patch( + "flask_restx.namespace.Namespace.payload", + new_callable=PropertyMock, + return_value={ + "title": "Test Banner", + "description": "Banner description", + "img-src": "https://example.com/banner.png", + "link": "https://example.com", + "sort": 1, + "category": "homepage", + }, + ) + + +@pytest.fixture +def mock_session_factory(mocker): + mock_session = Mock() + mock_session.execute = Mock() + mock_session.add = Mock() + mock_session.commit = Mock() + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + +class TestDeleteExploreBannerApi: + def setup_method(self): + self.api = DeleteExploreBannerApi() + + def test_delete_banner_not_found(self, mocker, mock_admin_auth): + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: None), + ) + + with pytest.raises(NotFound, match="is not found"): + self.api.delete(uuid.uuid4()) + + def test_delete_banner_success(self, mocker, mock_admin_auth): + mock_banner = Mock() + + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: mock_banner), + ) + mocker.patch("controllers.console.admin.db.session.delete") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.delete(uuid.uuid4()) + + assert status == 204 + assert response["result"] == "success" + + +class TestInsertExploreBannerApi: + def setup_method(self): + self.api = InsertExploreBannerApi() + + def test_insert_banner_success(self, mocker, mock_admin_auth, mock_banner_payload): + mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 201 + assert response["result"] == "success" + + def test_banner_payload_valid_language(self): + payload = { + "title": "Test Banner", + "description": "Banner description", + "img-src": "https://example.com/banner.png", + "link": "https://example.com", + "sort": 1, + "category": "homepage", + "language": "en-US", + } + + model = InsertExploreBannerPayload.model_validate(payload) + assert model.language == "en-US" + + def test_banner_payload_invalid_language(self): + payload = { + "title": "Test Banner", + "description": "Banner description", + "img-src": "https://example.com/banner.png", + "link": "https://example.com", + "sort": 1, + "category": "homepage", + "language": "invalid-lang", + } + + with pytest.raises(ValueError, match="invalid-lang is not a valid language"): + InsertExploreBannerPayload.model_validate(payload) + + +class TestInsertExploreAppApiDelete: + def setup_method(self): + self.api = InsertExploreAppApi() + + def test_delete_when_not_in_explore(self, mocker, mock_admin_auth): + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: s, + __exit__=Mock(return_value=False), + execute=lambda *_: Mock(scalar_one_or_none=lambda: None), + ), + ) + + response, status = self.api.delete(uuid.uuid4()) + + assert status == 204 + assert response["result"] == "success" + + def test_delete_when_in_explore_with_trial_app(self, mocker, mock_admin_auth): + """Test deleting an app from explore that has a trial app.""" + app_id = uuid.uuid4() + + mock_recommended = Mock(spec=RecommendedApp) + mock_recommended.app_id = "app-123" + + mock_app = Mock(spec=App) + mock_app.is_public = True + + mock_trial = Mock() + + # Mock session context manager and its execute + mock_session = Mock() + mock_session.execute = Mock() + mock_session.delete = Mock() + + # Set up side effects for execute calls + mock_session.execute.side_effect = [ + Mock(scalar_one_or_none=lambda: mock_recommended), + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalars=Mock(return_value=Mock(all=lambda: []))), + Mock(scalar_one_or_none=lambda: mock_trial), + ] + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + mocker.patch("controllers.console.admin.db.session.delete") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.delete(app_id) + + assert status == 204 + assert response["result"] == "success" + assert mock_app.is_public is False + + def test_delete_with_installed_apps(self, mocker, mock_admin_auth): + """Test deleting an app that has installed apps in other tenants.""" + app_id = uuid.uuid4() + + mock_recommended = Mock(spec=RecommendedApp) + mock_recommended.app_id = "app-123" + + mock_app = Mock(spec=App) + mock_app.is_public = True + + mock_installed_app = Mock(spec=InstalledApp) + + # Mock session + mock_session = Mock() + mock_session.execute = Mock() + mock_session.delete = Mock() + + mock_session.execute.side_effect = [ + Mock(scalar_one_or_none=lambda: mock_recommended), + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalars=Mock(return_value=Mock(all=lambda: [mock_installed_app]))), + Mock(scalar_one_or_none=lambda: None), + ] + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + mocker.patch("controllers.console.admin.db.session.delete") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.delete(app_id) + + assert status == 204 + assert mock_session.delete.called + + +class TestInsertExploreAppListApi: + def setup_method(self): + self.api = InsertExploreAppListApi() + + def test_app_not_found(self, mocker, mock_admin_auth, mock_console_payload): + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: None), + ) + + with pytest.raises(NotFound, match="is not found"): + self.api.post() + + def test_create_recommended_app( + self, + mocker, + mock_admin_auth, + mock_console_payload, + ): + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.tenant_id = "tenant" + mock_app.is_public = False + + # db.session.execute → fetch App + mocker.patch( + "controllers.console.admin.db.session.execute", + return_value=Mock(scalar_one_or_none=lambda: mock_app), + ) + + # session_factory.create_session → recommended_app lookup + mock_session = Mock() + mock_session.execute = Mock(return_value=Mock(scalar_one_or_none=lambda: None)) + + mocker.patch( + "controllers.console.admin.session_factory.create_session", + return_value=Mock( + __enter__=lambda s: mock_session, + __exit__=Mock(return_value=False), + ), + ) + + mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 201 + assert response["result"] == "success" + assert mock_app.is_public is True + + def test_update_recommended_app(self, mocker, mock_admin_auth, mock_console_payload, mock_session_factory): + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.is_public = False + + mock_recommended = Mock(spec=RecommendedApp) + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: mock_recommended), + ], + ) + + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True + + def test_site_data_overrides_payload( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + site = Mock() + site.description = "Site Desc" + site.copyright = "Site Copyright" + site.privacy_policy = "Site Privacy" + site.custom_disclaimer = "Site Disclaimer" + + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = site + mock_app.tenant_id = "tenant" + mock_app.is_public = False + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: None), + Mock(scalar_one_or_none=lambda: None), + ], + ) + + commit_spy = mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True + commit_spy.assert_called_once() + + def test_create_trial_app_when_can_trial_enabled( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + mock_console_payload["can_trial"] = True + mock_console_payload["trial_limit"] = 5 + + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.tenant_id = "tenant" + mock_app.is_public = False + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: None), + Mock(scalar_one_or_none=lambda: None), + ], + ) + + add_spy = mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + self.api.post() + + assert any(call.args[0].__class__.__name__ == "TrialApp" for call in add_spy.call_args_list) + + def test_update_recommended_app_with_trial( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + """Test updating a recommended app when trial is enabled.""" + mock_console_payload["can_trial"] = True + mock_console_payload["trial_limit"] = 10 + + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.is_public = False + mock_app.tenant_id = "tenant-123" + + mock_recommended = Mock(spec=RecommendedApp) + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: mock_recommended), + Mock(scalar_one_or_none=lambda: None), + ], + ) + + add_spy = mocker.patch("controllers.console.admin.db.session.add") + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True + + def test_update_recommended_app_without_trial( + self, + mocker, + mock_admin_auth, + mock_console_payload, + mock_session_factory, + ): + """Test updating a recommended app without trial enabled.""" + mock_app = Mock(spec=App) + mock_app.id = "app-id" + mock_app.site = None + mock_app.is_public = False + + mock_recommended = Mock(spec=RecommendedApp) + + mocker.patch( + "controllers.console.admin.db.session.execute", + side_effect=[ + Mock(scalar_one_or_none=lambda: mock_app), + Mock(scalar_one_or_none=lambda: mock_recommended), + ], + ) + + mocker.patch("controllers.console.admin.db.session.commit") + + response, status = self.api.post() + + assert status == 200 + assert response["result"] == "success" + assert mock_app.is_public is True class TestInsertExploreAppPayload: diff --git a/api/tests/unit_tests/controllers/console/test_apikey.py b/api/tests/unit_tests/controllers/console/test_apikey.py new file mode 100644 index 0000000000..c18dd044a7 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_apikey.py @@ -0,0 +1,138 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.apikey import ( + BaseApiKeyListResource, + BaseApiKeyResource, + _get_resource, +) + + +@pytest.fixture +def tenant_context_admin(): + with patch("controllers.console.apikey.current_account_with_tenant") as mock: + user = MagicMock() + user.is_admin_or_owner = True + mock.return_value = (user, "tenant-123") + yield mock + + +@pytest.fixture +def tenant_context_non_admin(): + with patch("controllers.console.apikey.current_account_with_tenant") as mock: + user = MagicMock() + user.is_admin_or_owner = False + mock.return_value = (user, "tenant-123") + yield mock + + +@pytest.fixture +def db_mock(): + with patch("controllers.console.apikey.db") as mock_db: + mock_db.session = MagicMock() + yield mock_db + + +@pytest.fixture(autouse=True) +def bypass_permissions(): + with patch( + "controllers.console.apikey.edit_permission_required", + lambda f: f, + ): + yield + + +class DummyApiKeyListResource(BaseApiKeyListResource): + resource_type = "app" + resource_model = MagicMock() + resource_id_field = "app_id" + token_prefix = "app-" + + +class DummyApiKeyResource(BaseApiKeyResource): + resource_type = "app" + resource_model = MagicMock() + resource_id_field = "app_id" + + +class TestGetResource: + def test_get_resource_success(self): + fake_resource = MagicMock() + + with ( + patch("controllers.console.apikey.select") as mock_select, + patch("controllers.console.apikey.Session") as mock_session, + patch("controllers.console.apikey.db") as mock_db, + ): + mock_db.engine = MagicMock() + mock_select.return_value.filter_by.return_value = MagicMock() + + session = mock_session.return_value.__enter__.return_value + session.execute.return_value.scalar_one_or_none.return_value = fake_resource + + result = _get_resource("rid", "tid", MagicMock) + assert result == fake_resource + + def test_get_resource_not_found(self): + with ( + patch("controllers.console.apikey.select") as mock_select, + patch("controllers.console.apikey.Session") as mock_session, + patch("controllers.console.apikey.db") as mock_db, + patch("controllers.console.apikey.flask_restx.abort") as abort, + ): + mock_db.engine = MagicMock() + mock_select.return_value.filter_by.return_value = MagicMock() + + session = mock_session.return_value.__enter__.return_value + session.execute.return_value.scalar_one_or_none.return_value = None + + _get_resource("rid", "tid", MagicMock) + + abort.assert_called_once() + + +class TestBaseApiKeyListResource: + def test_get_apikeys_success(self, tenant_context_admin, db_mock): + resource = DummyApiKeyListResource() + + with patch("controllers.console.apikey._get_resource"): + db_mock.session.scalars.return_value.all.return_value = [MagicMock(), MagicMock()] + + result = DummyApiKeyListResource.get.__wrapped__(resource, "resource-id") + assert "items" in result + + +class TestBaseApiKeyResource: + def test_delete_forbidden(self, tenant_context_non_admin, db_mock): + resource = DummyApiKeyResource() + + with patch("controllers.console.apikey._get_resource"): + with pytest.raises(Forbidden): + DummyApiKeyResource.delete(resource, "rid", "kid") + + def test_delete_key_not_found(self, tenant_context_admin, db_mock): + resource = DummyApiKeyResource() + db_mock.session.scalar.return_value = None + + with patch("controllers.console.apikey._get_resource"): + with pytest.raises(Exception) as exc_info: + DummyApiKeyResource.delete(resource, "rid", "kid") + + # flask_restx.abort raises HTTPException with message in data attribute + assert exc_info.value.data["message"] == "API key not found" + + def test_delete_success(self, tenant_context_admin, db_mock): + resource = DummyApiKeyResource() + db_mock.session.scalar.return_value = MagicMock() + + with ( + patch("controllers.console.apikey._get_resource"), + patch("controllers.console.apikey.ApiTokenCache.delete"), + ): + result, status = DummyApiKeyResource.delete(resource, "rid", "kid") + + assert status == 204 + assert result == {"result": "success"} + db_mock.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/test_extension.py b/api/tests/unit_tests/controllers/console/test_extension.py index 85eb6e7d71..0d1fb39348 100644 --- a/api/tests/unit_tests/controllers/console/test_extension.py +++ b/api/tests/unit_tests/controllers/console/test_extension.py @@ -22,7 +22,7 @@ from controllers.console.extension import ( ) if _NEEDS_METHOD_VIEW_CLEANUP: - delattr(builtins, "MethodView") + del builtins.MethodView from models.account import AccountStatus from models.api_based_extension import APIBasedExtension diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py deleted file mode 100644 index b9bc42fb25..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_init_validate.py +++ /dev/null @@ -1,46 +0,0 @@ -import builtins -from unittest.mock import patch - -import pytest -from flask import Flask -from flask.views import MethodView - -from extensions import ext_fastopenapi - -if not hasattr(builtins, "MethodView"): - builtins.MethodView = MethodView # type: ignore[attr-defined] - - -@pytest.fixture -def app() -> Flask: - app = Flask(__name__) - app.config["TESTING"] = True - app.secret_key = "test-secret-key" - return app - - -def test_console_init_get_returns_finished_when_no_init_password(app: Flask, monkeypatch: pytest.MonkeyPatch): - ext_fastopenapi.init_app(app) - monkeypatch.delenv("INIT_PASSWORD", raising=False) - - with patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"): - client = app.test_client() - response = client.get("/console/api/init") - - assert response.status_code == 200 - assert response.get_json() == {"status": "finished"} - - -def test_console_init_post_returns_success(app: Flask, monkeypatch: pytest.MonkeyPatch): - ext_fastopenapi.init_app(app) - monkeypatch.setenv("INIT_PASSWORD", "test-init-password") - - with ( - patch("controllers.console.init_validate.dify_config.EDITION", "SELF_HOSTED"), - patch("controllers.console.init_validate.TenantService.get_tenant_count", return_value=0), - ): - client = app.test_client() - response = client.post("/console/api/init", json={"password": "test-init-password"}) - - assert response.status_code == 201 - assert response.get_json() == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py b/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py deleted file mode 100644 index c0a984e216..0000000000 --- a/api/tests/unit_tests/controllers/console/test_fastopenapi_remote_files.py +++ /dev/null @@ -1,286 +0,0 @@ -"""Tests for remote file upload API endpoints using Flask-RESTX.""" - -import contextlib -from datetime import datetime -from types import SimpleNamespace -from unittest.mock import Mock, patch - -import httpx -import pytest -from flask import Flask, g - - -@pytest.fixture -def app() -> Flask: - """Create Flask app for testing.""" - app = Flask(__name__) - app.config["TESTING"] = True - app.config["SECRET_KEY"] = "test-secret-key" - return app - - -@pytest.fixture -def client(app): - """Create test client with console blueprint registered.""" - from controllers.console import bp - - app.register_blueprint(bp) - return app.test_client() - - -@pytest.fixture -def mock_account(): - """Create a mock account for testing.""" - from models import Account - - account = Mock(spec=Account) - account.id = "test-account-id" - account.current_tenant_id = "test-tenant-id" - return account - - -@pytest.fixture -def auth_ctx(app, mock_account): - """Context manager to set auth/tenant context in flask.g for a request.""" - - @contextlib.contextmanager - def _ctx(): - with app.test_request_context(): - g._login_user = mock_account - g._current_tenant = mock_account.current_tenant_id - yield - - return _ctx - - -class TestGetRemoteFileInfo: - """Test GET /console/api/remote-files/ endpoint.""" - - def test_get_remote_file_info_success(self, app, client, mock_account): - """Test successful retrieval of remote file info.""" - response = httpx.Response( - 200, - request=httpx.Request("HEAD", "http://example.com/file.txt"), - headers={"Content-Type": "text/plain", "Content-Length": "1024"}, - ) - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=response), - patch("libs.login.check_csrf_token", return_value=None), - ): - with app.test_request_context(): - g._login_user = mock_account - g._current_tenant = mock_account.current_tenant_id - encoded_url = "http%3A%2F%2Fexample.com%2Ffile.txt" - resp = client.get(f"/console/api/remote-files/{encoded_url}") - - assert resp.status_code == 200 - data = resp.get_json() - assert data["file_type"] == "text/plain" - assert data["file_length"] == 1024 - - def test_get_remote_file_info_fallback_to_get_on_head_failure(self, app, client, mock_account): - """Test fallback to GET when HEAD returns non-200 status.""" - head_response = httpx.Response( - 404, - request=httpx.Request("HEAD", "http://example.com/file.pdf"), - ) - get_response = httpx.Response( - 200, - request=httpx.Request("GET", "http://example.com/file.pdf"), - headers={"Content-Type": "application/pdf", "Content-Length": "2048"}, - ) - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_response), - patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_response), - patch("libs.login.check_csrf_token", return_value=None), - ): - with app.test_request_context(): - g._login_user = mock_account - g._current_tenant = mock_account.current_tenant_id - encoded_url = "http%3A%2F%2Fexample.com%2Ffile.pdf" - resp = client.get(f"/console/api/remote-files/{encoded_url}") - - assert resp.status_code == 200 - data = resp.get_json() - assert data["file_type"] == "application/pdf" - assert data["file_length"] == 2048 - - -class TestRemoteFileUpload: - """Test POST /console/api/remote-files/upload endpoint.""" - - @pytest.mark.parametrize( - ("head_status", "use_get"), - [ - (200, False), # HEAD succeeds - (405, True), # HEAD fails -> fallback GET - ], - ) - def test_upload_remote_file_success_paths(self, client, mock_account, auth_ctx, head_status, use_get): - url = "http://example.com/file.pdf" - head_resp = httpx.Response( - head_status, - request=httpx.Request("HEAD", url), - headers={"Content-Type": "application/pdf", "Content-Length": "1024"}, - ) - get_resp = httpx.Response( - 200, - request=httpx.Request("GET", url), - headers={"Content-Type": "application/pdf", "Content-Length": "1024"}, - content=b"file content", - ) - - file_info = SimpleNamespace( - extension="pdf", - size=1024, - filename="file.pdf", - mimetype="application/pdf", - ) - uploaded_file = SimpleNamespace( - id="uploaded-file-id", - name="file.pdf", - size=1024, - extension="pdf", - mime_type="application/pdf", - created_by="test-account-id", - created_at=datetime(2024, 1, 1, 12, 0, 0), - ) - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp) as p_head, - patch("controllers.console.remote_files.ssrf_proxy.get", return_value=get_resp) as p_get, - patch( - "controllers.console.remote_files.helpers.guess_file_info_from_response", - return_value=file_info, - ), - patch( - "controllers.console.remote_files.FileService.is_file_size_within_limit", - return_value=True, - ), - patch("controllers.console.remote_files.db", spec=["engine"]), - patch("controllers.console.remote_files.FileService") as mock_file_service, - patch( - "controllers.console.remote_files.file_helpers.get_signed_file_url", - return_value="http://example.com/signed-url", - ), - patch("libs.login.check_csrf_token", return_value=None), - ): - mock_file_service.return_value.upload_file.return_value = uploaded_file - - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": url}, - ) - - assert resp.status_code == 201 - p_head.assert_called_once() - # GET is used either for fallback (HEAD fails) or to fetch content after HEAD succeeds - p_get.assert_called_once() - mock_file_service.return_value.upload_file.assert_called_once() - - data = resp.get_json() - assert data["id"] == "uploaded-file-id" - assert data["name"] == "file.pdf" - assert data["size"] == 1024 - assert data["extension"] == "pdf" - assert data["url"] == "http://example.com/signed-url" - assert data["mime_type"] == "application/pdf" - assert data["created_by"] == "test-account-id" - - @pytest.mark.parametrize( - ("size_ok", "raises", "expected_status", "expected_msg"), - [ - # When size check fails in controller, API returns 413 with message "File size exceeded..." - (False, None, 413, "file size exceeded"), - # When service raises unsupported type, controller maps to 415 with message "File type not allowed." - (True, "unsupported", 415, "file type not allowed"), - ], - ) - def test_upload_remote_file_errors( - self, client, mock_account, auth_ctx, size_ok, raises, expected_status, expected_msg - ): - url = "http://example.com/x.pdf" - head_resp = httpx.Response( - 200, - request=httpx.Request("HEAD", url), - headers={"Content-Type": "application/pdf", "Content-Length": "9"}, - ) - file_info = SimpleNamespace(extension="pdf", size=9, filename="x.pdf", mimetype="application/pdf") - - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch("controllers.console.remote_files.ssrf_proxy.head", return_value=head_resp), - patch( - "controllers.console.remote_files.helpers.guess_file_info_from_response", - return_value=file_info, - ), - patch( - "controllers.console.remote_files.FileService.is_file_size_within_limit", - return_value=size_ok, - ), - patch("controllers.console.remote_files.db", spec=["engine"]), - patch("libs.login.check_csrf_token", return_value=None), - ): - if raises == "unsupported": - from services.errors.file import UnsupportedFileTypeError - - with patch("controllers.console.remote_files.FileService") as mock_file_service: - mock_file_service.return_value.upload_file.side_effect = UnsupportedFileTypeError("bad") - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": url}, - ) - else: - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": url}, - ) - - assert resp.status_code == expected_status - data = resp.get_json() - msg = (data.get("error") or {}).get("message") or data.get("message", "") - assert expected_msg in msg.lower() - - def test_upload_remote_file_fetch_failure(self, client, mock_account, auth_ctx): - """Test upload when fetching of remote file fails.""" - with ( - patch( - "controllers.console.remote_files.current_account_with_tenant", - return_value=(mock_account, "test-tenant-id"), - ), - patch( - "controllers.console.remote_files.ssrf_proxy.head", - side_effect=httpx.RequestError("Connection failed"), - ), - patch("libs.login.check_csrf_token", return_value=None), - ): - with auth_ctx(): - resp = client.post( - "/console/api/remote-files/upload", - json={"url": "http://unreachable.com/file.pdf"}, - ) - - assert resp.status_code == 400 - data = resp.get_json() - msg = (data.get("error") or {}).get("message") or data.get("message", "") - assert "failed to fetch" in msg.lower() diff --git a/api/tests/unit_tests/controllers/console/test_feature.py b/api/tests/unit_tests/controllers/console/test_feature.py new file mode 100644 index 0000000000..d8debc1f2c --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_feature.py @@ -0,0 +1,81 @@ +from werkzeug.exceptions import Unauthorized + + +def unwrap(func): + """ + Recursively unwrap decorated functions. + """ + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestFeatureApi: + def test_get_tenant_features_success(self, mocker): + from controllers.console.feature import FeatureApi + + mocker.patch( + "controllers.console.feature.current_account_with_tenant", + return_value=("account_id", "tenant_123"), + ) + + mocker.patch("controllers.console.feature.FeatureService.get_features").return_value.model_dump.return_value = { + "features": {"feature_a": True} + } + + api = FeatureApi() + + raw_get = unwrap(FeatureApi.get) + result = raw_get(api) + + assert result == {"features": {"feature_a": True}} + + +class TestSystemFeatureApi: + def test_get_system_features_authenticated(self, mocker): + """ + current_user.is_authenticated == True + """ + + from controllers.console.feature import SystemFeatureApi + + fake_user = mocker.Mock() + fake_user.is_authenticated = True + + mocker.patch( + "controllers.console.feature.current_user", + fake_user, + ) + + mocker.patch( + "controllers.console.feature.FeatureService.get_system_features" + ).return_value.model_dump.return_value = {"features": {"sys_feature": True}} + + api = SystemFeatureApi() + result = api.get() + + assert result == {"features": {"sys_feature": True}} + + def test_get_system_features_unauthenticated(self, mocker): + """ + current_user.is_authenticated raises Unauthorized + """ + + from controllers.console.feature import SystemFeatureApi + + fake_user = mocker.Mock() + type(fake_user).is_authenticated = mocker.PropertyMock(side_effect=Unauthorized()) + + mocker.patch( + "controllers.console.feature.current_user", + fake_user, + ) + + mocker.patch( + "controllers.console.feature.FeatureService.get_system_features" + ).return_value.model_dump.return_value = {"features": {"sys_feature": False}} + + api = SystemFeatureApi() + result = api.get() + + assert result == {"features": {"sys_feature": False}} diff --git a/api/tests/unit_tests/controllers/console/test_files.py b/api/tests/unit_tests/controllers/console/test_files.py new file mode 100644 index 0000000000..5df9daa7f8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_files.py @@ -0,0 +1,300 @@ +import io +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from constants import DOCUMENT_EXTENSIONS +from controllers.common.errors import ( + BlockedFileExtensionError, + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.files import ( + FileApi, + FilePreviewApi, + FileSupportTypeApi, +) + + +def unwrap(func): + """ + Recursively unwrap decorated functions. + """ + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask(__name__) + app.testing = True + return app + + +@pytest.fixture(autouse=True) +def mock_decorators(): + """ + Make decorators no-ops so logic is directly testable + """ + with ( + patch("controllers.console.files.setup_required", new=lambda f: f), + patch("controllers.console.files.login_required", new=lambda f: f), + patch("controllers.console.files.account_initialization_required", new=lambda f: f), + patch("controllers.console.files.cloud_edition_billing_resource_check", return_value=lambda f: f), + ): + yield + + +@pytest.fixture +def mock_current_user(): + user = MagicMock() + user.is_dataset_editor = True + return user + + +@pytest.fixture +def mock_account_context(mock_current_user): + with patch( + "controllers.console.files.current_account_with_tenant", + return_value=(mock_current_user, None), + ): + yield + + +@pytest.fixture +def mock_db(): + with patch("controllers.console.files.db") as db_mock: + db_mock.engine = MagicMock() + yield db_mock + + +@pytest.fixture +def mock_file_service(mock_db): + with patch("controllers.console.files.FileService") as fs: + instance = fs.return_value + yield instance + + +class TestFileApiGet: + def test_get_upload_config(self, app): + api = FileApi() + get_method = unwrap(api.get) + + with app.test_request_context(): + data, status = get_method(api) + + assert status == 200 + assert "file_size_limit" in data + assert "batch_count_limit" in data + + +class TestFileApiPost: + def test_no_file_uploaded(self, app, mock_account_context): + api = FileApi() + post_method = unwrap(api.post) + + with app.test_request_context(method="POST", data={}): + with pytest.raises(NoFileUploadedError): + post_method(api) + + def test_too_many_files(self, app, mock_account_context): + api = FileApi() + post_method = unwrap(api.post) + + with app.test_request_context(method="POST"): + from unittest.mock import MagicMock, patch + + with patch("controllers.console.files.request") as mock_request: + mock_request.files = MagicMock() + mock_request.files.__len__.return_value = 2 + mock_request.files.__contains__.return_value = True + mock_request.form = MagicMock() + mock_request.form.get.return_value = None + + with pytest.raises(TooManyFilesError): + post_method(api) + + def test_filename_missing(self, app, mock_account_context): + api = FileApi() + post_method = unwrap(api.post) + + data = { + "file": (io.BytesIO(b"abc"), ""), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(FilenameNotExistsError): + post_method(api) + + def test_dataset_upload_without_permission(self, app, mock_current_user): + mock_current_user.is_dataset_editor = False + + with patch( + "controllers.console.files.current_account_with_tenant", + return_value=(mock_current_user, None), + ): + api = FileApi() + post_method = unwrap(api.post) + + data = { + "file": (io.BytesIO(b"abc"), "test.txt"), + "source": "datasets", + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(Forbidden): + post_method(api) + + def test_successful_upload(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + mock_file = MagicMock() + mock_file.id = "file-id-123" + mock_file.filename = "test.txt" + mock_file.name = "test.txt" + mock_file.size = 1024 + mock_file.extension = "txt" + mock_file.mime_type = "text/plain" + mock_file.created_by = "user-123" + mock_file.created_at = 1234567890 + mock_file.preview_url = "http://example.com/preview/file-id-123" + mock_file.source_url = "http://example.com/source/file-id-123" + mock_file.original_url = None + mock_file.user_id = "user-123" + mock_file.tenant_id = "tenant-123" + mock_file.conversation_id = None + mock_file.file_key = "file-key-123" + + mock_file_service.upload_file.return_value = mock_file + + data = { + "file": (io.BytesIO(b"hello"), "test.txt"), + } + + with app.test_request_context(method="POST", data=data): + response, status = post_method(api) + + assert status == 201 + assert response["id"] == "file-id-123" + assert response["name"] == "test.txt" + + def test_upload_with_invalid_source(self, app, mock_account_context, mock_file_service): + """Test that invalid source parameter gets normalized to None""" + api = FileApi() + post_method = unwrap(api.post) + + # Create a properly structured mock file object + mock_file = MagicMock() + mock_file.id = "file-id-456" + mock_file.filename = "test.txt" + mock_file.name = "test.txt" + mock_file.size = 512 + mock_file.extension = "txt" + mock_file.mime_type = "text/plain" + mock_file.created_by = "user-456" + mock_file.created_at = 1234567890 + mock_file.preview_url = None + mock_file.source_url = None + mock_file.original_url = None + mock_file.user_id = "user-456" + mock_file.tenant_id = "tenant-456" + mock_file.conversation_id = None + mock_file.file_key = "file-key-456" + + mock_file_service.upload_file.return_value = mock_file + + data = { + "file": (io.BytesIO(b"content"), "test.txt"), + "source": "invalid_source", # Should be normalized to None + } + + with app.test_request_context(method="POST", data=data): + response, status = post_method(api) + + assert status == 201 + assert response["id"] == "file-id-456" + # Verify that FileService was called with source=None + mock_file_service.upload_file.assert_called_once() + call_kwargs = mock_file_service.upload_file.call_args[1] + assert call_kwargs["source"] is None + + def test_file_too_large_error(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + from services.errors.file import FileTooLargeError as ServiceFileTooLargeError + + error = ServiceFileTooLargeError("File is too large") + mock_file_service.upload_file.side_effect = error + + data = { + "file": (io.BytesIO(b"x" * 1000000), "big.txt"), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(FileTooLargeError): + post_method(api) + + def test_unsupported_file_type(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + error = ServiceUnsupportedFileTypeError() + mock_file_service.upload_file.side_effect = error + + data = { + "file": (io.BytesIO(b"x"), "bad.exe"), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(UnsupportedFileTypeError): + post_method(api) + + def test_blocked_extension(self, app, mock_account_context, mock_file_service): + api = FileApi() + post_method = unwrap(api.post) + + from services.errors.file import BlockedFileExtensionError as ServiceBlockedFileExtensionError + + error = ServiceBlockedFileExtensionError("File extension is blocked") + mock_file_service.upload_file.side_effect = error + + data = { + "file": (io.BytesIO(b"x"), "blocked.txt"), + } + + with app.test_request_context(method="POST", data=data): + with pytest.raises(BlockedFileExtensionError): + post_method(api) + + +class TestFilePreviewApi: + def test_get_preview(self, app, mock_file_service): + api = FilePreviewApi() + get_method = unwrap(api.get) + mock_file_service.get_file_preview.return_value = "preview text" + + with app.test_request_context(): + result = get_method(api, "1234") + + assert result == {"content": "preview text"} + + +class TestFileSupportTypeApi: + def test_get_supported_types(self, app): + api = FileSupportTypeApi() + get_method = unwrap(api.get) + + with app.test_request_context(): + result = get_method(api) + + assert result == {"allowed_extensions": list(DOCUMENT_EXTENSIONS)} diff --git a/api/tests/unit_tests/controllers/console/test_human_input_form.py b/api/tests/unit_tests/controllers/console/test_human_input_form.py new file mode 100644 index 0000000000..232b6eee79 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_human_input_form.py @@ -0,0 +1,293 @@ +from __future__ import annotations + +import json +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest +from flask import Response + +from controllers.console.human_input_form import ( + ConsoleHumanInputFormApi, + ConsoleWorkflowEventsApi, + DifyAPIRepositoryFactory, + WorkflowResponseConverter, + _jsonify_form_definition, +) +from controllers.web.error import NotFoundError +from models.enums import CreatorUserRole +from models.human_input import RecipientType +from models.model import AppMode + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def test_jsonify_form_definition() -> None: + expiration = datetime(2024, 1, 1, tzinfo=UTC) + definition = SimpleNamespace(model_dump=lambda: {"fields": []}) + form = SimpleNamespace(get_definition=lambda: definition, expiration_time=expiration) + + response = _jsonify_form_definition(form) + + assert isinstance(response, Response) + payload = json.loads(response.get_data(as_text=True)) + assert payload["expiration_time"] == int(expiration.timestamp()) + + +def test_ensure_console_access_rejects(monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1") + monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-2")) + + with pytest.raises(NotFoundError): + ConsoleHumanInputFormApi._ensure_console_access(form) + + +def test_get_form_definition_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + expiration = datetime(2024, 1, 1, tzinfo=UTC) + definition = SimpleNamespace(model_dump=lambda: {"fields": ["a"]}) + form = SimpleNamespace(tenant_id="tenant-1", get_definition=lambda: definition, expiration_time=expiration) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_definition_by_token_for_console(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1")) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/form/human_input/token", method="GET"): + response = handler(api, form_token="token") + + payload = json.loads(response.get_data(as_text=True)) + assert payload["fields"] == ["a"] + + +def test_get_form_definition_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_definition_by_token_for_console(self, _token): + return None + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr("controllers.console.human_input_form.current_account_with_tenant", lambda: (None, "tenant-1")) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/form/human_input/token", method="GET"): + with pytest.raises(NotFoundError): + handler(api, form_token="token") + + +def test_post_form_invalid_recipient_type(app, monkeypatch: pytest.MonkeyPatch) -> None: + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.EMAIL_MEMBER) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + with pytest.raises(NotFoundError): + handler(api, form_token="token") + + +def test_post_form_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + submit_mock = Mock() + form = SimpleNamespace(tenant_id="tenant-1", recipient_type=RecipientType.CONSOLE) + + class _ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def get_form_by_token(self, _token): + return form + + def submit_form_by_token(self, **kwargs): + submit_mock(**kwargs) + + monkeypatch.setattr("controllers.console.human_input_form.HumanInputService", _ServiceStub) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "tenant-1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleHumanInputFormApi() + handler = _unwrap(api.post) + + with app.test_request_context( + "/console/api/form/human_input/token", + method="POST", + json={"inputs": {"content": "ok"}, "action": "approve"}, + ): + response = handler(api, form_token="token") + + assert response.get_json() == {} + submit_mock.assert_called_once() + + +def test_workflow_events_not_found(app, monkeypatch: pytest.MonkeyPatch) -> None: + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return None + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + with pytest.raises(NotFoundError): + handler(api, workflow_run_id="run-1") + + +def test_workflow_events_requires_account(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + created_by_role=CreatorUserRole.END_USER, + created_by="user-1", + tenant_id="t1", + ) + + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return workflow_run + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + with pytest.raises(NotFoundError): + handler(api, workflow_run_id="run-1") + + +def test_workflow_events_requires_creator(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-2", + tenant_id="t1", + ) + + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return workflow_run + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="u1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + with pytest.raises(NotFoundError): + handler(api, workflow_run_id="run-1") + + +def test_workflow_events_finished(app, monkeypatch: pytest.MonkeyPatch) -> None: + workflow_run = SimpleNamespace( + id="run-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-1", + tenant_id="t1", + app_id="app-1", + finished_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + app_model = SimpleNamespace(mode=AppMode.WORKFLOW) + + class _RepoStub: + def get_workflow_run_by_id_and_tenant_id(self, **_kwargs): + return workflow_run + + response_obj = SimpleNamespace( + event=SimpleNamespace(value="finished"), + model_dump=lambda mode="json": {"status": "done"}, + ) + + monkeypatch.setattr( + DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda *_args, **_kwargs: _RepoStub(), + ) + monkeypatch.setattr( + "controllers.console.human_input_form._retrieve_app_for_workflow_run", + lambda *_args, **_kwargs: app_model, + ) + monkeypatch.setattr( + WorkflowResponseConverter, + "workflow_run_result_to_finish_response", + lambda **_kwargs: response_obj, + ) + monkeypatch.setattr( + "controllers.console.human_input_form.current_account_with_tenant", + lambda: (SimpleNamespace(id="user-1"), "t1"), + ) + monkeypatch.setattr("controllers.console.human_input_form.db", SimpleNamespace(engine=object())) + + api = ConsoleWorkflowEventsApi() + handler = _unwrap(api.get) + + with app.test_request_context("/console/api/workflow/run/events", method="GET"): + response = handler(api, workflow_run_id="run-1") + + assert response.mimetype == "text/event-stream" + assert "data" in response.get_data(as_text=True) diff --git a/api/tests/unit_tests/controllers/console/test_init_validate.py b/api/tests/unit_tests/controllers/console/test_init_validate.py new file mode 100644 index 0000000000..3077304cbe --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_init_validate.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock + +import pytest + +from controllers.console import init_validate +from controllers.console.error import AlreadySetupError, InitValidateFailedError + + +class _SessionStub: + def __init__(self, has_setup: bool): + self._has_setup = has_setup + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def execute(self, *_args, **_kwargs): + return SimpleNamespace(scalar_one_or_none=lambda: Mock() if self._has_setup else None) + + +def test_get_init_status_finished(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: True) + result = init_validate.get_init_status() + assert result.status == "finished" + + +def test_get_init_status_not_started(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate, "get_init_validate_status", lambda: False) + result = init_validate.get_init_status() + assert result.status == "not_started" + + +def test_validate_init_password_already_setup(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 1) + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="POST"): + with pytest.raises(AlreadySetupError): + init_validate.validate_init_password(init_validate.InitValidatePayload(password="pw")) + + +def test_validate_init_password_wrong_password(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setenv("INIT_PASSWORD", "expected") + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="POST"): + with pytest.raises(InitValidateFailedError): + init_validate.validate_init_password(init_validate.InitValidatePayload(password="wrong")) + assert init_validate.session.get("is_init_validated") is False + + +def test_validate_init_password_success(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setattr(init_validate.TenantService, "get_tenant_count", lambda: 0) + monkeypatch.setenv("INIT_PASSWORD", "expected") + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="POST"): + result = init_validate.validate_init_password(init_validate.InitValidatePayload(password="expected")) + assert result.result == "success" + assert init_validate.session.get("is_init_validated") is True + + +def test_get_init_validate_status_not_self_hosted(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "CLOUD") + assert init_validate.get_init_validate_status() is True + + +def test_get_init_validate_status_validated_session(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setenv("INIT_PASSWORD", "expected") + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="GET"): + init_validate.session["is_init_validated"] = True + assert init_validate.get_init_validate_status() is True + + +def test_get_init_validate_status_setup_exists(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setenv("INIT_PASSWORD", "expected") + monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(True)) + monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object())) + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="GET"): + init_validate.session.pop("is_init_validated", None) + assert init_validate.get_init_validate_status() is True + + +def test_get_init_validate_status_not_validated(app, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(init_validate.dify_config, "EDITION", "SELF_HOSTED") + monkeypatch.setenv("INIT_PASSWORD", "expected") + monkeypatch.setattr(init_validate, "Session", lambda *_args, **_kwargs: _SessionStub(False)) + monkeypatch.setattr(init_validate, "db", SimpleNamespace(engine=object())) + app.secret_key = "test-secret" + + with app.test_request_context("/console/api/init", method="GET"): + init_validate.session.pop("is_init_validated", None) + assert init_validate.get_init_validate_status() is False diff --git a/api/tests/unit_tests/controllers/console/test_remote_files.py b/api/tests/unit_tests/controllers/console/test_remote_files.py new file mode 100644 index 0000000000..1be402c8ab --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_remote_files.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import urllib.parse +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import httpx +import pytest + +from controllers.common.errors import FileTooLargeError, RemoteFileUploadError, UnsupportedFileTypeError +from controllers.console import remote_files as remote_files_module +from services.errors.file import FileTooLargeError as ServiceFileTooLargeError +from services.errors.file import UnsupportedFileTypeError as ServiceUnsupportedFileTypeError + + +def _unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class _FakeResponse: + def __init__( + self, + *, + status_code: int = 200, + headers: dict[str, str] | None = None, + method: str = "GET", + content: bytes = b"", + text: str = "", + error: Exception | None = None, + ) -> None: + self.status_code = status_code + self.headers = headers or {} + self.request = SimpleNamespace(method=method) + self.content = content + self.text = text + self._error = error + + def raise_for_status(self) -> None: + if self._error: + raise self._error + + +def _mock_upload_dependencies( + monkeypatch: pytest.MonkeyPatch, + *, + file_size_within_limit: bool = True, +): + file_info = SimpleNamespace( + filename="report.txt", + extension=".txt", + mimetype="text/plain", + size=3, + ) + monkeypatch.setattr( + remote_files_module.helpers, + "guess_file_info_from_response", + MagicMock(return_value=file_info), + ) + + file_service_cls = MagicMock() + file_service_cls.is_file_size_within_limit.return_value = file_size_within_limit + monkeypatch.setattr(remote_files_module, "FileService", file_service_cls) + monkeypatch.setattr(remote_files_module, "current_account_with_tenant", lambda: (SimpleNamespace(id="u1"), None)) + monkeypatch.setattr(remote_files_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr( + remote_files_module.file_helpers, + "get_signed_file_url", + lambda upload_file_id: f"https://signed.example/{upload_file_id}", + ) + + return file_service_cls + + +def test_get_remote_file_info_uses_head_when_successful(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.GetRemoteFileInfo() + handler = _unwrap(api.get) + decoded_url = "https://example.com/test.txt" + encoded_url = urllib.parse.quote(decoded_url, safe="") + + head_resp = _FakeResponse( + status_code=200, + headers={"Content-Type": "text/plain", "Content-Length": "128"}, + method="HEAD", + ) + head_mock = MagicMock(return_value=head_resp) + get_mock = MagicMock() + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", head_mock) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + with app.test_request_context(method="GET"): + payload = handler(api, url=encoded_url) + + assert payload == {"file_type": "text/plain", "file_length": 128} + head_mock.assert_called_once_with(decoded_url) + get_mock.assert_not_called() + + +def test_get_remote_file_info_falls_back_to_get_and_uses_default_headers(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.GetRemoteFileInfo() + handler = _unwrap(api.get) + decoded_url = "https://example.com/test.txt" + encoded_url = urllib.parse.quote(decoded_url, safe="") + + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=503))) + get_mock = MagicMock(return_value=_FakeResponse(status_code=200, headers={}, method="GET")) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + with app.test_request_context(method="GET"): + payload = handler(api, url=encoded_url) + + assert payload == {"file_type": "application/octet-stream", "file_length": 0} + get_mock.assert_called_once_with(decoded_url, timeout=3) + + +def test_remote_file_upload_success_when_fetch_falls_back_to_get(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/report.txt" + + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=404))) + get_resp = _FakeResponse(status_code=200, method="GET", content=b"fallback-content") + get_mock = MagicMock(return_value=get_resp) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + file_service_cls = _mock_upload_dependencies(monkeypatch) + upload_file = SimpleNamespace( + id="file-1", + name="report.txt", + size=16, + extension=".txt", + mime_type="text/plain", + created_by="u1", + created_at=datetime(2024, 1, 1, tzinfo=UTC), + ) + file_service_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context(method="POST", json={"url": url}): + payload, status = handler(api) + + assert status == 201 + assert payload["id"] == "file-1" + assert payload["url"] == "https://signed.example/file-1" + get_mock.assert_called_once_with(url=url, timeout=3, follow_redirects=True) + file_service_cls.return_value.upload_file.assert_called_once_with( + filename="report.txt", + content=b"fallback-content", + mimetype="text/plain", + user=SimpleNamespace(id="u1"), + source_url=url, + ) + + +def test_remote_file_upload_fetches_content_with_second_get_when_head_succeeds( + app, monkeypatch: pytest.MonkeyPatch +) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/photo.jpg" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="HEAD", content=b"head-content")), + ) + extra_get_resp = _FakeResponse(status_code=200, method="GET", content=b"downloaded-content") + get_mock = MagicMock(return_value=extra_get_resp) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", get_mock) + + file_service_cls = _mock_upload_dependencies(monkeypatch) + upload_file = SimpleNamespace( + id="file-2", + name="photo.jpg", + size=18, + extension=".jpg", + mime_type="image/jpeg", + created_by="u1", + created_at=datetime(2024, 1, 2, tzinfo=UTC), + ) + file_service_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context(method="POST", json={"url": url}): + payload, status = handler(api) + + assert status == 201 + assert payload["id"] == "file-2" + get_mock.assert_called_once_with(url) + assert file_service_cls.return_value.upload_file.call_args.kwargs["content"] == b"downloaded-content" + + +def test_remote_file_upload_raises_when_fallback_get_still_not_ok(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/fail.txt" + + monkeypatch.setattr(remote_files_module.ssrf_proxy, "head", MagicMock(return_value=_FakeResponse(status_code=500))) + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "get", + MagicMock(return_value=_FakeResponse(status_code=502, text="bad gateway")), + ) + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: bad gateway"): + handler(api) + + +def test_remote_file_upload_raises_on_httpx_request_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/fail.txt" + + request = httpx.Request("HEAD", url) + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(side_effect=httpx.RequestError("network down", request=request)), + ) + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(RemoteFileUploadError, match=f"Failed to fetch file from {url}: network down"): + handler(api) + + +def test_remote_file_upload_rejects_oversized_file(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/large.bin" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), + ) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + + _mock_upload_dependencies(monkeypatch, file_size_within_limit=False) + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(FileTooLargeError): + handler(api) + + +def test_remote_file_upload_translates_service_file_too_large_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/large.bin" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), + ) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls.return_value.upload_file.side_effect = ServiceFileTooLargeError("size exceeded") + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(FileTooLargeError, match="size exceeded"): + handler(api) + + +def test_remote_file_upload_translates_service_unsupported_type_error(app, monkeypatch: pytest.MonkeyPatch) -> None: + api = remote_files_module.RemoteFileUpload() + handler = _unwrap(api.post) + url = "https://example.com/file.exe" + + monkeypatch.setattr( + remote_files_module.ssrf_proxy, + "head", + MagicMock(return_value=_FakeResponse(status_code=200, method="GET", content=b"payload")), + ) + monkeypatch.setattr(remote_files_module.ssrf_proxy, "get", MagicMock()) + file_service_cls = _mock_upload_dependencies(monkeypatch) + file_service_cls.return_value.upload_file.side_effect = ServiceUnsupportedFileTypeError() + + with app.test_request_context(method="POST", json={"url": url}): + with pytest.raises(UnsupportedFileTypeError): + handler(api) diff --git a/api/tests/unit_tests/controllers/console/test_spec.py b/api/tests/unit_tests/controllers/console/test_spec.py new file mode 100644 index 0000000000..05a4befaa8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_spec.py @@ -0,0 +1,49 @@ +from unittest.mock import patch + +import controllers.console.spec as spec_module + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestSpecSchemaDefinitionsApi: + def test_get_success(self): + api = spec_module.SpecSchemaDefinitionsApi() + method = unwrap(api.get) + + schema_definitions = [{"type": "string"}] + + with patch.object( + spec_module, + "SchemaManager", + ) as schema_manager_cls: + schema_manager_cls.return_value.get_all_schema_definitions.return_value = schema_definitions + + resp, status = method(api) + + assert status == 200 + assert resp == schema_definitions + + def test_get_exception_returns_empty_list(self): + api = spec_module.SpecSchemaDefinitionsApi() + method = unwrap(api.get) + + with ( + patch.object( + spec_module, + "SchemaManager", + side_effect=Exception("boom"), + ), + patch.object( + spec_module.logger, + "exception", + ) as log_exception, + ): + resp, status = method(api) + + assert status == 200 + assert resp == [] + log_exception.assert_called_once() diff --git a/api/tests/unit_tests/controllers/console/test_version.py b/api/tests/unit_tests/controllers/console/test_version.py new file mode 100644 index 0000000000..8d8d324be1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_version.py @@ -0,0 +1,162 @@ +from unittest.mock import MagicMock, patch + +import controllers.console.version as version_module + + +class TestHasNewVersion: + def test_has_new_version_true(self): + result = version_module._has_new_version( + latest_version="1.2.0", + current_version="1.1.0", + ) + assert result is True + + def test_has_new_version_false(self): + result = version_module._has_new_version( + latest_version="1.0.0", + current_version="1.1.0", + ) + assert result is False + + def test_has_new_version_invalid_version(self): + with patch.object(version_module.logger, "warning") as log_warning: + result = version_module._has_new_version( + latest_version="invalid", + current_version="1.0.0", + ) + + assert result is False + log_warning.assert_called_once() + + +class TestCheckVersionUpdate: + def test_no_check_update_url(self): + query = version_module.VersionQuery(current_version="1.0.0") + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "", + ), + patch.object( + version_module.dify_config.project, + "version", + "1.0.0", + ), + patch.object( + version_module.dify_config, + "CAN_REPLACE_LOGO", + True, + ), + patch.object( + version_module.dify_config, + "MODEL_LB_ENABLED", + False, + ), + ): + result = version_module.check_version_update(query) + + assert result.version == "1.0.0" + assert result.can_auto_update is False + assert result.features.can_replace_logo is True + assert result.features.model_load_balancing_enabled is False + + def test_http_error_fallback(self): + query = version_module.VersionQuery(current_version="1.0.0") + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "http://example.com", + ), + patch.object( + version_module.httpx, + "get", + side_effect=Exception("boom"), + ), + patch.object( + version_module.logger, + "warning", + ) as log_warning, + ): + result = version_module.check_version_update(query) + + assert result.version == "1.0.0" + log_warning.assert_called_once() + + def test_new_version_available(self): + query = version_module.VersionQuery(current_version="1.0.0") + + response = MagicMock() + response.json.return_value = { + "version": "1.2.0", + "releaseDate": "2024-01-01", + "releaseNotes": "New features", + "canAutoUpdate": True, + } + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "http://example.com", + ), + patch.object( + version_module.httpx, + "get", + return_value=response, + ), + patch.object( + version_module.dify_config.project, + "version", + "1.0.0", + ), + patch.object( + version_module.dify_config, + "CAN_REPLACE_LOGO", + False, + ), + patch.object( + version_module.dify_config, + "MODEL_LB_ENABLED", + True, + ), + ): + result = version_module.check_version_update(query) + + assert result.version == "1.2.0" + assert result.release_date == "2024-01-01" + assert result.release_notes == "New features" + assert result.can_auto_update is True + + def test_no_new_version(self): + query = version_module.VersionQuery(current_version="1.2.0") + + response = MagicMock() + response.json.return_value = { + "version": "1.1.0", + } + + with ( + patch.object( + version_module.dify_config, + "CHECK_UPDATE_URL", + "http://example.com", + ), + patch.object( + version_module.httpx, + "get", + return_value=response, + ), + patch.object( + version_module.dify_config.project, + "version", + "1.2.0", + ), + ): + result = version_module.check_version_update(query) + + assert result.version == "1.2.0" + assert result.can_auto_update is False diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py index 6777077de8..f6e096a97b 100644 --- a/api/tests/unit_tests/controllers/console/test_wraps.py +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -328,7 +328,7 @@ class TestSystemSetup: def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = "some_password" @setup_required @@ -345,7 +345,7 @@ class TestSystemSetup: def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): """Test NotSetupError when no INIT_PASSWORD and setup not complete""" # Arrange - mock_db.session.query.return_value.first.return_value = None # No setup + mock_db.session.scalar.return_value = None # No setup mock_environ_get.return_value = None # No INIT_PASSWORD @setup_required diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py new file mode 100644 index 0000000000..42be02cdaf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -0,0 +1,341 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from controllers.console import console_ns +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailCodeError, +) +from controllers.console.error import AccountInFreezeError +from controllers.console.workspace.account import ( + AccountAvatarApi, + AccountDeleteApi, + AccountDeleteVerifyApi, + AccountInitApi, + AccountIntegrateApi, + AccountInterfaceLanguageApi, + AccountInterfaceThemeApi, + AccountNameApi, + AccountPasswordApi, + AccountProfileApi, + AccountTimezoneApi, + ChangeEmailCheckApi, + ChangeEmailResetApi, + CheckEmailUnique, +) +from controllers.console.workspace.error import ( + AccountAlreadyInitedError, + CurrentPasswordIncorrectError, + InvalidAccountDeletionCodeError, +) +from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAccountInitApi: + def test_init_success(self, app): + api = AccountInitApi() + method = unwrap(api.post) + + account = MagicMock(status="inactive") + payload = { + "interface_language": "en-US", + "timezone": "UTC", + "invitation_code": "code123", + } + + with ( + app.test_request_context("/account/init", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + patch("controllers.console.workspace.account.db.session.commit", return_value=None), + patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), + patch("controllers.console.workspace.account.db.session.scalar") as scalar_mock, + ): + scalar_mock.return_value = MagicMock(status="unused") + resp = method(api) + + assert resp["result"] == "success" + + def test_init_already_initialized(self, app): + api = AccountInitApi() + method = unwrap(api.post) + + account = MagicMock(status="active") + + with ( + app.test_request_context("/account/init"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + ): + with pytest.raises(AccountAlreadyInitedError): + method(api) + + +class TestAccountProfileApi: + def test_get_profile_success(self, app): + api = AccountProfileApi() + method = unwrap(api.get) + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/account/profile"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + ): + result = method(api) + + assert result["id"] == "u1" + + +class TestAccountUpdateApis: + @pytest.mark.parametrize( + ("api_cls", "payload"), + [ + (AccountNameApi, {"name": "test"}), + (AccountAvatarApi, {"avatar": "img.png"}), + (AccountInterfaceLanguageApi, {"interface_language": "en-US"}), + (AccountInterfaceThemeApi, {"interface_theme": "dark"}), + (AccountTimezoneApi, {"timezone": "UTC"}), + ], + ) + def test_update_success(self, app, api_cls, payload): + api = api_cls() + method = unwrap(api.post) + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.account.AccountService.update_account", return_value=user), + ): + result = method(api) + + assert result["id"] == "u1" + + +class TestAccountPasswordApi: + def test_password_success(self, app): + api = AccountPasswordApi() + method = unwrap(api.post) + + payload = { + "password": "old", + "new_password": "new123", + "repeat_new_password": "new123", + } + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None), + ): + result = method(api) + + assert result["id"] == "u1" + + def test_password_wrong_current(self, app): + api = AccountPasswordApi() + method = unwrap(api.post) + + payload = { + "password": "bad", + "new_password": "new123", + "repeat_new_password": "new123", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.update_account_password", + side_effect=ServicePwdError(), + ), + ): + with pytest.raises(CurrentPasswordIncorrectError): + method(api) + + +class TestAccountIntegrateApi: + def test_get_integrates(self, app): + api = AccountIntegrateApi() + method = unwrap(api.get) + + account = MagicMock(id="acc1") + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock, + ): + scalars_mock.return_value.all.return_value = [] + result = method(api) + + assert "data" in result + assert len(result["data"]) == 2 + + +class TestAccountDeleteApi: + def test_delete_verify_success(self, app): + api = AccountDeleteVerifyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code", + return_value=("token", "1234"), + ), + patch( + "controllers.console.workspace.account.AccountService.send_account_deletion_verification_email", + return_value=None, + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_delete_invalid_code(self, app): + api = AccountDeleteApi() + method = unwrap(api.post) + + payload = {"token": "t", "code": "x"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.verify_account_deletion_code", + return_value=False, + ), + ): + with pytest.raises(InvalidAccountDeletionCodeError): + method(api) + + +class TestChangeEmailApis: + def test_check_email_code_invalid(self, app): + api = ChangeEmailCheckApi() + method = unwrap(api.post) + + payload = {"email": "a@test.com", "code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.account.AccountService.get_change_email_data", + return_value={"email": "a@test.com", "code": "y"}, + ), + ): + with pytest.raises(EmailCodeError): + method(api) + + def test_reset_email_already_used(self, app): + api = ChangeEmailResetApi() + method = unwrap(api.post) + + payload = {"new_email": "x@test.com", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False), + patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False), + ): + with pytest.raises(EmailAlreadyInUseError): + method(api) + + +class TestCheckEmailUniqueApi: + def test_email_unique_success(self, app): + api = CheckEmailUnique() + method = unwrap(api.post) + + payload = {"email": "ok@test.com"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False), + patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True), + ): + result = method(api) + + assert result["result"] == "success" + + def test_email_in_freeze(self, app): + api = CheckEmailUnique() + method = unwrap(api.post) + + payload = {"email": "x@test.com"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True), + ): + with pytest.raises(AccountInFreezeError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py new file mode 100644 index 0000000000..b4e03f681d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py @@ -0,0 +1,139 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console.error import AccountNotFound +from controllers.console.workspace.agent_providers import ( + AgentProviderApi, + AgentProviderListApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAgentProviderListApi: + def test_get_success(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + providers = [{"name": "openai"}, {"name": "anthropic"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", + return_value=providers, + ), + ): + result = method(api) + + assert result == providers + + def test_get_empty_list(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", + return_value=[], + ), + ): + result = method(api) + + assert result == [] + + def test_get_account_not_found(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + side_effect=AccountNotFound(), + ), + ): + with pytest.raises(AccountNotFound): + method(api) + + +class TestAgentProviderApi: + def test_get_success(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + provider_name = "openai" + provider_data = {"name": "openai", "models": ["gpt-4"]} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", + return_value=provider_data, + ), + ): + result = method(api, provider_name) + + assert result == provider_data + + def test_get_provider_not_found(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + provider_name = "unknown" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", + return_value=None, + ), + ): + result = method(api, provider_name) + + assert result is None + + def test_get_account_not_found(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + side_effect=AccountNotFound(), + ), + ): + with pytest.raises(AccountNotFound): + method(api, "openai") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py new file mode 100644 index 0000000000..51f76af172 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -0,0 +1,305 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console.workspace.endpoint import ( + EndpointCreateApi, + EndpointDeleteApi, + EndpointDisableApi, + EndpointEnableApi, + EndpointListApi, + EndpointListForSinglePluginApi, + EndpointUpdateApi, +) +from core.plugin.impl.exc import PluginPermissionDeniedError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user_and_tenant(): + return MagicMock(id="u1"), "t1" + + +@pytest.fixture +def patch_current_account(user_and_tenant): + with patch( + "controllers.console.workspace.endpoint.current_account_with_tenant", + return_value=user_and_tenant, + ): + yield + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointCreateApi: + def test_create_success(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {"a": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_create_permission_denied(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.endpoint.EndpointService.create_endpoint", + side_effect=PluginPermissionDeniedError("denied"), + ), + ): + with pytest.raises(ValueError): + method(api) + + def test_create_validation_error(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p1", + "name": "", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointListApi: + def test_list_success(self, app): + api = EndpointListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]), + ): + result = method(api) + + assert "endpoints" in result + assert len(result["endpoints"]) == 1 + + def test_list_invalid_query(self, app): + api = EndpointListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=0&page_size=10"), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointListForSinglePluginApi: + def test_list_for_plugin_success(self, app): + api = EndpointListForSinglePluginApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10&plugin_id=p1"), + patch( + "controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin", + return_value=[{"id": "e1"}], + ), + ): + result = method(api) + + assert "endpoints" in result + + def test_list_for_plugin_missing_param(self, app): + api = EndpointListForSinglePluginApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointDeleteApi: + def test_delete_success(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_delete_invalid_payload(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) + + def test_delete_service_failure(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointUpdateApi: + def test_update_success(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = { + "endpoint_id": "e1", + "name": "new-name", + "settings": {"x": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_update_validation_error(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1", "settings": {}} + + with ( + app.test_request_context("/", json=payload), + ): + with pytest.raises(ValueError): + method(api) + + def test_update_service_failure(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = { + "endpoint_id": "e1", + "name": "n", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointEnableApi: + def test_enable_success(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_enable_invalid_payload(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) + + def test_enable_service_failure(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointDisableApi: + def test_disable_success(self, app): + api = EndpointDisableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_disable_invalid_payload(self, app): + api = EndpointDisableApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index 59b6614d5e..f2e57eb65f 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -13,8 +13,8 @@ from flask import Flask from flask.views import MethodView from werkzeug.exceptions import Forbidden -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.errors.validate import CredentialsValidateFailedError +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py new file mode 100644 index 0000000000..718b57ba6b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -0,0 +1,607 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import HTTPException + +import services +from controllers.console.auth.error import ( + CannotTransferOwnerToSelfError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, + MemberNotInTenantError, + NotOwnerError, + OwnerTransferLimitError, +) +from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded +from controllers.console.workspace.members import ( + DatasetOperatorMemberListApi, + MemberCancelInviteApi, + MemberInviteEmailApi, + MemberListApi, + MemberUpdateRoleApi, + OwnerTransfer, + OwnerTransferCheckApi, + SendOwnerTransferEmailApi, +) +from services.errors.account import AccountAlreadyInTenantError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestMemberListApi: + def test_get_success(self, app): + api = MemberListApi() + method = unwrap(api.get) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + member.id = "m1" + member.name = "Member" + member.email = "member@test.com" + member.avatar = "avatar.png" + member.role = "admin" + member.status = "active" + members = [member] + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members), + ): + result, status = method(api) + + assert status == 200 + assert len(result["accounts"]) == 1 + + def test_get_no_tenant(self, app): + api = MemberListApi() + method = unwrap(api.get) + + user = MagicMock(current_tenant=None) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(ValueError): + method(api) + + +class TestMemberInviteEmailApi: + def test_invite_success(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + "language": "en-US", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, status = method(api) + + assert status == 201 + assert result["result"] == "success" + + def test_invite_limit_exceeded(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = False + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + ): + with pytest.raises(WorkspaceMembersLimitExceeded): + method(api) + + def test_invite_already_member(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch( + "controllers.console.workspace.members.RegisterService.invite_new_member", + side_effect=AccountAlreadyInTenantError(), + ), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, status = method(api) + + assert result["invitation_results"][0]["status"] == "success" + + def test_invite_invalid_role(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + payload = { + "emails": ["a@test.com"], + "role": "owner", + } + + with app.test_request_context("/", json=payload): + result, status = method(api) + + assert status == 400 + assert result["code"] == "invalid-role" + + def test_invite_generic_exception(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch( + "controllers.console.workspace.members.RegisterService.invite_new_member", + side_effect=Exception("boom"), + ), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, _ = method(api) + + assert result["invitation_results"][0]["status"] == "failed" + + +class TestMemberCancelInviteApi: + def test_cancel_success(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get") as get_mock, + patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"), + ): + get_mock.return_value = member + result, status = method(api, member.id) + + assert status == 200 + assert result["result"] == "success" + + def test_cancel_not_found(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get") as get_mock, + ): + get_mock.return_value = None + + with pytest.raises(HTTPException): + method(api, "x") + + def test_cancel_cannot_operate_self(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get") as get_mock, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.CannotOperateSelfError("x"), + ), + ): + get_mock.return_value = member + result, status = method(api, member.id) + + assert status == 400 + + def test_cancel_no_permission(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get") as get_mock, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.NoPermissionError("x"), + ), + ): + get_mock.return_value = member + result, status = method(api, member.id) + + assert status == 403 + + def test_cancel_member_not_in_tenant(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get") as get_mock, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.MemberNotInTenantError(), + ), + ): + get_mock.return_value = member + result, status = method(api, member.id) + + assert status == 404 + + +class TestMemberUpdateRoleApi: + def test_update_success(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + payload = {"role": "normal"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get", return_value=member), + patch("controllers.console.workspace.members.TenantService.update_member_role"), + ): + result = method(api, "id") + + if isinstance(result, tuple): + result = result[0] + + assert result["result"] == "success" + + def test_update_invalid_role(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + payload = {"role": "invalid-role"} + + with app.test_request_context("/", json=payload): + result, status = method(api, "id") + + assert status == 400 + + def test_update_member_not_found(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + payload = {"role": "normal"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.members.current_account_with_tenant", + return_value=(MagicMock(current_tenant=MagicMock()), "t1"), + ), + patch("controllers.console.workspace.members.db.session.get", return_value=None), + ): + with pytest.raises(HTTPException): + method(api, "id") + + +class TestDatasetOperatorMemberListApi: + def test_get_success(self, app): + api = DatasetOperatorMemberListApi() + method = unwrap(api.get) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + member.id = "op1" + member.name = "Operator" + member.email = "operator@test.com" + member.avatar = "avatar.png" + member.role = "operator" + member.status = "active" + members = [member] + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members + ), + ): + result, status = method(api) + + assert status == 200 + assert len(result["accounts"]) == 1 + + def test_get_no_tenant(self, app): + api = DatasetOperatorMemberListApi() + method = unwrap(api.get) + + user = MagicMock(current_tenant=None) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(ValueError): + method(api) + + +class TestSendOwnerTransferEmailApi: + def test_send_success(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(name="ws") + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token" + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_send_ip_limit(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True), + ): + with pytest.raises(EmailSendIpLimitError): + method(api) + + def test_send_not_owner(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/", json={}), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False), + ): + with pytest.raises(NotOwnerError): + method(api) + + +class TestOwnerTransferCheckApi: + def test_check_invalid_code(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "a@test.com", "code": "y"}, + ), + ): + with pytest.raises(EmailCodeError): + method(api) + + def test_rate_limited(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=True, + ), + ): + with pytest.raises(OwnerTransferLimitError): + method(api) + + def test_invalid_token(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), + ): + with pytest.raises(InvalidTokenError): + method(api) + + def test_invalid_email(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "b@test.com", "code": "x"}, + ), + ): + with pytest.raises(InvalidEmailError): + method(api) + + +class TestOwnerTransferApi: + def test_transfer_self(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + ): + with pytest.raises(CannotTransferOwnerToSelfError): + method(api, "1") + + def test_invalid_token(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), + ): + with pytest.raises(InvalidTokenError): + method(api, "2") + + def test_member_not_in_tenant(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + member = MagicMock() + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "a@test.com"}, + ), + patch("controllers.console.workspace.members.db.session.get", return_value=member), + patch("controllers.console.workspace.members.TenantService.is_member", return_value=False), + ): + with pytest.raises(MemberNotInTenantError): + method(api, "2") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py new file mode 100644 index 0000000000..af0c2c5594 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -0,0 +1,388 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic_core import ValidationError +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace.model_providers import ( + ModelProviderCredentialApi, + ModelProviderCredentialSwitchApi, + ModelProviderIconApi, + ModelProviderListApi, + ModelProviderPaymentCheckoutUrlApi, + ModelProviderValidateApi, + PreferredProviderTypeUpdateApi, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError + +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" +INVALID_UUID = "123" + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestModelProviderListApi: + def test_get_success(self, app): + api = ModelProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?model_type=llm"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_provider_list", + return_value=[{"name": "openai"}], + ), + ): + result = method(api) + + assert "data" in result + + +class TestModelProviderCredentialApi: + def test_get_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context(f"/?credential_id={VALID_UUID}"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential", + return_value={"key": "value"}, + ), + ): + result = method(api, provider="openai") + + assert "credentials" in result + + def test_get_invalid_uuid(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context(f"/?credential_id={INVALID_UUID}"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + def test_post_create_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}, "name": "test"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", + return_value=None, + ), + ): + result, status = method(api, provider="openai") + + assert result["result"] == "success" + assert status == 201 + + def test_post_create_validation_error(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", + side_effect=CredentialsValidateFailedError("bad"), + ), + ): + with pytest.raises(ValueError): + method(api, provider="openai") + + def test_put_update_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.put) + + payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_put_invalid_uuid(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.put) + + payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + def test_delete_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.delete) + + payload = {"credential_id": VALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential", + return_value=None, + ), + ): + result, status = method(api, provider="openai") + + assert result["result"] == "success" + assert status == 204 + + +class TestModelProviderCredentialSwitchApi: + def test_switch_success(self, app): + api = ModelProviderCredentialSwitchApi() + method = unwrap(api.post) + + payload = {"credential_id": VALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_switch_invalid_uuid(self, app): + api = ModelProviderCredentialSwitchApi() + method = unwrap(api.post) + + payload = {"credential_id": INVALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + +class TestModelProviderValidateApi: + def test_validate_success(self, app): + api = ModelProviderValidateApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_validate_failure(self, app): + api = ModelProviderValidateApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", + side_effect=CredentialsValidateFailedError("bad"), + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "error" + + +class TestModelProviderIconApi: + def test_icon_success(self, app): + api = ModelProviderIconApi() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon", + return_value=(b"123", "image/png"), + ), + ): + response = api.get("t1", "openai", "logo", "en") + + assert response.mimetype == "image/png" + + def test_icon_not_found(self, app): + api = ModelProviderIconApi() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon", + return_value=(None, None), + ), + ): + with pytest.raises(ValueError): + api.get("t1", "openai", "logo", "en") + + +class TestPreferredProviderTypeUpdateApi: + def test_update_success(self, app): + api = PreferredProviderTypeUpdateApi() + method = unwrap(api.post) + + payload = {"preferred_provider_type": "custom"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_invalid_enum(self, app): + api = PreferredProviderTypeUpdateApi() + method = unwrap(api.post) + + payload = {"preferred_provider_type": "invalid"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + +class TestModelProviderPaymentCheckoutUrlApi: + def test_checkout_success(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + user = MagicMock(id="u1", email="x@test.com") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(user, "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", + return_value=None, + ), + patch( + "controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link", + return_value={"url": "x"}, + ), + ): + result = method(api, provider="anthropic") + + assert "url" in result + + def test_invalid_provider(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(ValueError): + method(api, provider="openai") + + def test_permission_denied(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + user = MagicMock(id="u1", email="x@test.com") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(user, "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", + side_effect=Forbidden(), + ), + ): + with pytest.raises(Forbidden): + method(api, provider="anthropic") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py new file mode 100644 index 0000000000..43b8e1ac2e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -0,0 +1,447 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.workspace.models import ( + DefaultModelApi, + ModelProviderAvailableModelApi, + ModelProviderModelApi, + ModelProviderModelCredentialApi, + ModelProviderModelCredentialSwitchApi, + ModelProviderModelDisableApi, + ModelProviderModelEnableApi, + ModelProviderModelParameterRuleApi, + ModelProviderModelValidateApi, +) +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDefaultModelApi: + def test_get_success(self, app: Flask): + api = DefaultModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context( + "/", + query_string={"model_type": ModelType.LLM.value}, + ), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"} + + result = method(api) + + assert "data" in result + + def test_post_success(self, app: Flask): + api = DefaultModelApi() + method = unwrap(api.post) + + payload = { + "model_settings": [ + { + "model_type": ModelType.LLM.value, + "provider": "openai", + "model": "gpt-4", + } + ] + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api) + + assert result["result"] == "success" + + def test_get_returns_empty_when_no_default(self, app): + api = DefaultModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_default_model_of_model_type.return_value = None + + result = method(api) + + assert "data" in result + + +class TestModelProviderModelApi: + def test_get_models_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_models_by_provider.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + def test_post_models_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "load_balancing": { + "configs": [{"weight": 1}], + "enabled": True, + }, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + patch("controllers.console.workspace.models.ModelLoadBalancingService"), + ): + result, status = method(api, "openai") + + assert status == 200 + + def test_delete_model_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.delete) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 204 + + def test_get_models_returns_empty(self, app): + api = ModelProviderModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_models_by_provider.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + +class TestModelProviderModelCredentialApi: + def test_get_credentials_success(self, app: Flask): + api = ModelProviderModelCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context( + "/", + query_string={ + "model": "gpt-4", + "model_type": ModelType.LLM.value, + }, + ), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as provider_service, + patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service, + ): + provider_service.return_value.get_model_credential.return_value = { + "credentials": {}, + "current_credential_id": None, + "current_credential_name": None, + } + provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] + lb_service.return_value.get_load_balancing_configs.return_value = (False, []) + + result = method(api, "openai") + + assert "credentials" in result + + def test_create_credential_success(self, app: Flask): + api = ModelProviderModelCredentialApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credentials": {"key": "val"}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 201 + + def test_get_empty_credentials(self, app): + api = ModelProviderModelCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb, + ): + service.return_value.get_model_credential.return_value = None + service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] + lb.return_value.get_load_balancing_configs.return_value = (False, []) + + result = method(api, "openai") + + assert result["credentials"] == {} + + def test_delete_success(self, app): + api = ModelProviderModelCredentialApi() + method = unwrap(api.delete) + + payload = { + "model": "gpt", + "model_type": ModelType.LLM.value, + "credential_id": "123e4567-e89b-12d3-a456-426614174000", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 204 + + +class TestModelProviderModelCredentialSwitchApi: + def test_switch_success(self, app: Flask): + api = ModelProviderModelCredentialSwitchApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credential_id": "abc", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + +class TestModelEnableDisableApis: + def test_enable_model(self, app: Flask): + api = ModelProviderModelEnableApi() + method = unwrap(api.patch) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + def test_disable_model(self, app: Flask): + api = ModelProviderModelDisableApi() + method = unwrap(api.patch) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + +class TestModelProviderModelValidateApi: + def test_validate_success(self, app: Flask): + api = ModelProviderModelValidateApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credentials": {"key": "val"}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + @pytest.mark.parametrize("model_name", ["gpt-4", "gpt"]) + def test_validate_failure(self, app: Flask, model_name: str): + api = ModelProviderModelValidateApi() + method = unwrap(api.post) + + payload = { + "model": model_name, + "model_type": ModelType.LLM.value, + "credentials": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid") + + result = method(api, "openai") + + assert result["result"] == "error" + + +class TestParameterAndAvailableModels: + def test_parameter_rules(self, app: Flask): + api = ModelProviderModelParameterRuleApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt-4"}), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_model_parameter_rules.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + def test_available_models(self, app: Flask): + api = ModelProviderAvailableModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_models_by_model_type.return_value = [] + + result = method(api, ModelType.LLM.value) + + assert "data" in result + + def test_empty_rules(self, app): + api = ModelProviderModelParameterRuleApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt"}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_model_parameter_rules.return_value = [] + + result = method(api, "openai") + + assert result["data"] == [] + + def test_no_models(self, app): + api = ModelProviderAvailableModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_models_by_model_type.return_value = [] + + result = method(api, ModelType.LLM.value) + + assert result["data"] == [] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py new file mode 100644 index 0000000000..eb19243225 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -0,0 +1,1025 @@ +import io +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace.plugin import ( + PluginAssetApi, + PluginAutoUpgradeExcludePluginApi, + PluginChangePermissionApi, + PluginChangePreferencesApi, + PluginDebuggingKeyApi, + PluginDeleteAllInstallTaskItemsApi, + PluginDeleteInstallTaskApi, + PluginDeleteInstallTaskItemApi, + PluginFetchDynamicSelectOptionsApi, + PluginFetchDynamicSelectOptionsWithCredentialsApi, + PluginFetchInstallTaskApi, + PluginFetchInstallTasksApi, + PluginFetchManifestApi, + PluginFetchMarketplacePkgApi, + PluginFetchPermissionApi, + PluginFetchPreferencesApi, + PluginIconApi, + PluginInstallFromGithubApi, + PluginInstallFromMarketplaceApi, + PluginInstallFromPkgApi, + PluginListApi, + PluginListInstallationsFromIdsApi, + PluginListLatestVersionsApi, + PluginReadmeApi, + PluginUninstallApi, + PluginUpgradeFromGithubApi, + PluginUpgradeFromMarketplaceApi, + PluginUploadFromBundleApi, + PluginUploadFromGithubApi, + PluginUploadFromPkgApi, +) +from core.plugin.impl.exc import PluginDaemonClientSideError +from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user(): + u = MagicMock() + u.id = "u1" + u.is_admin_or_owner = True + return u + + +@pytest.fixture +def tenant(): + return "t1" + + +class TestPluginListLatestVersionsApi: + def test_success(self, app): + api = PluginListLatestVersionsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.plugin.PluginService.list_latest_versions", return_value={"p1": "1.0"} + ), + ): + result = method(api) + + assert "versions" in result + + def test_daemon_error(self, app): + api = PluginListLatestVersionsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.plugin.PluginService.list_latest_versions", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginDebuggingKeyApi: + def test_debugging_key_success(self, app): + api = PluginDebuggingKeyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.get_debugging_key", return_value="k"), + ): + result = method(api) + + assert result["key"] == "k" + + def test_debugging_key_error(self, app): + api = PluginDebuggingKeyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.get_debugging_key", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginListApi: + def test_plugin_list(self, app): + api = PluginListApi() + method = unwrap(api.get) + + mock_list = MagicMock(list=[{"id": 1}], total=1) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.list_with_total", return_value=mock_list), + ): + result = method(api) + + assert result["total"] == 1 + + +class TestPluginIconApi: + def test_plugin_icon(self, app): + api = PluginIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?tenant_id=t1&filename=a.png"), + patch("controllers.console.workspace.plugin.PluginService.get_asset", return_value=(b"x", "image/png")), + ): + response = method(api) + + assert response.mimetype == "image/png" + + +class TestPluginAssetApi: + def test_plugin_asset(self, app): + api = PluginAssetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p&file_name=a.bin"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.extract_asset", return_value=b"x"), + ): + response = method(api) + + assert response.mimetype == "application/octet-stream" + + +class TestPluginUploadFromPkgApi: + def test_upload_pkg_success(self, app): + api = PluginUploadFromPkgApi() + method = unwrap(api.post) + + data = { + "pkg": (io.BytesIO(b"x"), "test.pkg"), + } + + with ( + app.test_request_context("/", data=data, content_type="multipart/form-data"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.upload_pkg", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_upload_pkg_too_large(self, app): + api = PluginUploadFromPkgApi() + method = unwrap(api.post) + + data = { + "pkg": (io.BytesIO(b"x"), "test.pkg"), + } + + with ( + app.test_request_context("/", data=data, content_type="multipart/form-data"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), + patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, + ): + with pytest.raises(ValueError): + method(api) + + upload_pkg_mock.assert_not_called() + + +class TestPluginInstallFromPkgApi: + def test_install_from_pkg(self, app): + api = PluginInstallFromPkgApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_local_pkg", return_value={"ok": True} + ), + ): + result = method(api) + + assert result["ok"] is True + + +class TestPluginUninstallApi: + def test_uninstall(self, app): + api = PluginUninstallApi() + method = unwrap(api.post) + + payload = {"plugin_installation_id": "x"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.uninstall", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + +class TestPluginChangePermissionApi: + def test_change_permission_forbidden(self, app): + api = PluginChangePermissionApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=False) + + payload = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(Forbidden): + method(api) + + def test_change_permission_success(self, app): + api = PluginChangePermissionApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + +class TestPluginFetchPermissionApi: + def test_fetch_permission_default(self, app): + api = PluginFetchPermissionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=None), + ): + result = method(api) + + assert result["install_permission"] is not None + + +class TestPluginFetchDynamicSelectOptionsApi: + def test_fetch_dynamic_options(self, app, user): + api = PluginFetchDynamicSelectOptionsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_id=p&provider=x&action=y¶meter=z&provider_type=tool"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options", + return_value=[1, 2], + ), + ): + result = method(api) + + assert result["options"] == [1, 2] + + +class TestPluginReadmeApi: + def test_fetch_readme(self, app): + api = PluginReadmeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_readme", return_value="readme"), + ): + result = method(api) + + assert result["readme"] == "readme" + + +class TestPluginListInstallationsFromIdsApi: + def test_success(self, app): + api = PluginListInstallationsFromIdsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1", "p2"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", + return_value=[{"id": "p1"}], + ), + ): + result = method(api) + + assert "plugins" in result + + def test_daemon_error(self, app): + api = PluginListInstallationsFromIdsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUploadFromGithubApi: + def test_success(self, app): + api = PluginUploadFromGithubApi() + method = unwrap(api.post) + + payload = {"repo": "r", "version": "v", "package": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", return_value={"ok": True} + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUploadFromGithubApi() + method = unwrap(api.post) + + payload = {"repo": "r", "version": "v", "package": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUploadFromBundleApi: + def test_success(self, app): + api = PluginUploadFromBundleApi() + method = unwrap(api.post) + + file = FileStorage( + stream=io.BytesIO(b"x"), + filename="test.bundle", + content_type="application/octet-stream", + ) + + with ( + app.test_request_context( + "/", + data={"bundle": file}, + content_type="multipart/form-data", + ), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.upload_bundle", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_too_large(self, app): + api = PluginUploadFromBundleApi() + method = unwrap(api.post) + + file = FileStorage( + stream=io.BytesIO(b"x"), + filename="test.bundle", + content_type="application/octet-stream", + ) + + with ( + app.test_request_context( + "/", + data={"bundle": file}, + content_type="multipart/form-data", + ), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), + patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, + ): + with pytest.raises(ValueError): + method(api) + + upload_bundle_mock.assert_not_called() + + +class TestPluginInstallFromGithubApi: + def test_success(self, app): + api = PluginInstallFromGithubApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.install_from_github", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginInstallFromGithubApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromMarketplaceApi: + def test_success(self, app): + api = PluginInstallFromMarketplaceApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginInstallFromMarketplaceApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchMarketplacePkgApi: + def test_success(self, app): + api = PluginFetchMarketplacePkgApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", return_value={"m": 1}), + ): + result = method(api) + + assert "manifest" in result + + def test_daemon_error(self, app): + api = PluginFetchMarketplacePkgApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchManifestApi: + def test_success(self, app): + api = PluginFetchManifestApi() + method = unwrap(api.get) + + manifest = MagicMock() + manifest.model_dump.return_value = {"x": 1} + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", return_value=manifest), + ): + result = method(api) + + assert "manifest" in result + + def test_daemon_error(self, app): + api = PluginFetchManifestApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchInstallTasksApi: + def test_success(self, app): + api = PluginFetchInstallTasksApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_install_tasks", return_value=[{"id": 1}]), + ): + result = method(api) + + assert "tasks" in result + + def test_daemon_error(self, app): + api = PluginFetchInstallTasksApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_install_tasks", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchInstallTaskApi: + def test_success(self, app): + api = PluginFetchInstallTaskApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_install_task", return_value={"id": "x"}), + ): + result = method(api, "x") + + assert "task" in result + + def test_daemon_error(self, app): + api = PluginFetchInstallTaskApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_install_task", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "t") + + +class TestPluginDeleteInstallTaskApi: + def test_success(self, app): + api = PluginDeleteInstallTaskApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.delete_install_task", return_value=True), + ): + result = method(api, "x") + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteInstallTaskApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_install_task", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "t") + + +class TestPluginDeleteAllInstallTaskItemsApi: + def test_success(self, app): + api = PluginDeleteAllInstallTaskItemsApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", return_value=True + ), + ): + result = method(api) + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteAllInstallTaskItemsApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginDeleteInstallTaskItemApi: + def test_success(self, app): + api = PluginDeleteInstallTaskItemApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.delete_install_task_item", return_value=True), + ): + result = method(api, "task1", "item1") + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteInstallTaskItemApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_install_task_item", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "task1", "item1") + + +class TestPluginUpgradeFromMarketplaceApi: + def test_success(self, app): + api = PluginUpgradeFromMarketplaceApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUpgradeFromMarketplaceApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUpgradeFromGithubApi: + def test_success(self, app): + api = PluginUpgradeFromGithubApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUpgradeFromGithubApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: + def test_success(self, app): + api = PluginFetchDynamicSelectOptionsWithCredentialsApi() + method = unwrap(api.post) + + user = MagicMock(id="u1", is_admin_or_owner=True) + + payload = { + "plugin_id": "p", + "provider": "x", + "action": "y", + "parameter": "z", + "credential_id": "c", + "credentials": {"k": "v"}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", + return_value=[1], + ), + ): + result = method(api) + + assert result["options"] == [1] + + def test_daemon_error(self, app): + api = PluginFetchDynamicSelectOptionsWithCredentialsApi() + method = unwrap(api.post) + + user = MagicMock(id="u1", is_admin_or_owner=True) + + payload = { + "plugin_id": "p", + "provider": "x", + "action": "y", + "parameter": "z", + "credential_id": "c", + "credentials": {"k": "v"}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginChangePreferencesApi: + def test_success(self, app): + api = PluginChangePreferencesApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "permission": { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + }, + "auto_upgrade": { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + "upgrade_time_of_day": 0, + "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + }, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.change_strategy", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_permission_fail(self, app): + api = PluginChangePreferencesApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "permission": { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + }, + "auto_upgrade": { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + "upgrade_time_of_day": 0, + "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + }, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +class TestPluginFetchPreferencesApi: + def test_success(self, app): + api = PluginFetchPreferencesApi() + method = unwrap(api.get) + + permission = MagicMock( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + + auto_upgrade = MagicMock( + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=1, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=[], + include_plugins=[], + ) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=permission + ), + patch( + "controllers.console.workspace.plugin.PluginAutoUpgradeService.get_strategy", return_value=auto_upgrade + ), + ): + result = method(api) + + assert "permission" in result + assert "auto_upgrade" in result + + +class TestPluginAutoUpgradeExcludePluginApi: + def test_success(self, app): + api = PluginAutoUpgradeExcludePluginApi() + method = unwrap(api.post) + + payload = {"plugin_id": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_fail(self, app): + api = PluginAutoUpgradeExcludePluginApi() + method = unwrap(api.post) + + payload = {"plugin_id": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=False), + ): + result = method(api) + + assert result["success"] is False diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py index b15676d9b7..16ea1bf509 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_restx import Api +from werkzeug.exceptions import Forbidden -from controllers.console.workspace.tool_providers import ToolProviderMCPApi +from controllers.console.workspace.tool_providers import ( + ToolApiListApi, + ToolApiProviderAddApi, + ToolApiProviderDeleteApi, + ToolApiProviderGetApi, + ToolApiProviderGetRemoteSchemaApi, + ToolApiProviderListToolsApi, + ToolApiProviderUpdateApi, + ToolBuiltinListApi, + ToolBuiltinProviderAddApi, + ToolBuiltinProviderCredentialsSchemaApi, + ToolBuiltinProviderDeleteApi, + ToolBuiltinProviderGetCredentialInfoApi, + ToolBuiltinProviderGetCredentialsApi, + ToolBuiltinProviderGetOauthClientSchemaApi, + ToolBuiltinProviderIconApi, + ToolBuiltinProviderInfoApi, + ToolBuiltinProviderListToolsApi, + ToolBuiltinProviderSetDefaultApi, + ToolBuiltinProviderUpdateApi, + ToolLabelsApi, + ToolOAuthCallback, + ToolOAuthCustomClient, + ToolPluginOAuthApi, + ToolProviderListApi, + ToolProviderMCPApi, + ToolWorkflowListApi, + ToolWorkflowProviderCreateApi, + ToolWorkflowProviderDeleteApi, + ToolWorkflowProviderGetApi, + ToolWorkflowProviderUpdateApi, + is_valid_url, +) from core.db.session_factory import configure_session_factory from extensions.ext_database import db from services.tools.mcp_tools_manage_service import ReconnectResult -# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file. -# They are intentionally no-ops because the test already patches the required -# behaviors explicitly via @patch and context managers below. +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + @pytest.fixture def _mock_cache(): return @@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_ # 若 transform 后包含 tools 字段,确保非空 assert isinstance(body.get("tools"), list) assert body["tools"] + + +class TestUtils: + def test_is_valid_url(self): + assert is_valid_url("https://example.com") + assert is_valid_url("http://example.com") + assert not is_valid_url("") + assert not is_valid_url("ftp://example.com") + assert not is_valid_url("not-a-url") + assert not is_valid_url(None) + + +class TestToolProviderListApi: + def test_get_success(self, app): + api = ToolProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers", + return_value=["p1"], + ), + ): + assert method(api) == ["p1"] + + +class TestBuiltinProviderApis: + def test_list_tools(self, app): + api = ToolBuiltinProviderListToolsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools", + return_value=[{"a": 1}], + ), + ): + assert method(api, "provider") == [{"a": 1}] + + def test_info(self, app): + api = ToolBuiltinProviderInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info", + return_value={"x": 1}, + ), + ): + assert method(api, "provider") == {"x": 1} + + def test_delete(self, app): + api = ToolBuiltinProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credential_id": "cid"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider", + return_value={"result": "success"}, + ), + ): + assert method(api, "provider")["result"] == "success" + + def test_add_invalid_type(self, app): + api = ToolBuiltinProviderAddApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}, "type": "invalid"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + ): + with pytest.raises(ValueError): + method(api, "provider") + + def test_add_success(self, app): + api = ToolBuiltinProviderAddApi() + method = unwrap(api.post) + + payload = {"credentials": {}, "type": "oauth2", "name": "n"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider", + return_value={"id": 1}, + ), + ): + assert method(api, "provider")["id"] == 1 + + def test_update(self, app): + api = ToolBuiltinProviderUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "c1", "credentials": {}, "name": "n"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_credentials(self, app): + api = ToolBuiltinProviderGetCredentialsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials", + return_value={"k": "v"}, + ), + ): + assert method(api, "provider") == {"k": "v"} + + def test_icon(self, app): + api = ToolBuiltinProviderIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon", + return_value=(b"x", "image/png"), + ), + ): + response = method(api, "provider") + assert response.mimetype == "image/png" + + def test_credentials_schema(self, app): + api = ToolBuiltinProviderCredentialsSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema", + return_value={"schema": {}}, + ), + ): + assert method(api, "provider", "oauth2") == {"schema": {}} + + def test_set_default_credential(self, app): + api = ToolBuiltinProviderSetDefaultApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"id": "c1"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_credential_info(self, app): + api = ToolBuiltinProviderGetCredentialInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info", + return_value={"info": "x"}, + ), + ): + assert method(api, "provider") == {"info": "x"} + + def test_get_oauth_client_schema(self, app): + api = ToolBuiltinProviderGetOauthClientSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema", + return_value={"schema": {}}, + ), + ): + assert method(api, "provider") == {"schema": {}} + + +class TestApiProviderApis: + def test_add(self, app): + api = ToolApiProviderAddApi() + method = unwrap(api.post) + + payload = { + "credentials": {}, + "schema_type": "openapi", + "schema": "{}", + "provider": "p", + "icon": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider", + return_value={"id": 1}, + ), + ): + assert method(api)["id"] == 1 + + def test_remote_schema(self, app): + api = ToolApiProviderGetRemoteSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?url=http://x.com"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema", + return_value={"schema": "x"}, + ), + ): + assert method(api)["schema"] == "x" + + def test_list_tools(self, app): + api = ToolApiProviderListToolsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?provider=p"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools", + return_value=[{"tool": 1}], + ), + ): + assert method(api) == [{"tool": 1}] + + def test_update(self, app): + api = ToolApiProviderUpdateApi() + method = unwrap(api.post) + + payload = { + "credentials": {}, + "schema_type": "openapi", + "schema": "{}", + "provider": "p", + "original_provider": "o", + "icon": {}, + "privacy_policy": "", + "custom_disclaimer": "", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider", + return_value={"ok": True}, + ), + ): + assert method(api)["ok"] + + def test_delete(self, app): + api = ToolApiProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"provider": "p"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider", + return_value={"result": "success"}, + ), + ): + assert method(api)["result"] == "success" + + def test_get(self, app): + api = ToolApiProviderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?provider=p"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider", + return_value={"x": 1}, + ), + ): + assert method(api) == {"x": 1} + + +class TestWorkflowApis: + def test_create(self, app): + api = ToolWorkflowProviderCreateApi() + method = unwrap(api.post) + + payload = { + "workflow_app_id": "123e4567-e89b-12d3-a456-426614174000", + "name": "n", + "label": "l", + "description": "d", + "icon": {}, + "parameters": [], + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool", + return_value={"id": 1}, + ), + ): + assert method(api)["id"] == 1 + + def test_update_invalid(self, app): + api = ToolWorkflowProviderUpdateApi() + method = unwrap(api.post) + + payload = { + "workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000", + "name": "Tool", + "label": "Tool Label", + "description": "A tool", + "icon": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool", + return_value={"ok": True}, + ), + ): + result = method(api) + assert result["ok"] + + def test_delete(self, app): + api = ToolWorkflowProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool", + return_value={"ok": True}, + ), + ): + assert method(api)["ok"] + + def test_get_error(self, app): + api = ToolWorkflowProviderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestLists: + def test_builtin_list(self, app): + api = ToolBuiltinListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + def test_api_list(self, app): + api = ToolApiListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + def test_workflow_list(self, app): + api = ToolWorkflowListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + +class TestLabels: + def test_labels(self, app): + api = ToolLabelsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels", + return_value=["l1"], + ), + ): + assert method(api) == ["l1"] + + +class TestOAuth: + def test_oauth_no_client(self, app): + api = ToolPluginOAuthApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "provider") + + def test_oauth_callback_no_cookie(self, app): + api = ToolOAuthCallback() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "provider") + + +class TestOAuthCustomClient: + def test_save_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"client_params": {"a": 1}}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params", + return_value={"client_id": "x"}, + ), + ): + assert method(api, "provider") == {"client_id": "x"} + + def test_delete_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py new file mode 100644 index 0000000000..4776bc7af0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py @@ -0,0 +1,558 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden + +from controllers.console.workspace.trigger_providers import ( + TriggerOAuthAuthorizeApi, + TriggerOAuthCallbackApi, + TriggerOAuthClientManageApi, + TriggerProviderIconApi, + TriggerProviderInfoApi, + TriggerProviderListApi, + TriggerSubscriptionBuilderBuildApi, + TriggerSubscriptionBuilderCreateApi, + TriggerSubscriptionBuilderGetApi, + TriggerSubscriptionBuilderLogsApi, + TriggerSubscriptionBuilderUpdateApi, + TriggerSubscriptionBuilderVerifyApi, + TriggerSubscriptionDeleteApi, + TriggerSubscriptionListApi, + TriggerSubscriptionUpdateApi, + TriggerSubscriptionVerifyApi, +) +from controllers.web.error import NotFoundError +from core.plugin.entities.plugin_daemon import CredentialType +from models.account import Account + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def mock_user(): + user = MagicMock(spec=Account) + user.id = "u1" + user.current_tenant_id = "t1" + return user + + +class TestTriggerProviderApis: + def test_icon_success(self, app): + api = TriggerProviderIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon", + return_value="icon", + ), + ): + assert method(api, "github") == "icon" + + def test_list_providers(self, app): + api = TriggerProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers", + return_value=[], + ), + ): + assert method(api) == [] + + def test_provider_info(self, app): + api = TriggerProviderInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider", + return_value={"id": "p1"}, + ), + ): + assert method(api, "github") == {"id": "p1"} + + +class TestTriggerSubscriptionListApi: + def test_list_success(self, app): + api = TriggerSubscriptionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", + return_value=[], + ), + ): + assert method(api, "github") == [] + + def test_list_invalid_provider(self, app): + api = TriggerSubscriptionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", + side_effect=ValueError("bad"), + ), + ): + result, status = method(api, "bad") + assert status == 404 + + +class TestTriggerSubscriptionBuilderApis: + def test_create_builder(self, app): + api = TriggerSubscriptionBuilderCreateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + return_value={"id": "b1"}, + ), + ): + result = method(api, "github") + assert "subscription_builder" in result + + def test_get_builder(self, app): + api = TriggerSubscriptionBuilderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + return_value={"id": "b1"}, + ), + ): + assert method(api, "github", "b1") == {"id": "b1"} + + def test_verify_builder(self, app): + api = TriggerSubscriptionBuilderVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {"a": 1}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", + return_value={"ok": True}, + ), + ): + assert method(api, "github", "b1") == {"ok": True} + + def test_verify_builder_error(self, app): + api = TriggerSubscriptionBuilderVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", + side_effect=Exception("err"), + ), + ): + with pytest.raises(ValueError): + method(api, "github", "b1") + + def test_update_builder(self, app): + api = TriggerSubscriptionBuilderUpdateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "n"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + return_value={"id": "b1"}, + ), + ): + assert method(api, "github", "b1") == {"id": "b1"} + + def test_logs(self, app): + api = TriggerSubscriptionBuilderLogsApi() + method = unwrap(api.get) + + log = MagicMock() + log.model_dump.return_value = {"a": 1} + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs", + return_value=[log], + ), + ): + assert "logs" in method(api, "github", "b1") + + def test_build(self, app): + api = TriggerSubscriptionBuilderBuildApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder", + return_value=None, + ), + ): + assert method(api, "github", "b1") == 200 + + +class TestTriggerSubscriptionCrud: + def test_update_rename_only(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + sub = MagicMock() + sub.provider_id = "github" + sub.credential_type = CredentialType.UNAUTHORIZED + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=sub, + ), + patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"), + ): + assert method(api, "s1") == 200 + + def test_update_not_found(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=None, + ), + ): + with pytest.raises(NotFoundError): + method(api, "x") + + def test_update_rebuild(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + sub = MagicMock() + sub.provider_id = "github" + sub.credential_type = CredentialType.OAUTH2 + sub.credentials = {} + sub.parameters = {} + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=sub, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription" + ), + ): + assert method(api, "s1") == 200 + + def test_delete_subscription(self, app): + api = TriggerSubscriptionDeleteApi() + method = unwrap(api.post) + + mock_session = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch("controllers.console.workspace.trigger_providers.db") as mock_db, + patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls, + patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription" + ), + ): + mock_db.engine = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + result = method(api, "sub1") + + assert result["result"] == "success" + + def test_delete_subscription_value_error(self, app): + api = TriggerSubscriptionDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch("controllers.console.workspace.trigger_providers.db") as mock_db, + patch("controllers.console.workspace.trigger_providers.Session") as session_cls, + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider", + side_effect=ValueError("bad"), + ), + ): + mock_db.engine = MagicMock() + session_cls.return_value.__enter__.return_value = MagicMock() + + with pytest.raises(BadRequest): + method(api, "sub1") + + +class TestTriggerOAuthApis: + def test_oauth_authorize_success(self, app): + api = TriggerOAuthAuthorizeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + return_value=MagicMock(id="b1"), + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context", + return_value="ctx", + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url", + return_value=MagicMock(authorization_url="url"), + ), + ): + resp = method(api, "github") + assert resp.status_code == 200 + + def test_oauth_authorize_no_client(self, app): + api = TriggerOAuthAuthorizeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(NotFoundError): + method(api, "github") + + def test_oauth_callback_forbidden(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "github") + + def test_oauth_callback_success(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials", + return_value=MagicMock(credentials={"a": 1}, expires_at=1), + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder" + ), + ): + resp = method(api, "github") + assert resp.status_code == 302 + + def test_oauth_callback_no_oauth_client(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", + return_value=ctx, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "github") + + def test_oauth_callback_empty_credentials(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", + return_value=ctx, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials", + return_value=MagicMock(credentials=None, expires_at=None), + ), + ): + with pytest.raises(ValueError): + method(api, "github") + + +class TestTriggerOAuthClientManageApi: + def test_get_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params", + return_value={}, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled", + return_value=False, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists", + return_value=True, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider", + return_value=MagicMock(get_oauth_client_schema=lambda: {}), + ), + ): + result = method(api, "github") + assert "configured" in result + + def test_post_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"enabled": True}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "github") == {"ok": True} + + def test_delete_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "github") == {"ok": True} + + def test_oauth_client_post_value_error(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"enabled": True}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", + side_effect=ValueError("bad"), + ), + ): + with pytest.raises(BadRequest): + method(api, "github") + + +class TestTriggerSubscriptionVerifyApi: + def test_verify_success(self, app): + api = TriggerSubscriptionVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", + return_value={"ok": True}, + ), + ): + assert method(api, "github", "s1") == {"ok": True} + + @pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")]) + def test_verify_errors(self, app, raised_exception): + api = TriggerSubscriptionVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", + side_effect=raised_exception, + ), + ): + with pytest.raises(BadRequest): + method(api, "github", "s1") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py new file mode 100644 index 0000000000..f5ebe0b534 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -0,0 +1,796 @@ +from datetime import datetime +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import Unauthorized + +import services +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.error import AccountNotLinkTenantError +from controllers.console.workspace.workspace import ( + CustomConfigWorkspaceApi, + SwitchWorkspaceApi, + TenantApi, + TenantListApi, + WebappLogoWorkspaceApi, + WorkspaceInfoApi, + WorkspaceListApi, + WorkspacePermissionApi, +) +from enums.cloud_plan import CloudPlan +from models.account import TenantStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestTenantListApi: + def test_get_success_saas_path(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={ + "t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}, + "t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0}, + }, + ) as get_plan_bulk_mock, + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert len(result["workspaces"]) == 2 + assert result["workspaces"][0]["current"] is True + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + get_features_mock.assert_not_called() + + def test_get_saas_path_partial_fallback_does_not_gate_plan_on_billing_enabled(self, app): + """Bulk omits a tenant: resolve plan via subscription.plan only; billing.enabled is not used. + + billing.enabled is mocked False to prove the endpoint does not gate on it for this path + (SaaS contract treats enabled as on; display follows subscription.plan). + """ + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + features_t2 = MagicMock() + features_t2.billing.enabled = False + features_t2.billing.subscription.plan = CloudPlan.PROFESSIONAL + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}}, + ) as get_plan_bulk_mock, + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features_t2, + ) as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + get_features_mock.assert_called_once_with("t2") + + def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app): + """Test fallback to FeatureService when bulk billing returns empty result. + + BillingService.get_plan_bulk catches exceptions internally and returns empty dict, + so we simulate the real failure mode by returning empty dict for non-empty input. + """ + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + features = MagicMock() + features.billing.enabled = False + features.billing.subscription.plan = CloudPlan.TEAM + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"), + patch( + "controllers.console.workspace.workspace.BillingService.get_plan_bulk", + return_value={}, # Simulates real failure: empty result for non-empty input + ) as get_plan_bulk_mock, + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features, + ) as get_features_mock, + patch("controllers.console.workspace.workspace.logger.warning") as logger_warning_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.TEAM + assert result["workspaces"][1]["plan"] == CloudPlan.TEAM + get_plan_bulk_mock.assert_called_once_with(["t1", "t2"]) + assert get_features_mock.call_count == 2 + logger_warning_mock.assert_called_once() + + def test_get_billing_disabled_community_path(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant = MagicMock( + id="t1", + name="Tenant", + status="active", + created_at=datetime.utcnow(), + ) + + features = MagicMock() + features.billing.enabled = False + features.billing.subscription.plan = CloudPlan.SANDBOX + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features, + ) as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + get_features_mock.assert_called_once_with("t1") + + def test_get_enterprise_only_skips_feature_service(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX + assert result["workspaces"][0]["current"] is False + assert result["workspaces"][1]["current"] is True + get_features_mock.assert_not_called() + + def test_get_enterprise_only_with_empty_tenants(self, app): + api = TenantListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None) + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[], + ), + patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True), + patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False), + patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"), + patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock, + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"] == [] + get_features_mock.assert_not_called() + + +class TestWorkspaceListApi: + def test_get_success(self, app): + api = WorkspaceListApi() + method = unwrap(api.get) + + tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow()) + + paginate_result = MagicMock( + items=[tenant], + has_next=False, + total=1, + ) + + with ( + app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}), + patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result), + ): + result, status = method(api) + + assert status == 200 + assert result["total"] == 1 + assert result["has_more"] is False + + def test_get_has_next_true(self, app): + api = WorkspaceListApi() + method = unwrap(api.get) + + tenant = MagicMock( + id="t1", + name="T", + status="active", + created_at=datetime.utcnow(), + ) + + paginate_result = MagicMock( + items=[tenant], + has_next=True, + total=10, + ) + + with ( + app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}), + patch( + "controllers.console.workspace.workspace.db.paginate", + return_value=paginate_result, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["has_more"] is True + + +class TestTenantApi: + def test_post_active_tenant(self, app): + api = TenantApi() + method = unwrap(api.post) + + tenant = MagicMock(status="active") + + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} + ), + ): + result, status = method(api) + + assert status == 200 + assert result["id"] == "t1" + + def test_post_archived_with_switch(self, app): + api = TenantApi() + method = unwrap(api.post) + + archived = MagicMock(status=TenantStatus.ARCHIVE) + new_tenant = MagicMock(status="active") + + user = MagicMock(current_tenant=archived) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"} + ), + ): + result, status = method(api) + + assert result["id"] == "new" + + def test_post_archived_no_tenant(self, app): + api = TenantApi() + method = unwrap(api.post) + + user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE)) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]), + ): + with pytest.raises(Unauthorized): + method(api) + + def test_post_info_path(self, app): + api = TenantApi() + method = unwrap(api.post) + + tenant = MagicMock(status="active") + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/info"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(user, "t1"), + ), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"id": "t1"}, + ), + patch("controllers.console.workspace.workspace.logger.warning") as warn_mock, + ): + result, status = method(api) + + warn_mock.assert_called_once() + assert status == 200 + + +class TestSwitchWorkspaceApi: + def test_switch_success(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "t2"} + tenant = MagicMock(id="t2") + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"} + ), + ): + get_mock.return_value = tenant + result = method(api) + + assert result["result"] == "success" + + def test_switch_not_linked(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "bad"} + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception), + ): + with pytest.raises(AccountNotLinkTenantError): + method(api) + + def test_switch_tenant_not_found(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "missing"} + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch("controllers.console.workspace.workspace.db.session.get") as get_mock, + ): + get_mock.return_value = None + + with pytest.raises(ValueError): + method(api) + + +class TestCustomConfigWorkspaceApi: + def test_post_success(self, app): + api = CustomConfigWorkspaceApi() + method = unwrap(api.post) + + tenant = MagicMock(custom_config_dict={}) + + payload = {"remove_webapp_brand": True} + + with ( + app.test_request_context("/workspaces/custom-config", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_logo_fallback(self, app): + api = CustomConfigWorkspaceApi() + method = unwrap(api.post) + + tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"}) + + payload = {"remove_webapp_brand": False} + + with ( + app.test_request_context("/workspaces/custom-config", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch( + "controllers.console.workspace.workspace.db.get_or_404", + return_value=tenant, + ), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"id": "t1"}, + ), + ): + result = method(api) + + assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo" + assert result["result"] == "success" + + +class TestWebappLogoWorkspaceApi: + def test_no_file(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/upload", data={}), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + ): + with pytest.raises(NoFileUploadedError): + method(api) + + def test_too_many_files(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + data = { + "file": MagicMock(), + "extra": MagicMock(), + } + + with ( + app.test_request_context("/upload", data=data), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + ): + with pytest.raises(TooManyFilesError): + method(api) + + def test_invalid_extension(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = MagicMock(filename="test.txt") + + with ( + app.test_request_context("/upload", data={"file": file}), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + ): + with pytest.raises(UnsupportedFileTypeError): + method(api) + + def test_upload_success(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"data"), + filename="logo.png", + content_type="image/png", + ) + + upload = MagicMock(id="file1") + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.return_value = upload + + result, status = method(api) + + assert status == 201 + assert result["id"] == "file1" + + def test_filename_missing(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"data"), + filename="", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + ): + with pytest.raises(FilenameNotExistsError): + method(api) + + def test_file_too_large(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"x"), + filename="logo.png", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big") + + with pytest.raises(FileTooLargeError): + method(api) + + def test_service_unsupported_file(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"x"), + filename="logo.png", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError() + + with pytest.raises(UnsupportedFileTypeError): + method(api) + + +class TestWorkspaceInfoApi: + def test_post_success(self, app): + api = WorkspaceInfoApi() + method = unwrap(api.post) + + tenant = MagicMock() + + payload = {"name": "New Name"} + + with ( + app.test_request_context("/workspaces/info", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"name": "New Name"}, + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_no_current_tenant(self, app): + api = WorkspaceInfoApi() + method = unwrap(api.post) + + payload = {"name": "X"} + + with ( + app.test_request_context("/workspaces/info", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestWorkspacePermissionApi: + def test_get_success(self, app): + api = WorkspacePermissionApi() + method = unwrap(api.get) + + permission = MagicMock( + workspace_id="t1", + allow_member_invite=True, + allow_owner_transfer=False, + ) + + with ( + app.test_request_context("/permission"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission", + return_value=permission, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["workspace_id"] == "t1" + + def test_no_current_tenant(self, app): + api = WorkspacePermissionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/permission"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(ValueError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py new file mode 100644 index 0000000000..b290748155 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import importlib +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace import plugin_permission_required +from models.account import TenantPluginPermission + + +class _SessionStub: + def __init__(self, permission): + self._permission = permission + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *_args, **_kwargs): + return self + + def where(self, *_args, **_kwargs): + return self + + def first(self): + return self._permission + + +def _workspace_module(): + return importlib.import_module(plugin_permission_required.__module__) + + +def _patch_session(monkeypatch: pytest.MonkeyPatch, permission): + module = _workspace_module() + monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission)) + monkeypatch.setattr(module, "db", SimpleNamespace(engine=object())) + + +def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, None) + + @plugin_permission_required() + def handler(): + return "ok" + + assert handler() == "ok" + + +def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.NOBODY, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + assert handler() == "ok" + + +def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.NOBODY, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(debug_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.ADMINS, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(debug_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() diff --git a/api/tests/unit_tests/controllers/inner_api/__init__.py b/api/tests/unit_tests/controllers/inner_api/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/__init__.py b/api/tests/unit_tests/controllers/inner_api/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py new file mode 100644 index 0000000000..844f04fe72 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin.py @@ -0,0 +1,313 @@ +""" +Unit tests for inner_api plugin endpoints + +Tests endpoint structure (method existence) for all plugin APIs, plus +handler-level logic tests for representative non-streaming endpoints. +Auth/setup decorators are tested separately in test_auth_wraps.py; +handler tests use inspect.unwrap() to bypass them. +""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.inner_api.plugin.plugin import ( + PluginFetchAppInfoApi, + PluginInvokeAppApi, + PluginInvokeEncryptApi, + PluginInvokeLLMApi, + PluginInvokeLLMWithStructuredOutputApi, + PluginInvokeModerationApi, + PluginInvokeParameterExtractorNodeApi, + PluginInvokeQuestionClassifierNodeApi, + PluginInvokeRerankApi, + PluginInvokeSpeech2TextApi, + PluginInvokeSummaryApi, + PluginInvokeTextEmbeddingApi, + PluginInvokeToolApi, + PluginInvokeTTSApi, + PluginUploadFileRequestApi, +) + + +def _extract_raw_post(cls): + """Extract the raw post() method from a plugin endpoint class. + + Plugin endpoint methods are wrapped by several decorators (get_user_tenant, + setup_required, plugin_inner_api_only, plugin_data). These decorators + use @wraps where possible. This helper ensures we retrieve the original + post(self, user_model, tenant_model, payload) function by unwrapping + and, if necessary, walking the closure of the innermost wrapper. + """ + bottom = inspect.unwrap(cls.post) + + # If unwrap() didn't get us to the raw function (e.g. if a decorator + # missed @wraps), try to extract it from the closure if it looks like + # a plugin_data or similar wrapper that closes over 'view_func'. + if hasattr(bottom, "__code__") and "view_func" in bottom.__code__.co_freevars: + try: + idx = bottom.__code__.co_freevars.index("view_func") + return bottom.__closure__[idx].cell_contents + except (AttributeError, TypeError, IndexError): + pass + + return bottom + + +class TestPluginInvokeLLMApi: + """Test PluginInvokeLLMApi endpoint structure""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeLLMApi() + + def test_has_post_method(self, api_instance): + """Test that endpoint has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeLLMWithStructuredOutputApi: + """Test PluginInvokeLLMWithStructuredOutputApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeLLMWithStructuredOutputApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeTextEmbeddingApi: + """Test PluginInvokeTextEmbeddingApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeTextEmbeddingApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeRerankApi: + """Test PluginInvokeRerankApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeRerankApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeTTSApi: + """Test PluginInvokeTTSApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeTTSApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeSpeech2TextApi: + """Test PluginInvokeSpeech2TextApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeSpeech2TextApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeModerationApi: + """Test PluginInvokeModerationApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeModerationApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeToolApi: + """Test PluginInvokeToolApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeToolApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeParameterExtractorNodeApi: + """Test PluginInvokeParameterExtractorNodeApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeParameterExtractorNodeApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeQuestionClassifierNodeApi: + """Test PluginInvokeQuestionClassifierNodeApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeQuestionClassifierNodeApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeAppApi: + """Test PluginInvokeAppApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeAppApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginInvokeEncryptApi: + """Test PluginInvokeEncryptApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeEncryptApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.PluginEncrypter") + def test_post_returns_encrypted_data(self, mock_encrypter, api_instance, app: Flask): + """Test that post() delegates to PluginEncrypter and returns model_dump output""" + # Arrange + mock_encrypter.invoke_encrypt.return_value = {"encrypted": "data"} + mock_tenant = MagicMock() + mock_user = MagicMock() + mock_payload = MagicMock() + + # Act — extract raw post() bypassing all decorators including plugin_data + raw_post = _extract_raw_post(PluginInvokeEncryptApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_encrypter.invoke_encrypt.assert_called_once_with(mock_tenant, mock_payload) + assert result["data"] == {"encrypted": "data"} + assert result.get("error") == "" + + @patch("controllers.inner_api.plugin.plugin.PluginEncrypter") + def test_post_returns_error_on_exception(self, mock_encrypter, api_instance, app: Flask): + """Test that post() catches exceptions and returns error response""" + # Arrange + mock_encrypter.invoke_encrypt.side_effect = RuntimeError("encrypt failed") + mock_tenant = MagicMock() + mock_user = MagicMock() + mock_payload = MagicMock() + + # Act + raw_post = _extract_raw_post(PluginInvokeEncryptApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + assert "encrypt failed" in result["error"] + + +class TestPluginInvokeSummaryApi: + """Test PluginInvokeSummaryApi endpoint""" + + @pytest.fixture + def api_instance(self): + return PluginInvokeSummaryApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + +class TestPluginUploadFileRequestApi: + """Test PluginUploadFileRequestApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginUploadFileRequestApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.get_signed_file_url_for_plugin") + def test_post_returns_signed_url(self, mock_get_url, api_instance, app: Flask): + """Test that post() generates a signed URL and returns it""" + # Arrange + mock_get_url.return_value = "https://storage.example.com/signed-upload-url" + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_user = MagicMock() + mock_user.id = "user-id" + mock_payload = MagicMock() + mock_payload.filename = "test.pdf" + mock_payload.mimetype = "application/pdf" + + # Act + raw_post = _extract_raw_post(PluginUploadFileRequestApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_get_url.assert_called_once_with( + filename="test.pdf", mimetype="application/pdf", tenant_id="tenant-id", user_id="user-id" + ) + assert result["data"]["url"] == "https://storage.example.com/signed-upload-url" + + +class TestPluginFetchAppInfoApi: + """Test PluginFetchAppInfoApi endpoint structure and handler logic""" + + @pytest.fixture + def api_instance(self): + return PluginFetchAppInfoApi() + + def test_has_post_method(self, api_instance): + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.plugin.plugin.PluginAppBackwardsInvocation") + def test_post_returns_app_info(self, mock_invocation, api_instance, app: Flask): + """Test that post() fetches app info and returns it""" + # Arrange + mock_invocation.fetch_app_info.return_value = {"app_name": "My App", "mode": "chat"} + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_user = MagicMock() + mock_payload = MagicMock() + mock_payload.app_id = "app-123" + + # Act + raw_post = _extract_raw_post(PluginFetchAppInfoApi) + result = raw_post(api_instance, user_model=mock_user, tenant_model=mock_tenant, payload=mock_payload) + + # Assert + mock_invocation.fetch_app_info.assert_called_once_with("app-123", "tenant-id") + assert result["data"] == {"app_name": "My App", "mode": "chat"} diff --git a/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py new file mode 100644 index 0000000000..eac57fe4b7 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/plugin/test_plugin_wraps.py @@ -0,0 +1,308 @@ +""" +Unit tests for inner_api plugin decorators +""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.plugin.wraps import ( + TenantUserPayload, + get_user, + get_user_tenant, + plugin_data, +) + + +class TestTenantUserPayload: + """Test TenantUserPayload Pydantic model""" + + def test_valid_payload(self): + """Test valid payload passes validation""" + data = {"tenant_id": "tenant123", "user_id": "user456"} + payload = TenantUserPayload.model_validate(data) + assert payload.tenant_id == "tenant123" + assert payload.user_id == "user456" + + def test_missing_tenant_id(self): + """Test missing tenant_id raises ValidationError""" + with pytest.raises(ValidationError): + TenantUserPayload.model_validate({"user_id": "user456"}) + + def test_missing_user_id(self): + """Test missing user_id raises ValidationError""" + with pytest.raises(ValidationError): + TenantUserPayload.model_validate({"tenant_id": "tenant123"}) + + +class TestGetUser: + """Test get_user function""" + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_return_existing_user_by_id(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + """Test returning existing user when found by ID""" + # Arrange + mock_user = MagicMock() + mock_user.id = "user123" + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.get.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", "user123") + + # Assert + assert result == mock_user + mock_session.get.assert_called_once() + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_return_existing_anonymous_user_by_session_id( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test returning existing anonymous user by session_id""" + # Arrange + mock_user = MagicMock() + mock_user.session_id = "anonymous_session" + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + # non-anonymous path uses session.get(); anonymous uses session.scalar() + mock_session.get.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", "anonymous_session") + + # Assert + assert result == mock_user + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_create_new_user_when_not_found(self, mock_db, mock_session_class, mock_enduser_class, app: Flask): + """Test creating new user when not found in database""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.get.return_value = None + mock_new_user = MagicMock() + mock_enduser_class.return_value = mock_new_user + + # Act + with app.app_context(): + result = get_user("tenant123", "user123") + + # Assert + assert result == mock_new_user + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + mock_session.refresh.assert_called_once() + + @patch("controllers.inner_api.plugin.wraps.select") + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_use_default_session_id_when_user_id_none( + self, mock_db, mock_session_class, mock_enduser_class, mock_select, app: Flask + ): + """Test using default session ID when user_id is None""" + # Arrange + mock_user = MagicMock() + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + # When user_id is None, is_anonymous=True, so session.scalar() is used + mock_session.scalar.return_value = mock_user + + # Act + with app.app_context(): + result = get_user("tenant123", None) + + # Assert + assert result == mock_user + + @patch("controllers.inner_api.plugin.wraps.EndUser") + @patch("controllers.inner_api.plugin.wraps.Session") + @patch("controllers.inner_api.plugin.wraps.db") + def test_should_raise_error_on_database_exception( + self, mock_db, mock_session_class, mock_enduser_class, app: Flask + ): + """Test raising ValueError when database operation fails""" + # Arrange + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.get.side_effect = Exception("Database error") + + # Act & Assert + with app.app_context(): + with pytest.raises(ValueError, match="user not found"): + get_user("tenant123", "user123") + + +class TestGetUserTenant: + """Test get_user_tenant decorator""" + + @patch("controllers.inner_api.plugin.wraps.Tenant") + def test_should_inject_tenant_and_user_models(self, mock_tenant_class, app: Flask, monkeypatch): + """Test that decorator injects tenant_model and user_model into kwargs""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return {"tenant": tenant_model, "user": user_model} + + mock_tenant = MagicMock() + mock_tenant.id = "tenant123" + mock_user = MagicMock() + mock_user.id = "user456" + + # Act + with app.test_request_context(json={"tenant_id": "tenant123", "user_id": "user456"}): + monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: + with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: + mock_get.return_value = mock_tenant + mock_get_user.return_value = mock_user + result = protected_view() + + # Assert + assert result["tenant"] == mock_tenant + assert result["user"] == mock_user + + def test_should_raise_error_when_tenant_id_missing(self, app: Flask): + """Test that Pydantic ValidationError is raised when tenant_id is missing from payload""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return "success" + + # Act & Assert - Pydantic validates payload before manual check + with app.test_request_context(json={"user_id": "user456"}): + with pytest.raises(ValidationError): + protected_view() + + def test_should_raise_error_when_tenant_not_found(self, app: Flask): + """Test that ValueError is raised when tenant is not found""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return "success" + + # Act & Assert + with app.test_request_context(json={"tenant_id": "nonexistent", "user_id": "user456"}): + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: + mock_get.return_value = None + with pytest.raises(ValueError, match="tenant not found"): + protected_view() + + @patch("controllers.inner_api.plugin.wraps.Tenant") + def test_should_use_default_session_id_when_user_id_empty(self, mock_tenant_class, app: Flask, monkeypatch): + """Test that default session ID is used when user_id is empty string""" + + # Arrange + @get_user_tenant + def protected_view(tenant_model, user_model, **kwargs): + return {"tenant": tenant_model, "user": user_model} + + mock_tenant = MagicMock() + mock_tenant.id = "tenant123" + mock_user = MagicMock() + + # Act - use empty string for user_id to trigger default logic + with app.test_request_context(json={"tenant_id": "tenant123", "user_id": ""}): + monkeypatch.setattr(app, "login_manager", MagicMock(), raising=False) + with patch("controllers.inner_api.plugin.wraps.db.session.get") as mock_get: + with patch("controllers.inner_api.plugin.wraps.get_user") as mock_get_user: + mock_get.return_value = mock_tenant + mock_get_user.return_value = mock_user + result = protected_view() + + # Assert + assert result["tenant"] == mock_tenant + assert result["user"] == mock_user + from models.model import DefaultEndUserSessionID + + mock_get_user.assert_called_once_with("tenant123", DefaultEndUserSessionID.DEFAULT_SESSION_ID) + + +class PluginTestPayload: + """Simple test payload class""" + + def __init__(self, data: dict): + self.value = data.get("value") + + @classmethod + def model_validate(cls, data: dict): + return cls(data) + + +class TestPluginData: + """Test plugin_data decorator""" + + def test_should_inject_valid_payload(self, app: Flask): + """Test that valid payload is injected into kwargs""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act + with app.test_request_context(json={"value": "test_data"}): + result = protected_view() + + # Assert + assert result.value == "test_data" + + def test_should_raise_error_on_invalid_json(self, app: Flask): + """Test that ValueError is raised when JSON parsing fails""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act & Assert - Malformed JSON triggers ValueError + with app.test_request_context(data="not valid json", content_type="application/json"): + with pytest.raises(ValueError): + protected_view() + + def test_should_raise_error_on_invalid_payload(self, app: Flask): + """Test that ValueError is raised when payload validation fails""" + + # Arrange + class InvalidPayload: + @classmethod + def model_validate(cls, data: dict): + raise Exception("Validation failed") + + @plugin_data(payload_type=InvalidPayload) + def protected_view(payload, **kwargs): + return payload + + # Act & Assert + with app.test_request_context(json={"data": "test"}): + with pytest.raises(ValueError, match="invalid payload"): + protected_view() + + def test_should_work_as_parameterized_decorator(self, app: Flask): + """Test that decorator works when used with parentheses""" + + # Arrange + @plugin_data(payload_type=PluginTestPayload) + def protected_view(payload, **kwargs): + return payload + + # Act + with app.test_request_context(json={"value": "parameterized"}): + result = protected_view() + + # Assert + assert result.value == "parameterized" diff --git a/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py new file mode 100644 index 0000000000..efe1841f08 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/test_auth_wraps.py @@ -0,0 +1,309 @@ +""" +Unit tests for inner_api auth decorators +""" + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import HTTPException + +from configs import dify_config +from controllers.inner_api.wraps import ( + billing_inner_api_only, + enterprise_inner_api_only, + enterprise_inner_api_user_auth, + plugin_inner_api_only, +) + + +class TestBillingInnerApiOnly: + """Test billing_inner_api_only decorator""" + + def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask): + """Test that valid API key allows access when INNER_API is enabled""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_inner_api_disabled(self, app: Flask): + """Test that 404 is returned when INNER_API is disabled""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_401_when_api_key_missing(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is missing""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + def test_should_return_401_when_api_key_invalid(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is invalid""" + + # Arrange + @billing_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + +class TestEnterpriseInnerApiOnly: + """Test enterprise_inner_api_only decorator""" + + def test_should_allow_when_inner_api_enabled_and_valid_key(self, app: Flask): + """Test that valid API key allows access when INNER_API is enabled""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_inner_api_disabled(self, app: Flask): + """Test that 404 is returned when INNER_API is disabled""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_401_when_api_key_missing(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is missing""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + def test_should_return_401_when_api_key_invalid(self, app: Flask): + """Test that 401 is returned when X-Inner-Api-Key header is invalid""" + + # Arrange + @enterprise_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "INNER_API", True): + with patch.object(dify_config, "INNER_API_KEY", "valid_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 401 + + +class TestEnterpriseInnerApiUserAuth: + """Test enterprise_inner_api_user_auth decorator for HMAC-based user authentication""" + + def test_should_pass_through_when_inner_api_disabled(self, app: Flask): + """Test that request passes through when INNER_API is disabled""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(): + with patch.object(dify_config, "INNER_API", False): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_authorization_header_missing(self, app: Flask): + """Test that request passes through when Authorization header is missing""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(headers={}): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_authorization_format_invalid(self, app: Flask): + """Test that request passes through when Authorization format is invalid (no colon)""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act + with app.test_request_context(headers={"Authorization": "invalid_format"}): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_pass_through_when_hmac_signature_invalid(self, app: Flask): + """Test that request passes through when HMAC signature is invalid""" + + # Arrange + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user", "no_user") + + # Act - use wrong signature + with app.test_request_context( + headers={"Authorization": "Bearer user123:wrong_signature", "X-Inner-Api-Key": "valid_key"} + ): + with patch.object(dify_config, "INNER_API", True): + result = protected_view() + + # Assert + assert result == "no_user" + + def test_should_inject_user_when_hmac_signature_valid(self, app: Flask): + """Test that user is injected when HMAC signature is valid""" + # Arrange + from base64 import b64encode + from hashlib import sha1 + from hmac import new as hmac_new + + @enterprise_inner_api_user_auth + def protected_view(**kwargs): + return kwargs.get("user") + + # Calculate valid HMAC signature + user_id = "user123" + inner_api_key = "valid_key" + data_to_sign = f"DIFY {user_id}" + signature = hmac_new(inner_api_key.encode("utf-8"), data_to_sign.encode("utf-8"), sha1) + valid_signature = b64encode(signature.digest()).decode("utf-8") + + # Create mock user + mock_user = MagicMock() + mock_user.id = user_id + + # Act + with app.test_request_context( + headers={"Authorization": f"Bearer {user_id}:{valid_signature}", "X-Inner-Api-Key": inner_api_key} + ): + with patch.object(dify_config, "INNER_API", True): + with patch("controllers.inner_api.wraps.db.session.get") as mock_get: + mock_get.return_value = mock_user + result = protected_view() + + # Assert + assert result == mock_user + + +class TestPluginInnerApiOnly: + """Test plugin_inner_api_only decorator""" + + def test_should_allow_when_plugin_daemon_key_set_and_valid_key(self, app: Flask): + """Test that valid API key allows access when PLUGIN_DAEMON_KEY is set""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act + with app.test_request_context(headers={"X-Inner-Api-Key": "valid_plugin_key"}): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"): + with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_return_404_when_plugin_daemon_key_not_set(self, app: Flask): + """Test that 404 is returned when PLUGIN_DAEMON_KEY is not set""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", ""): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 + + def test_should_return_404_when_api_key_invalid(self, app: Flask): + """Test that 404 is returned when X-Inner-Api-Key header is invalid (note: returns 404, not 401)""" + + # Arrange + @plugin_inner_api_only + def protected_view(): + return "success" + + # Act & Assert + with app.test_request_context(headers={"X-Inner-Api-Key": "invalid_key"}): + with patch.object(dify_config, "PLUGIN_DAEMON_KEY", "plugin_key"): + with patch.object(dify_config, "INNER_API_KEY_FOR_PLUGIN", "valid_plugin_key"): + with pytest.raises(HTTPException) as exc_info: + protected_view() + assert exc_info.value.code == 404 diff --git a/api/tests/unit_tests/controllers/inner_api/test_mail.py b/api/tests/unit_tests/controllers/inner_api/test_mail.py new file mode 100644 index 0000000000..c2ca35693e --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/test_mail.py @@ -0,0 +1,206 @@ +""" +Unit tests for inner_api mail module +""" + +from unittest.mock import patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.mail import ( + BaseMail, + BillingMail, + EnterpriseMail, + InnerMailPayload, +) + + +class TestInnerMailPayload: + """Test InnerMailPayload Pydantic model""" + + def test_valid_payload_with_all_fields(self): + """Test valid payload with all fields passes validation""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + "substitutions": {"key": "value"}, + } + payload = InnerMailPayload.model_validate(data) + assert payload.to == ["test@example.com"] + assert payload.subject == "Test Subject" + assert payload.body == "Test Body" + assert payload.substitutions == {"key": "value"} + + def test_valid_payload_without_substitutions(self): + """Test valid payload without optional substitutions""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + payload = InnerMailPayload.model_validate(data) + assert payload.to == ["test@example.com"] + assert payload.subject == "Test Subject" + assert payload.body == "Test Body" + assert payload.substitutions is None + + def test_empty_to_list_fails_validation(self): + """Test that empty 'to' list fails validation due to min_length=1""" + data = { + "to": [], + "subject": "Test Subject", + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_multiple_recipients_allowed(self): + """Test that multiple recipients are allowed""" + data = { + "to": ["user1@example.com", "user2@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + payload = InnerMailPayload.model_validate(data) + assert len(payload.to) == 2 + assert "user1@example.com" in payload.to + assert "user2@example.com" in payload.to + + def test_missing_to_field_fails_validation(self): + """Test that missing 'to' field fails validation""" + data = { + "subject": "Test Subject", + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_missing_subject_fails_validation(self): + """Test that missing 'subject' field fails validation""" + data = { + "to": ["test@example.com"], + "body": "Test Body", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + def test_missing_body_fails_validation(self): + """Test that missing 'body' field fails validation""" + data = { + "to": ["test@example.com"], + "subject": "Test Subject", + } + with pytest.raises(ValidationError): + InnerMailPayload.model_validate(data) + + +class TestBaseMail: + """Test BaseMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create BaseMail API instance""" + return BaseMail() + + @patch("controllers.inner_api.mail.send_inner_email_task") + def test_post_sends_email_task(self, mock_task, api_instance, app: Flask): + """Test that POST sends inner email task""" + # Arrange + mock_task.delay.return_value = None + + # Act + with app.test_request_context( + json={ + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + ): + with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns: + mock_ns.payload = { + "to": ["test@example.com"], + "subject": "Test Subject", + "body": "Test Body", + } + result = api_instance.post() + + # Assert + assert result == ({"message": "success"}, 200) + mock_task.delay.assert_called_once_with( + to=["test@example.com"], + subject="Test Subject", + body="Test Body", + substitutions=None, + ) + + @patch("controllers.inner_api.mail.send_inner_email_task") + def test_post_with_substitutions(self, mock_task, api_instance, app: Flask): + """Test that POST sends email with substitutions""" + # Arrange + mock_task.delay.return_value = None + + # Act + with app.test_request_context(): + with patch("controllers.inner_api.mail.inner_api_ns") as mock_ns: + mock_ns.payload = { + "to": ["test@example.com"], + "subject": "Hello {{name}}", + "body": "Welcome {{name}}!", + "substitutions": {"name": "John"}, + } + result = api_instance.post() + + # Assert + assert result == ({"message": "success"}, 200) + mock_task.delay.assert_called_once_with( + to=["test@example.com"], + subject="Hello {{name}}", + body="Welcome {{name}}!", + substitutions={"name": "John"}, + ) + + +class TestEnterpriseMail: + """Test EnterpriseMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create EnterpriseMail API instance""" + return EnterpriseMail() + + def test_has_enterprise_inner_api_only_decorator(self, api_instance): + """Test that EnterpriseMail has enterprise_inner_api_only decorator""" + # Check method_decorators + from controllers.inner_api.wraps import enterprise_inner_api_only + + assert enterprise_inner_api_only in api_instance.method_decorators + + def test_has_setup_required_decorator(self, api_instance): + """Test that EnterpriseMail has setup_required decorator""" + # Check by decorator name instead of object reference + decorator_names = [d.__name__ for d in api_instance.method_decorators] + assert "setup_required" in decorator_names + + +class TestBillingMail: + """Test BillingMail API endpoint""" + + @pytest.fixture + def api_instance(self): + """Create BillingMail API instance""" + return BillingMail() + + def test_has_billing_inner_api_only_decorator(self, api_instance): + """Test that BillingMail has billing_inner_api_only decorator""" + # Check method_decorators + from controllers.inner_api.wraps import billing_inner_api_only + + assert billing_inner_api_only in api_instance.method_decorators + + def test_has_setup_required_decorator(self, api_instance): + """Test that BillingMail has setup_required decorator""" + # Check by decorator name instead of object reference + decorator_names = [d.__name__ for d in api_instance.method_decorators] + assert "setup_required" in decorator_names diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/__init__.py b/api/tests/unit_tests/controllers/inner_api/workspace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py new file mode 100644 index 0000000000..56a8f94963 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/workspace/test_workspace.py @@ -0,0 +1,184 @@ +""" +Unit tests for inner_api workspace module + +Tests Pydantic model validation and endpoint handler logic. +Auth/setup decorators are tested separately in test_auth_wraps.py; +handler tests use inspect.unwrap() to bypass them and focus on business logic. +""" + +import inspect +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.workspace.workspace import ( + EnterpriseWorkspace, + EnterpriseWorkspaceNoOwnerEmail, + WorkspaceCreatePayload, + WorkspaceOwnerlessPayload, +) + + +class TestWorkspaceCreatePayload: + """Test WorkspaceCreatePayload Pydantic model validation""" + + def test_valid_payload(self): + """Test valid payload with all fields passes validation""" + data = { + "name": "My Workspace", + "owner_email": "owner@example.com", + } + payload = WorkspaceCreatePayload.model_validate(data) + assert payload.name == "My Workspace" + assert payload.owner_email == "owner@example.com" + + def test_missing_name_fails_validation(self): + """Test that missing name fails validation""" + data = {"owner_email": "owner@example.com"} + with pytest.raises(ValidationError) as exc_info: + WorkspaceCreatePayload.model_validate(data) + assert "name" in str(exc_info.value) + + def test_missing_owner_email_fails_validation(self): + """Test that missing owner_email fails validation""" + data = {"name": "My Workspace"} + with pytest.raises(ValidationError) as exc_info: + WorkspaceCreatePayload.model_validate(data) + assert "owner_email" in str(exc_info.value) + + +class TestWorkspaceOwnerlessPayload: + """Test WorkspaceOwnerlessPayload Pydantic model validation""" + + def test_valid_payload(self): + """Test valid payload with name passes validation""" + data = {"name": "My Workspace"} + payload = WorkspaceOwnerlessPayload.model_validate(data) + assert payload.name == "My Workspace" + + def test_missing_name_fails_validation(self): + """Test that missing name fails validation""" + data = {} + with pytest.raises(ValidationError) as exc_info: + WorkspaceOwnerlessPayload.model_validate(data) + assert "name" in str(exc_info.value) + + +class TestEnterpriseWorkspace: + """Test EnterpriseWorkspace API endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py) + and exercise the core business logic directly. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseWorkspace() + + def test_has_post_method(self, api_instance): + """Test that EnterpriseWorkspace has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.workspace.workspace.tenant_was_created") + @patch("controllers.inner_api.workspace.workspace.TenantService") + @patch("controllers.inner_api.workspace.workspace.db") + def test_post_creates_workspace_with_owner(self, mock_db, mock_tenant_svc, mock_event, api_instance, app: Flask): + """Test that post() creates a workspace and assigns the owner account""" + # Arrange + mock_account = MagicMock() + mock_account.email = "owner@example.com" + mock_db.session.scalar.return_value = mock_account + + now = datetime(2025, 1, 1, 12, 0, 0) + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_tenant.name = "My Workspace" + mock_tenant.plan = "sandbox" + mock_tenant.status = "normal" + mock_tenant.created_at = now + mock_tenant.updated_at = now + mock_tenant_svc.create_tenant.return_value = mock_tenant + + # Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py) + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace", "owner_email": "owner@example.com"} + result = unwrapped_post(api_instance) + + # Assert + assert result["message"] == "enterprise workspace created." + assert result["tenant"]["id"] == "tenant-id" + assert result["tenant"]["name"] == "My Workspace" + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_tenant_svc.create_tenant_member.assert_called_once_with(mock_tenant, mock_account, role="owner") + mock_event.send.assert_called_once_with(mock_tenant) + + @patch("controllers.inner_api.workspace.workspace.db") + def test_post_returns_404_when_owner_not_found(self, mock_db, api_instance, app: Flask): + """Test that post() returns 404 when the owner account does not exist""" + # Arrange + mock_db.session.scalar.return_value = None + + # Act + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace", "owner_email": "missing@example.com"} + result = unwrapped_post(api_instance) + + # Assert + assert result == ({"message": "owner account not found."}, 404) + + +class TestEnterpriseWorkspaceNoOwnerEmail: + """Test EnterpriseWorkspaceNoOwnerEmail API endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators (tested in test_auth_wraps.py) + and exercise the core business logic directly. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseWorkspaceNoOwnerEmail() + + def test_has_post_method(self, api_instance): + """Test that endpoint has post method""" + assert hasattr(api_instance, "post") + assert callable(api_instance.post) + + @patch("controllers.inner_api.workspace.workspace.tenant_was_created") + @patch("controllers.inner_api.workspace.workspace.TenantService") + def test_post_creates_ownerless_workspace(self, mock_tenant_svc, mock_event, api_instance, app: Flask): + """Test that post() creates a workspace without an owner and returns expected fields""" + # Arrange + now = datetime(2025, 1, 1, 12, 0, 0) + mock_tenant = MagicMock() + mock_tenant.id = "tenant-id" + mock_tenant.name = "My Workspace" + mock_tenant.encrypt_public_key = "pub-key" + mock_tenant.plan = "sandbox" + mock_tenant.status = "normal" + mock_tenant.custom_config = None + mock_tenant.created_at = now + mock_tenant.updated_at = now + mock_tenant_svc.create_tenant.return_value = mock_tenant + + # Act — unwrap to bypass auth/setup decorators (tested in test_auth_wraps.py) + unwrapped_post = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.workspace.workspace.inner_api_ns") as mock_ns: + mock_ns.payload = {"name": "My Workspace"} + result = unwrapped_post(api_instance) + + # Assert + assert result["message"] == "enterprise workspace created." + assert result["tenant"]["id"] == "tenant-id" + assert result["tenant"]["encrypt_public_key"] == "pub-key" + assert result["tenant"]["custom_config"] == {} + mock_tenant_svc.create_tenant.assert_called_once_with("My Workspace", is_from_dashboard=True) + mock_event.send.assert_called_once_with(mock_tenant) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index b70e70105c..1923ab7fa7 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -29,7 +29,7 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index c5b1cbc127..4e4482f704 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -34,7 +34,7 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from core.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_message.py b/api/tests/unit_tests/controllers/service_api/app/test_message.py index 4de12de829..c2b8aed1ae 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_message.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_message.py @@ -31,6 +31,7 @@ from controllers.service_api.app.message import ( MessageListQuery, MessageSuggestedApi, ) +from models.enums import FeedbackRating from models.model import App, AppMode, EndUser from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( @@ -310,7 +311,7 @@ class TestMessageService: app_model=Mock(spec=App), message_id=str(uuid.uuid4()), user=Mock(spec=EndUser), - rating="like", + rating=FeedbackRating.LIKE, content="Great response!", ) @@ -326,7 +327,7 @@ class TestMessageService: app_model=Mock(spec=App), message_id="invalid_message_id", user=Mock(spec=EndUser), - rating="like", + rating=FeedbackRating.LIKE, content=None, ) diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 0eb3854c84..4eada73b82 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -35,7 +35,7 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -315,7 +315,7 @@ class TestWorkflowStopMechanism: def test_graph_engine_manager_has_send_stop_command(self): """Test GraphEngineManager has send_stop_command method.""" - from core.workflow.graph_engine.manager import GraphEngineManager + from dify_graph.graph_engine.manager import GraphEngineManager assert hasattr(GraphEngineManager, "send_stop_command") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index fcaa61a871..9e95f45a0a 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,7 +1,7 @@ from types import SimpleNamespace from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus def test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 7cb2f1050c..8fe41cd19f 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.dataset import ( from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError from models.account import Account from models.dataset import DatasetPermissionEnum +from models.enums import TagType from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.tag_service import TagService @@ -277,7 +278,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Test Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "0" # Required for Pydantic validation - must be string mock_tag_service.get_tags.return_value = [mock_tag] @@ -316,7 +317,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "new_tag_1" mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag_service.save_tags.return_value = mock_tag mock_service_api_ns.payload = {"name": "New Tag"} @@ -378,7 +379,7 @@ class TestDatasetTagsApi: mock_tag = Mock() mock_tag.id = "tag_1" mock_tag.name = "Updated Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_tag.binding_count = "5" mock_tag_service.update_tags.return_value = mock_tag mock_tag_service.get_tag_binding_count.return_value = 5 @@ -866,7 +867,7 @@ class TestTagService: mock_tag = Mock() mock_tag.id = str(uuid.uuid4()) mock_tag.name = "New Tag" - mock_tag.type = "knowledge" + mock_tag.type = TagType.KNOWLEDGE mock_save.return_value = mock_tag result = TagService.save_tags({"name": "New Tag", "type": "knowledge"}) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index dc651a1627..5c48ef1804 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -32,6 +32,7 @@ from controllers.service_api.dataset.segment import ( SegmentListQuery, ) from models.dataset import ChildChunk, Dataset, Document, DocumentSegment +from models.enums import IndexingStatus from services.dataset_service import DocumentService, SegmentService @@ -657,12 +658,27 @@ class TestSegmentIndexingRequirements: dataset.indexing_technique = technique assert dataset.indexing_technique in ["high_quality", "economy"] - @pytest.mark.parametrize("status", ["waiting", "parsing", "indexing", "completed", "error"]) + @pytest.mark.parametrize( + "status", + [ + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + ], + ) def test_valid_indexing_statuses(self, status): """Test valid document indexing statuses.""" document = Mock(spec=Document) document.indexing_status = status - assert document.indexing_status in ["waiting", "parsing", "indexing", "completed", "error"] + assert document.indexing_status in { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + } def test_completed_status_required_for_segments(self): """Test that completed status is required for segment operations.""" diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index f98109af79..e6e841be19 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -35,6 +35,7 @@ from controllers.service_api.dataset.document import ( InvalidMetadataError, ) from controllers.service_api.dataset.error import ArchivedDocumentImmutableError +from models.enums import IndexingStatus from services.dataset_service import DocumentService from services.entities.knowledge_entities.knowledge_entities import ProcessRule, RetrievalModel @@ -244,23 +245,26 @@ class TestDocumentService: class TestDocumentIndexingStatus: """Test document indexing status values.""" + _VALID_STATUSES = { + IndexingStatus.WAITING, + IndexingStatus.PARSING, + IndexingStatus.INDEXING, + IndexingStatus.COMPLETED, + IndexingStatus.ERROR, + IndexingStatus.PAUSED, + } + def test_completed_status(self): """Test completed status.""" - status = "completed" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.COMPLETED in self._VALID_STATUSES def test_indexing_status(self): """Test indexing status.""" - status = "indexing" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.INDEXING in self._VALID_STATUSES def test_error_status(self): """Test error status.""" - status = "error" - valid_statuses = ["waiting", "parsing", "indexing", "completed", "error", "paused"] - assert status in valid_statuses + assert IndexingStatus.ERROR in self._VALID_STATUSES class TestDocumentDocForm: diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py index 61fce3ed97..95c2f5cf92 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_hit_testing.py @@ -39,14 +39,21 @@ class TestHitTestingPayload: def test_payload_with_all_fields(self): """Test payload with all optional fields.""" + retrieval_model_data = { + "search_method": "semantic_search", + "reranking_enable": False, + "score_threshold_enabled": False, + "top_k": 5, + } payload = HitTestingPayload( query="test query", - retrieval_model={"top_k": 5}, + retrieval_model=retrieval_model_data, external_retrieval_model={"provider": "openai"}, attachment_ids=["att_1", "att_2"], ) assert payload.query == "test query" - assert payload.retrieval_model == {"top_k": 5} + assert payload.retrieval_model is not None + assert payload.retrieval_model.top_k == 5 assert payload.external_retrieval_model == {"provider": "openai"} assert payload.attachment_ids == ["att_1", "att_2"] @@ -134,7 +141,13 @@ class TestHitTestingApiPost: mock_dataset_svc.get_dataset.return_value = mock_dataset mock_dataset_svc.check_dataset_permission.return_value = None - retrieval_model = {"search_method": "semantic", "top_k": 10, "score_threshold": 0.8} + retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "score_threshold_enabled": True, + "top_k": 10, + "score_threshold": 0.8, + } mock_hit_svc.retrieve.return_value = {"query": "complex query", "records": []} mock_hit_svc.hit_testing_args_check.return_value = None @@ -152,7 +165,11 @@ class TestHitTestingApiPost: assert response["query"] == "complex query" call_kwargs = mock_hit_svc.retrieve.call_args - assert call_kwargs.kwargs.get("retrieval_model") == retrieval_model + # retrieval_model is serialized via model_dump, verify key fields + passed_retrieval_model = call_kwargs.kwargs.get("retrieval_model") + assert passed_retrieval_model is not None + assert passed_retrieval_model["search_method"] == "semantic_search" + assert passed_retrieval_model["top_k"] == 10 @patch("controllers.service_api.dataset.hit_testing.service_api_ns") @patch("controllers.console.datasets.hit_testing_base.DatasetService") diff --git a/api/tests/unit_tests/controllers/trigger/test_webhook.py b/api/tests/unit_tests/controllers/trigger/test_webhook.py index d633365f2b..91c793d292 100644 --- a/api/tests/unit_tests/controllers/trigger/test_webhook.py +++ b/api/tests/unit_tests/controllers/trigger/test_webhook.py @@ -23,6 +23,7 @@ def mock_jsonify(): class DummyWebhookTrigger: webhook_id = "wh-1" + webhook_url = "http://localhost:5001/triggers/webhook/wh-1" tenant_id = "tenant-1" app_id = "app-1" node_id = "node-1" @@ -104,7 +105,32 @@ class TestHandleWebhookDebug: @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") @patch.object(module.WebhookService, "extract_and_validate_webhook_data") @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) - @patch.object(module.TriggerDebugEventBus, "dispatch") + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=0) + def test_debug_requires_active_listener( + self, + mock_dispatch, + mock_build_inputs, + mock_extract, + mock_get, + ): + mock_get.return_value = (DummyWebhookTrigger(), None, "node_config") + mock_extract.return_value = {"method": "POST"} + + response, status = module.handle_webhook_debug("wh-1") + + assert status == 409 + assert response["error"] == "No active debug listener" + assert response["message"] == ( + "The webhook debug URL only works while the Variable Inspector is listening. " + "Use the published webhook URL to execute the workflow in Celery." + ) + assert response["execution_url"] == DummyWebhookTrigger.webhook_url + mock_dispatch.assert_called_once() + + @patch.object(module.WebhookService, "get_webhook_trigger_and_workflow") + @patch.object(module.WebhookService, "extract_and_validate_webhook_data") + @patch.object(module.WebhookService, "build_workflow_inputs", return_value={"x": 1}) + @patch.object(module.TriggerDebugEventBus, "dispatch", return_value=1) @patch.object(module.WebhookService, "generate_webhook_response") def test_debug_success( self, diff --git a/api/tests/unit_tests/controllers/web/__init__.py b/api/tests/unit_tests/controllers/web/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/web/conftest.py b/api/tests/unit_tests/controllers/web/conftest.py new file mode 100644 index 0000000000..274d78c9cf --- /dev/null +++ b/api/tests/unit_tests/controllers/web/conftest.py @@ -0,0 +1,85 @@ +"""Shared fixtures for controllers.web unit tests.""" + +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any + +import pytest +from flask import Flask + + +@pytest.fixture +def app() -> Flask: + """Minimal Flask app for request contexts.""" + flask_app = Flask(__name__) + flask_app.config["TESTING"] = True + return flask_app + + +class FakeSession: + """Stand-in for db.session that returns pre-seeded objects by model class name.""" + + def __init__(self, mapping: dict[str, Any] | None = None): + self._mapping: dict[str, Any] = mapping or {} + self._model_name: str | None = None + + def query(self, model: type) -> FakeSession: + self._model_name = model.__name__ + return self + + def where(self, *_args: object, **_kwargs: object) -> FakeSession: + return self + + def first(self) -> Any: + assert self._model_name is not None + return self._mapping.get(self._model_name) + + +class FakeDB: + """Minimal db stub exposing engine and session.""" + + def __init__(self, session: FakeSession | None = None): + self.session = session or FakeSession() + self.engine = object() + + +def make_app_model( + *, + app_id: str = "app-1", + tenant_id: str = "tenant-1", + mode: str = "chat", + enable_site: bool = True, + status: str = "normal", +) -> SimpleNamespace: + """Build a fake App model with common defaults.""" + tenant = SimpleNamespace( + id=tenant_id, + status="normal", + plan="basic", + custom_config_dict={}, + ) + return SimpleNamespace( + id=app_id, + tenant_id=tenant_id, + tenant=tenant, + mode=mode, + enable_site=enable_site, + status=status, + workflow=None, + app_model_config=None, + ) + + +def make_end_user( + *, + user_id: str = "end-user-1", + session_id: str = "session-1", + external_user_id: str = "ext-user-1", +) -> SimpleNamespace: + """Build a fake EndUser model with common defaults.""" + return SimpleNamespace( + id=user_id, + session_id=session_id, + external_user_id=external_user_id, + ) diff --git a/api/tests/unit_tests/controllers/web/test_app.py b/api/tests/unit_tests/controllers/web/test_app.py new file mode 100644 index 0000000000..ce7ae27188 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_app.py @@ -0,0 +1,165 @@ +"""Unit tests for controllers.web.app endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.app import AppAccessMode, AppMeta, AppParameterApi, AppWebAuthPermission +from controllers.web.error import AppUnavailableError + + +# --------------------------------------------------------------------------- +# AppParameterApi +# --------------------------------------------------------------------------- +class TestAppParameterApi: + def test_advanced_chat_mode_uses_workflow(self, app: Flask) -> None: + features_dict = {"opening_statement": "Hello"} + workflow = SimpleNamespace( + features_dict=features_dict, + user_input_form=lambda to_old_structure=False: [], + ) + app_model = SimpleNamespace(mode="advanced-chat", workflow=workflow) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {"result": "ok"} + result = AppParameterApi().get(app_model, SimpleNamespace()) + + mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[]) + assert result == {"result": "ok"} + + def test_workflow_mode_uses_workflow(self, app: Flask) -> None: + features_dict = {} + workflow = SimpleNamespace( + features_dict=features_dict, + user_input_form=lambda to_old_structure=False: [{"var": "x"}], + ) + app_model = SimpleNamespace(mode="workflow", workflow=workflow) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {} + AppParameterApi().get(app_model, SimpleNamespace()) + + mock_params.assert_called_once_with(features_dict=features_dict, user_input_form=[{"var": "x"}]) + + def test_advanced_chat_mode_no_workflow_raises(self, app: Flask) -> None: + app_model = SimpleNamespace(mode="advanced-chat", workflow=None) + with app.test_request_context("/parameters"): + with pytest.raises(AppUnavailableError): + AppParameterApi().get(app_model, SimpleNamespace()) + + def test_standard_mode_uses_app_model_config(self, app: Flask) -> None: + config = SimpleNamespace(to_dict=lambda: {"user_input_form": [{"var": "y"}], "key": "val"}) + app_model = SimpleNamespace(mode="chat", app_model_config=config) + + with ( + app.test_request_context("/parameters"), + patch("controllers.web.app.get_parameters_from_feature_dict", return_value={}) as mock_params, + patch("controllers.web.app.fields.Parameters") as mock_fields, + ): + mock_fields.model_validate.return_value.model_dump.return_value = {} + AppParameterApi().get(app_model, SimpleNamespace()) + + call_kwargs = mock_params.call_args + assert call_kwargs.kwargs["user_input_form"] == [{"var": "y"}] + + def test_standard_mode_no_config_raises(self, app: Flask) -> None: + app_model = SimpleNamespace(mode="chat", app_model_config=None) + with app.test_request_context("/parameters"): + with pytest.raises(AppUnavailableError): + AppParameterApi().get(app_model, SimpleNamespace()) + + +# --------------------------------------------------------------------------- +# AppMeta +# --------------------------------------------------------------------------- +class TestAppMeta: + @patch("controllers.web.app.AppService") + def test_get_returns_meta(self, mock_service_cls: MagicMock, app: Flask) -> None: + mock_service_cls.return_value.get_app_meta.return_value = {"tool_icons": {}} + app_model = SimpleNamespace(id="app-1") + + with app.test_request_context("/meta"): + result = AppMeta().get(app_model, SimpleNamespace()) + + assert result == {"tool_icons": {}} + + +# --------------------------------------------------------------------------- +# AppAccessMode +# --------------------------------------------------------------------------- +class TestAppAccessMode: + @patch("controllers.web.app.FeatureService.get_system_features") + def test_returns_public_when_webapp_auth_disabled(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + with app.test_request_context("/webapp/access-mode?appId=app-1"): + result = AppAccessMode().get() + + assert result == {"accessMode": "public"} + + @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.app.FeatureService.get_system_features") + def test_returns_access_mode_with_app_id( + self, mock_features: MagicMock, mock_access: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_access.return_value = SimpleNamespace(access_mode="internal") + + with app.test_request_context("/webapp/access-mode?appId=app-1"): + result = AppAccessMode().get() + + assert result == {"accessMode": "internal"} + mock_access.assert_called_once_with("app-1") + + @patch("controllers.web.app.AppService.get_app_id_by_code", return_value="resolved-id") + @patch("controllers.web.app.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.app.FeatureService.get_system_features") + def test_resolves_app_code_to_id( + self, mock_features: MagicMock, mock_access: MagicMock, mock_resolve: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_access.return_value = SimpleNamespace(access_mode="external") + + with app.test_request_context("/webapp/access-mode?appCode=code1"): + result = AppAccessMode().get() + + mock_resolve.assert_called_once_with("code1") + mock_access.assert_called_once_with("resolved-id") + assert result == {"accessMode": "external"} + + @patch("controllers.web.app.FeatureService.get_system_features") + def test_raises_when_no_app_id_or_code(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + + with app.test_request_context("/webapp/access-mode"): + with pytest.raises(ValueError, match="appId or appCode"): + AppAccessMode().get() + + +# --------------------------------------------------------------------------- +# AppWebAuthPermission +# --------------------------------------------------------------------------- +class TestAppWebAuthPermission: + @patch("controllers.web.app.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_returns_true_when_no_permission_check_required(self, mock_check: MagicMock, app: Flask) -> None: + with app.test_request_context("/webapp/permission?appId=app-1", headers={"X-App-Code": "code1"}): + result = AppWebAuthPermission().get() + + assert result == {"result": True} + + def test_raises_when_missing_app_id(self, app: Flask) -> None: + with app.test_request_context("/webapp/permission", headers={"X-App-Code": "code1"}): + with pytest.raises(ValueError, match="appId"): + AppWebAuthPermission().get() diff --git a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py new file mode 100644 index 0000000000..01f34345aa --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -0,0 +1,135 @@ +"""Unit tests for controllers.web.audio endpoints.""" + +from __future__ import annotations + +from io import BytesIO +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.audio import AudioApi, TextApi +from controllers.web.error import ( + AudioTooLargeError, + CompletionRequestError, + NoAudioUploadedError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError +from services.errors.audio import ( + AudioTooLargeServiceError, + NoAudioUploadedServiceError, + ProviderNotSupportSpeechToTextServiceError, + UnsupportedAudioTypeServiceError, +) + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1", external_user_id="ext-1") + + +# --------------------------------------------------------------------------- +# AudioApi (audio-to-text) +# --------------------------------------------------------------------------- +class TestAudioApi: + @patch("controllers.web.audio.AudioService.transcript_asr", return_value={"text": "hello"}) + def test_happy_path(self, mock_asr: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + data = {"file": (BytesIO(b"fake-audio"), "test.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + result = AudioApi().post(_app_model(), _end_user()) + + assert result == {"text": "hello"} + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=NoAudioUploadedServiceError()) + def test_no_audio_uploaded(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b""), "empty.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(NoAudioUploadedError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=AudioTooLargeServiceError("too big")) + def test_audio_too_large(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"big"), "big.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(AudioTooLargeError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=UnsupportedAudioTypeServiceError()) + def test_unsupported_type(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"bad"), "bad.xyz")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(UnsupportedAudioTypeError): + AudioApi().post(_app_model(), _end_user()) + + @patch( + "controllers.web.audio.AudioService.transcript_asr", + side_effect=ProviderNotSupportSpeechToTextServiceError(), + ) + def test_provider_not_support(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderNotSupportSpeechToTextError): + AudioApi().post(_app_model(), _end_user()) + + @patch( + "controllers.web.audio.AudioService.transcript_asr", + side_effect=ProviderTokenNotInitError(description="no token"), + ) + def test_provider_not_init(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderNotInitializeError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=QuotaExceededError()) + def test_quota_exceeded(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderQuotaExceededError): + AudioApi().post(_app_model(), _end_user()) + + @patch("controllers.web.audio.AudioService.transcript_asr", side_effect=ModelCurrentlyNotSupportError()) + def test_model_not_support(self, mock_asr: MagicMock, app: Flask) -> None: + data = {"file": (BytesIO(b"x"), "x.mp3")} + with app.test_request_context("/audio-to-text", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + AudioApi().post(_app_model(), _end_user()) + + +# --------------------------------------------------------------------------- +# TextApi (text-to-audio) +# --------------------------------------------------------------------------- +class TestTextApi: + @patch("controllers.web.audio.AudioService.transcript_tts", return_value="audio-bytes") + @patch("controllers.web.audio.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None: + mock_ns.payload = {"text": "hello", "voice": "alloy"} + + with app.test_request_context("/text-to-audio", method="POST"): + result = TextApi().post(_app_model(), _end_user()) + + assert result == "audio-bytes" + mock_tts.assert_called_once() + + @patch( + "controllers.web.audio.AudioService.transcript_tts", + side_effect=InvokeError(description="invoke failed"), + ) + @patch("controllers.web.audio.web_ns") + def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_tts: MagicMock, app: Flask) -> None: + mock_ns.payload = {"text": "hello"} + + with app.test_request_context("/text-to-audio", method="POST"): + with pytest.raises(CompletionRequestError): + TextApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py new file mode 100644 index 0000000000..e88bcf2ae6 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -0,0 +1,161 @@ +"""Unit tests for controllers.web.completion endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi +from controllers.web.error import ( + CompletionRequestError, + NotChatAppError, + NotCompletionAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError +from dify_graph.model_runtime.errors.invoke import InvokeError + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# CompletionApi +# --------------------------------------------------------------------------- +class TestCompletionApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(NotCompletionAppError): + CompletionApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "hi"}) + @patch("controllers.web.completion.AppGenerateService.generate") + @patch("controllers.web.completion.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "test"} + mock_gen.return_value = "response-obj" + + with app.test_request_context("/completion-messages", method="POST"): + result = CompletionApi().post(_completion_app(), _end_user()) + + assert result == {"answer": "hi"} + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=ProviderTokenNotInitError(description="not init"), + ) + @patch("controllers.web.completion.web_ns") + def test_provider_not_init_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderNotInitializeError): + CompletionApi().post(_completion_app(), _end_user()) + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=QuotaExceededError(), + ) + @patch("controllers.web.completion.web_ns") + def test_quota_exceeded_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderQuotaExceededError): + CompletionApi().post(_completion_app(), _end_user()) + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=ModelCurrentlyNotSupportError(), + ) + @patch("controllers.web.completion.web_ns") + def test_model_not_support_error(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/completion-messages", method="POST"): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + CompletionApi().post(_completion_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# CompletionStopApi +# --------------------------------------------------------------------------- +class TestCompletionStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/completion-messages/task-1/stop", method="POST"): + with pytest.raises(NotCompletionAppError): + CompletionStopApi().post(_chat_app(), _end_user(), "task-1") + + @patch("controllers.web.completion.AppTaskService.stop_task") + def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None: + with app.test_request_context("/completion-messages/task-1/stop", method="POST"): + result, status = CompletionStopApi().post(_completion_app(), _end_user(), "task-1") + + assert status == 200 + assert result == {"result": "success"} + + +# --------------------------------------------------------------------------- +# ChatApi +# --------------------------------------------------------------------------- +class TestChatApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/chat-messages", method="POST"): + with pytest.raises(NotChatAppError): + ChatApi().post(_completion_app(), _end_user()) + + @patch("controllers.web.completion.helper.compact_generate_response", return_value={"answer": "reply"}) + @patch("controllers.web.completion.AppGenerateService.generate") + @patch("controllers.web.completion.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "hi"} + mock_gen.return_value = "response" + + with app.test_request_context("/chat-messages", method="POST"): + result = ChatApi().post(_chat_app(), _end_user()) + + assert result == {"answer": "reply"} + + @patch( + "controllers.web.completion.AppGenerateService.generate", + side_effect=InvokeError(description="rate limit"), + ) + @patch("controllers.web.completion.web_ns") + def test_invoke_error_mapped(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}, "query": "x"} + + with app.test_request_context("/chat-messages", method="POST"): + with pytest.raises(CompletionRequestError): + ChatApi().post(_chat_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# ChatStopApi +# --------------------------------------------------------------------------- +class TestChatStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/chat-messages/task-1/stop", method="POST"): + with pytest.raises(NotChatAppError): + ChatStopApi().post(_completion_app(), _end_user(), "task-1") + + @patch("controllers.web.completion.AppTaskService.stop_task") + def test_stop_success(self, mock_stop: MagicMock, app: Flask) -> None: + with app.test_request_context("/chat-messages/task-1/stop", method="POST"): + result, status = ChatStopApi().post(_chat_app(), _end_user(), "task-1") + + assert status == 200 + assert result == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/web/test_conversation.py b/api/tests/unit_tests/controllers/web/test_conversation.py new file mode 100644 index 0000000000..e5adbbbf66 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_conversation.py @@ -0,0 +1,183 @@ +"""Unit tests for controllers.web.conversation endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.conversation import ( + ConversationApi, + ConversationListApi, + ConversationPinApi, + ConversationRenameApi, + ConversationUnPinApi, +) +from controllers.web.error import NotChatAppError +from services.errors.conversation import ConversationNotExistsError + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# ConversationListApi +# --------------------------------------------------------------------------- +class TestConversationListApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/conversations"): + with pytest.raises(NotChatAppError): + ConversationListApi().get(_completion_app(), _end_user()) + + @patch("controllers.web.conversation.WebConversationService.pagination_by_last_id") + @patch("controllers.web.conversation.db") + def test_happy_path(self, mock_db: MagicMock, mock_paginate: MagicMock, app: Flask) -> None: + conv_id = str(uuid4()) + conv = SimpleNamespace( + id=conv_id, + name="Test", + inputs={}, + status="normal", + introduction="", + created_at=1700000000, + updated_at=1700000000, + ) + mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[conv]) + mock_db.engine = "engine" + + session_mock = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + + with ( + app.test_request_context("/conversations?limit=20"), + patch("controllers.web.conversation.Session", return_value=session_ctx), + ): + result = ConversationListApi().get(_chat_app(), _end_user()) + + assert result["limit"] == 20 + assert result["has_more"] is False + + +# --------------------------------------------------------------------------- +# ConversationApi (delete) +# --------------------------------------------------------------------------- +class TestConversationApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}"): + with pytest.raises(NotChatAppError): + ConversationApi().delete(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.ConversationService.delete") + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}"): + result, status = ConversationApi().delete(_chat_app(), _end_user(), c_id) + + assert status == 204 + assert result["result"] == "success" + + @patch("controllers.web.conversation.ConversationService.delete", side_effect=ConversationNotExistsError()) + def test_delete_not_found(self, mock_delete: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}"): + with pytest.raises(NotFound, match="Conversation Not Exists"): + ConversationApi().delete(_chat_app(), _end_user(), c_id) + + +# --------------------------------------------------------------------------- +# ConversationRenameApi +# --------------------------------------------------------------------------- +class TestConversationRenameApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/name", method="POST", json={"name": "x"}): + with pytest.raises(NotChatAppError): + ConversationRenameApi().post(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.ConversationService.rename") + @patch("controllers.web.conversation.web_ns") + def test_rename_success(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: + c_id = uuid4() + mock_ns.payload = {"name": "New Name", "auto_generate": False} + conv = SimpleNamespace( + id=str(c_id), + name="New Name", + inputs={}, + status="normal", + introduction="", + created_at=1700000000, + updated_at=1700000000, + ) + mock_rename.return_value = conv + + with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "New Name"}): + result = ConversationRenameApi().post(_chat_app(), _end_user(), c_id) + + assert result["name"] == "New Name" + + @patch( + "controllers.web.conversation.ConversationService.rename", + side_effect=ConversationNotExistsError(), + ) + @patch("controllers.web.conversation.web_ns") + def test_rename_not_found(self, mock_ns: MagicMock, mock_rename: MagicMock, app: Flask) -> None: + c_id = uuid4() + mock_ns.payload = {"name": "X", "auto_generate": False} + + with app.test_request_context(f"/conversations/{c_id}/name", method="POST", json={"name": "X"}): + with pytest.raises(NotFound, match="Conversation Not Exists"): + ConversationRenameApi().post(_chat_app(), _end_user(), c_id) + + +# --------------------------------------------------------------------------- +# ConversationPinApi / ConversationUnPinApi +# --------------------------------------------------------------------------- +class TestConversationPinApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/pin", method="PATCH"): + with pytest.raises(NotChatAppError): + ConversationPinApi().patch(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.WebConversationService.pin") + def test_pin_success(self, mock_pin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): + result = ConversationPinApi().patch(_chat_app(), _end_user(), c_id) + + assert result["result"] == "success" + + @patch("controllers.web.conversation.WebConversationService.pin", side_effect=ConversationNotExistsError()) + def test_pin_not_found(self, mock_pin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/pin", method="PATCH"): + with pytest.raises(NotFound): + ConversationPinApi().patch(_chat_app(), _end_user(), c_id) + + +class TestConversationUnPinApi: + def test_non_chat_mode_raises(self, app: Flask) -> None: + with app.test_request_context(f"/conversations/{uuid4()}/unpin", method="PATCH"): + with pytest.raises(NotChatAppError): + ConversationUnPinApi().patch(_completion_app(), _end_user(), uuid4()) + + @patch("controllers.web.conversation.WebConversationService.unpin") + def test_unpin_success(self, mock_unpin: MagicMock, app: Flask) -> None: + c_id = uuid4() + with app.test_request_context(f"/conversations/{c_id}/unpin", method="PATCH"): + result = ConversationUnPinApi().patch(_chat_app(), _end_user(), c_id) + + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/web/test_error.py b/api/tests/unit_tests/controllers/web/test_error.py new file mode 100644 index 0000000000..0387d002ba --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_error.py @@ -0,0 +1,75 @@ +"""Unit tests for controllers.web.error HTTP exception classes.""" + +from __future__ import annotations + +import pytest + +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + AppSuggestedQuestionsAfterAnswerDisabledError, + AppUnavailableError, + AudioTooLargeError, + CompletionRequestError, + ConversationCompletedError, + InvalidArgumentError, + InvokeRateLimitError, + NoAudioUploadedError, + NotChatAppError, + NotCompletionAppError, + NotFoundError, + NotWorkflowAppError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderNotSupportSpeechToTextError, + ProviderQuotaExceededError, + UnsupportedAudioTypeError, + WebAppAuthAccessDeniedError, + WebAppAuthRequiredError, + WebFormRateLimitExceededError, +) + +_ERROR_SPECS: list[tuple[type, str, int]] = [ + (AppUnavailableError, "app_unavailable", 400), + (NotCompletionAppError, "not_completion_app", 400), + (NotChatAppError, "not_chat_app", 400), + (NotWorkflowAppError, "not_workflow_app", 400), + (ConversationCompletedError, "conversation_completed", 400), + (ProviderNotInitializeError, "provider_not_initialize", 400), + (ProviderQuotaExceededError, "provider_quota_exceeded", 400), + (ProviderModelCurrentlyNotSupportError, "model_currently_not_support", 400), + (CompletionRequestError, "completion_request_error", 400), + (AppMoreLikeThisDisabledError, "app_more_like_this_disabled", 403), + (AppSuggestedQuestionsAfterAnswerDisabledError, "app_suggested_questions_after_answer_disabled", 403), + (NoAudioUploadedError, "no_audio_uploaded", 400), + (AudioTooLargeError, "audio_too_large", 413), + (UnsupportedAudioTypeError, "unsupported_audio_type", 415), + (ProviderNotSupportSpeechToTextError, "provider_not_support_speech_to_text", 400), + (WebAppAuthRequiredError, "web_sso_auth_required", 401), + (WebAppAuthAccessDeniedError, "web_app_access_denied", 401), + (InvokeRateLimitError, "rate_limit_error", 429), + (WebFormRateLimitExceededError, "web_form_rate_limit_exceeded", 429), + (NotFoundError, "not_found", 404), + (InvalidArgumentError, "invalid_param", 400), +] + + +@pytest.mark.parametrize( + ("cls", "expected_code", "expected_status"), + _ERROR_SPECS, + ids=[cls.__name__ for cls, _, _ in _ERROR_SPECS], +) +def test_error_class_attributes(cls: type, expected_code: str, expected_status: int) -> None: + """Each error class exposes the correct error_code and HTTP status code.""" + assert cls.error_code == expected_code + assert cls.code == expected_status + + +def test_error_classes_have_description() -> None: + """Every error class has a description (string or None for generic errors).""" + # NotFoundError and InvalidArgumentError use None description by design + _NO_DESCRIPTION = {NotFoundError, InvalidArgumentError} + for cls, _, _ in _ERROR_SPECS: + if cls in _NO_DESCRIPTION: + continue + assert isinstance(cls.description, str), f"{cls.__name__} missing description" + assert len(cls.description) > 0, f"{cls.__name__} has empty description" diff --git a/api/tests/unit_tests/controllers/web/test_feature.py b/api/tests/unit_tests/controllers/web/test_feature.py new file mode 100644 index 0000000000..fe45d5f059 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_feature.py @@ -0,0 +1,38 @@ +"""Unit tests for controllers.web.feature endpoints.""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from flask import Flask + +from controllers.web.feature import SystemFeatureApi + + +class TestSystemFeatureApi: + @patch("controllers.web.feature.FeatureService.get_system_features") + def test_returns_system_features(self, mock_features: MagicMock, app: Flask) -> None: + mock_model = MagicMock() + mock_model.model_dump.return_value = {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}} + mock_features.return_value = mock_model + + with app.test_request_context("/system-features"): + result = SystemFeatureApi().get() + + assert result == {"sso_enforced_for_signin": False, "webapp_auth": {"enabled": False}} + mock_features.assert_called_once() + + @patch("controllers.web.feature.FeatureService.get_system_features") + def test_unauthenticated_access(self, mock_features: MagicMock, app: Flask) -> None: + """SystemFeatureApi is unauthenticated by design — no WebApiResource decorator.""" + mock_model = MagicMock() + mock_model.model_dump.return_value = {} + mock_features.return_value = mock_model + + # Verify it's a bare Resource, not WebApiResource + from flask_restx import Resource + + from controllers.web.wraps import WebApiResource + + assert issubclass(SystemFeatureApi, Resource) + assert not issubclass(SystemFeatureApi, WebApiResource) diff --git a/api/tests/unit_tests/controllers/web/test_files.py b/api/tests/unit_tests/controllers/web/test_files.py new file mode 100644 index 0000000000..a3921b0373 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_files.py @@ -0,0 +1,89 @@ +"""Unit tests for controllers.web.files endpoints.""" + +from __future__ import annotations + +from io import BytesIO +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, +) +from controllers.web.files import FileApi + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +class TestFileApi: + def test_no_file_uploaded(self, app: Flask) -> None: + with app.test_request_context("/files/upload", method="POST", content_type="multipart/form-data"): + with pytest.raises(NoFileUploadedError): + FileApi().post(_app_model(), _end_user()) + + def test_too_many_files(self, app: Flask) -> None: + data = { + "file": (BytesIO(b"a"), "a.txt"), + "file2": (BytesIO(b"b"), "b.txt"), + } + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + # Now has "file" key but len(request.files) > 1 + with pytest.raises(TooManyFilesError): + FileApi().post(_app_model(), _end_user()) + + def test_filename_missing(self, app: Flask) -> None: + data = {"file": (BytesIO(b"content"), "")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(FilenameNotExistsError): + FileApi().post(_app_model(), _end_user()) + + @patch("controllers.web.files.FileService") + @patch("controllers.web.files.db") + def test_upload_success(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + from datetime import datetime + + upload_file = SimpleNamespace( + id="file-1", + name="test.txt", + size=100, + extension="txt", + mime_type="text/plain", + created_by="eu-1", + created_at=datetime(2024, 1, 1), + ) + mock_file_svc_cls.return_value.upload_file.return_value = upload_file + + data = {"file": (BytesIO(b"content"), "test.txt")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + result, status = FileApi().post(_app_model(), _end_user()) + + assert status == 201 + assert result["id"] == "file-1" + assert result["name"] == "test.txt" + + @patch("controllers.web.files.FileService") + @patch("controllers.web.files.db") + def test_file_too_large_from_service(self, mock_db: MagicMock, mock_file_svc_cls: MagicMock, app: Flask) -> None: + import services.errors.file + + mock_db.engine = "engine" + mock_file_svc_cls.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError( + description="max 10MB" + ) + + data = {"file": (BytesIO(b"big"), "big.txt")} + with app.test_request_context("/files/upload", method="POST", data=data, content_type="multipart/form-data"): + with pytest.raises(FileTooLargeError): + FileApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_human_input_form.py b/api/tests/unit_tests/controllers/web/test_human_input_form.py index 4fb735b033..a1dbc80b20 100644 --- a/api/tests/unit_tests/controllers/web/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/web/test_human_input_form.py @@ -49,6 +49,17 @@ class _FakeSession: assert self._model_name is not None return self._mapping.get(self._model_name) + def get(self, model, ident): + return self._mapping.get(model.__name__) + + def scalar(self, stmt): + # Extract the model name from the select statement's column_descriptions + try: + name = stmt.column_descriptions[0]["entity"].__name__ + except (AttributeError, IndexError, KeyError): + return None + return self._mapping.get(name) + class _FakeDB: """Minimal db stub exposing engine and session.""" diff --git a/api/tests/unit_tests/controllers/web/test_message_endpoints.py b/api/tests/unit_tests/controllers/web/test_message_endpoints.py new file mode 100644 index 0000000000..89ab93d8d4 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_message_endpoints.py @@ -0,0 +1,156 @@ +"""Unit tests for controllers.web.message — feedback, more-like-this, suggested questions.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.error import ( + AppMoreLikeThisDisabledError, + NotChatAppError, + NotCompletionAppError, +) +from controllers.web.message import ( + MessageFeedbackApi, + MessageMoreLikeThisApi, + MessageSuggestedQuestionApi, +) +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# MessageFeedbackApi +# --------------------------------------------------------------------------- +class TestMessageFeedbackApi: + @patch("controllers.web.message.MessageService.create_feedback") + @patch("controllers.web.message.web_ns") + def test_feedback_success(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": "like", "content": "great"} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + assert result == {"result": "success"} + mock_create.assert_called_once() + + @patch("controllers.web.message.MessageService.create_feedback") + @patch("controllers.web.message.web_ns") + def test_feedback_null_rating(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": None} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + result = MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + assert result == {"result": "success"} + + @patch( + "controllers.web.message.MessageService.create_feedback", + side_effect=MessageNotExistsError(), + ) + @patch("controllers.web.message.web_ns") + def test_feedback_message_not_found(self, mock_ns: MagicMock, mock_create: MagicMock, app: Flask) -> None: + mock_ns.payload = {"rating": "dislike"} + msg_id = uuid4() + + with app.test_request_context(f"/messages/{msg_id}/feedbacks", method="POST"): + with pytest.raises(NotFound, match="Message Not Exists"): + MessageFeedbackApi().post(_chat_app(), _end_user(), msg_id) + + +# --------------------------------------------------------------------------- +# MessageMoreLikeThisApi +# --------------------------------------------------------------------------- +class TestMessageMoreLikeThisApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(NotCompletionAppError): + MessageMoreLikeThisApi().get(_chat_app(), _end_user(), msg_id) + + @patch("controllers.web.message.helper.compact_generate_response", return_value={"answer": "similar"}) + @patch("controllers.web.message.AppGenerateService.generate_more_like_this") + def test_happy_path(self, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + msg_id = uuid4() + mock_gen.return_value = "response" + + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + result = MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + assert result == {"answer": "similar"} + + @patch( + "controllers.web.message.AppGenerateService.generate_more_like_this", + side_effect=MessageNotExistsError(), + ) + def test_message_not_found(self, mock_gen: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(NotFound, match="Message Not Exists"): + MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + @patch( + "controllers.web.message.AppGenerateService.generate_more_like_this", + side_effect=MoreLikeThisDisabledError(), + ) + def test_feature_disabled(self, mock_gen: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/more-like-this?response_mode=blocking"): + with pytest.raises(AppMoreLikeThisDisabledError): + MessageMoreLikeThisApi().get(_completion_app(), _end_user(), msg_id) + + +# --------------------------------------------------------------------------- +# MessageSuggestedQuestionApi +# --------------------------------------------------------------------------- +class TestMessageSuggestedQuestionApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotChatAppError): + MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) + + def test_wrong_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotChatAppError): + MessageSuggestedQuestionApi().get(_completion_app(), _end_user(), msg_id) + + @patch("controllers.web.message.MessageService.get_suggested_questions_after_answer") + def test_happy_path(self, mock_suggest: MagicMock, app: Flask) -> None: + msg_id = uuid4() + mock_suggest.return_value = ["What about X?", "Tell me more about Y."] + + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + result = MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id) + + assert result["data"] == ["What about X?", "Tell me more about Y."] + + @patch( + "controllers.web.message.MessageService.get_suggested_questions_after_answer", + side_effect=MessageNotExistsError(), + ) + def test_message_not_found(self, mock_suggest: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/messages/{msg_id}/suggested-questions"): + with pytest.raises(NotFound, match="Message not found"): + MessageSuggestedQuestionApi().get(_chat_app(), _end_user(), msg_id) diff --git a/api/tests/unit_tests/controllers/web/test_message_list.py b/api/tests/unit_tests/controllers/web/test_message_list.py index 1c096bfbcf..2bb425cdba 100644 --- a/api/tests/unit_tests/controllers/web/test_message_list.py +++ b/api/tests/unit_tests/controllers/web/test_message_list.py @@ -137,7 +137,7 @@ def test_message_list_mapping(app: Flask) -> None: {"id": "file-dict", "filename": "a.txt", "type": "file", "transfer_method": "local"}, message_file_obj, ], - status="success", + status="normal", error=None, message_metadata_dict={"meta": "value"}, extra_contents=[ diff --git a/api/tests/unit_tests/controllers/web/test_passport.py b/api/tests/unit_tests/controllers/web/test_passport.py new file mode 100644 index 0000000000..58d58626b2 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_passport.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web.error import WebAppAuthRequiredError +from controllers.web.passport import ( + PassportService, + decode_enterprise_webapp_user_id, + exchange_token_for_existing_web_user, + generate_session_id, +) +from services.webapp_auth_service import WebAppAuthType + + +def test_decode_enterprise_webapp_user_id_none() -> None: + assert decode_enterprise_webapp_user_id(None) is None + + +def test_decode_enterprise_webapp_user_id_invalid_source(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: {"token_source": "bad"}) + with pytest.raises(Unauthorized): + decode_enterprise_webapp_user_id("token") + + +def test_decode_enterprise_webapp_user_id_valid(monkeypatch: pytest.MonkeyPatch) -> None: + decoded = {"token_source": "webapp_login_token", "user_id": "u1"} + monkeypatch.setattr(PassportService, "verify", lambda *_args, **_kwargs: decoded) + assert decode_enterprise_webapp_user_id("token") == decoded + + +def test_exchange_token_public_flow(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + return site if _scalar_side_effect.calls == 1 else app_model + + db_session = SimpleNamespace(scalar=_scalar_side_effect) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + monkeypatch.setattr("controllers.web.passport._exchange_for_public_app_token", lambda *_args, **_kwargs: "resp") + + decoded = {"auth_type": "public"} + result = exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.PUBLIC) + assert result == "resp" + + +def test_exchange_token_requires_external(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True) + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + return site if _scalar_side_effect.calls == 1 else app_model + + db_session = SimpleNamespace(scalar=_scalar_side_effect) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + decoded = {"auth_type": "internal"} + with pytest.raises(WebAppAuthRequiredError): + exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.EXTERNAL) + + +def test_exchange_token_missing_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + site = SimpleNamespace(id="s1", app_id="a1", code="code", status="normal") + app_model = SimpleNamespace(id="a1", status="normal", enable_site=True, tenant_id="t1") + + def _scalar_side_effect(*_args, **_kwargs): + if not hasattr(_scalar_side_effect, "calls"): + _scalar_side_effect.calls = 0 + _scalar_side_effect.calls += 1 + if _scalar_side_effect.calls == 1: + return site + if _scalar_side_effect.calls == 2: + return app_model + return None + + db_session = SimpleNamespace(scalar=_scalar_side_effect, add=lambda *_a, **_k: None, commit=lambda: None) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + decoded = {"auth_type": "internal"} + with pytest.raises(NotFound): + exchange_token_for_existing_web_user("code", decoded, WebAppAuthType.INTERNAL) + + +def test_generate_session_id(monkeypatch: pytest.MonkeyPatch) -> None: + counts = [1, 0] + + def _scalar(*_args, **_kwargs): + return counts.pop(0) + + db_session = SimpleNamespace(scalar=_scalar) + monkeypatch.setattr("controllers.web.passport.db", SimpleNamespace(session=db_session)) + + session_id = generate_session_id() + assert session_id diff --git a/api/tests/unit_tests/controllers/web/test_pydantic_models.py b/api/tests/unit_tests/controllers/web/test_pydantic_models.py new file mode 100644 index 0000000000..dcf8133712 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_pydantic_models.py @@ -0,0 +1,423 @@ +"""Unit tests for Pydantic models defined in controllers.web modules. + +Covers validation logic, field defaults, constraints, and custom validators +for all ~15 Pydantic models across the web controller layer. +""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from pydantic import ValidationError + +# --------------------------------------------------------------------------- +# app.py models +# --------------------------------------------------------------------------- +from controllers.web.app import AppAccessModeQuery + + +class TestAppAccessModeQuery: + def test_alias_resolution(self) -> None: + q = AppAccessModeQuery.model_validate({"appId": "abc", "appCode": "xyz"}) + assert q.app_id == "abc" + assert q.app_code == "xyz" + + def test_defaults_to_none(self) -> None: + q = AppAccessModeQuery.model_validate({}) + assert q.app_id is None + assert q.app_code is None + + def test_accepts_snake_case(self) -> None: + q = AppAccessModeQuery(app_id="id1", app_code="code1") + assert q.app_id == "id1" + assert q.app_code == "code1" + + +# --------------------------------------------------------------------------- +# audio.py models +# --------------------------------------------------------------------------- +from controllers.web.audio import TextToAudioPayload + + +class TestTextToAudioPayload: + def test_defaults(self) -> None: + p = TextToAudioPayload.model_validate({}) + assert p.message_id is None + assert p.voice is None + assert p.text is None + assert p.streaming is None + + def test_valid_uuid_message_id(self) -> None: + uid = str(uuid4()) + p = TextToAudioPayload(message_id=uid) + assert p.message_id == uid + + def test_none_message_id_passthrough(self) -> None: + p = TextToAudioPayload(message_id=None) + assert p.message_id is None + + def test_invalid_uuid_message_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + TextToAudioPayload(message_id="not-a-uuid") + + +# --------------------------------------------------------------------------- +# completion.py models +# --------------------------------------------------------------------------- +from controllers.web.completion import ChatMessagePayload, CompletionMessagePayload + + +class TestCompletionMessagePayload: + def test_defaults(self) -> None: + p = CompletionMessagePayload(inputs={}) + assert p.query == "" + assert p.files is None + assert p.response_mode is None + assert p.retriever_from == "web_app" + + def test_accepts_full_payload(self) -> None: + p = CompletionMessagePayload( + inputs={"key": "val"}, + query="test", + files=[{"id": "f1"}], + response_mode="streaming", + ) + assert p.response_mode == "streaming" + assert p.files == [{"id": "f1"}] + + def test_invalid_response_mode(self) -> None: + with pytest.raises(ValidationError): + CompletionMessagePayload(inputs={}, response_mode="invalid") + + +class TestChatMessagePayload: + def test_valid_uuid_fields(self) -> None: + cid = str(uuid4()) + pid = str(uuid4()) + p = ChatMessagePayload(inputs={}, query="hi", conversation_id=cid, parent_message_id=pid) + assert p.conversation_id == cid + assert p.parent_message_id == pid + + def test_none_uuid_fields(self) -> None: + p = ChatMessagePayload(inputs={}, query="hi") + assert p.conversation_id is None + assert p.parent_message_id is None + + def test_invalid_conversation_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ChatMessagePayload(inputs={}, query="hi", conversation_id="bad") + + def test_invalid_parent_message_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ChatMessagePayload(inputs={}, query="hi", parent_message_id="bad") + + def test_query_required(self) -> None: + with pytest.raises(ValidationError): + ChatMessagePayload(inputs={}) + + +# --------------------------------------------------------------------------- +# conversation.py models +# --------------------------------------------------------------------------- +from controllers.web.conversation import ConversationListQuery, ConversationRenamePayload + + +class TestConversationListQuery: + def test_defaults(self) -> None: + q = ConversationListQuery() + assert q.last_id is None + assert q.limit == 20 + assert q.pinned is None + assert q.sort_by == "-updated_at" + + def test_limit_lower_bound(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(limit=0) + + def test_limit_upper_bound(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(limit=101) + + def test_limit_boundaries_valid(self) -> None: + assert ConversationListQuery(limit=1).limit == 1 + assert ConversationListQuery(limit=100).limit == 100 + + def test_valid_sort_by_options(self) -> None: + for opt in ("created_at", "-created_at", "updated_at", "-updated_at"): + assert ConversationListQuery(sort_by=opt).sort_by == opt + + def test_invalid_sort_by(self) -> None: + with pytest.raises(ValidationError): + ConversationListQuery(sort_by="invalid") + + def test_valid_last_id(self) -> None: + uid = str(uuid4()) + assert ConversationListQuery(last_id=uid).last_id == uid + + def test_invalid_last_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + ConversationListQuery(last_id="not-uuid") + + +class TestConversationRenamePayload: + def test_auto_generate_true_no_name_required(self) -> None: + p = ConversationRenamePayload(auto_generate=True) + assert p.name is None + + def test_auto_generate_false_requires_name(self) -> None: + with pytest.raises(ValidationError, match="name is required"): + ConversationRenamePayload(auto_generate=False) + + def test_auto_generate_false_blank_name_rejected(self) -> None: + with pytest.raises(ValidationError, match="name is required"): + ConversationRenamePayload(auto_generate=False, name=" ") + + def test_auto_generate_false_with_valid_name(self) -> None: + p = ConversationRenamePayload(auto_generate=False, name="My Chat") + assert p.name == "My Chat" + + def test_defaults(self) -> None: + p = ConversationRenamePayload(name="test") + assert p.auto_generate is False + assert p.name == "test" + + +# --------------------------------------------------------------------------- +# message.py models +# --------------------------------------------------------------------------- +from controllers.web.message import MessageFeedbackPayload, MessageListQuery, MessageMoreLikeThisQuery + + +class TestMessageListQuery: + def test_valid_query(self) -> None: + cid = str(uuid4()) + q = MessageListQuery(conversation_id=cid) + assert q.conversation_id == cid + assert q.first_id is None + assert q.limit == 20 + + def test_invalid_conversation_id(self) -> None: + with pytest.raises(ValidationError, match="not a valid uuid"): + MessageListQuery(conversation_id="bad") + + def test_limit_bounds(self) -> None: + cid = str(uuid4()) + with pytest.raises(ValidationError): + MessageListQuery(conversation_id=cid, limit=0) + with pytest.raises(ValidationError): + MessageListQuery(conversation_id=cid, limit=101) + + def test_valid_first_id(self) -> None: + cid = str(uuid4()) + fid = str(uuid4()) + q = MessageListQuery(conversation_id=cid, first_id=fid) + assert q.first_id == fid + + def test_invalid_first_id(self) -> None: + cid = str(uuid4()) + with pytest.raises(ValidationError, match="not a valid uuid"): + MessageListQuery(conversation_id=cid, first_id="invalid") + + +class TestMessageFeedbackPayload: + def test_defaults(self) -> None: + p = MessageFeedbackPayload() + assert p.rating is None + assert p.content is None + + def test_valid_ratings(self) -> None: + assert MessageFeedbackPayload(rating="like").rating == "like" + assert MessageFeedbackPayload(rating="dislike").rating == "dislike" + + def test_invalid_rating(self) -> None: + with pytest.raises(ValidationError): + MessageFeedbackPayload(rating="neutral") + + +class TestMessageMoreLikeThisQuery: + def test_valid_modes(self) -> None: + assert MessageMoreLikeThisQuery(response_mode="blocking").response_mode == "blocking" + assert MessageMoreLikeThisQuery(response_mode="streaming").response_mode == "streaming" + + def test_invalid_mode(self) -> None: + with pytest.raises(ValidationError): + MessageMoreLikeThisQuery(response_mode="invalid") + + def test_required(self) -> None: + with pytest.raises(ValidationError): + MessageMoreLikeThisQuery() + + +# --------------------------------------------------------------------------- +# remote_files.py models +# --------------------------------------------------------------------------- +from controllers.web.remote_files import RemoteFileUploadPayload + + +class TestRemoteFileUploadPayload: + def test_valid_url(self) -> None: + p = RemoteFileUploadPayload(url="https://example.com/file.pdf") + assert str(p.url) == "https://example.com/file.pdf" + + def test_invalid_url(self) -> None: + with pytest.raises(ValidationError): + RemoteFileUploadPayload(url="not-a-url") + + def test_url_required(self) -> None: + with pytest.raises(ValidationError): + RemoteFileUploadPayload() + + +# --------------------------------------------------------------------------- +# saved_message.py models +# --------------------------------------------------------------------------- +from controllers.web.saved_message import SavedMessageCreatePayload, SavedMessageListQuery + + +class TestSavedMessageListQuery: + def test_defaults(self) -> None: + q = SavedMessageListQuery() + assert q.last_id is None + assert q.limit == 20 + + def test_limit_bounds(self) -> None: + with pytest.raises(ValidationError): + SavedMessageListQuery(limit=0) + with pytest.raises(ValidationError): + SavedMessageListQuery(limit=101) + + def test_valid_last_id(self) -> None: + uid = str(uuid4()) + q = SavedMessageListQuery(last_id=uid) + assert q.last_id == uid + + def test_empty_last_id(self) -> None: + q = SavedMessageListQuery(last_id="") + assert q.last_id == "" + + +class TestSavedMessageCreatePayload: + def test_valid_message_id(self) -> None: + uid = str(uuid4()) + p = SavedMessageCreatePayload(message_id=uid) + assert p.message_id == uid + + def test_required(self) -> None: + with pytest.raises(ValidationError): + SavedMessageCreatePayload() + + +# --------------------------------------------------------------------------- +# workflow.py models +# --------------------------------------------------------------------------- +from controllers.web.workflow import WorkflowRunPayload + + +class TestWorkflowRunPayload: + def test_defaults(self) -> None: + p = WorkflowRunPayload(inputs={}) + assert p.inputs == {} + assert p.files is None + + def test_with_files(self) -> None: + p = WorkflowRunPayload(inputs={"k": "v"}, files=[{"id": "f1"}]) + assert p.files == [{"id": "f1"}] + + def test_inputs_required(self) -> None: + with pytest.raises(ValidationError): + WorkflowRunPayload() + + +# --------------------------------------------------------------------------- +# forgot_password.py models +# --------------------------------------------------------------------------- +from controllers.web.forgot_password import ( + ForgotPasswordCheckPayload, + ForgotPasswordResetPayload, + ForgotPasswordSendPayload, +) + + +class TestForgotPasswordSendPayload: + def test_valid_email(self) -> None: + p = ForgotPasswordSendPayload(email="user@example.com") + assert p.email == "user@example.com" + + def test_invalid_email(self) -> None: + with pytest.raises(ValidationError, match="not a valid email"): + ForgotPasswordSendPayload(email="not-an-email") + + def test_language_optional(self) -> None: + p = ForgotPasswordSendPayload(email="a@b.com") + assert p.language is None + + +class TestForgotPasswordCheckPayload: + def test_valid(self) -> None: + p = ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="tok") + assert p.email == "a@b.com" + assert p.code == "1234" + assert p.token == "tok" + + def test_empty_token_rejected(self) -> None: + with pytest.raises(ValidationError): + ForgotPasswordCheckPayload(email="a@b.com", code="1234", token="") + + +class TestForgotPasswordResetPayload: + def test_valid_passwords(self) -> None: + p = ForgotPasswordResetPayload(token="tok", new_password="Valid1234", password_confirm="Valid1234") + assert p.new_password == "Valid1234" + + def test_weak_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="short", password_confirm="short") + + def test_letters_only_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="abcdefghi", password_confirm="abcdefghi") + + def test_digits_only_password_rejected(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + ForgotPasswordResetPayload(token="tok", new_password="123456789", password_confirm="123456789") + + +# --------------------------------------------------------------------------- +# login.py models +# --------------------------------------------------------------------------- +from controllers.web.login import EmailCodeLoginSendPayload, EmailCodeLoginVerifyPayload, LoginPayload + + +class TestLoginPayload: + def test_valid(self) -> None: + p = LoginPayload(email="a@b.com", password="Valid1234") + assert p.email == "a@b.com" + + def test_invalid_email(self) -> None: + with pytest.raises(ValidationError, match="not a valid email"): + LoginPayload(email="bad", password="Valid1234") + + def test_weak_password(self) -> None: + with pytest.raises(ValidationError, match="Password must contain"): + LoginPayload(email="a@b.com", password="weak") + + +class TestEmailCodeLoginSendPayload: + def test_valid(self) -> None: + p = EmailCodeLoginSendPayload(email="a@b.com") + assert p.language is None + + def test_with_language(self) -> None: + p = EmailCodeLoginSendPayload(email="a@b.com", language="zh-Hans") + assert p.language == "zh-Hans" + + +class TestEmailCodeLoginVerifyPayload: + def test_valid(self) -> None: + p = EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="tok") + assert p.code == "1234" + + def test_empty_token_rejected(self) -> None: + with pytest.raises(ValidationError): + EmailCodeLoginVerifyPayload(email="a@b.com", code="1234", token="") diff --git a/api/tests/unit_tests/controllers/web/test_remote_files.py b/api/tests/unit_tests/controllers/web/test_remote_files.py new file mode 100644 index 0000000000..8554f440b7 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_remote_files.py @@ -0,0 +1,147 @@ +"""Unit tests for controllers.web.remote_files endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.common.errors import FileTooLargeError, RemoteFileUploadError +from controllers.web.remote_files import RemoteFileInfoApi, RemoteFileUploadApi + + +def _app_model() -> SimpleNamespace: + return SimpleNamespace(id="app-1") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# RemoteFileInfoApi +# --------------------------------------------------------------------------- +class TestRemoteFileInfoApi: + @patch("controllers.web.remote_files.ssrf_proxy") + def test_head_success(self, mock_proxy: MagicMock, app: Flask) -> None: + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"Content-Type": "application/pdf", "Content-Length": "1024"} + mock_proxy.head.return_value = mock_resp + + with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.pdf"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.pdf") + + assert result["file_type"] == "application/pdf" + assert result["file_length"] == 1024 + + @patch("controllers.web.remote_files.ssrf_proxy") + def test_fallback_to_get(self, mock_proxy: MagicMock, app: Flask) -> None: + head_resp = MagicMock() + head_resp.status_code = 405 # Method not allowed + get_resp = MagicMock() + get_resp.status_code = 200 + get_resp.headers = {"Content-Type": "text/plain", "Content-Length": "42"} + get_resp.raise_for_status = MagicMock() + mock_proxy.head.return_value = head_resp + mock_proxy.get.return_value = get_resp + + with app.test_request_context("/remote-files/https%3A%2F%2Fexample.com%2Ffile.txt"): + result = RemoteFileInfoApi().get(_app_model(), _end_user(), "https%3A%2F%2Fexample.com%2Ffile.txt") + + assert result["file_type"] == "text/plain" + mock_proxy.get.assert_called_once() + + +# --------------------------------------------------------------------------- +# RemoteFileUploadApi +# --------------------------------------------------------------------------- +class TestRemoteFileUploadApi: + @patch("controllers.web.remote_files.file_helpers.get_signed_file_url", return_value="https://signed-url") + @patch("controllers.web.remote_files.FileService") + @patch("controllers.web.remote_files.helpers.guess_file_info_from_response") + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + @patch("controllers.web.remote_files.db") + def test_upload_success( + self, + mock_db: MagicMock, + mock_ns: MagicMock, + mock_proxy: MagicMock, + mock_guess: MagicMock, + mock_file_svc_cls: MagicMock, + mock_signed: MagicMock, + app: Flask, + ) -> None: + mock_db.engine = "engine" + mock_ns.payload = {"url": "https://example.com/file.pdf"} + head_resp = MagicMock() + head_resp.status_code = 200 + head_resp.content = b"pdf-content" + head_resp.request.method = "HEAD" + mock_proxy.head.return_value = head_resp + get_resp = MagicMock() + get_resp.content = b"pdf-content" + mock_proxy.get.return_value = get_resp + + mock_guess.return_value = SimpleNamespace( + filename="file.pdf", extension="pdf", mimetype="application/pdf", size=100 + ) + mock_file_svc_cls.is_file_size_within_limit.return_value = True + + from datetime import datetime + + upload_file = SimpleNamespace( + id="f-1", + name="file.pdf", + size=100, + extension="pdf", + mime_type="application/pdf", + created_by="eu-1", + created_at=datetime(2024, 1, 1), + ) + mock_file_svc_cls.return_value.upload_file.return_value = upload_file + + with app.test_request_context("/remote-files/upload", method="POST"): + result, status = RemoteFileUploadApi().post(_app_model(), _end_user()) + + assert status == 201 + assert result["id"] == "f-1" + + @patch("controllers.web.remote_files.FileService.is_file_size_within_limit", return_value=False) + @patch("controllers.web.remote_files.helpers.guess_file_info_from_response") + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + def test_file_too_large( + self, + mock_ns: MagicMock, + mock_proxy: MagicMock, + mock_guess: MagicMock, + mock_size_check: MagicMock, + app: Flask, + ) -> None: + mock_ns.payload = {"url": "https://example.com/big.zip"} + head_resp = MagicMock() + head_resp.status_code = 200 + mock_proxy.head.return_value = head_resp + mock_guess.return_value = SimpleNamespace( + filename="big.zip", extension="zip", mimetype="application/zip", size=999999999 + ) + + with app.test_request_context("/remote-files/upload", method="POST"): + with pytest.raises(FileTooLargeError): + RemoteFileUploadApi().post(_app_model(), _end_user()) + + @patch("controllers.web.remote_files.ssrf_proxy") + @patch("controllers.web.remote_files.web_ns") + def test_fetch_failure_raises(self, mock_ns: MagicMock, mock_proxy: MagicMock, app: Flask) -> None: + import httpx + + mock_ns.payload = {"url": "https://example.com/bad"} + mock_proxy.head.side_effect = httpx.RequestError("connection failed") + + with app.test_request_context("/remote-files/upload", method="POST"): + with pytest.raises(RemoteFileUploadError): + RemoteFileUploadApi().post(_app_model(), _end_user()) diff --git a/api/tests/unit_tests/controllers/web/test_saved_message.py b/api/tests/unit_tests/controllers/web/test_saved_message.py new file mode 100644 index 0000000000..3d55804912 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_saved_message.py @@ -0,0 +1,97 @@ +"""Unit tests for controllers.web.saved_message endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.web.error import NotCompletionAppError +from controllers.web.saved_message import SavedMessageApi, SavedMessageListApi +from services.errors.message import MessageNotExistsError + + +def _completion_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="completion") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# SavedMessageListApi (GET) +# --------------------------------------------------------------------------- +class TestSavedMessageListApiGet: + def test_non_completion_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/saved-messages"): + with pytest.raises(NotCompletionAppError): + SavedMessageListApi().get(_chat_app(), _end_user()) + + @patch("controllers.web.saved_message.SavedMessageService.pagination_by_last_id") + def test_happy_path(self, mock_paginate: MagicMock, app: Flask) -> None: + mock_paginate.return_value = SimpleNamespace(limit=20, has_more=False, data=[]) + + with app.test_request_context("/saved-messages?limit=20"): + result = SavedMessageListApi().get(_completion_app(), _end_user()) + + assert result["limit"] == 20 + assert result["has_more"] is False + + +# --------------------------------------------------------------------------- +# SavedMessageListApi (POST) +# --------------------------------------------------------------------------- +class TestSavedMessageListApiPost: + def test_non_completion_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/saved-messages", method="POST"): + with pytest.raises(NotCompletionAppError): + SavedMessageListApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.saved_message.SavedMessageService.save") + @patch("controllers.web.saved_message.web_ns") + def test_save_success(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None: + msg_id = str(uuid4()) + mock_ns.payload = {"message_id": msg_id} + + with app.test_request_context("/saved-messages", method="POST"): + result = SavedMessageListApi().post(_completion_app(), _end_user()) + + assert result["result"] == "success" + + @patch("controllers.web.saved_message.SavedMessageService.save", side_effect=MessageNotExistsError()) + @patch("controllers.web.saved_message.web_ns") + def test_save_not_found(self, mock_ns: MagicMock, mock_save: MagicMock, app: Flask) -> None: + mock_ns.payload = {"message_id": str(uuid4())} + + with app.test_request_context("/saved-messages", method="POST"): + with pytest.raises(NotFound, match="Message Not Exists"): + SavedMessageListApi().post(_completion_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# SavedMessageApi (DELETE) +# --------------------------------------------------------------------------- +class TestSavedMessageApi: + def test_non_completion_mode_raises(self, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"): + with pytest.raises(NotCompletionAppError): + SavedMessageApi().delete(_chat_app(), _end_user(), msg_id) + + @patch("controllers.web.saved_message.SavedMessageService.delete") + def test_delete_success(self, mock_delete: MagicMock, app: Flask) -> None: + msg_id = uuid4() + with app.test_request_context(f"/saved-messages/{msg_id}", method="DELETE"): + result, status = SavedMessageApi().delete(_completion_app(), _end_user(), msg_id) + + assert status == 204 + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/web/test_site.py b/api/tests/unit_tests/controllers/web/test_site.py new file mode 100644 index 0000000000..6e9d754c43 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_site.py @@ -0,0 +1,126 @@ +"""Unit tests for controllers.web.site endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden + +from controllers.web.site import AppSiteApi, AppSiteInfo + + +def _tenant(*, status: str = "normal") -> SimpleNamespace: + return SimpleNamespace( + id="tenant-1", + status=status, + plan="basic", + custom_config_dict={"remove_webapp_brand": False, "replace_webapp_logo": False}, + ) + + +def _site() -> SimpleNamespace: + return SimpleNamespace( + title="Site", + icon_type="emoji", + icon="robot", + icon_background="#fff", + description="desc", + default_language="en", + chat_color_theme="light", + chat_color_theme_inverted=False, + copyright=None, + privacy_policy=None, + custom_disclaimer=None, + prompt_public=False, + show_workflow_steps=True, + use_icon_as_answer_icon=False, + ) + + +# --------------------------------------------------------------------------- +# AppSiteApi +# --------------------------------------------------------------------------- +class TestAppSiteApi: + @patch("controllers.web.site.FeatureService.get_features") + @patch("controllers.web.site.db") + def test_happy_path(self, mock_db: MagicMock, mock_features: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + mock_features.return_value = SimpleNamespace(can_replace_logo=False) + site_obj = _site() + mock_db.session.scalar.return_value = site_obj + tenant = _tenant() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + result = AppSiteApi().get(app_model, end_user) + + # marshal_with serializes AppSiteInfo to a dict + assert result["app_id"] == "app-1" + assert result["plan"] == "basic" + assert result["enable_site"] is True + + @patch("controllers.web.site.db") + def test_missing_site_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + mock_db.session.scalar.return_value = None + tenant = _tenant() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant, enable_site=True) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + with pytest.raises(Forbidden): + AppSiteApi().get(app_model, end_user) + + @patch("controllers.web.site.db") + def test_archived_tenant_raises_forbidden(self, mock_db: MagicMock, app: Flask) -> None: + app.config["RESTX_MASK_HEADER"] = "X-Fields" + from models.account import TenantStatus + + mock_db.session.scalar.return_value = _site() + tenant = SimpleNamespace( + id="tenant-1", + status=TenantStatus.ARCHIVE, + plan="basic", + custom_config_dict={}, + ) + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", tenant=tenant) + end_user = SimpleNamespace(id="eu-1") + + with app.test_request_context("/site"): + with pytest.raises(Forbidden): + AppSiteApi().get(app_model, end_user) + + +# --------------------------------------------------------------------------- +# AppSiteInfo +# --------------------------------------------------------------------------- +class TestAppSiteInfo: + def test_basic_fields(self) -> None: + tenant = _tenant() + site_obj = _site() + info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", False) + + assert info.app_id == "app-1" + assert info.end_user_id == "eu-1" + assert info.enable_site is True + assert info.plan == "basic" + assert info.can_replace_logo is False + assert info.model_config is None + + @patch("controllers.web.site.dify_config", SimpleNamespace(FILES_URL="https://files.example.com")) + def test_can_replace_logo_sets_custom_config(self) -> None: + tenant = SimpleNamespace( + id="tenant-1", + plan="pro", + custom_config_dict={"remove_webapp_brand": True, "replace_webapp_logo": True}, + ) + site_obj = _site() + info = AppSiteInfo(tenant, SimpleNamespace(id="app-1", enable_site=True), site_obj, "eu-1", True) + + assert info.can_replace_logo is True + assert info.custom_config["remove_webapp_brand"] is True + assert "webapp-logo" in info.custom_config["replace_webapp_logo"] diff --git a/api/tests/unit_tests/controllers/web/test_web_login.py b/api/tests/unit_tests/controllers/web/test_web_login.py index e62993e8d5..0661c02578 100644 --- a/api/tests/unit_tests/controllers/web/test_web_login.py +++ b/api/tests/unit_tests/controllers/web/test_web_login.py @@ -5,7 +5,8 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask -from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi +import services.errors.account +from controllers.web.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi, LoginApi, LoginStatusApi, LogoutApi def encode_code(code: str) -> str: @@ -89,3 +90,114 @@ class TestEmailCodeLoginApi: mock_revoke_token.assert_called_once_with("token-123") mock_login.assert_called_once() mock_reset_login_rate.assert_called_once_with("user@example.com") + + +class TestLoginApi: + @patch("controllers.web.login.WebAppAuthService.login", return_value="access-tok") + @patch("controllers.web.login.WebAppAuthService.authenticate") + def test_login_success(self, mock_auth: MagicMock, mock_login: MagicMock, app: Flask) -> None: + mock_auth.return_value = MagicMock() + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + response = LoginApi().post() + + assert response.get_json()["data"]["access_token"] == "access-tok" + mock_auth.assert_called_once() + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountLoginError(), + ) + def test_login_banned_account(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.error import AccountBannedError + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AccountBannedError): + LoginApi().post() + + @patch( + "controllers.web.login.WebAppAuthService.authenticate", + side_effect=services.errors.account.AccountPasswordError(), + ) + def test_login_wrong_password(self, mock_auth: MagicMock, app: Flask) -> None: + from controllers.console.auth.error import AuthenticationFailedError + + with app.test_request_context( + "/web/login", + method="POST", + json={"email": "user@example.com", "password": base64.b64encode(b"Valid1234").decode()}, + ): + with pytest.raises(AuthenticationFailedError): + LoginApi().post() + + +class TestLoginStatusApi: + @patch("controllers.web.login.extract_webapp_access_token", return_value=None) + def test_no_app_code_returns_logged_in_false(self, mock_extract: MagicMock, app: Flask) -> None: + with app.test_request_context("/web/login/status"): + result = LoginStatusApi().get() + + assert result["logged_in"] is False + assert result["app_logged_in"] is False + + @patch("controllers.web.login.decode_jwt_token") + @patch("controllers.web.login.PassportService") + @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=False) + @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1") + @patch("controllers.web.login.extract_webapp_access_token", return_value="tok") + def test_public_app_user_logged_in( + self, + mock_extract: MagicMock, + mock_app_id: MagicMock, + mock_perm: MagicMock, + mock_passport: MagicMock, + mock_decode: MagicMock, + app: Flask, + ) -> None: + mock_decode.return_value = (MagicMock(), MagicMock()) + + with app.test_request_context("/web/login/status?app_code=code1"): + result = LoginStatusApi().get() + + assert result["logged_in"] is True + assert result["app_logged_in"] is True + + @patch("controllers.web.login.decode_jwt_token", side_effect=Exception("bad")) + @patch("controllers.web.login.PassportService") + @patch("controllers.web.login.WebAppAuthService.is_app_require_permission_check", return_value=True) + @patch("controllers.web.login.AppService.get_app_id_by_code", return_value="app-1") + @patch("controllers.web.login.extract_webapp_access_token", return_value="tok") + def test_private_app_passport_fails( + self, + mock_extract: MagicMock, + mock_app_id: MagicMock, + mock_perm: MagicMock, + mock_passport_cls: MagicMock, + mock_decode: MagicMock, + app: Flask, + ) -> None: + mock_passport_cls.return_value.verify.side_effect = Exception("bad") + + with app.test_request_context("/web/login/status?app_code=code1"): + result = LoginStatusApi().get() + + assert result["logged_in"] is False + assert result["app_logged_in"] is False + + +class TestLogoutApi: + @patch("controllers.web.login.clear_webapp_access_token_from_cookie") + def test_logout_success(self, mock_clear: MagicMock, app: Flask) -> None: + with app.test_request_context("/web/logout", method="POST"): + response = LogoutApi().post() + + assert response.get_json() == {"result": "success"} + mock_clear.assert_called_once() diff --git a/api/tests/unit_tests/controllers/web/test_web_passport.py b/api/tests/unit_tests/controllers/web/test_web_passport.py new file mode 100644 index 0000000000..19b1d8504a --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_web_passport.py @@ -0,0 +1,192 @@ +"""Unit tests for controllers.web.passport — token issuance and enterprise auth exchange.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound, Unauthorized + +from controllers.web.error import WebAppAuthRequiredError +from controllers.web.passport import ( + PassportResource, + decode_enterprise_webapp_user_id, + exchange_token_for_existing_web_user, + generate_session_id, +) +from services.webapp_auth_service import WebAppAuthType + + +# --------------------------------------------------------------------------- +# decode_enterprise_webapp_user_id +# --------------------------------------------------------------------------- +class TestDecodeEnterpriseWebappUserId: + def test_none_token_returns_none(self) -> None: + assert decode_enterprise_webapp_user_id(None) is None + + @patch("controllers.web.passport.PassportService") + def test_valid_token_returns_decoded(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = { + "token_source": "webapp_login_token", + "user_id": "u1", + } + result = decode_enterprise_webapp_user_id("valid-jwt") + assert result["user_id"] == "u1" + + @patch("controllers.web.passport.PassportService") + def test_wrong_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = { + "token_source": "other_source", + } + with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"): + decode_enterprise_webapp_user_id("bad-jwt") + + @patch("controllers.web.passport.PassportService") + def test_missing_source_raises_unauthorized(self, mock_passport_cls: MagicMock) -> None: + mock_passport_cls.return_value.verify.return_value = {} + with pytest.raises(Unauthorized, match="Expected 'webapp_login_token'"): + decode_enterprise_webapp_user_id("no-source-jwt") + + +# --------------------------------------------------------------------------- +# generate_session_id +# --------------------------------------------------------------------------- +class TestGenerateSessionId: + @patch("controllers.web.passport.db") + def test_returns_unique_session_id(self, mock_db: MagicMock) -> None: + mock_db.session.scalar.return_value = 0 + sid = generate_session_id() + assert isinstance(sid, str) + assert len(sid) == 36 # UUID format + + @patch("controllers.web.passport.db") + def test_retries_on_collision(self, mock_db: MagicMock) -> None: + # First call returns count=1 (collision), second returns 0 + mock_db.session.scalar.side_effect = [1, 0] + sid = generate_session_id() + assert isinstance(sid, str) + assert mock_db.session.scalar.call_count == 2 + + +# --------------------------------------------------------------------------- +# exchange_token_for_existing_web_user +# --------------------------------------------------------------------------- +class TestExchangeTokenForExistingWebUser: + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_external_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + site = SimpleNamespace(code="code1", app_id="app-1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + + decoded = {"user_id": "u1", "auth_type": "internal"} # mismatch: expected "external" + with pytest.raises(WebAppAuthRequiredError, match="external"): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL + ) + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_internal_auth_type_mismatch_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + site = SimpleNamespace(code="code1", app_id="app-1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + + decoded = {"user_id": "u1", "auth_type": "external"} # mismatch: expected "internal" + with pytest.raises(WebAppAuthRequiredError, match="internal"): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.INTERNAL + ) + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + def test_site_not_found_raises(self, mock_db: MagicMock, mock_passport_cls: MagicMock) -> None: + mock_db.session.scalar.return_value = None + decoded = {"user_id": "u1", "auth_type": "external"} + with pytest.raises(NotFound): + exchange_token_for_existing_web_user( + app_code="code1", enterprise_user_decoded=decoded, auth_type=WebAppAuthType.EXTERNAL + ) + + +# --------------------------------------------------------------------------- +# PassportResource.get +# --------------------------------------------------------------------------- +class TestPassportResource: + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_missing_app_code_raises_unauthorized(self, mock_features: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + with app.test_request_context("/passport"): + with pytest.raises(Unauthorized, match="X-App-Code"): + PassportResource().get() + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.generate_session_id", return_value="new-sess-id") + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_creates_new_end_user_when_no_user_id( + self, + mock_features: MagicMock, + mock_db: MagicMock, + mock_gen_session: MagicMock, + mock_passport_cls: MagicMock, + app: Flask, + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + mock_db.session.scalar.side_effect = [site, app_model] + mock_passport_cls.return_value.issue.return_value = "issued-token" + + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + response = PassportResource().get() + + assert response.get_json()["access_token"] == "issued-token" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + @patch("controllers.web.passport.PassportService") + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_reuses_existing_end_user_when_user_id_provided( + self, + mock_features: MagicMock, + mock_db: MagicMock, + mock_passport_cls: MagicMock, + app: Flask, + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + app_model = SimpleNamespace(id="app-1", status="normal", enable_site=True, tenant_id="t1") + existing_user = SimpleNamespace(id="eu-1", session_id="sess-existing") + mock_db.session.scalar.side_effect = [site, app_model, existing_user] + mock_passport_cls.return_value.issue.return_value = "reused-token" + + with app.test_request_context("/passport?user_id=sess-existing", headers={"X-App-Code": "code1"}): + response = PassportResource().get() + + assert response.get_json()["access_token"] == "reused-token" + # Should not create a new end user + mock_db.session.add.assert_not_called() + + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_site_not_found_raises(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_db.session.scalar.return_value = None + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + PassportResource().get() + + @patch("controllers.web.passport.db") + @patch("controllers.web.passport.FeatureService.get_system_features") + def test_disabled_app_raises_not_found(self, mock_features: MagicMock, mock_db: MagicMock, app: Flask) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + site = SimpleNamespace(app_id="app-1", code="code1") + disabled_app = SimpleNamespace(id="app-1", status="normal", enable_site=False) + mock_db.session.scalar.side_effect = [site, disabled_app] + with app.test_request_context("/passport", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + PassportResource().get() diff --git a/api/tests/unit_tests/controllers/web/test_workflow.py b/api/tests/unit_tests/controllers/web/test_workflow.py new file mode 100644 index 0000000000..0973340527 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_workflow.py @@ -0,0 +1,95 @@ +"""Unit tests for controllers.web.workflow endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.error import ( + NotWorkflowAppError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.web.workflow import WorkflowRunApi, WorkflowTaskStopApi +from core.errors.error import ProviderTokenNotInitError, QuotaExceededError + + +def _workflow_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="workflow") + + +def _chat_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", mode="chat") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# WorkflowRunApi +# --------------------------------------------------------------------------- +class TestWorkflowRunApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(NotWorkflowAppError): + WorkflowRunApi().post(_chat_app(), _end_user()) + + @patch("controllers.web.workflow.helper.compact_generate_response", return_value={"result": "ok"}) + @patch("controllers.web.workflow.AppGenerateService.generate") + @patch("controllers.web.workflow.web_ns") + def test_happy_path(self, mock_ns: MagicMock, mock_gen: MagicMock, mock_compact: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {"key": "val"}} + mock_gen.return_value = "response" + + with app.test_request_context("/workflows/run", method="POST"): + result = WorkflowRunApi().post(_workflow_app(), _end_user()) + + assert result == {"result": "ok"} + + @patch( + "controllers.web.workflow.AppGenerateService.generate", + side_effect=ProviderTokenNotInitError(description="not init"), + ) + @patch("controllers.web.workflow.web_ns") + def test_provider_not_init(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(ProviderNotInitializeError): + WorkflowRunApi().post(_workflow_app(), _end_user()) + + @patch( + "controllers.web.workflow.AppGenerateService.generate", + side_effect=QuotaExceededError(), + ) + @patch("controllers.web.workflow.web_ns") + def test_quota_exceeded(self, mock_ns: MagicMock, mock_gen: MagicMock, app: Flask) -> None: + mock_ns.payload = {"inputs": {}} + + with app.test_request_context("/workflows/run", method="POST"): + with pytest.raises(ProviderQuotaExceededError): + WorkflowRunApi().post(_workflow_app(), _end_user()) + + +# --------------------------------------------------------------------------- +# WorkflowTaskStopApi +# --------------------------------------------------------------------------- +class TestWorkflowTaskStopApi: + def test_wrong_mode_raises(self, app: Flask) -> None: + with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): + with pytest.raises(NotWorkflowAppError): + WorkflowTaskStopApi().post(_chat_app(), _end_user(), "task-1") + + @patch("controllers.web.workflow.GraphEngineManager.send_stop_command") + @patch("controllers.web.workflow.AppQueueManager.set_stop_flag_no_user_check") + def test_stop_calls_both_mechanisms(self, mock_legacy: MagicMock, mock_graph: MagicMock, app: Flask) -> None: + with app.test_request_context("/workflows/tasks/task-1/stop", method="POST"): + result = WorkflowTaskStopApi().post(_workflow_app(), _end_user(), "task-1") + + assert result == {"result": "success"} + mock_legacy.assert_called_once_with("task-1") + mock_graph.assert_called_once_with("task-1") diff --git a/api/tests/unit_tests/controllers/web/test_workflow_events.py b/api/tests/unit_tests/controllers/web/test_workflow_events.py new file mode 100644 index 0000000000..64c09b5e22 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_workflow_events.py @@ -0,0 +1,127 @@ +"""Unit tests for controllers.web.workflow_events endpoints.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.web.error import NotFoundError +from controllers.web.workflow_events import WorkflowEventsApi +from models.enums import CreatorUserRole + + +def _workflow_app() -> SimpleNamespace: + return SimpleNamespace(id="app-1", tenant_id="tenant-1", mode="workflow") + + +def _end_user() -> SimpleNamespace: + return SimpleNamespace(id="eu-1") + + +# --------------------------------------------------------------------------- +# WorkflowEventsApi +# --------------------------------------------------------------------------- +class TestWorkflowEventsApi: + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_not_found(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = None + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_wrong_app(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="other-app", + created_by_role=CreatorUserRole.END_USER, + created_by="eu-1", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_not_created_by_end_user( + self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask + ) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="eu-1", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_workflow_run_wrong_end_user(self, mock_db: MagicMock, mock_factory: MagicMock, app: Flask) -> None: + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="other-user", + finished_at=None, + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + with app.test_request_context("/workflow/run-1/events"): + with pytest.raises(NotFoundError): + WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + @patch("controllers.web.workflow_events.WorkflowResponseConverter") + @patch("controllers.web.workflow_events.DifyAPIRepositoryFactory") + @patch("controllers.web.workflow_events.db") + def test_finished_run_returns_sse_response( + self, mock_db: MagicMock, mock_factory: MagicMock, mock_converter: MagicMock, app: Flask + ) -> None: + from datetime import datetime + + mock_db.engine = "engine" + run = SimpleNamespace( + id="run-1", + app_id="app-1", + created_by_role=CreatorUserRole.END_USER, + created_by="eu-1", + finished_at=datetime(2024, 1, 1), + ) + mock_repo = MagicMock() + mock_repo.get_workflow_run_by_id_and_tenant_id.return_value = run + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + finish_response = MagicMock() + finish_response.model_dump.return_value = {"task_id": "run-1"} + finish_response.event.value = "workflow_finished" + mock_converter.workflow_run_result_to_finish_response.return_value = finish_response + + with app.test_request_context("/workflow/run-1/events"): + response = WorkflowEventsApi().get(_workflow_app(), _end_user(), "run-1") + + assert response.mimetype == "text/event-stream" diff --git a/api/tests/unit_tests/controllers/web/test_wraps.py b/api/tests/unit_tests/controllers/web/test_wraps.py new file mode 100644 index 0000000000..85049ae975 --- /dev/null +++ b/api/tests/unit_tests/controllers/web/test_wraps.py @@ -0,0 +1,393 @@ +"""Unit tests for controllers.web.wraps — JWT auth decorator and validation helpers.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import BadRequest, NotFound, Unauthorized + +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError +from controllers.web.wraps import ( + _validate_user_accessibility, + _validate_webapp_token, + decode_jwt_token, +) + + +# --------------------------------------------------------------------------- +# _validate_webapp_token +# --------------------------------------------------------------------------- +class TestValidateWebappToken: + def test_enterprise_enabled_and_app_auth_requires_webapp_source(self) -> None: + """When both flags are true, a non-webapp source must raise.""" + decoded = {"token_source": "other"} + with pytest.raises(WebAppAuthRequiredError): + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_enterprise_enabled_and_app_auth_accepts_webapp_source(self) -> None: + decoded = {"token_source": "webapp"} + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_enterprise_enabled_and_app_auth_missing_source_raises(self) -> None: + decoded = {} + with pytest.raises(WebAppAuthRequiredError): + _validate_webapp_token(decoded, app_web_auth_enabled=True, system_webapp_auth_enabled=True) + + def test_public_app_rejects_webapp_source(self) -> None: + """When auth is not required, a webapp-sourced token must be rejected.""" + decoded = {"token_source": "webapp"} + with pytest.raises(Unauthorized): + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_public_app_accepts_non_webapp_source(self) -> None: + decoded = {"token_source": "other"} + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_public_app_accepts_no_source(self) -> None: + decoded = {} + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=False) + + def test_system_enabled_but_app_public(self) -> None: + """system_webapp_auth_enabled=True but app is public — webapp source rejected.""" + decoded = {"token_source": "webapp"} + with pytest.raises(Unauthorized): + _validate_webapp_token(decoded, app_web_auth_enabled=False, system_webapp_auth_enabled=True) + + +# --------------------------------------------------------------------------- +# _validate_user_accessibility +# --------------------------------------------------------------------------- +class TestValidateUserAccessibility: + def test_skips_when_auth_disabled(self) -> None: + """No checks when system or app auth is disabled.""" + _validate_user_accessibility( + decoded={}, + app_code="code", + app_web_auth_enabled=False, + system_webapp_auth_enabled=False, + webapp_settings=None, + ) + + def test_missing_user_id_raises(self) -> None: + decoded = {} + with pytest.raises(WebAppAuthRequiredError): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=SimpleNamespace(access_mode="internal"), + ) + + def test_missing_webapp_settings_raises(self) -> None: + decoded = {"user_id": "u1"} + with pytest.raises(WebAppAuthRequiredError, match="settings not found"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=None, + ) + + def test_missing_auth_type_raises(self) -> None: + decoded = {"user_id": "u1", "granted_at": 1} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="auth_type"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + def test_missing_granted_at_raises(self) -> None: + decoded = {"user_id": "u1", "auth_type": "external"} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="granted_at"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_external_auth_type_checks_sso_update_time( + self, mock_perm_check: MagicMock, mock_sso_time: MagicMock + ) -> None: + # granted_at is before SSO update time → denied + mock_sso_time.return_value = datetime.now(UTC) + old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp()) + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": old_granted} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_workspace_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_internal_auth_type_checks_workspace_sso_update_time( + self, mock_perm_check: MagicMock, mock_workspace_sso: MagicMock + ) -> None: + mock_workspace_sso.return_value = datetime.now(UTC) + old_granted = int((datetime.now(UTC) - timedelta(hours=1)).timestamp()) + decoded = {"user_id": "u1", "auth_type": "internal", "granted_at": old_granted} + settings = SimpleNamespace(access_mode="public") + with pytest.raises(WebAppAuthAccessDeniedError, match="SSO settings"): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.get_app_sso_settings_last_update_time") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=False) + def test_external_auth_passes_when_granted_after_sso_update( + self, mock_perm_check: MagicMock, mock_sso_time: MagicMock + ) -> None: + mock_sso_time.return_value = datetime.now(UTC) - timedelta(hours=2) + recent_granted = int(datetime.now(UTC).timestamp()) + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": recent_granted} + settings = SimpleNamespace(access_mode="public") + # Should not raise + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp", return_value=False) + @patch("controllers.web.wraps.AppService.get_app_id_by_code", return_value="app-id-1") + @patch("controllers.web.wraps.WebAppAuthService.is_app_require_permission_check", return_value=True) + def test_permission_check_denies_unauthorized_user( + self, mock_perm: MagicMock, mock_app_id: MagicMock, mock_allowed: MagicMock + ) -> None: + decoded = {"user_id": "u1", "auth_type": "external", "granted_at": int(datetime.now(UTC).timestamp())} + settings = SimpleNamespace(access_mode="internal") + with pytest.raises(WebAppAuthAccessDeniedError): + _validate_user_accessibility( + decoded=decoded, + app_code="code", + app_web_auth_enabled=True, + system_webapp_auth_enabled=True, + webapp_settings=settings, + ) + + +# --------------------------------------------------------------------------- +# decode_jwt_token +# --------------------------------------------------------------------------- +class TestDecodeJwtToken: + @patch("controllers.web.wraps._validate_user_accessibility") + @patch("controllers.web.wraps._validate_webapp_token") + @patch("controllers.web.wraps.EnterpriseService.WebAppAuth.get_app_access_mode_by_id") + @patch("controllers.web.wraps.AppService.get_app_id_by_code") + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_happy_path( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + mock_app_id: MagicMock, + mock_access_mode: MagicMock, + mock_validate_token: MagicMock, + mock_validate_user: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + end_user = SimpleNamespace(id="eu-1", session_id="sess-1") + + # Configure session mock to return correct objects via scalar() + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, end_user] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + result_app, result_user = decode_jwt_token() + + assert result_app.id == "app-1" + assert result_user.id == "eu-1" + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.extract_webapp_passport") + def test_missing_token_raises_unauthorized( + self, mock_extract: MagicMock, mock_features: MagicMock, app: Flask + ) -> None: + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_extract.return_value = None + + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(Unauthorized): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_missing_app_raises_not_found( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + session_mock = MagicMock() + session_mock.scalar.return_value = None # No app found + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_disabled_site_raises_bad_request( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=False) + + session_mock = MagicMock() + # scalar calls: app_model, site (code found), then end_user + session_mock.scalar.side_effect = [app_model, SimpleNamespace(code="code1"), None] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(BadRequest, match="Site is disabled"): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_missing_end_user_raises_not_found( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, None] # end_user is None + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(NotFound): + decode_jwt_token() + + @patch("controllers.web.wraps.FeatureService.get_system_features") + @patch("controllers.web.wraps.PassportService") + @patch("controllers.web.wraps.extract_webapp_passport") + @patch("controllers.web.wraps.db") + def test_user_id_mismatch_raises_unauthorized( + self, + mock_db: MagicMock, + mock_extract: MagicMock, + mock_passport_cls: MagicMock, + mock_features: MagicMock, + app: Flask, + ) -> None: + mock_extract.return_value = "jwt-token" + mock_passport_cls.return_value.verify.return_value = { + "app_code": "code1", + "app_id": "app-1", + "end_user_id": "eu-1", + } + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + + app_model = SimpleNamespace(id="app-1", enable_site=True) + site = SimpleNamespace(code="code1") + end_user = SimpleNamespace(id="eu-1", session_id="sess-1") + + session_mock = MagicMock() + session_mock.scalar.side_effect = [app_model, site, end_user] + session_ctx = MagicMock() + session_ctx.__enter__ = MagicMock(return_value=session_mock) + session_ctx.__exit__ = MagicMock(return_value=False) + mock_db.engine = "engine" + + with patch("controllers.web.wraps.Session", return_value=session_ctx): + with app.test_request_context("/", headers={"X-App-Code": "code1"}): + with pytest.raises(Unauthorized, match="expired"): + decode_jwt_token(user_id="different-user") diff --git a/api/tests/unit_tests/core/agent/conftest.py b/api/tests/unit_tests/core/agent/conftest.py new file mode 100644 index 0000000000..a2aa501720 --- /dev/null +++ b/api/tests/unit_tests/core/agent/conftest.py @@ -0,0 +1,80 @@ +import pytest + + +class DummyTool: + def __init__(self, name): + self.name = name + + +class DummyPromptEntity: + def __init__(self, first_prompt): + self.first_prompt = first_prompt + + +class DummyAgentConfig: + def __init__(self, prompt_entity=None): + self.prompt = prompt_entity + + +class DummyAppConfig: + def __init__(self, agent=None): + self.agent = agent + + +class DummyScratchpadUnit: + def __init__( + self, + final=False, + thought=None, + action_str=None, + observation=None, + agent_response=None, + ): + self._final = final + self.thought = thought + self.action_str = action_str + self.observation = observation + self.agent_response = agent_response + + def is_final(self): + return self._final + + +@pytest.fixture +def dummy_tool_factory(): + def _factory(name): + return DummyTool(name) + + return _factory + + +@pytest.fixture +def dummy_prompt_entity_factory(): + def _factory(first_prompt): + return DummyPromptEntity(first_prompt) + + return _factory + + +@pytest.fixture +def dummy_agent_config_factory(): + def _factory(prompt_entity=None): + return DummyAgentConfig(prompt_entity) + + return _factory + + +@pytest.fixture +def dummy_app_config_factory(): + def _factory(agent=None): + return DummyAppConfig(agent) + + return _factory + + +@pytest.fixture +def dummy_scratchpad_unit_factory(): + def _factory(**kwargs): + return DummyScratchpadUnit(**kwargs) + + return _factory diff --git a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py index 4a613e35b0..9073ae1044 100644 --- a/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py +++ b/api/tests/unit_tests/core/agent/output_parser/test_cot_output_parser.py @@ -1,70 +1,255 @@ +"""Unit tests for CotAgentOutputParser. + +Verifies expected parsing behavior for streaming content and JSON payloads, +including edge cases such as empty/non-string content and malformed JSON. +Assumes lightweight fixtures (SimpleNamespace/MagicMock) stand in for real +model output structures. Implementation under test: +core.agent.output_parser.cot_output_parser.CotAgentOutputParser. +""" + import json -from collections.abc import Generator +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest -from core.agent.entities import AgentScratchpadUnit from core.agent.output_parser.cot_output_parser import CotAgentOutputParser -from core.model_runtime.entities.llm_entities import AssistantPromptMessage, LLMResultChunk, LLMResultChunkDelta -def mock_llm_response(text) -> Generator[LLMResultChunk, None, None]: - for i in range(len(text)): - yield LLMResultChunk( - model="model", - prompt_messages=[], - delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=text[i], tool_calls=[])), +@pytest.fixture +def mock_action_class(mocker): + mock_action = MagicMock() + mocker.patch( + "core.agent.output_parser.cot_output_parser.AgentScratchpadUnit.Action", + mock_action, + ) + return mock_action + + +@pytest.fixture +def usage_dict(): + return {} + + +@pytest.fixture +def make_chunk(): + def _make_chunk(content=None, usage=None): + delta = SimpleNamespace( + message=SimpleNamespace(content=content), + usage=usage, ) + return SimpleNamespace(delta=delta) + + return _make_chunk -def test_cot_output_parser(): - test_cases = [ - { - "input": 'Through: abc\nAction: ```{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # code block with json - { - "input": 'Through: abc\nAction: ```json\n{"action": "Final Answer", "action_input": "```echarts\n {' - '}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # code block with JSON - { - "input": 'Through: abc\nAction: ```JSON\n{"action": "Final Answer", "action_input": "```echarts\n {' - '}\n```"}```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # list - { - "input": 'Through: abc\nAction: ```[{"action": "Final Answer", "action_input": "```echarts\n {}\n```"}]```', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # no code block - { - "input": 'Through: abc\nAction: {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}', - "action": {"action": "Final Answer", "action_input": "```echarts\n {}\n```"}, - "output": 'Through: abc\n {"action": "Final Answer", "action_input": "```echarts\\n {}\\n```"}', - }, - # no code block and json - {"input": "Through: abc\nAction: efg", "action": {}, "output": "Through: abc\n efg"}, - ] +# ============================================================ +# Test Suite +# ============================================================ - parser = CotAgentOutputParser() - usage_dict = {} - for test_case in test_cases: - # mock llm_response as a generator by text - llm_response: Generator[LLMResultChunk, None, None] = mock_llm_response(test_case["input"]) - results = parser.handle_react_stream_output(llm_response, usage_dict) - output = "" - for result in results: - if isinstance(result, str): - output += result - elif isinstance(result, AgentScratchpadUnit.Action): - if test_case["action"]: - assert result.to_dict() == test_case["action"] - output += json.dumps(result.to_dict()) - if test_case["output"]: - assert output == test_case["output"] + +class TestCotAgentOutputParser: + """Validate CotAgentOutputParser streaming + JSON parsing behavior. + + Lifecycle: no explicit setup/teardown; relies on pytest fixtures for + lightweight chunk/action doubles. Invariants: non-string/empty content + yields no output, usage gets recorded when provided, and valid action JSON + results in Action instantiation. Usage: invoke via pytest (e.g., + `pytest -k TestCotAgentOutputParser`). + """ + + # -------------------------------------------------------- + # Basic streaming & usage + # -------------------------------------------------------- + + def test_stream_plain_text(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("hello world")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "".join(result) == "hello world" + + def test_stream_empty_string(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + def test_stream_none_content(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk(None)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + @pytest.mark.parametrize("content", [123, 12.5, [], {}, object()]) + def test_non_string_content(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result == [] + + def test_usage_update(self, make_chunk, usage_dict) -> None: + usage_data = {"tokens": 99} + chunks = [make_chunk("abc", usage=usage_data)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert usage_dict["usage"] == usage_data + + # -------------------------------------------------------- + # JSON parsing (direct + streaming) + # -------------------------------------------------------- + + def test_single_json_action_valid(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '{"action": "search", "input": "query"}' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="search", action_input="query") + + def test_json_list_unwrap(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '[{"action": "lookup", "input": "abc"}]' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc") + + def test_json_missing_fields_returns_string(self, make_chunk, usage_dict) -> None: + content = '{"foo": "bar"}' + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # Expect the serialized JSON to be yielded as a single element. + assert result == [json.dumps({"foo": "bar"})] + + def test_invalid_json_string_input(self, make_chunk, usage_dict) -> None: + content = "{invalid json}" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert any("invalid json" in str(r) for r in result) + + def test_json_split_across_chunks(self, make_chunk, usage_dict, mock_action_class) -> None: + chunks = [ + make_chunk('{"action": '), + make_chunk('"multi", '), + make_chunk('"input": "step"}'), + ] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="multi", action_input="step") + + def test_unclosed_json_at_end(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk('{"foo": "bar"')] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + assert any('{"foo": "bar"' in item for item in result) + + # -------------------------------------------------------- + # Code block JSON extraction + # -------------------------------------------------------- + + def test_code_block_json_valid(self, make_chunk, usage_dict, mock_action_class) -> None: + content = """```json +{"action": "lookup", "input": "abc"} +```""" + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + mock_action_class.assert_called_once_with(action_name="lookup", action_input="abc") + + def test_code_block_multiple_json(self, make_chunk, usage_dict, mock_action_class) -> None: + # Multiple JSON objects inside single code fence (invalid combined JSON) + # Parser should safely ignore invalid combined block + content = """```json +{"action": "a1", "input": "x"} +{"action": "a2", "input": "y"} +```""" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # No valid parsed action expected due to invalid combined JSON + assert mock_action_class.call_count == 0 + assert isinstance(result, list) + + def test_code_block_invalid_json(self, make_chunk, usage_dict) -> None: + content = """```json +{invalid} +```""" + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert result + + def test_unclosed_code_block(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk('```json {"a":1}')] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + assert any('```json {"a":1}' in item for item in result) + + # -------------------------------------------------------- + # Action / Thought prefix handling + # -------------------------------------------------------- + + @pytest.mark.parametrize( + "content", + [ + " action: something", + " ACTION: something", + " thought: reasoning", + " THOUGHT: reasoning", + ], + ) + def test_prefix_handling(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + joined = "".join(str(item) for item in result) + expected_word = "something" if "action:" in content.lower() else "reasoning" + assert expected_word in joined + assert "action:" not in joined.lower() + assert "thought:" not in joined.lower() + + def test_prefix_mid_word_yield_delta_branch(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("xaction: test")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "x" in "".join(map(str, result)) + + # -------------------------------------------------------- + # Mixed streaming scenarios + # -------------------------------------------------------- + + def test_text_json_text_mix(self, make_chunk, usage_dict, mock_action_class) -> None: + content = 'start {"action": "mix", "input": "1"} end' + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + # JSON action should be parsed + mock_action_class.assert_called_once() + # Ensure surrounding text is streamed (character-level) + joined = "".join(str(r) for r in result if not isinstance(r, MagicMock)) + assert "start" in joined + assert "end" in joined + + def test_multiple_code_blocks_in_stream(self, make_chunk, usage_dict, mock_action_class) -> None: + content = '```json\n{"action":"a1","input":"x"}\n```middle```json\n{"action":"a2","input":"y"}\n```' + chunks = [make_chunk(content)] + list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert mock_action_class.call_count == 2 + + def test_backtick_noise(self, make_chunk, usage_dict) -> None: + chunks = [make_chunk("text with ` random ` backticks")] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert "text with" in "".join(result) + + # -------------------------------------------------------- + # Boundary & edge inputs + # -------------------------------------------------------- + + @pytest.mark.parametrize( + "content", + [ + "```", + "{", + "}", + "```json", + "action:", + "thought:", + " ", + ], + ) + def test_edge_inputs(self, make_chunk, usage_dict, content) -> None: + chunks = [make_chunk(content)] + result = list(CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)) + assert all(isinstance(item, str) for item in result) + joined = "".join(result) + if content == " ": + assert result == [] or joined == content + if content in {"```", "{", "}", "```json"}: + assert content in joined + if content.lower() in {"action:", "thought:"}: + assert "action:" not in joined.lower() + assert "thought:" not in joined.lower() diff --git a/api/tests/unit_tests/core/agent/strategy/test_base.py b/api/tests/unit_tests/core/agent/strategy/test_base.py new file mode 100644 index 0000000000..83ff79e8a1 --- /dev/null +++ b/api/tests/unit_tests/core/agent/strategy/test_base.py @@ -0,0 +1,174 @@ +from collections.abc import Generator +from unittest.mock import MagicMock + +import pytest + +from core.agent.strategy.base import BaseAgentStrategy + + +class DummyStrategy(BaseAgentStrategy): + """ + Concrete implementation for testing BaseAgentStrategy + """ + + def __init__(self, return_values=None, raise_exception=None): + self.return_values = return_values or [] + self.raise_exception = raise_exception + self.received_args = None + + def _invoke( + self, + params, + user_id, + conversation_id=None, + app_id=None, + message_id=None, + credentials=None, + ) -> Generator: + self.received_args = ( + params, + user_id, + conversation_id, + app_id, + message_id, + credentials, + ) + + if self.raise_exception: + raise self.raise_exception + + yield from self.return_values + + +class TestBaseAgentStrategyInstantiation: + def test_cannot_instantiate_abstract_class(self) -> None: + with pytest.raises(TypeError): + BaseAgentStrategy() + + +class TestBaseAgentStrategyInvoke: + @pytest.fixture + def mock_message(self): + return MagicMock(name="AgentInvokeMessage") + + @pytest.fixture + def mock_credentials(self): + return MagicMock(name="InvokeCredentials") + + @pytest.mark.parametrize( + ("params", "user_id", "conversation_id", "app_id", "message_id"), + [ + ({"key": "value"}, "user1", "conv1", "app1", "msg1"), + ({}, "user2", None, None, None), + ({"a": 1}, "", "", "", ""), + ({"nested": {"x": 1}}, "user3", None, "app3", None), + ], + ) + def test_invoke_success( + self, + mock_message, + mock_credentials, + params, + user_id, + conversation_id, + app_id, + message_id, + ) -> None: + # Arrange + strategy = DummyStrategy(return_values=[mock_message]) + + # Act + result = list( + strategy.invoke( + params=params, + user_id=user_id, + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + credentials=mock_credentials, + ) + ) + + # Assert + assert result == [mock_message] + assert strategy.received_args == ( + params, + user_id, + conversation_id, + app_id, + message_id, + mock_credentials, + ) + + def test_invoke_multiple_yields(self, mock_message) -> None: + # Arrange + messages = [mock_message, MagicMock(), MagicMock()] + strategy = DummyStrategy(return_values=messages) + + # Act + result = list(strategy.invoke(params={}, user_id="user")) + + # Assert + assert result == messages + + def test_invoke_empty_generator(self) -> None: + # Arrange + strategy = DummyStrategy(return_values=[]) + + # Act + result = list(strategy.invoke(params={}, user_id="user")) + + # Assert + assert result == [] + + def test_invoke_propagates_exception(self) -> None: + # Arrange + strategy = DummyStrategy(raise_exception=ValueError("failure")) + + # Act & Assert + with pytest.raises(ValueError, match="failure"): + list(strategy.invoke(params={}, user_id="user")) + + @pytest.mark.parametrize( + "invalid_params", + [ + None, + "", + 123, + [], + ], + ) + def test_invoke_invalid_params_type_pass_through(self, invalid_params) -> None: + """ + Base class does not validate types — ensure pass-through behavior + """ + strategy = DummyStrategy(return_values=[]) + + result = list(strategy.invoke(params=invalid_params, user_id="user")) + + assert result == [] + + def test_invoke_none_user_id(self) -> None: + strategy = DummyStrategy(return_values=[]) + + result = list(strategy.invoke(params={}, user_id=None)) + + assert result == [] + + +class TestBaseAgentStrategyGetParameters: + def test_get_parameters_default_empty_list(self) -> None: + strategy = DummyStrategy() + result = strategy.get_parameters() + + assert isinstance(result, list) + assert result == [] + + def test_get_parameters_returns_new_list_each_time(self) -> None: + strategy = DummyStrategy() + + first = strategy.get_parameters() + second = strategy.get_parameters() + + assert first == second == [] + assert first is not second diff --git a/api/tests/unit_tests/core/agent/strategy/test_plugin.py b/api/tests/unit_tests/core/agent/strategy/test_plugin.py new file mode 100644 index 0000000000..e0894f1e90 --- /dev/null +++ b/api/tests/unit_tests/core/agent/strategy/test_plugin.py @@ -0,0 +1,272 @@ +# File: tests/unit_tests/core/agent/strategy/test_plugin.py + +from unittest.mock import MagicMock + +import pytest + +from core.agent.strategy.plugin import PluginAgentStrategy + +# ============================================================ +# Fixtures +# ============================================================ + + +@pytest.fixture +def mock_parameter(): + def _factory(name="param", return_value="initialized"): + param = MagicMock() + param.name = name + param.init_frontend_parameter = MagicMock(return_value=return_value) + return param + + return _factory + + +@pytest.fixture +def mock_declaration(mock_parameter): + param1 = mock_parameter("param1", "init1") + param2 = mock_parameter("param2", "init2") + + identity = MagicMock() + identity.provider = "provider_x" + identity.name = "strategy_x" + + declaration = MagicMock() + declaration.parameters = [param1, param2] + declaration.identity = identity + + return declaration + + +@pytest.fixture +def strategy(mock_declaration): + return PluginAgentStrategy( + tenant_id="tenant_123", + declaration=mock_declaration, + meta_version="v1", + ) + + +# ============================================================ +# Initialization Tests +# ============================================================ + + +class TestPluginAgentStrategyInitialization: + def test_init_sets_attributes(self, mock_declaration) -> None: + strategy = PluginAgentStrategy( + tenant_id="tenant_test", + declaration=mock_declaration, + meta_version="meta_v", + ) + + assert strategy.tenant_id == "tenant_test" + assert strategy.declaration == mock_declaration + assert strategy.meta_version == "meta_v" + + def test_init_meta_version_none(self, mock_declaration) -> None: + strategy = PluginAgentStrategy( + tenant_id="tenant_test", + declaration=mock_declaration, + meta_version=None, + ) + + assert strategy.meta_version is None + + +# ============================================================ +# get_parameters Tests +# ============================================================ + + +class TestGetParameters: + def test_get_parameters_returns_parameters(self, strategy, mock_declaration) -> None: + result = strategy.get_parameters() + assert result == mock_declaration.parameters + + +# ============================================================ +# initialize_parameters Tests +# ============================================================ + + +class TestInitializeParameters: + def test_initialize_parameters_success(self, strategy, mock_declaration) -> None: + params = {"param1": "value1"} + + result = strategy.initialize_parameters(params.copy()) + + assert result["param1"] == "init1" + assert result["param2"] == "init2" + + mock_declaration.parameters[0].init_frontend_parameter.assert_called_once_with("value1") + mock_declaration.parameters[1].init_frontend_parameter.assert_called_once_with(None) + + @pytest.mark.parametrize( + "input_params", + [ + {}, + {"param1": None}, + {"param1": ""}, + {"param1": 0}, + {"param1": []}, + {"param1": {}, "param2": "value"}, + ], + ) + def test_initialize_parameters_edge_cases(self, strategy, input_params) -> None: + result = strategy.initialize_parameters(input_params.copy()) + + for param in strategy.declaration.parameters: + assert param.name in result + + def test_initialize_parameters_invalid_input_type(self, strategy) -> None: + with pytest.raises(AttributeError): + strategy.initialize_parameters(None) + + +# ============================================================ +# _invoke Tests +# ============================================================ + + +class TestInvoke: + def test_invoke_success_all_arguments(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter(["msg1", "msg2"])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mock_convert = mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={"converted": True}, + ) + + result = list( + strategy._invoke( + params={"param1": "value"}, + user_id="user_1", + conversation_id="conv_1", + app_id="app_1", + message_id="msg_1", + credentials=None, + ) + ) + + assert result == ["msg1", "msg2"] + mock_convert.assert_called_once() + mock_manager.invoke.assert_called_once() + + call_kwargs = mock_manager.invoke.call_args.kwargs + assert call_kwargs["tenant_id"] == "tenant_123" + assert call_kwargs["user_id"] == "user_1" + assert call_kwargs["agent_provider"] == "provider_x" + assert call_kwargs["agent_strategy"] == "strategy_x" + assert call_kwargs["agent_params"] == {"converted": True} + assert call_kwargs["conversation_id"] == "conv_1" + assert call_kwargs["app_id"] == "app_1" + assert call_kwargs["message_id"] == "msg_1" + assert call_kwargs["context"] is not None + + def test_invoke_with_credentials(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter([])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + # Patch PluginInvokeContext to bypass pydantic validation + mock_context = MagicMock() + mocker.patch( + "core.agent.strategy.plugin.PluginInvokeContext", + return_value=mock_context, + ) + + credentials = MagicMock() + + result = list( + strategy._invoke( + params={}, + user_id="user_1", + credentials=credentials, + ) + ) + + assert result == [] + mock_manager.invoke.assert_called_once() + + @pytest.mark.parametrize( + ("conversation_id", "app_id", "message_id"), + [ + (None, None, None), + ("conv", None, None), + (None, "app", None), + (None, None, "msg"), + ], + ) + def test_invoke_optional_arguments(self, strategy, mocker, conversation_id, app_id, message_id) -> None: + mock_manager = MagicMock() + mock_manager.invoke = MagicMock(return_value=iter([])) + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + result = list( + strategy._invoke( + params={}, + user_id="user_1", + conversation_id=conversation_id, + app_id=app_id, + message_id=message_id, + ) + ) + + assert result == [] + mock_manager.invoke.assert_called_once() + + def test_invoke_convert_raises_exception(self, strategy, mocker) -> None: + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=MagicMock(), + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + side_effect=ValueError("conversion failed"), + ) + + with pytest.raises(ValueError): + list(strategy._invoke(params={}, user_id="user_1")) + + def test_invoke_manager_raises_exception(self, strategy, mocker) -> None: + mock_manager = MagicMock() + mock_manager.invoke.side_effect = RuntimeError("invoke failed") + + mocker.patch( + "core.agent.strategy.plugin.PluginAgentClient", + return_value=mock_manager, + ) + + mocker.patch( + "core.agent.strategy.plugin.convert_parameters_to_plugin_format", + return_value={}, + ) + + with pytest.raises(RuntimeError): + list(strategy._invoke(params={}, user_id="user_1")) diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py new file mode 100644 index 0000000000..683cc0e36f --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -0,0 +1,802 @@ +import json +from decimal import Decimal +from unittest.mock import MagicMock + +import pytest + +import core.agent.base_agent_runner as module +from core.agent.base_agent_runner import BaseAgentRunner + +# ========================================================== +# Fixtures +# ========================================================== + + +@pytest.fixture +def mock_db_session(mocker): + session = mocker.MagicMock() + mocker.patch.object(module.db, "session", session) + return session + + +@pytest.fixture +def runner(mocker, mock_db_session): + r = BaseAgentRunner.__new__(BaseAgentRunner) + r.tenant_id = "tenant" + r.user_id = "user" + r.agent_thought_count = 0 + r.message = mocker.MagicMock(id="msg_current", conversation_id="conv1") + r.app_config = mocker.MagicMock() + r.app_config.app_id = "app1" + r.app_config.agent = None + r.dataset_tools = [] + r.application_generate_entity = mocker.MagicMock(invoke_from="test") + r._current_thoughts = [] + return r + + +# ========================================================== +# _repack_app_generate_entity +# ========================================================== + + +class TestRepack: + def test_sets_empty_if_none(self, runner, mocker): + entity = mocker.MagicMock() + entity.app_config.prompt_template.simple_prompt_template = None + result = runner._repack_app_generate_entity(entity) + assert result.app_config.prompt_template.simple_prompt_template == "" + + def test_keeps_existing(self, runner, mocker): + entity = mocker.MagicMock() + entity.app_config.prompt_template.simple_prompt_template = "abc" + result = runner._repack_app_generate_entity(entity) + assert result.app_config.prompt_template.simple_prompt_template == "abc" + + +# ========================================================== +# update_prompt_message_tool +# ========================================================== + + +class TestUpdatePromptTool: + def build_param(self, mocker, **kwargs): + p = mocker.MagicMock() + p.form = kwargs.get("form") + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + p.type = mock_type + + p.name = kwargs.get("name", "p1") + p.llm_description = "desc" + p.input_schema = kwargs.get("input_schema") + p.options = kwargs.get("options") + p.required = kwargs.get("required", False) + return p + + def test_skip_non_llm(self, runner, mocker): + tool = mocker.MagicMock() + param = self.build_param(mocker, form="NOT_LLM") + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"] == {} + + def test_enum_and_required(self, runner, mocker): + option = mocker.MagicMock(value="opt1") + param = self.build_param( + mocker, + form=module.ToolParameter.ToolParameterForm.LLM, + options=[option], + required=True, + ) + + tool = mocker.MagicMock() + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert "p1" in result.parameters["required"] + + def test_skip_file_type_param(self, runner, mocker): + tool = mocker.MagicMock() + param = self.build_param(mocker, form=module.ToolParameter.ToolParameterForm.LLM) + param.type = module.ToolParameter.ToolParameterType.FILE + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"] == {} + + def test_duplicate_required_not_duplicated(self, runner, mocker): + tool = mocker.MagicMock() + + param = self.build_param( + mocker, + form=module.ToolParameter.ToolParameterForm.LLM, + required=True, + ) + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": ["p1"]} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + + assert result.parameters["required"].count("p1") == 1 + + +# ========================================================== +# create_agent_thought +# ========================================================== + + +class TestCreateAgentThought: + def test_with_files(self, runner, mock_db_session, mocker): + mock_thought = mocker.MagicMock(id=10) + mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) + + result = runner.create_agent_thought("m", "msg", "tool", "input", ["f1"]) + assert result == "10" + assert runner.agent_thought_count == 1 + + def test_without_files(self, runner, mock_db_session, mocker): + mock_thought = mocker.MagicMock(id=11) + mocker.patch.object(module, "MessageAgentThought", return_value=mock_thought) + + result = runner.create_agent_thought("m", "msg", "tool", "input", []) + assert result == "11" + + +# ========================================================== +# save_agent_thought +# ========================================================== + + +class TestSaveAgentThought: + def setup_agent(self, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1;tool2" + agent.tool_labels = {} + agent.thought = "" + return agent + + def test_not_found(self, runner, mock_db_session): + mock_db_session.scalar.return_value = None + with pytest.raises(ValueError): + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + + def test_full_update(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + mock_label = mocker.MagicMock() + mock_label.to_dict.return_value = {"en_US": "label"} + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=mock_label) + + usage = mocker.MagicMock( + prompt_tokens=1, + prompt_price_unit=Decimal("0.1"), + prompt_unit_price=Decimal("0.1"), + completion_tokens=2, + completion_price_unit=Decimal("0.2"), + completion_unit_price=Decimal("0.2"), + total_tokens=3, + total_price=Decimal("0.3"), + ) + + runner.save_agent_thought( + "id", + "tool1;tool2", + {"a": 1}, + "thought", + {"b": 2}, + {"meta": 1}, + "answer", + ["f1"], + usage, + ) + + assert agent.answer == "answer" + assert agent.tokens == 3 + assert "tool1" in json.loads(agent.tool_labels_str) + + def test_label_fallback_when_none(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + agent.tool = "unknown_tool" + mock_db_session.scalar.return_value = agent + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + labels = json.loads(agent.tool_labels_str) + assert "unknown_tool" in labels + + def test_json_failure_paths(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + bad_obj = MagicMock() + bad_obj.__str__.return_value = "bad" + + runner.save_agent_thought( + "id", + None, + bad_obj, + None, + bad_obj, + bad_obj, + None, + [], + None, + ) + + assert mock_db_session.commit.called + + def test_messages_ids_none(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + runner.save_agent_thought("id", None, None, None, None, None, None, None, None) + assert mock_db_session.commit.called + + def test_success_dict_serialization(self, runner, mock_db_session, mocker): + agent = self.setup_agent(mocker) + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought( + "id", + None, + {"a": 1}, + None, + {"b": 2}, + None, + None, + [], + None, + ) + + assert isinstance(agent.tool_input, str) + assert isinstance(agent.observation, str) + + +# ========================================================== +# organize_agent_user_prompt +# ========================================================== + + +class TestOrganizeUserPrompt: + def test_no_files(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [] + msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + def test_with_files_no_config(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + msg = mocker.MagicMock(id="1", query="hello", app_model_config=None) + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + def test_image_detail_low_fallback(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + file_config = mocker.MagicMock() + file_config.image_config = mocker.MagicMock(detail=None) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config) + mocker.patch.object(module.file_factory, "build_from_message_files", return_value=[]) + + msg = mocker.MagicMock(id="1", query="hello") + msg.app_model_config.to_dict.return_value = {} + + result = runner.organize_agent_user_prompt(msg) + assert result.content == "hello" + + +# ========================================================== +# organize_agent_history +# ========================================================== + + +class TestOrganizeHistory: + def test_empty(self, runner, mock_db_session, mocker): + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] + mocker.patch.object(module, "extract_thread_messages", return_value=[]) + result = runner.organize_agent_history([]) + assert result == [] + + def test_with_answer_only(self, runner, mock_db_session, mocker): + msg = mocker.MagicMock(id="m1", answer="ans", agent_thoughts=[], app_model_config=None) + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert any(isinstance(x, module.AssistantPromptMessage) for x in result) + + def test_skip_current_message(self, runner, mock_db_session, mocker): + msg = mocker.MagicMock(id="msg_current", agent_thoughts=[], answer="ans", app_model_config=None) + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert result == [] + + def test_with_tool_calls_invalid_json(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input="invalid", + observation="invalid", + thought="thinking", + ) + msg = mocker.MagicMock(id="m2", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_empty_tool_name_split(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock(tool=";", thought="thinking") + msg = mocker.MagicMock(id="m5", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_valid_json_tool_flow(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input=json.dumps({"tool1": {"x": 1}}), + observation=json.dumps({"tool1": "obs"}), + thought="thinking", + ) + + msg = mocker.MagicMock( + id="m100", + agent_thoughts=[thought], + answer=None, + app_model_config=None, + ) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + +# ========================================================== +# _convert_tool_to_prompt_message_tool (new coverage) +# ========================================================== + + +class TestConvertToolToPromptMessageTool: + def test_basic_conversion(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + runtime_param = mocker.MagicMock() + runtime_param.form = module.ToolParameter.ToolParameterForm.LLM + runtime_param.name = "param1" + runtime_param.llm_description = "desc" + runtime_param.required = True + runtime_param.input_schema = None + runtime_param.options = None + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + runtime_param.type = mock_type + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [runtime_param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) + assert entity == tool_entity + + def test_full_conversion_multiple_params(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + # LLM param with input_schema override + param1 = mocker.MagicMock() + param1.form = module.ToolParameter.ToolParameterForm.LLM + param1.name = "p1" + param1.llm_description = "desc" + param1.required = True + param1.input_schema = {"type": "integer"} + param1.options = None + param1.type = mocker.MagicMock() + + # SYSTEM_FILES param should be skipped + param2 = mocker.MagicMock() + param2.form = module.ToolParameter.ToolParameterForm.LLM + param2.name = "file_param" + param2.type = module.ToolParameter.ToolParameterType.SYSTEM_FILES + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param1, param2] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, entity = runner._convert_tool_to_prompt_message_tool(tool) + + assert entity == tool_entity + + +# ========================================================== +# _init_prompt_tools additional branches +# ========================================================== + + +class TestInitPromptToolsExtended: + def test_agent_tool_branch(self, runner, mocker): + agent_tool = mocker.MagicMock(tool_name="agent_tool") + runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) + mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", return_value=(MagicMock(), "entity")) + + tools, prompts = runner._init_prompt_tools() + assert "agent_tool" in tools + + def test_exception_in_conversion(self, runner, mocker): + agent_tool = mocker.MagicMock(tool_name="bad_tool") + runner.app_config.agent = mocker.MagicMock(tools=[agent_tool]) + mocker.patch.object(runner, "_convert_tool_to_prompt_message_tool", side_effect=Exception) + + tools, prompts = runner._init_prompt_tools() + assert tools == {} + + +# ========================================================== +# Additional Coverage Tests (DO NOT MODIFY EXISTING TESTS) +# ========================================================== + + +class TestAdditionalCoverage: + def test_update_prompt_with_input_schema(self, runner, mocker): + tool = mocker.MagicMock() + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "p1" + param.required = False + param.llm_description = "desc" + param.options = None + param.input_schema = {"type": "number"} + + mock_type = mocker.MagicMock() + mock_type.as_normal_type.return_value = "string" + param.type = mock_type + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + assert result.parameters["properties"]["p1"]["type"] == "number" + + def test_save_agent_thought_existing_labels(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {"tool1": {"en_US": "existing"}} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + labels = json.loads(agent.tool_labels_str) + assert labels["tool1"]["en_US"] == "existing" + + def test_save_agent_thought_tool_meta_string(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + runner.save_agent_thought("id", None, None, None, None, "meta_string", None, [], None) + assert agent.tool_meta_str == "meta_string" + + def test_convert_dataset_retriever_tool(self, runner, mocker): + ds_tool = mocker.MagicMock() + ds_tool.entity.identity.name = "ds" + ds_tool.entity.description.llm = "desc" + + param = mocker.MagicMock() + param.name = "query" + param.llm_description = "desc" + param.required = True + + ds_tool.get_runtime_parameters.return_value = [param] + + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) + assert prompt is not None + + def test_organize_user_prompt_with_file_objects(self, runner, mock_db_session, mocker): + mock_db_session.scalars.return_value.all.return_value = [mocker.MagicMock()] + + file_config = mocker.MagicMock() + file_config.image_config = mocker.MagicMock(detail=None) + + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_config) + mocker.patch.object(module.file_factory, "build_from_message_files", return_value=["file1"]) + mocker.patch.object(module.file_manager, "to_prompt_message_content", return_value=mocker.MagicMock()) + + mocker.patch.object(module, "UserPromptMessage", side_effect=lambda **kw: MagicMock(**kw)) + mocker.patch.object(module, "TextPromptMessageContent", side_effect=lambda **kw: MagicMock(**kw)) + + msg = mocker.MagicMock(id="1", query="hello") + msg.app_model_config.to_dict.return_value = {} + + result = runner.organize_agent_user_prompt(msg) + assert result is not None + + def test_organize_history_without_tool_names(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock(tool=None, thought="thinking") + msg = mocker.MagicMock(id="m3", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + def test_organize_history_multiple_tools_split(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1;tool2", + tool_input=json.dumps({"tool1": {}, "tool2": {}}), + observation=json.dumps({"tool1": "o1", "tool2": "o2"}), + thought="thinking", + ) + msg = mocker.MagicMock(id="m4", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + result = runner.organize_agent_history([]) + assert isinstance(result, list) + + # ================= Additional Surgical Coverage ================= + + def test_convert_tool_select_enum_branch(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "select_param" + param.required = True + param.llm_description = "desc" + param.input_schema = None + + option1 = mocker.MagicMock(value="A") + option2 = mocker.MagicMock(value="B") + param.options = [option1, option2] + param.type = module.ToolParameter.ToolParameterType.SELECT + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) + assert prompt_tool is not None + + +class TestConvertDatasetRetrieverTool: + def test_required_param_added(self, runner, mocker): + ds_tool = mocker.MagicMock() + ds_tool.entity.identity.name = "ds" + ds_tool.entity.description.llm = "desc" + + param = mocker.MagicMock() + param.name = "query" + param.llm_description = "desc" + param.required = True + + ds_tool.get_runtime_parameters.return_value = [param] + + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt = runner._convert_dataset_retriever_tool_to_prompt_message_tool(ds_tool) + + assert prompt is not None + + +class TestBaseAgentRunnerInit: + def test_init_sets_stream_tool_call_and_files(self, mocker): + session = mocker.MagicMock() + session.query.return_value.where.return_value.count.return_value = 2 + mocker.patch.object(module.db, "session", session) + + mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) + mocker.patch.object(module.DatasetRetrieverTool, "get_dataset_tools", return_value=["ds_tool"]) + + llm = mocker.MagicMock() + llm.get_model_schema.return_value = mocker.MagicMock( + features=[module.ModelFeature.STREAM_TOOL_CALL, module.ModelFeature.VISION] + ) + model_instance = mocker.MagicMock(model_type_instance=llm, model="m", credentials="c") + + app_config = mocker.MagicMock() + app_config.app_id = "app1" + app_config.agent = None + app_config.dataset = mocker.MagicMock(dataset_ids=["d1"], retrieve_config={"k": "v"}) + app_config.additional_features = mocker.MagicMock(show_retrieve_source=True) + + app_generate = mocker.MagicMock(invoke_from="test", inputs={}, files=["file1"]) + message = mocker.MagicMock(id="msg1", conversation_id="conv1") + + runner = BaseAgentRunner( + tenant_id="tenant", + application_generate_entity=app_generate, + conversation=mocker.MagicMock(), + app_config=app_config, + model_config=mocker.MagicMock(), + config=mocker.MagicMock(), + queue_manager=mocker.MagicMock(), + message=message, + user_id="user", + model_instance=model_instance, + ) + + assert runner.stream_tool_call is True + assert runner.files == ["file1"] + assert runner.dataset_tools == ["ds_tool"] + assert runner.agent_thought_count == 2 + + +class TestBaseAgentRunnerCoverage: + def test_convert_tool_skips_non_llm_param(self, runner, mocker): + tool = mocker.MagicMock(tool_name="tool1") + + param = mocker.MagicMock() + param.form = "NOT_LLM" + param.type = mocker.MagicMock() + + tool_entity = mocker.MagicMock() + tool_entity.entity.description.llm = "desc" + tool_entity.get_merged_runtime_parameters.return_value = [param] + + mocker.patch.object(module.ToolManager, "get_agent_tool_runtime", return_value=tool_entity) + mocker.patch.object(module, "PromptMessageTool", side_effect=lambda **kw: MagicMock(**kw)) + + prompt_tool, _ = runner._convert_tool_to_prompt_message_tool(tool) + + assert prompt_tool.parameters["properties"] == {} + + def test_init_prompt_tools_adds_dataset_tools(self, runner, mocker): + dataset_tool = mocker.MagicMock() + dataset_tool.entity.identity.name = "ds" + runner.dataset_tools = [dataset_tool] + + mocker.patch.object(runner, "_convert_dataset_retriever_tool_to_prompt_message_tool", return_value=MagicMock()) + + tools, prompt_tools = runner._init_prompt_tools() + + assert tools["ds"] == dataset_tool + assert len(prompt_tools) == 1 + + def test_update_prompt_message_tool_select_enum(self, runner, mocker): + tool = mocker.MagicMock() + + option1 = mocker.MagicMock(value="A") + option2 = mocker.MagicMock(value="B") + + param = mocker.MagicMock() + param.form = module.ToolParameter.ToolParameterForm.LLM + param.name = "select_param" + param.required = False + param.llm_description = "desc" + param.input_schema = None + param.options = [option1, option2] + param.type = module.ToolParameter.ToolParameterType.SELECT + + tool.get_runtime_parameters.return_value = [param] + + prompt_tool = mocker.MagicMock() + prompt_tool.parameters = {"properties": {}, "required": []} + + result = runner.update_prompt_message_tool(tool, prompt_tool) + + assert result.parameters["properties"]["select_param"]["enum"] == ["A", "B"] + + def test_save_agent_thought_json_dumps_fallbacks(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + tool_input = {"a": 1} + observation = {"b": 2} + tool_meta = {"c": 3} + + real_dumps = json.dumps + + def dumps_side_effect(value, *args, **kwargs): + if value in (tool_input, observation, tool_meta) and kwargs.get("ensure_ascii") is False: + raise TypeError("fail") + return real_dumps(value, *args, **kwargs) + + mocker.patch.object(module.json, "dumps", side_effect=dumps_side_effect) + + runner.save_agent_thought( + "id", + "tool1", + tool_input, + None, + observation, + tool_meta, + None, + [], + None, + ) + + assert isinstance(agent.tool_input, str) + assert isinstance(agent.observation, str) + assert isinstance(agent.tool_meta_str, str) + + def test_save_agent_thought_skips_empty_tool_name(self, runner, mock_db_session, mocker): + agent = mocker.MagicMock() + agent.tool = "tool1;;" + agent.tool_labels = {} + agent.thought = "" + mock_db_session.scalar.return_value = agent + + mocker.patch.object(module.ToolManager, "get_tool_label", return_value=None) + + runner.save_agent_thought("id", None, None, None, None, None, None, [], None) + + labels = json.loads(agent.tool_labels_str) + assert "" not in labels + + def test_organize_history_includes_system_prompt(self, runner, mock_db_session, mocker): + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [] + mocker.patch.object(module, "extract_thread_messages", return_value=[]) + + system_message = module.SystemPromptMessage(content="sys") + + result = runner.organize_agent_history([system_message]) + + assert system_message in result + + def test_organize_history_tool_inputs_and_observation_none(self, runner, mock_db_session, mocker): + thought = mocker.MagicMock( + tool="tool1", + tool_input=None, + observation=None, + thought="thinking", + ) + msg = mocker.MagicMock(id="m6", agent_thoughts=[thought], answer=None, app_model_config=None) + + mock_db_session.execute.return_value.scalars.return_value.all.return_value = [msg] + mocker.patch.object(module, "extract_thread_messages", return_value=[msg]) + mocker.patch("uuid.uuid4", return_value="uuid") + + mocker.patch.object( + runner, + "organize_agent_user_prompt", + return_value=module.UserPromptMessage(content="user"), + ) + + result = runner.organize_agent_history([]) + + assert any(isinstance(item, module.ToolPromptMessage) for item in result) diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py new file mode 100644 index 0000000000..f6d1edbaf0 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -0,0 +1,551 @@ +import json +from unittest.mock import MagicMock + +import pytest + +from core.agent.cot_agent_runner import CotAgentRunner +from core.agent.entities import AgentScratchpadUnit +from core.agent.errors import AgentMaxIterationError +from dify_graph.model_runtime.entities.llm_entities import LLMUsage + + +class DummyRunner(CotAgentRunner): + """Concrete implementation for testing abstract methods.""" + + def __init__(self, **kwargs): + # Completely bypass BaseAgentRunner __init__ to avoid DB/session usage + for k, v in kwargs.items(): + setattr(self, k, v) + # Minimal required defaults + self.history_prompt_messages = [] + self.memory = None + + def _organize_prompt_messages(self): + return [] + + +@pytest.fixture +def runner(mocker): + # Prevent BaseAgentRunner __init__ from hitting database + mocker.patch( + "core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history", + return_value=[], + ) + # Prepare required constructor dependencies for BaseAgentRunner + application_generate_entity = MagicMock() + application_generate_entity.model_conf = MagicMock() + application_generate_entity.model_conf.stop = [] + application_generate_entity.model_conf.provider = "openai" + application_generate_entity.model_conf.parameters = {} + application_generate_entity.trace_manager = None + application_generate_entity.invoke_from = "test" + + app_config = MagicMock() + app_config.agent = MagicMock() + app_config.agent.max_iteration = 1 + app_config.prompt_template.simple_prompt_template = "Hello {{name}}" + + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.model_name = "test-model" + model_instance.invoke_llm.return_value = [] + + model_config = MagicMock() + model_config.model = "test-model" + + queue_manager = MagicMock() + message = MagicMock() + + runner = DummyRunner( + tenant_id="tenant", + application_generate_entity=application_generate_entity, + conversation=MagicMock(), + app_config=app_config, + model_config=model_config, + config=MagicMock(), + queue_manager=queue_manager, + message=message, + user_id="user", + model_instance=model_instance, + ) + + # Patch internal methods to isolate behavior + runner._repack_app_generate_entity = MagicMock() + runner._init_prompt_tools = MagicMock(return_value=({}, [])) + runner.recalc_llm_max_tokens = MagicMock() + runner.create_agent_thought = MagicMock(return_value="thought-id") + runner.save_agent_thought = MagicMock() + runner.update_prompt_message_tool = MagicMock() + runner.agent_callback = None + runner.memory = None + runner.history_prompt_messages = [] + + return runner + + +class TestFillInputs: + @pytest.mark.parametrize( + ("instruction", "inputs", "expected"), + [ + ("Hello {{name}}", {"name": "John"}, "Hello John"), + ("No placeholders", {"name": "John"}, "No placeholders"), + ("{{a}}{{b}}", {"a": 1, "b": 2}, "12"), + ("{{x}}", {"x": None}, "None"), + ("", {"x": "y"}, ""), + ], + ) + def test_fill_in_inputs(self, runner, instruction, inputs, expected): + result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs) + assert result == expected + + +class TestConvertDictToAction: + def test_convert_valid_dict(self, runner): + action_dict = {"action": "test", "action_input": {"a": 1}} + action = runner._convert_dict_to_action(action_dict) + assert action.action_name == "test" + assert action.action_input == {"a": 1} + + def test_convert_missing_keys(self, runner): + with pytest.raises(KeyError): + runner._convert_dict_to_action({"invalid": 1}) + + +class TestFormatAssistantMessage: + def test_format_assistant_message_multiple_scratchpads(self, runner): + sp1 = AgentScratchpadUnit( + agent_response="resp1", + thought="thought1", + action_str="action1", + action=AgentScratchpadUnit.Action(action_name="tool", action_input={}), + observation="obs1", + ) + sp2 = AgentScratchpadUnit( + agent_response="final", + thought="", + action_str="", + action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"), + observation=None, + ) + result = runner._format_assistant_message([sp1, sp2]) + assert "Final Answer:" in result + + def test_format_with_final(self, runner): + scratchpad = AgentScratchpadUnit( + agent_response="Done", + thought="", + action_str="", + action=None, + observation=None, + ) + # Simulate final state via action name + scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done") + result = runner._format_assistant_message([scratchpad]) + assert "Final Answer" in result + + def test_format_with_action_and_observation(self, runner): + scratchpad = AgentScratchpadUnit( + agent_response="resp", + thought="thinking", + action_str="action", + action=None, + observation="obs", + ) + # Non-final state: provide a non-final action + scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + result = runner._format_assistant_message([scratchpad]) + assert "Thought:" in result + assert "Action:" in result + assert "Observation:" in result + + +class TestHandleInvokeAction: + def test_handle_invoke_action_tool_not_present(self, runner): + action = AgentScratchpadUnit.Action(action_name="missing", action_input={}) + response, meta = runner._handle_invoke_action(action, {}, []) + assert "there is not a tool named" in response + + def test_tool_with_json_string_args(self, runner, mocker): + action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1})) + tool_instance = MagicMock() + tool_instances = {"tool": tool_instance} + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("result", [], MagicMock(to_dict=lambda: {})), + ) + + response, meta = runner._handle_invoke_action(action, tool_instances, []) + assert response == "result" + + +class TestOrganizeHistoricPromptMessages: + def test_empty_history(self, runner, mocker): + mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt", + return_value=[], + ) + result = runner._organize_historic_prompt_messages([]) + assert result == [] + + +class TestRun: + def test_run_handles_empty_parser_output(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {})) + assert isinstance(results, list) + + def test_run_with_action_and_tool_invocation(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.agent_callback = None + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {"tool": MagicMock()})) + + def test_run_respects_max_iteration_boundary(self, runner, mocker): + runner.app_config.agent.max_iteration = 1 + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.agent_callback = None + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {"tool": MagicMock()})) + + def test_run_basic_flow(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {"name": "John"})) + assert results + + def test_run_max_iteration_error(self, runner, mocker): + runner.app_config.agent.max_iteration = 0 + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query", {})) + + def test_run_increase_usage_aggregation(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + runner.app_config.agent.max_iteration = 2 + + usage_1 = LLMUsage.empty_usage() + usage_1.prompt_tokens = 1 + usage_1.completion_tokens = 1 + usage_1.total_tokens = 2 + usage_1.prompt_price = 1 + usage_1.completion_price = 1 + usage_1.total_price = 2 + + usage_2 = LLMUsage.empty_usage() + usage_2.prompt_tokens = 1 + usage_2.completion_tokens = 1 + usage_2.total_tokens = 2 + usage_2.prompt_price = 1 + usage_2.completion_price = 1 + usage_2.total_price = 2 + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + handle_output = mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + side_effect=[ + [action], + [], + ], + ) + + def _handle_side_effect(chunks, usage_dict): + call_index = handle_output.call_count + usage_dict["usage"] = usage_1 if call_index == 1 else usage_2 + return [action] if call_index == 1 else [] + + handle_output.side_effect = _handle_side_effect + runner.model_instance.invoke_llm = MagicMock(return_value=[]) + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + fake_prompt_tool = MagicMock() + fake_prompt_tool.name = "tool" + runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool])) + + results = list(runner.run(message, "query", {})) + final_usage = results[-1].delta.usage + assert final_usage is not None + assert final_usage.prompt_tokens == 2 + assert final_usage.completion_tokens == 2 + assert final_usage.total_tokens == 4 + assert final_usage.prompt_price == 2 + assert final_usage.completion_price == 2 + assert final_usage.total_price == 4 + + def test_run_when_no_action_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "" + + def test_run_usage_missing_key_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[], + ) + + runner.model_instance.invoke_llm = MagicMock(return_value=[]) + + list(runner.run(message, "query", {})) + + def test_run_prompt_tool_update_branch(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="tool", action_input={}) + + # First iteration → action + # Second iteration → no action (empty list) + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + side_effect=[[action], []], + ) + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", [], MagicMock(to_dict=lambda: {})), + ) + + runner.app_config.agent.max_iteration = 5 + + fake_prompt_tool = MagicMock() + fake_prompt_tool.name = "tool" + + runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool])) + + runner.update_prompt_message_tool = MagicMock() + runner.agent_callback = None + + list(runner.run(message, "query", {})) + + runner.update_prompt_message_tool.assert_called_once() + + def test_historic_with_assistant_and_tool_calls(self, runner): + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + + assistant = AssistantPromptMessage(content="thinking") + assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] + + tool_msg = ToolPromptMessage(content="obs", tool_call_id="1") + + runner.history_prompt_messages = [assistant, tool_msg] + + result = runner._organize_historic_prompt_messages([]) + assert isinstance(result, list) + + def test_historic_final_flush_branch(self, runner): + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + + assistant = AssistantPromptMessage(content="final") + runner.history_prompt_messages = [assistant] + + result = runner._organize_historic_prompt_messages([]) + assert isinstance(result, list) + + +class TestInitReactState: + def test_init_react_state_resets_state(self, runner, mocker): + mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"]) + runner._agent_scratchpad = ["old"] + runner._query = "old" + + runner._init_react_state("new-query") + + assert runner._query == "new-query" + assert runner._agent_scratchpad == [] + assert runner._historic_prompt_messages == ["historic"] + + +class TestHandleInvokeActionExtended: + def test_tool_with_invalid_json_string_args(self, runner, mocker): + action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json") + tool_instance = MagicMock() + tool_instances = {"tool": tool_instance} + + mocker.patch( + "core.agent.cot_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})), + ) + + message_file_ids = [] + response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids) + + assert response == "ok" + assert message_file_ids == ["file1"] + runner.queue_manager.publish.assert_called() + + +class TestFillInputsEdgeCases: + def test_fill_inputs_with_empty_inputs(self, runner): + result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {}) + assert result == "Hello {{x}}" + + def test_fill_inputs_with_exception_in_replace(self, runner): + class BadValue: + def __str__(self): + raise Exception("fail") + + # Should silently continue on exception + result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()}) + assert result == "Hello {{x}}" + + +class TestOrganizeHistoricPromptMessagesExtended: + def test_user_message_flushes_scratchpad(self, runner, mocker): + from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + user_message = UserPromptMessage(content="Hi") + + runner.history_prompt_messages = [user_message] + + mock_transform = mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform", + ) + mock_transform.return_value.get_prompt.return_value = ["final"] + + result = runner._organize_historic_prompt_messages([]) + assert result == ["final"] + + def test_tool_message_without_scratchpad_raises(self, runner): + from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + + runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] + + with pytest.raises(NotImplementedError): + runner._organize_historic_prompt_messages([]) + + def test_agent_history_transform_invocation(self, runner, mocker): + mock_transform = MagicMock() + mock_transform.get_prompt.return_value = [] + + mocker.patch( + "core.agent.cot_agent_runner.AgentHistoryPromptTransform", + return_value=mock_transform, + ) + + runner.history_prompt_messages = [] + result = runner._organize_historic_prompt_messages([]) + assert result == [] + + +class TestRunAdditionalBranches: + def test_run_with_no_action_final_answer_empty(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=["thinking"], + ) + + results = list(runner.run(message, "query", {})) + assert any(hasattr(r, "delta") for r in results) + + def test_run_with_final_answer_action_string(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done") + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "done" + + def test_run_with_final_answer_action_dict(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1}) + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert json.loads(results[-1].delta.message.content) == {"a": 1} + + def test_run_with_string_final_answer(self, runner, mocker): + message = MagicMock() + message.id = "msg-id" + + # Remove invalid branch: Pydantic enforces str|dict for action_input + action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345") + + mocker.patch( + "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output", + return_value=[action], + ) + + results = list(runner.run(message, "query", {})) + assert results[-1].delta.message.content == "12345" diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py new file mode 100644 index 0000000000..f9d69d1196 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -0,0 +1,215 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.agent.cot_chat_agent_runner import CotChatAgentRunner +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent +from tests.unit_tests.core.agent.conftest import ( + DummyAgentConfig, + DummyAppConfig, + DummyTool, +) +from tests.unit_tests.core.agent.conftest import ( + DummyPromptEntity as DummyPrompt, +) + + +class DummyFileUploadConfig: + def __init__(self, image_config=None): + self.image_config = image_config + + +class DummyImageConfig: + def __init__(self, detail=None): + self.detail = detail + + +class DummyGenerateEntity: + def __init__(self, file_upload_config=None): + self.file_upload_config = file_upload_config + + +class DummyUnit: + def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None): + self._final = final + self.thought = thought + self.action_str = action_str + self.observation = observation + self.agent_response = agent_response + + def is_final(self): + return self._final + + +@pytest.fixture +def runner(): + runner = CotChatAgentRunner.__new__(CotChatAgentRunner) + runner._instruction = "test_instruction" + runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")] + runner._query = "user query" + runner._agent_scratchpad = [] + runner.files = [] + runner.application_generate_entity = DummyGenerateEntity() + runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"]) + return runner + + +class TestOrganizeSystemPrompt: + def test_organize_system_prompt_success(self, runner, mocker): + first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}" + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt))) + + mocker.patch( + "core.agent.cot_chat_agent_runner.jsonable_encoder", + return_value=[{"name": "tool1"}, {"name": "tool2"}], + ) + + result = runner._organize_system_prompt() + + assert "test_instruction" in result.content + assert "tool1" in result.content + assert "tool2" in result.content + assert "tool1, tool2" in result.content + + def test_organize_system_prompt_missing_agent(self, runner): + runner.app_config = DummyAppConfig(agent=None) + with pytest.raises(AssertionError): + runner._organize_system_prompt() + + def test_organize_system_prompt_missing_prompt(self, runner): + runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None)) + with pytest.raises(AssertionError): + runner._organize_system_prompt() + + +class TestOrganizeUserQuery: + @pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")]) + def test_organize_user_query_no_files(self, runner, files): + runner.files = files + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert result[0].content == "query" + + @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): + from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_content = ImagePromptMessageContent( + url="http://test", + format="png", + mime_type="image/png", + ) + mock_to_prompt.return_value = mock_content + mock_user_prompt.side_effect = lambda content: MagicMock(content=content) + + runner.files = ["file1"] + runner.application_generate_entity = DummyGenerateEntity(None) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert mock_content in result[0].content + mock_to_prompt.assert_called_once_with( + "file1", + image_detail_config=ImagePromptMessageContent.DETAIL.LOW, + ) + + @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): + from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + mock_content = ImagePromptMessageContent( + url="http://test", + format="png", + mime_type="image/png", + ) + mock_to_prompt.return_value = mock_content + mock_user_prompt.side_effect = lambda content: MagicMock(content=content) + + runner.files = ["file1"] + + image_config = DummyImageConfig(detail="high") + runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config)) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert mock_content in result[0].content + mock_to_prompt.assert_called_once_with( + "file1", + image_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") + def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner): + mock_to_prompt.return_value = TextPromptMessageContent(data="file_content") + runner.files = ["file1"] + runner.application_generate_entity = DummyGenerateEntity(None) + + result = runner._organize_user_query("query", []) + assert len(result) == 1 + assert isinstance(result[0].content, list) + + +class TestOrganizePromptMessages: + def test_no_scratchpad(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + result = runner._organize_prompt_messages() + assert "system" in result + assert "query" in result + runner._organize_historic_prompt_messages.assert_called_once() + + def test_with_final_scratchpad(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + unit = DummyUnit(final=True, agent_response="done") + runner._agent_scratchpad = [unit] + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Final Answer: done" in combined + + def test_with_thought_action_observation(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + unit = DummyUnit( + final=False, + thought="thinking", + action_str="action", + observation="observe", + ) + runner._agent_scratchpad = [unit] + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Thought: thinking" in combined + assert "Action: action" in combined + assert "Observation: observe" in combined + + def test_multiple_units_mixed(self, runner, mocker): + runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}"))) + runner._organize_system_prompt = MagicMock(return_value="system") + runner._organize_user_query = MagicMock(return_value=["query"]) + + units = [ + DummyUnit(final=False, thought="t1"), + DummyUnit(final=True, agent_response="done"), + ] + runner._agent_scratchpad = units + + result = runner._organize_prompt_messages() + assistant_msgs = [m for m in result if hasattr(m, "content")] + combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)]) + assert "Thought: t1" in combined + assert "Final Answer: done" in combined diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py new file mode 100644 index 0000000000..ab822bb57d --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -0,0 +1,234 @@ +import json + +import pytest + +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def runner(mocker, dummy_tool_factory): + runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner) + + runner._instruction = "Test instruction" + runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")] + runner._query = "What is Python?" + runner._agent_scratchpad = [] + + mocker.patch( + "core.agent.cot_completion_agent_runner.jsonable_encoder", + side_effect=lambda tools: [{"name": t.name} for t in tools], + ) + + return runner + + +# ====================================================== +# _organize_instruction_prompt Tests +# ====================================================== + + +class TestOrganizeInstructionPrompt: + def test_success_all_placeholders( + self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + ): + template = ( + "{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}" + ) + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + result = runner._organize_instruction_prompt() + + assert "Test instruction" in result + assert "toolA" in result + assert "toolB" in result + tools_payload = json.loads(result.split(" | ")[1]) + assert {item["name"] for item in tools_payload} == {"toolA", "toolB"} + + def test_agent_none_raises(self, runner, dummy_app_config_factory): + runner.app_config = dummy_app_config_factory(agent=None) + with pytest.raises(ValueError, match="Agent configuration is not set"): + runner._organize_instruction_prompt() + + def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory): + runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None)) + with pytest.raises(ValueError, match="prompt entity is not set"): + runner._organize_instruction_prompt() + + +# ====================================================== +# _organize_historic_prompt Tests +# ====================================================== + + +class TestOrganizeHistoricPrompt: + def test_with_user_and_assistant_string(self, runner, mocker): + user_msg = UserPromptMessage(content="Hello") + assistant_msg = AssistantPromptMessage(content="Hi there") + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[user_msg, assistant_msg], + ) + + result = runner._organize_historic_prompt() + + assert "Question: Hello" in result + assert "Hi there" in result + + def test_assistant_list_with_text_content(self, runner, mocker): + text_content = TextPromptMessageContent(data="Partial answer") + assistant_msg = AssistantPromptMessage(content=[text_content]) + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[assistant_msg], + ) + + result = runner._organize_historic_prompt() + + assert "Partial answer" in result + + def test_assistant_list_with_non_text_content_ignored(self, runner, mocker): + non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png") + assistant_msg = AssistantPromptMessage(content=[non_text_content]) + + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[assistant_msg], + ) + + result = runner._organize_historic_prompt() + assert result == "" + + def test_empty_history(self, runner, mocker): + mocker.patch.object( + runner, + "_organize_historic_prompt_messages", + return_value=[], + ) + + result = runner._organize_historic_prompt() + assert result == "" + + +# ====================================================== +# _organize_prompt_messages Tests +# ====================================================== + + +class TestOrganizePromptMessages: + def test_full_flow_with_scratchpad( + self, + runner, + mocker, + dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, + dummy_scratchpad_unit_factory, + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n") + + runner._agent_scratchpad = [ + dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"), + dummy_scratchpad_unit_factory(final=True, agent_response="Done"), + ] + + result = runner._organize_prompt_messages() + + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], UserPromptMessage) + + content = result[0].content + + assert "History" in content + assert "Thought: Thinking" in content + assert "Action: Act" in content + assert "Observation: Obs" in content + assert "Final Answer: Done" in content + assert "Question: What is Python?" in content + + def test_no_scratchpad( + self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="") + + runner._agent_scratchpad = None + + result = runner._organize_prompt_messages() + + assert "Question: What is Python?" in result[0].content + + @pytest.mark.parametrize( + ("thought", "action", "observation"), + [ + ("T", None, None), + ("T", "A", None), + ("T", None, "O"), + ], + ) + def test_partial_scratchpad_units( + self, + runner, + mocker, + thought, + action, + observation, + dummy_app_config_factory, + dummy_agent_config_factory, + dummy_prompt_entity_factory, + dummy_scratchpad_unit_factory, + ): + template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}" + + runner.app_config = dummy_app_config_factory( + agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template)) + ) + + mocker.patch.object(runner, "_organize_historic_prompt", return_value="") + + runner._agent_scratchpad = [ + dummy_scratchpad_unit_factory( + final=False, + thought=thought, + action_str=action, + observation=observation, + ) + ] + + result = runner._organize_prompt_messages() + content = result[0].content + + assert "Thought:" in content + if action: + assert "Action:" in content + if observation: + assert "Observation:" in content diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py new file mode 100644 index 0000000000..299c9b31d2 --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -0,0 +1,452 @@ +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + DocumentPromptMessageContent, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) + +# ============================== +# Dummy Helper Classes +# ============================== + + +def build_usage(pt=1, ct=1, tt=2) -> LLMUsage: + usage = LLMUsage.empty_usage() + usage.prompt_tokens = pt + usage.completion_tokens = ct + usage.total_tokens = tt + usage.prompt_price = 0 + usage.completion_price = 0 + usage.total_price = 0 + return usage + + +class DummyMessage: + def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None): + self.content: str | None = content + self.tool_calls: list[Any] = tool_calls or [] + + +class DummyDelta: + def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None): + self.message: DummyMessage | None = message + self.usage: LLMUsage | None = usage + + +class DummyChunk: + def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None): + self.delta: DummyDelta = DummyDelta(message=message, usage=usage) + + +class DummyResult: + def __init__( + self, + message: DummyMessage | None = None, + usage: LLMUsage | None = None, + prompt_messages: list[DummyMessage] | None = None, + ): + self.message: DummyMessage | None = message + self.usage: LLMUsage | None = usage + self.prompt_messages: list[DummyMessage] = prompt_messages or [] + self.system_fingerprint: str = "" + + +# ============================== +# Fixtures +# ============================== + + +@pytest.fixture +def runner(mocker): + # Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context + mocker.patch( + "core.agent.base_agent_runner.BaseAgentRunner.__init__", + return_value=None, + ) + + # Patch streaming chunk models to avoid validation on dummy message objects + mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock) + mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock) + + app_config = MagicMock() + app_config.agent = MagicMock(max_iteration=2) + app_config.prompt_template = MagicMock(simple_prompt_template="system") + + application_generate_entity = MagicMock() + application_generate_entity.model_conf = MagicMock(parameters={}, stop=None) + application_generate_entity.trace_manager = MagicMock() + application_generate_entity.invoke_from = "test" + application_generate_entity.app_config = MagicMock(app_id="app") + application_generate_entity.file_upload_config = None + + queue_manager = MagicMock() + model_instance = MagicMock() + model_instance.model = "test-model" + model_instance.model_name = "test-model" + + message = MagicMock(id="msg1") + conversation = MagicMock(id="conv1") + + runner = FunctionCallAgentRunner( + tenant_id="tenant", + application_generate_entity=application_generate_entity, + conversation=conversation, + app_config=app_config, + model_config=MagicMock(), + config=MagicMock(), + queue_manager=queue_manager, + message=message, + user_id="user", + model_instance=model_instance, + ) + + # Manually inject required attributes normally set by BaseAgentRunner + runner.tenant_id = "tenant" + runner.application_generate_entity = application_generate_entity + runner.conversation = conversation + runner.app_config = app_config + runner.model_config = MagicMock() + runner.config = MagicMock() + runner.queue_manager = queue_manager + runner.message = message + runner.user_id = "user" + runner.model_instance = model_instance + + runner.stream_tool_call = False + runner.memory = None + runner.history_prompt_messages = [] + runner._current_thoughts = [] + runner.files = [] + runner.agent_callback = MagicMock() + + runner._init_prompt_tools = MagicMock(return_value=({}, [])) + runner.create_agent_thought = MagicMock(return_value="thought1") + runner.save_agent_thought = MagicMock() + runner.recalc_llm_max_tokens = MagicMock() + runner.update_prompt_message_tool = MagicMock() + + return runner + + +# ============================== +# Tool Call Checks +# ============================== + + +class TestToolCallChecks: + @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) + def test_check_tool_calls(self, runner, tool_calls, expected): + chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls)) + assert runner.check_tool_calls(chunk) is expected + + @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)]) + def test_check_blocking_tool_calls(self, runner, tool_calls, expected): + result = DummyResult(message=DummyMessage(tool_calls=tool_calls)) + assert runner.check_blocking_tool_calls(result) is expected + + +# ============================== +# Extract Tool Calls +# ============================== + + +class TestExtractToolCalls: + def test_extract_tool_calls_with_valid_json(self, runner): + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_tool_calls(chunk) + + assert calls == [("1", "tool", {"a": 1})] + + def test_extract_tool_calls_empty_arguments(self, runner): + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = "" + + chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_tool_calls(chunk) + + assert calls == [("1", "tool", {})] + + def test_extract_blocking_tool_calls(self, runner): + tool_call = MagicMock() + tool_call.id = "2" + tool_call.function.name = "block" + tool_call.function.arguments = json.dumps({"x": 2}) + + result = DummyResult(message=DummyMessage(tool_calls=[tool_call])) + calls = runner.extract_blocking_tool_calls(result) + + assert calls == [("2", "block", {"x": 2})] + + +# ============================== +# System Message Initialization +# ============================== + + +class TestInitSystemMessage: + def test_init_system_message_empty_prompt_messages(self, runner): + result = runner._init_system_message("system", []) + assert len(result) == 1 + + def test_init_system_message_insert_at_start(self, runner): + msgs = [MagicMock()] + result = runner._init_system_message("system", msgs) + assert result[0].content == "system" + + def test_init_system_message_no_template(self, runner): + result = runner._init_system_message("", []) + assert result == [] + + +# ============================== +# Organize User Query +# ============================== + + +class TestOrganizeUserQuery: + def test_without_files(self, runner): + result = runner._organize_user_query("query", []) + assert len(result) == 1 + + def test_with_none_query(self, runner): + result = runner._organize_user_query(None, []) + assert len(result) == 1 + + def test_with_files_uses_image_detail_config(self, runner, mocker): + file_content = TextPromptMessageContent(data="file-content") + mock_to_prompt = mocker.patch( + "core.agent.fc_agent_runner.file_manager.to_prompt_message_content", + return_value=file_content, + ) + + image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH) + runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config) + runner.files = ["file1"] + + result = runner._organize_user_query("query", []) + + assert len(result) == 1 + assert isinstance(result[0].content, list) + mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH) + + +# ============================== +# Clear User Prompt Images +# ============================== + + +class TestClearUserPromptImageMessages: + def test_clear_text_and_image_content(self, runner): + text = MagicMock() + text.type = "text" + text.data = "hello" + + image = MagicMock() + image.type = "image" + image.data = "img" + + user_msg = MagicMock() + user_msg.__class__.__name__ = "UserPromptMessage" + user_msg.content = [text, image] + + result = runner._clear_user_prompt_image_messages([user_msg]) + assert isinstance(result, list) + + def test_clear_includes_file_placeholder(self, runner): + text = TextPromptMessageContent(data="hello") + image = ImagePromptMessageContent(format="url", mime_type="image/png") + document = DocumentPromptMessageContent(format="url", mime_type="application/pdf") + + user_msg = UserPromptMessage(content=[text, image, document]) + + result = runner._clear_user_prompt_image_messages([user_msg]) + + assert result[0].content == "hello\n[image]\n[file]" + + +# ============================== +# Run Method Tests +# ============================== + + +class TestRunMethod: + def test_run_non_streaming_no_tool_calls(self, runner): + message = MagicMock(id="m1") + dummy_message = DummyMessage(content="hello") + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + runner.queue_manager.publish.assert_called() + + queue_calls = runner.queue_manager.publish.call_args_list + assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls) + + def test_run_streaming_branch(self, runner): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + content = [TextPromptMessageContent(data="hi")] + chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage()) + + def generator(): + yield chunk + + runner.model_instance.invoke_llm.return_value = generator() + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + + def test_run_streaming_tool_calls_list_content(self, runner): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + content = [TextPromptMessageContent(data="hi")] + chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage()) + + def generator(): + yield chunk + + final_message = DummyMessage(content="done", tool_calls=[]) + final_result = DummyResult(message=final_message, usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [generator(), final_result] + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + + def test_run_non_streaming_list_content(self, runner): + message = MagicMock(id="m1") + content = [TextPromptMessageContent(data="hi")] + dummy_message = DummyMessage(content=content) + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi" + + def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker): + message = MagicMock(id="m1") + runner.stream_tool_call = True + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage()) + + def generator(): + yield chunk + + runner.model_instance.invoke_llm.return_value = generator() + + real_dumps = json.dumps + + def flaky_dumps(obj, *args, **kwargs): + if kwargs.get("ensure_ascii") is False: + return real_dumps(obj, *args, **kwargs) + raise TypeError("boom") + + mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps) + + outputs = list(runner.run(message, "query")) + assert len(outputs) == 1 + + def test_run_with_missing_tool_instance(self, runner): + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "missing" + tool_call.function.arguments = json.dumps({}) + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + final_message = DummyMessage(content="done", tool_calls=[]) + final_result = DummyResult(message=final_message, usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [result, final_result] + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + + def test_run_with_tool_instance_and_files(self, runner, mocker): + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = json.dumps({"a": 1}) + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage()) + + runner.model_instance.invoke_llm.side_effect = [result, final_result] + + tool_instance = MagicMock() + prompt_tool = MagicMock() + prompt_tool.name = "tool" + runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool]) + + tool_invoke_meta = MagicMock() + tool_invoke_meta.to_dict.return_value = {"ok": True} + mocker.patch( + "core.agent.fc_agent_runner.ToolEngine.agent_invoke", + return_value=("ok", ["file1"], tool_invoke_meta), + ) + + outputs = list(runner.run(message, "query")) + assert len(outputs) >= 1 + assert any( + isinstance(call.args[0], QueueMessageFileEvent) + and call.args[0].message_file_id == "file1" + and call.args[1] == PublishFrom.APPLICATION_MANAGER + for call in runner.queue_manager.publish.call_args_list + ) + + def test_run_max_iteration_error(self, runner): + runner.app_config.agent.max_iteration = 0 + + message = MagicMock(id="m1") + + tool_call = MagicMock() + tool_call.id = "1" + tool_call.function.name = "tool" + tool_call.function.arguments = "{}" + + dummy_message = DummyMessage(content="", tool_calls=[tool_call]) + result = DummyResult(message=dummy_message, usage=build_usage()) + + runner.model_instance.invoke_llm.return_value = result + + with pytest.raises(AgentMaxIterationError): + list(runner.run(message, "query")) diff --git a/api/tests/unit_tests/core/agent/test_plugin_entities.py b/api/tests/unit_tests/core/agent/test_plugin_entities.py new file mode 100644 index 0000000000..9955190aca --- /dev/null +++ b/api/tests/unit_tests/core/agent/test_plugin_entities.py @@ -0,0 +1,324 @@ +"""Unit tests for core.agent.plugin_entities. + +Covers entities such as AgentFeature, AgentProviderEntityWithPlugin, +AgentStrategyEntity, AgentStrategyIdentity, AgentStrategyParameter, +AgentStrategyProviderEntity, and AgentStrategyProviderIdentity. Tests rely on +Pydantic ValidationError behavior and pytest fixtures for validation and +mocking; ensure entity invariants and validation rules remain stable. +""" + +import pytest +from pydantic import ValidationError + +from core.agent.plugin_entities import ( + AgentFeature, + AgentProviderEntityWithPlugin, + AgentStrategyEntity, + AgentStrategyIdentity, + AgentStrategyParameter, + AgentStrategyProviderEntity, + AgentStrategyProviderIdentity, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolIdentity, ToolProviderIdentity + +# ========================================================= +# Fixtures +# ========================================================= + + +@pytest.fixture +def mock_identity(mocker): + return mocker.MagicMock(spec=AgentStrategyIdentity) + + +@pytest.fixture +def mock_provider_identity(mocker): + return mocker.MagicMock(spec=AgentStrategyProviderIdentity) + + +# ========================================================= +# AgentStrategyParameterType Tests +# ========================================================= + + +class TestAgentStrategyParameterType: + @pytest.mark.parametrize( + "enum_member", + list(AgentStrategyParameter.AgentStrategyParameterType), + ) + def test_as_normal_type_calls_external_function(self, mocker, enum_member) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.as_normal_type", + return_value="normalized", + ) + + result = enum_member.as_normal_type() + + mock_func.assert_called_once_with(enum_member) + assert result == "normalized" + + def test_as_normal_type_propagates_exception(self, mocker) -> None: + enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING + mocker.patch( + "core.agent.plugin_entities.as_normal_type", + side_effect=RuntimeError("boom"), + ) + + with pytest.raises(RuntimeError): + enum_member.as_normal_type() + + @pytest.mark.parametrize( + ("enum_member", "value"), + [ + (AgentStrategyParameter.AgentStrategyParameterType.STRING, "abc"), + (AgentStrategyParameter.AgentStrategyParameterType.NUMBER, 10), + (AgentStrategyParameter.AgentStrategyParameterType.BOOLEAN, True), + (AgentStrategyParameter.AgentStrategyParameterType.ANY, {"a": 1}), + (AgentStrategyParameter.AgentStrategyParameterType.STRING, None), + (AgentStrategyParameter.AgentStrategyParameterType.FILES, []), + ], + ) + def test_cast_value_calls_external_function(self, mocker, enum_member, value) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.cast_parameter_value", + return_value="casted", + ) + + result = enum_member.cast_value(value) + + mock_func.assert_called_once_with(enum_member, value) + assert result == "casted" + + def test_cast_value_propagates_exception(self, mocker) -> None: + enum_member = AgentStrategyParameter.AgentStrategyParameterType.STRING + mocker.patch( + "core.agent.plugin_entities.cast_parameter_value", + side_effect=ValueError("invalid"), + ) + + with pytest.raises(ValueError): + enum_member.cast_value("bad") + + +# ========================================================= +# AgentStrategyParameter Tests +# ========================================================= + + +class TestAgentStrategyParameter: + def test_valid_creation_minimal(self) -> None: + # bypass base PluginParameter required fields using model_construct + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + help=None, + ) + assert param.type == AgentStrategyParameter.AgentStrategyParameterType.STRING + assert param.help is None + + def test_valid_creation_with_help(self) -> None: + help_obj = I18nObject(en_US="test") + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + help=help_obj, + ) + assert param.help == help_obj + + @pytest.mark.parametrize("invalid_type", [None, "invalid_type", 999, [], {}, ["bad"], {"bad": 1}]) + def test_invalid_type_raises_validation_error(self, invalid_type) -> None: + with pytest.raises(ValidationError) as exc_info: + AgentStrategyParameter(type=invalid_type, name="x", label=I18nObject(en_US="y", zh_Hans="y")) + + assert any(error["loc"] == ("type",) for error in exc_info.value.errors()) + + def test_init_frontend_parameter_calls_external(self, mocker) -> None: + mock_func = mocker.patch( + "core.agent.plugin_entities.init_frontend_parameter", + return_value="frontend", + ) + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + result = param.init_frontend_parameter("value") + + mock_func.assert_called_once_with(param, param.type, "value") + assert result == "frontend" + + def test_init_frontend_parameter_propagates_exception(self, mocker) -> None: + mocker.patch( + "core.agent.plugin_entities.init_frontend_parameter", + side_effect=RuntimeError("error"), + ) + + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + with pytest.raises(RuntimeError): + param.init_frontend_parameter("value") + + +# ========================================================= +# AgentStrategyProviderEntity Tests +# ========================================================= + + +class TestAgentStrategyProviderEntity: + def test_creation_with_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity( + identity=mock_provider_identity, + plugin_id="plugin-123", + ) + assert entity.plugin_id == "plugin-123" + + def test_creation_with_empty_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity( + identity=mock_provider_identity, + plugin_id="", + ) + assert entity.plugin_id == "" + + def test_creation_without_plugin_id(self, mock_provider_identity) -> None: + entity = AgentStrategyProviderEntity(identity=mock_provider_identity) + assert entity.plugin_id is None + + def test_invalid_identity_raises(self) -> None: + with pytest.raises(ValidationError): + AgentStrategyProviderEntity(identity="invalid") + + +# ========================================================= +# AgentStrategyEntity Tests +# ========================================================= + + +class TestAgentStrategyEntity: + def test_parameters_default_empty(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + ) + assert entity.parameters == [] + + def test_parameters_none_converted_to_empty(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=None, + ) + assert entity.parameters == [] + + def test_parameters_preserved(self, mock_identity) -> None: + param = AgentStrategyParameter.model_construct( + type=AgentStrategyParameter.AgentStrategyParameterType.STRING, + name="test", + label="label", + ) + + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=[param], + ) + assert entity.parameters == [param] + + def test_invalid_parameters_type_raises(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters="invalid", + ) + + @pytest.mark.parametrize( + "features", + [ + None, + [], + [AgentFeature.HISTORY_MESSAGES], + ], + ) + def test_features_valid(self, mock_identity, features) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + features=features, + ) + assert entity.features == features + + def test_invalid_features_type_raises(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + features="invalid", + ) + + def test_output_schema_and_meta_version(self, mock_identity) -> None: + entity = AgentStrategyEntity( + identity=mock_identity, + description=I18nObject(en_US="test"), + output_schema={"type": "object"}, + meta_version="v1", + ) + assert entity.output_schema == {"type": "object"} + assert entity.meta_version == "v1" + + def test_missing_required_fields_raise(self, mock_identity) -> None: + with pytest.raises(ValidationError): + AgentStrategyEntity(identity=mock_identity) + + +# ========================================================= +# AgentProviderEntityWithPlugin Tests +# ========================================================= + + +class TestAgentProviderEntityWithPlugin: + def test_default_strategies_empty(self, mock_provider_identity) -> None: + entity = AgentProviderEntityWithPlugin(identity=mock_provider_identity) + assert entity.strategies == [] + + def test_strategies_assignment(self, mock_provider_identity, mock_identity) -> None: + strategy = AgentStrategyEntity.model_construct( + identity=mock_identity, + description=I18nObject(en_US="test"), + parameters=[], + ) + + entity = AgentProviderEntityWithPlugin( + identity=mock_provider_identity, + strategies=[strategy], + ) + assert entity.strategies == [strategy] + + def test_invalid_strategies_type_raises(self, mock_provider_identity) -> None: + with pytest.raises(ValidationError): + AgentProviderEntityWithPlugin( + identity=mock_provider_identity, + strategies="invalid", + ) + + +# ========================================================= +# Inheritance Smoke Tests +# ========================================================= + + +class TestInheritanceBehavior: + def test_agent_strategy_identity_inherits(self) -> None: + assert issubclass(AgentStrategyIdentity, ToolIdentity) + + def test_agent_strategy_provider_identity_inherits(self) -> None: + assert issubclass(AgentStrategyProviderIdentity, ToolProviderIdentity) diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index 9dddb18595..de99833aac 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,6 +1,6 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.workflow.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from dify_graph.file.models import FileTransferMethod, FileUploadConfig, ImageConfig +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent def test_convert_with_vision(): diff --git a/api/tests/unit_tests/core/app/apps/__init__.py b/api/tests/unit_tests/core/app/apps/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/__init__.py b/api/tests/unit_tests/core/app/apps/advanced_chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py new file mode 100644 index 0000000000..6ca4f60459 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_config_manager.py @@ -0,0 +1,75 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager +from models.model import AppMode + + +class TestAdvancedChatAppConfigManager: + def test_get_app_config(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.ADVANCED_CHAT.value) + workflow = SimpleNamespace(id="wf-1", features_dict={}) + + with ( + patch( + "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.WorkflowVariablesConfigManager.convert", + return_value=[], + ), + ): + app_config = AdvancedChatAppConfigManager.get_app_config(app_model, workflow) + + assert app_config.workflow_id == "wf-1" + assert app_config.app_mode == AppMode.ADVANCED_CHAT + + def test_config_validate_filters_keys(self): + def _add_key(key, value): + def _inner(*args, **kwargs): + config = kwargs.get("config") if kwargs else args[-1] + config = {**config, key: value} + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.advanced_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 1), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=_add_key("opening_statement", 2), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=_add_key("suggested_questions_after_answer", 3), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=_add_key("speech_to_text", 4), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 5), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=_add_key("retriever_resource", 6), + ), + patch( + "core.app.apps.advanced_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 7), + ), + ): + filtered = AdvancedChatAppConfigManager.config_validate(tenant_id="t1", config={}) + + assert filtered["file_upload"] == 1 + assert filtered["opening_statement"] == 2 + assert filtered["suggested_questions_after_answer"] == 3 + assert filtered["speech_to_text"] == 4 + assert filtered["text_to_speech"] == 5 + assert filtered["retriever_resource"] == 6 + assert filtered["sensitive_word_avoidance"] == 7 diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py new file mode 100644 index 0000000000..441d2fcd17 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -0,0 +1,1260 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import BaseModel, ValidationError + +from constants import UUID_NIL +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestAdvancedChatAppGeneratorValidation: + def test_generate_requires_query(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="query is required"): + generator.generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}}, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + def test_generate_requires_string_query(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="query must be a string"): + generator.generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}, "query": 123}, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + def test_single_iteration_generate_validates_args(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args={"inputs": {}}, + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args={}, + streaming=False, + ) + + def test_single_loop_generate_validates_args(self): + generator = AdvancedChatAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args=SimpleNamespace(inputs={}), + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args=SimpleNamespace(inputs=None), + streaming=False, + ) + + +class TestAdvancedChatAppGeneratorInternals: + @staticmethod + def _build_app_config() -> WorkflowUIBasedAppConfig: + return WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + def test_generate_loads_conversation_and_files(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + + conversation = SimpleNamespace(id="conversation-id") + built_files: list[object] = [] + build_files_called = {"called": False} + captured: dict[str, object] = {} + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.ConversationService.get_conversation", + lambda **kwargs: conversation, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda *args, **kwargs: {"enabled": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.file_factory.build_from_mappings", + lambda **kwargs: build_files_called.update({"called": True}) or built_files, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: kwargs["user_inputs"]) + + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.TraceQueueManager", DummyTraceQueueManager) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + from models import Account + + user = Account(name="Tester", email="tester@example.com") + user.id = "user-id" + + result = generator.generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(features_dict={}), + user=user, + args={ + "query": "hello", + "inputs": {"k": "v"}, + "conversation_id": "conversation-id", + "files": [{"id": "f"}], + }, + invoke_from=InvokeFrom.WEB_APP, + workflow_run_id="run-id", + streaming=False, + ) + + assert result == {"ok": True} + assert captured["conversation"] is conversation + assert captured["application_generate_entity"].files == built_files + assert build_files_called["called"] is True + + def test_resume_delegates_to_generate(self, monkeypatch): + generator = AdvancedChatAppGenerator() + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=self._build_app_config(), + inputs={}, + query="hello", + files=[], + user_id="user", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + captured: dict[str, object] = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"resumed": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.resume( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + user=SimpleNamespace(), + conversation=SimpleNamespace(id="conversation-id"), + message=SimpleNamespace(id="message-id"), + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_runtime_state=SimpleNamespace(), + pause_state_config=None, + ) + + assert result == {"resumed": True} + assert captured["graph_runtime_state"] is not None + + def test_single_iteration_generate_builds_debug_task(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + captured: dict[str, object] = {} + prefill_calls: list[object] = [] + var_loader = SimpleNamespace(loader="draft") + workflow = SimpleNamespace(id="workflow-id") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(repo="execution"), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(repo="node"), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.DraftVarLoader", lambda **kwargs: var_loader) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + + class _DraftVarService: + def __init__(self, session): + _ = session + + def prefill_conversation_variable_default_values(self, workflow, user_id): + prefill_calls.append((workflow, user_id)) + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.single_iteration_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=workflow, + node_id="node-1", + user=SimpleNamespace(id="user-id"), + args={"inputs": {"foo": "bar"}}, + streaming=False, + ) + + assert result == {"ok": True} + assert prefill_calls == [(workflow, "user-id")] + assert captured["variable_loader"] is var_loader + assert captured["application_generate_entity"].single_iteration_run.node_id == "node-1" + + def test_single_loop_generate_builds_debug_task(self, monkeypatch): + generator = AdvancedChatAppGenerator() + app_config = self._build_app_config() + captured: dict[str, object] = {} + prefill_calls: list[object] = [] + var_loader = SimpleNamespace(loader="draft") + workflow = SimpleNamespace(id="workflow-id") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda **kwargs: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(repo="execution"), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(repo="node"), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.DraftVarLoader", lambda **kwargs: var_loader) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace() + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=lambda: SimpleNamespace()), + ) + + class _DraftVarService: + def __init__(self, session): + _ = session + + def prefill_conversation_variable_default_values(self, workflow, user_id): + prefill_calls.append((workflow, user_id)) + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.WorkflowDraftVariableService", _DraftVarService) + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + result = generator.single_loop_generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=workflow, + node_id="node-2", + user=SimpleNamespace(id="user-id"), + args=SimpleNamespace(inputs={"foo": "bar"}), + streaming=False, + ) + + assert result == {"ok": True} + assert prefill_calls == [(workflow, "user-id")] + assert captured["variable_loader"] is var_loader + assert captured["application_generate_entity"].single_loop_run.node_id == "node-2" + + def test_generate_internal_flow_initial_conversation_with_pause_layer(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 0 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) + message = SimpleNamespace(id="msg-1") + db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) + captured: dict[str, object] = {} + thread_data: dict[str, object] = {} + + monkeypatch.setattr(generator, "_init_generate_records", lambda *args: (conversation, message)) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.get_thread_messages_length", lambda _: 2) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.MessageBasedAppQueueManager", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.PauseStatePersistenceLayer", + lambda **kwargs: "pause-layer", + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.current_app", + SimpleNamespace(_get_current_object=lambda: SimpleNamespace(name="flask")), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.contextvars.copy_context", lambda: "ctx") + + class _Thread: + def __init__(self, *, target, kwargs): + thread_data["target"] = target + thread_data["kwargs"] = kwargs + + def start(self): + thread_data["started"] = True + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) + ) + monkeypatch.setattr(generator, "_get_draft_var_saver_factory", lambda *args, **kwargs: "draft-factory") + monkeypatch.setattr( + generator, + "_handle_advanced_chat_response", + lambda **kwargs: captured.update(kwargs) or {"raw": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateResponseConverter.convert", + lambda response, invoke_from: {"response": response, "invoke_from": invoke_from}, + ) + + pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") + + response = generator._generate( + workflow=SimpleNamespace(features={"feature": True}), + user=SimpleNamespace(id="user"), + invoke_from=InvokeFrom.WEB_APP, + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + conversation=None, + message=None, + stream=False, + pause_state_config=pause_state_config, + ) + + assert response["response"] == {"raw": True} + assert thread_data["started"] is True + assert "pause-layer" in thread_data["kwargs"]["graph_engine_layers"] + assert generator._dialogue_count == 3 + db_session.commit.assert_called_once() + db_session.refresh.assert_called_once_with(conversation) + db_session.close.assert_called_once() + assert captured["draft_var_saver_factory"] == "draft-factory" + + def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 0 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) + message = SimpleNamespace(id="msg-2") + db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) + init_records = MagicMock() + thread_data: dict[str, object] = {} + + monkeypatch.setattr(generator, "_init_generate_records", init_records) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.get_thread_messages_length", lambda _: 0) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.MessageBasedAppQueueManager", + lambda **kwargs: SimpleNamespace(**kwargs), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.current_app", + SimpleNamespace(_get_current_object=lambda: SimpleNamespace(name="flask")), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.contextvars.copy_context", lambda: "ctx") + + class _Thread: + def __init__(self, *, target, kwargs): + thread_data["target"] = target + thread_data["kwargs"] = kwargs + + def start(self): + thread_data["started"] = True + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) + ) + monkeypatch.setattr(generator, "_get_draft_var_saver_factory", lambda *args, **kwargs: "draft-factory") + monkeypatch.setattr( + generator, + "_handle_advanced_chat_response", + lambda **kwargs: {"raw": True}, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateResponseConverter.convert", + lambda response, invoke_from: response, + ) + + response = generator._generate( + workflow=SimpleNamespace(features={}), + user=SimpleNamespace(id="user"), + invoke_from=InvokeFrom.WEB_APP, + application_generate_entity=application_generate_entity, + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + conversation=conversation, + message=message, + stream=False, + ) + + assert response == {"raw": True} + init_records.assert_not_called() + assert thread_data["started"] is True + db_session.commit.assert_not_called() + db_session.refresh.assert_not_called() + db_session.close.assert_called_once() + + def test_generate_worker_raises_when_workflow_not_found(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock(return_value=None) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + with pytest.raises(ValueError, match="Workflow not found"): + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + def test_generate_worker_raises_when_app_not_found_for_internal_call(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + None, + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + with pytest.raises(ValueError, match="App not found"): + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + def test_generate_worker_handles_stopped_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise GenerateTaskStoppedError() + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_not_called() + + def test_generate_worker_handles_validation_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _ValidationModel(BaseModel): + value: int + + try: + _ValidationModel(value="invalid") + except ValidationError as error: + validation_error = error + else: + raise AssertionError("validation error should be created") + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise validation_error + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_called_once() + + def test_generate_worker_handles_value_and_unknown_errors(self, monkeypatch): + app_config = self._build_app_config() + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + def _make_runner(error: Exception): + class _Runner: + def __init__(self, **kwargs): + _ = kwargs + + def run(self): + raise error + + return _Runner + + for raised_error in [ValueError("bad input"), RuntimeError("unexpected")]: + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="internal-user", + stream=False, + invoke_from=InvokeFrom.DEBUGGER, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv")) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", + _make_runner(raised_error), + ) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.dify_config", SimpleNamespace(DEBUG=True)) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + queue_manager.publish_error.assert_called_once() + + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + def test_handle_response_re_raises_value_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + app_config = self._build_app_config() + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + class _Pipeline: + def __init__(self, **kwargs): + _ = kwargs + + def process(self): + raise ValueError("other error") + + logger_exception = MagicMock() + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.logger.exception", logger_exception) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppGenerateTaskPipeline", _Pipeline) + + with pytest.raises(ValueError, match="other error"): + generator._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + logger_exception.assert_called_once() + + def test_refresh_model_returns_detached_model(self, monkeypatch): + source_model = SimpleNamespace(id="source-id") + detached_model = SimpleNamespace(id="source-id", detached=True) + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def get(self, model_type, model_id): + _ = model_type + return detached_model if model_id == "source-id" else None + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) + + refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) + + assert refreshed is detached_model + + def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): + generator = AdvancedChatAppGenerator() + generator._dialogue_count = 1 + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.SERVICE_API, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + queue_manager = MagicMock() + + generator._get_conversation = MagicMock(return_value=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT)) + generator._get_message = MagicMock(return_value=SimpleNamespace(id="msg")) + + class _Runner: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def run(self): + from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + raise InvokeAuthorizationError("bad key") + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner) + + @contextmanager + def _fake_context(*args, **kwargs): + yield + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.preserve_flask_contexts", _fake_context) + + class _Session: + def __init__(self, *args, **kwargs): + self.scalar = MagicMock( + side_effect=[ + SimpleNamespace(id="workflow-id", tenant_id="tenant", app_id="app"), + SimpleNamespace(id="end-user-id", session_id="session-id"), + SimpleNamespace(id="app"), + ] + ) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + + generator._generate_worker( + flask_app=SimpleNamespace(), + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + context=SimpleNamespace(), + variable_loader=SimpleNamespace(), + workflow_execution_repository=SimpleNamespace(), + workflow_node_execution_repository=SimpleNamespace(), + graph_engine_layers=(), + graph_runtime_state=None, + ) + + assert queue_manager.publish_error.called + + def test_generate_debugger_enables_retrieve_source(self, monkeypatch): + generator = AdvancedChatAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + captured = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + app_model = SimpleNamespace(id="app", tenant_id="tenant") + workflow = SimpleNamespace(features_dict={}) + from models import Account + + user = Account(name="Tester", email="tester@example.com") + user.id = "user" + + result = generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args={"query": "hello\x00", "inputs": {}}, + invoke_from=InvokeFrom.DEBUGGER, + workflow_run_id="run-id", + streaming=False, + ) + + assert result == {"ok": True} + assert app_config.additional_features.show_retrieve_source is True + assert captured["application_generate_entity"].query == "hello" + + def test_generate_service_api_sets_parent_message_id(self, monkeypatch): + generator = AdvancedChatAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.AdvancedChatAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.advanced_chat.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + captured = {} + + def _fake_generate(**kwargs): + captured.update(kwargs) + return {"ok": True} + + monkeypatch.setattr(generator, "_generate", _fake_generate) + + app_model = SimpleNamespace(id="app", tenant_id="tenant") + workflow = SimpleNamespace(features_dict={}) + from models.model import EndUser + + user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") + user.id = "end-user" + + generator.generate( + app_model=app_model, + workflow=workflow, + user=user, + args={"query": "hello", "inputs": {}, "parent_message_id": "p1"}, + invoke_from=InvokeFrom.SERVICE_API, + workflow_run_id="run-id", + streaming=False, + ) + + assert captured["application_generate_entity"].parent_message_id == UUID_NIL diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 0ca54a2f4a..15aceef2c7 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from core.workflow.variables import SegmentType +from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow @@ -125,7 +125,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, @@ -265,7 +269,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, @@ -412,7 +420,11 @@ class TestAdvancedChatAppRunnerConversationVariables: patch("core.app.apps.advanced_chat.app_runner.select") as mock_select, patch("core.app.apps.advanced_chat.app_runner.db") as mock_db, patch.object(runner, "_init_graph") as mock_init_graph, - patch.object(runner, "handle_input_moderation", return_value=False), + patch.object( + runner, + "handle_input_moderation", + return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query), + ), patch.object(runner, "handle_annotation_reply", return_value=False), patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class, patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class, diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py new file mode 100644 index 0000000000..5792a2f1e2 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -0,0 +1,170 @@ +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import QueueStopEvent +from core.moderation.base import ModerationError + + +@pytest.fixture +def build_runner(): + """Construct a minimal AdvancedChatAppRunner with heavy dependencies mocked.""" + app_id = str(uuid4()) + workflow_id = str(uuid4()) + + # Mocks for constructor args + mock_queue_manager = MagicMock() + + mock_conversation = MagicMock() + mock_conversation.id = str(uuid4()) + mock_conversation.app_id = app_id + + mock_message = MagicMock() + mock_message.id = str(uuid4()) + + mock_workflow = MagicMock() + mock_workflow.id = workflow_id + mock_workflow.tenant_id = str(uuid4()) + mock_workflow.app_id = app_id + mock_workflow.type = "chat" + mock_workflow.graph_dict = {} + mock_workflow.environment_variables = [] + + mock_app_config = MagicMock() + mock_app_config.app_id = app_id + mock_app_config.workflow_id = workflow_id + mock_app_config.tenant_id = str(uuid4()) + + gen = MagicMock(spec=AdvancedChatAppGenerateEntity) + gen.app_config = mock_app_config + gen.inputs = {"q": "raw"} + gen.query = "raw-query" + gen.files = [] + gen.user_id = str(uuid4()) + gen.invoke_from = InvokeFrom.SERVICE_API + gen.workflow_run_id = str(uuid4()) + gen.task_id = str(uuid4()) + gen.call_depth = 0 + gen.single_iteration_run = None + gen.single_loop_run = None + gen.trace_manager = None + + runner = AdvancedChatAppRunner( + application_generate_entity=gen, + queue_manager=mock_queue_manager, + conversation=mock_conversation, + message=mock_message, + dialogue_count=1, + variable_loader=MagicMock(), + workflow=mock_workflow, + system_user_id=str(uuid4()), + app=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + return runner + + +def _patch_common_run_deps(runner: AdvancedChatAppRunner): + """Context manager that patches common heavy deps used by run().""" + return patch.multiple( + "core.app.apps.advanced_chat.app_runner", + Session=MagicMock( + return_value=MagicMock( + __enter__=lambda s: s, + __exit__=lambda *a, **k: False, + scalar=lambda *a, **k: MagicMock(), + ), + ), + select=MagicMock(), + db=MagicMock(engine=MagicMock()), + RedisChannel=MagicMock(), + redis_client=MagicMock(), + WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}), + GraphRuntimeState=MagicMock(), + ) + + +def test_handle_input_moderation_stops_on_moderation_error(build_runner): + runner = build_runner + + # moderation_for_inputs raises ModerationError -> should stop and emit stop event + with ( + patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("blocked")), + patch.object(runner, "_complete_with_stream_output") as mock_complete, + ): + stop, new_inputs, new_query = runner.handle_input_moderation( + app_record=MagicMock(), + app_generate_entity=runner.application_generate_entity, + inputs={"k": "v"}, + query="hello", + message_id="mid", + ) + + assert stop is True + # inputs/query should be unchanged on error path + assert new_inputs == {"k": "v"} + assert new_query == "hello" + # ensure stopped_by reason is INPUT_MODERATION + assert mock_complete.called + args, kwargs = mock_complete.call_args + assert kwargs.get("stopped_by") == QueueStopEvent.StopBy.INPUT_MODERATION + + +def test_run_applies_overridden_inputs_and_query_from_moderation(build_runner): + runner = build_runner + + overridden_inputs = {"q": "sanitized"} + overridden_query = "sanitized-query" + + with ( + _patch_common_run_deps(runner), + patch.object( + runner, + "moderation_for_inputs", + return_value=(True, overridden_inputs, overridden_query), + ) as mock_moderate, + patch.object(runner, "handle_annotation_reply", return_value=False) as mock_anno, + patch.object(runner, "_init_graph", return_value=MagicMock()) as mock_init_graph, + ): + runner.run() + + # moderation called with original values + mock_moderate.assert_called_once() + + # application_generate_entity should be updated to overridden values + assert runner.application_generate_entity.inputs == overridden_inputs + assert runner.application_generate_entity.query == overridden_query + + # annotation reply should use the new query + mock_anno.assert_called() + assert mock_anno.call_args.kwargs.get("query") == overridden_query + + # since not stopped, graph initialization should proceed + assert mock_init_graph.called + + +def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_runner): + runner = build_runner + + with ( + _patch_common_run_deps(runner), + # Simulate handle_input_moderation signalling to stop + patch.object( + runner, + "handle_input_moderation", + return_value=(True, runner.application_generate_entity.inputs, runner.application_generate_entity.query), + ) as mock_handle, + patch.object(runner, "_init_graph") as mock_init_graph, + patch.object(runner, "handle_annotation_reply") as mock_anno, + ): + runner.run() + + mock_handle.assert_called_once() + # Ensure no further steps executed + mock_anno.assert_not_called() + mock_init_graph.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py new file mode 100644 index 0000000000..5b199e0c52 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -0,0 +1,96 @@ +from collections.abc import Generator + +from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, +) +from dify_graph.enums import WorkflowNodeExecutionStatus + + +class TestAdvancedChatGenerateResponseConverter: + def test_blocking_simple_response_metadata(self): + data = ChatbotAppBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + answer="hi", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + ) + blocking = ChatbotAppBlockingResponse(task_id="t1", data=data) + response = AdvancedChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + assert "usage" not in response["metadata"] + + def test_stream_simple_response_includes_node_events(self): + node_start = NodeStartStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeStartStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeFinishStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ), + ) + + def stream() -> Generator[ChatbotAppStreamResponse, None, None]: + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=PingStreamResponse(task_id="t1"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=node_start, + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=node_finish, + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageEndStreamResponse(task_id="t1", id="m1"), + ) + + converted = list(AdvancedChatAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert converted[0] == "ping" + assert converted[1]["event"] == "node_started" + assert converted[2]["event"] == "node_finished" + assert converted[3]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py similarity index 56% rename from api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py rename to api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index a94b5445f7..83a6e0f231 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_extra_contents.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -9,8 +9,16 @@ import pytest from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.entities.queue_entities import QueueTextChunkEvent, QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired +from core.app.entities.queue_entities import ( + QueuePingEvent, + QueueTextChunkEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import StreamEvent +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import HumanInputContent from models.model import EndUser @@ -185,3 +193,97 @@ def test_resume_appends_chunks_to_paused_answer() -> None: assert message.answer == "beforeafter" assert message.status == MessageStatus.NORMAL + + +def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_id = "workflow-1" + pipeline._ensure_workflow_initialized = mock.Mock() + runtime_state = SimpleNamespace() + pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state) + pipeline._handle_advanced_chat_message_end_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)]) + ) + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace( + event=StreamEvent.WORKFLOW_FINISHED, + data=SimpleNamespace(status=WorkflowExecutionStatus.SUCCEEDED), + ) + + event = QueueWorkflowSucceededEvent(outputs={}) + responses = list(pipeline._handle_workflow_succeeded_event(event)) + + assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED] + + +def test_workflow_partial_success_emits_message_end_before_workflow_finished() -> None: + pipeline = _build_pipeline() + pipeline._application_generate_entity = SimpleNamespace(task_id="task-1") + pipeline._workflow_id = "workflow-1" + pipeline._ensure_workflow_initialized = mock.Mock() + runtime_state = SimpleNamespace() + pipeline._ensure_graph_runtime_initialized = mock.Mock(return_value=runtime_state) + pipeline._handle_advanced_chat_message_end_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.MESSAGE_END)]) + ) + pipeline._workflow_response_converter = mock.Mock() + pipeline._workflow_response_converter.workflow_finish_to_stream_response.return_value = SimpleNamespace( + event=StreamEvent.WORKFLOW_FINISHED, + data=SimpleNamespace(status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED), + ) + + event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + responses = list(pipeline._handle_workflow_partial_success_event(event)) + + assert [resp.event for resp in responses] == [StreamEvent.MESSAGE_END, StreamEvent.WORKFLOW_FINISHED] + + +def test_process_stream_response_breaks_after_workflow_succeeded() -> None: + pipeline = _build_pipeline() + succeeded_event = QueueWorkflowSucceededEvent(outputs={}) + ping_event = QueuePingEvent() + queue_messages = [ + SimpleNamespace(event=succeeded_event), + SimpleNamespace(event=ping_event), + ] + + pipeline._conversation_name_generate_thread = None + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages) + pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING)) + pipeline._handle_workflow_succeeded_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)]) + ) + + responses = list(pipeline._process_stream_response()) + + assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED] + pipeline._handle_workflow_succeeded_event.assert_called_once_with(succeeded_event, trace_manager=None) + pipeline._base_task_pipeline.ping_stream_response.assert_not_called() + + +def test_process_stream_response_breaks_after_workflow_partial_success() -> None: + pipeline = _build_pipeline() + partial_event = QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + ping_event = QueuePingEvent() + queue_messages = [ + SimpleNamespace(event=partial_event), + SimpleNamespace(event=ping_event), + ] + + pipeline._conversation_name_generate_thread = None + pipeline._base_task_pipeline = mock.Mock() + pipeline._base_task_pipeline.queue_manager = mock.Mock() + pipeline._base_task_pipeline.queue_manager.listen.return_value = iter(queue_messages) + pipeline._base_task_pipeline.ping_stream_response = mock.Mock(return_value=SimpleNamespace(event=StreamEvent.PING)) + pipeline._handle_workflow_partial_success_event = mock.Mock( + return_value=iter([SimpleNamespace(event=StreamEvent.WORKFLOW_FINISHED)]) + ) + + responses = list(pipeline._process_stream_response()) + + assert [resp.event for resp in responses] == [StreamEvent.WORKFLOW_FINISHED] + pipeline._handle_workflow_partial_success_event.assert_called_once_with(partial_event, trace_manager=None) + pipeline._base_task_pipeline.ping_stream_response.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py new file mode 100644 index 0000000000..0a244b3fea --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -0,0 +1,600 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from core.app.entities.queue_entities import ( + QueueAdvancedChatMessageEndEvent, + QueueAnnotationReplyEvent, + QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueMessageReplaceEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueuePingEvent, + QueueRetrieverResourcesEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + AnnotationReply, + AnnotationReplyAccount, + MessageAudioStreamResponse, + MessageEndStreamResponse, + PingStreamResponse, +) +from core.base.tts.app_generator_tts_publisher import AudioTrunk +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from models.enums import MessageStatus +from models.model import AppMode, EndUser + + +def _make_pipeline(): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.ADVANCED_CHAT, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = AdvancedChatAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + query="hello", + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_run_id="run-id", + ) + + message = SimpleNamespace( + id="message-id", + query="hello", + created_at=datetime.utcnow(), + status=MessageStatus.NORMAL, + answer="", + ) + conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") + + pipeline = AdvancedChatAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + conversation=conversation, + message=message, + user=user, + stream=False, + dialogue_count=1, + draft_var_saver_factory=lambda **kwargs: None, + ) + + return pipeline + + +class TestAdvancedChatGenerateTaskPipeline: + def test_ensure_workflow_initialized_raises(self): + pipeline = _make_pipeline() + + with pytest.raises(ValueError, match="workflow run not initialized"): + pipeline._ensure_workflow_initialized() + + def test_to_blocking_response_returns_message_end(self): + pipeline = _make_pipeline() + pipeline._task_state.answer = "done" + + def _gen(): + yield MessageEndStreamResponse(task_id="task", id="message-id", metadata={"k": "v"}) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.answer == "done" + assert response.data.metadata == {"k": "v"} + + def test_handle_text_chunk_event_updates_state(self): + pipeline = _make_pipeline() + pipeline._message_cycle_manager = SimpleNamespace( + message_to_stream_response=lambda **kwargs: MessageEndStreamResponse( + task_id="task", id="message-id", metadata={} + ) + ) + + event = SimpleNamespace(text="hi", from_variable_selector=None) + + responses = list(pipeline._handle_text_chunk_event(event)) + + assert pipeline._task_state.answer == "hi" + assert responses + + def test_listen_audio_msg_returns_audio_stream(self): + pipeline = _make_pipeline() + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + + def test_handle_ping_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task") + + responses = list(pipeline._handle_ping_event(QueuePingEvent())) + + assert isinstance(responses[0], PingStreamResponse) + + def test_handle_error_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + pipeline._database_session = _fake_session + + responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom")))) + + assert isinstance(responses[0], ValueError) + + def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + monkeypatch.setattr(pipeline, "_get_message", lambda **kwargs: SimpleNamespace()) + + responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent())) + + assert pipeline._workflow_run_id == "run-id" + assert responses == ["started"] + + def test_message_end_to_stream_response_strips_annotation_reply(self): + pipeline = _make_pipeline() + pipeline._task_state.metadata.annotation_reply = AnnotationReply( + id="ann", + account=AnnotationReplyAccount(id="acc", name="acc"), + ) + + response = pipeline._message_end_to_stream_response() + + assert "annotation_reply" not in response.metadata + + def test_handle_output_moderation_chunk_publishes_stop(self): + pipeline = _make_pipeline() + events: list[object] = [] + + class _Moderation: + def should_direct_output(self): + return True + + def get_final_output(self): + return "final" + + pipeline._base_task_pipeline.output_moderation_handler = _Moderation() + pipeline._base_task_pipeline.queue_manager = SimpleNamespace( + publish=lambda event, pub_from: events.append(event) + ) + + result = pipeline._handle_output_moderation_chunk("ignored") + + assert result is True + assert pipeline._task_state.answer == "final" + assert any(isinstance(event, QueueTextChunkEvent) for event in events) + assert any(isinstance(event, QueueStopEvent) for event in events) + + def test_handle_node_succeeded_event_records_files(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.fetch_files_from_node_outputs = lambda outputs: [ + {"type": "file", "transfer_method": "local"} + ] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + event = SimpleNamespace( + node_type=BuiltinNodeTypes.ANSWER, + outputs={"k": "v"}, + node_execution_id="exec", + node_id="node", + ) + + responses = list(pipeline._handle_node_succeeded_event(event)) + + assert responses == ["done"] + assert pipeline._recorded_files + + def test_iteration_and_loop_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_run_id = "run-id" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: ( + "iter_start" + ) + pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "iter_next" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: ( + "iter_done" + ) + pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop_start" + pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" + pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done" + + iter_start = QueueIterationStartEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + iter_next = QueueIterationNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + node_run_index=1, + ) + iter_done = QueueIterationCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_start = QueueLoopStartEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_next = QueueLoopNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + node_run_index=1, + ) + loop_done = QueueLoopCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + + assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter_start"] + assert list(pipeline._handle_iteration_next_event(iter_next)) == ["iter_next"] + assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["iter_done"] + assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop_start"] + assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"] + assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"] + + def test_workflow_finish_handlers(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_run_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: ["pause"] + pipeline._persist_human_input_extra_content = lambda **kwargs: None + pipeline._save_message = lambda **kwargs: None + pipeline._base_task_pipeline.queue_manager.publish = lambda *args, **kwargs: None + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + pipeline._get_message = lambda **kwargs: SimpleNamespace(id="message-id") + + @contextmanager + def _fake_session(): + yield SimpleNamespace(scalar=lambda *args, **kwargs: None) + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + succeeded_responses = list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) + assert len(succeeded_responses) == 2 + assert isinstance(succeeded_responses[0], MessageEndStreamResponse) + assert succeeded_responses[1] == "finish" + + partial_success_responses = list( + pipeline._handle_workflow_partial_success_event( + QueueWorkflowPartialSuccessEvent(exceptions_count=1, outputs={}) + ) + ) + assert len(partial_success_responses) == 2 + assert isinstance(partial_success_responses[0], MessageEndStreamResponse) + assert partial_success_responses[1] == "finish" + assert ( + list(pipeline._handle_workflow_failed_event(QueueWorkflowFailedEvent(error="err", exceptions_count=1)))[0] + == "finish" + ) + assert list(pipeline._handle_workflow_paused_event(QueueWorkflowPausedEvent(reasons=[], outputs={}))) == [ + "pause" + ] + + def test_node_failure_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "node_finish" + pipeline._save_output_for_event = lambda event, node_execution_id: None + + failed_event = QueueNodeFailedEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + exc_event = QueueNodeExceptionEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._handle_node_failed_events(failed_event)) == ["node_finish"] + assert list(pipeline._handle_node_failed_events(exc_event)) == ["node_finish"] + + def test_handle_text_chunk_event_tracks_streaming_metrics(self): + pipeline = _make_pipeline() + published: list[object] = [] + + class _Publisher: + def publish(self, message): + published.append(message) + + pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk") + + event = SimpleNamespace(text="hi", from_variable_selector=["a"]) + queue_message = SimpleNamespace(event=event) + + responses = list( + pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message) + ) + + assert responses == ["chunk"] + assert pipeline._task_state.is_streaming_response is True + assert pipeline._task_state.first_token_time is not None + assert pipeline._task_state.last_token_time is not None + assert pipeline._task_state.answer == "hi" + assert published == [queue_message] + + def test_handle_output_moderation_chunk_appends_token(self): + pipeline = _make_pipeline() + seen: list[str] = [] + + class _Moderation: + def should_direct_output(self): + return False + + def append_new_token(self, text): + seen.append(text) + + pipeline._base_task_pipeline.output_moderation_handler = _Moderation() + + result = pipeline._handle_output_moderation_chunk("token") + + assert result is False + assert seen == ["token"] + + def test_handle_retriever_and_annotation_events(self): + pipeline = _make_pipeline() + calls = {"retriever": 0, "annotation": 0} + + def _hit_retriever(event): + calls["retriever"] += 1 + + def _hit_annotation(event): + calls["annotation"] += 1 + + pipeline._message_cycle_manager.handle_retriever_resources = _hit_retriever + pipeline._message_cycle_manager.handle_annotation_reply = _hit_annotation + + retriever_event = QueueRetrieverResourcesEvent(retriever_resources=[]) + annotation_event = QueueAnnotationReplyEvent(message_annotation_id="ann") + + assert list(pipeline._handle_retriever_resources_event(retriever_event)) == [] + assert list(pipeline._handle_annotation_reply_event(annotation_event)) == [] + assert calls == {"retriever": 1, "annotation": 1} + + def test_handle_message_replace_event(self): + pipeline = _make_pipeline() + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + + event = QueueMessageReplaceEvent( + text="new", + reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, + ) + + assert list(pipeline._handle_message_replace_event(event)) == ["replace"] + + def test_handle_human_input_events(self): + pipeline = _make_pipeline() + persisted: list[str] = [] + pipeline._persist_human_input_extra_content = lambda **kwargs: persisted.append("saved") + pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled" + pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout" + + filled_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="title", + rendered_content="content", + action_id="action", + action_text="action", + ) + timeout_event = QueueHumanInputFormTimeoutEvent( + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="title", + expiration_time=datetime.utcnow(), + ) + + assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] + assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"] + assert persisted == ["saved"] + + def test_save_message_strips_markdown_and_sets_usage(self): + pipeline = _make_pipeline() + pipeline._recorded_files = [ + { + "type": "image", + "transfer_method": "remote", + "remote_url": "http://example.com/file.png", + "related_id": "file-id", + } + ] + pipeline._task_state.answer = "![img](url) hello" + pipeline._task_state.is_streaming_response = True + pipeline._task_state.first_token_time = pipeline._base_task_pipeline.start_at + 0.1 + pipeline._task_state.last_token_time = pipeline._base_task_pipeline.start_at + 0.2 + + message = SimpleNamespace( + id="message-id", + status=MessageStatus.PAUSED, + answer="", + updated_at=None, + provider_response_latency=None, + message_tokens=None, + message_unit_price=None, + message_price_unit=None, + answer_tokens=None, + answer_unit_price=None, + answer_price_unit=None, + total_price=None, + currency=None, + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + from_end_user_id="end-user", + ) + + class _Session: + def scalar(self, *args, **kwargs): + return message + + def add_all(self, items): + self.items = items + + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + + pipeline._save_message(session=_Session(), graph_runtime_state=graph_runtime_state) + + assert message.status == MessageStatus.NORMAL + assert message.answer == "hello" + assert message.message_metadata + + def test_handle_stop_event_saves_message_for_moderation(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._message_end_to_stream_response = lambda: "end" + saved: list[str] = [] + + def _save_message(**kwargs): + saved.append("saved") + + pipeline._save_message = _save_message + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + responses = list(pipeline._handle_stop_event(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION))) + + assert responses == ["end"] + assert saved == ["saved"] + + def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" + pipeline._message_cycle_manager.message_replace_to_stream_response = lambda **kwargs: "replace" + pipeline._message_end_to_stream_response = lambda: "end" + + saved: list[str] = [] + + def _save_message(**kwargs): + saved.append("saved") + + pipeline._save_message = _save_message + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + + responses = list(pipeline._handle_advanced_chat_message_end_event(QueueAdvancedChatMessageEndEvent())) + + assert responses == ["replace", "end"] + assert saved == ["saved"] + + def test_dispatch_event_handles_node_exception(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed" + pipeline._save_output_for_event = lambda *args, **kwargs: None + + event = QueueNodeExceptionEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._dispatch_event(event)) == ["failed"] diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py new file mode 100644 index 0000000000..a871e8d93b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_config_manager.py @@ -0,0 +1,302 @@ +import uuid +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.agent_chat.app_config_manager import ( + AgentChatAppConfigManager, +) +from core.entities.agent_entities import PlanningStrategy + + +class TestAgentChatAppConfigManagerGetAppConfig: + def test_get_app_config_override_config(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"ignored": True} + + override_config = {"model": {"provider": "p"}} + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=override_config, + ) + + assert result.app_model_config_dict == override_config + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert result.variables == "variables" + assert result.external_data_variables == "external" + + def test_get_app_config_conversation_specific(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + conversation = mocker.MagicMock() + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=None, + ) + + assert result.app_model_config_dict == app_model_config.to_dict.return_value + assert result.app_model_config_from.value == "conversation-specific-config" + + def test_get_app_config_latest_config(self, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + mocker.patch("core.app.apps.agent_chat.app_config_manager.ModelConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.convert") + mocker.patch("core.app.apps.agent_chat.app_config_manager.AgentConfigManager.convert") + mocker.patch.object(AgentChatAppConfigManager, "convert_features") + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.convert", + return_value=("variables", "external"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.AgentChatAppConfig", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) + + result = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=None, + ) + + assert result.app_model_config_from.value == "app-latest-config" + + +class TestAgentChatAppConfigManagerConfigValidate: + def test_config_validate_filters_related_keys(self, mocker): + config = { + "model": {}, + "user_input_form": {}, + "file_upload": {}, + "prompt_template": {}, + "agent_mode": {}, + "opening_statement": {}, + "suggested_questions_after_answer": {}, + "speech_to_text": {}, + "text_to_speech": {}, + "retriever_resource": {}, + "dataset": {}, + "moderation": {}, + "extra": "value", + } + + def return_with_key(key): + return config, [key] + + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.ModelConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("model"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("user_input_form"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("file_upload"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults", + side_effect=lambda app_mode, cfg: return_with_key("prompt_template"), + ) + mocker.patch.object( + AgentChatAppConfigManager, + "validate_agent_mode_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("agent_mode"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("opening_statement"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("suggested_questions_after_answer"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("speech_to_text"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("text_to_speech"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=lambda cfg: return_with_key("retriever_resource"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, app_mode, cfg: return_with_key("dataset"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=lambda tenant_id, cfg: return_with_key("moderation"), + ) + + filtered = AgentChatAppConfigManager.config_validate("tenant", config) + assert set(filtered.keys()) == { + "model", + "user_input_form", + "file_upload", + "prompt_template", + "agent_mode", + "opening_statement", + "suggested_questions_after_answer", + "speech_to_text", + "text_to_speech", + "retriever_resource", + "dataset", + "moderation", + } + assert "extra" not in filtered + + +class TestValidateAgentModeAndSetDefaults: + def test_defaults_when_missing(self): + config = {} + updated, keys = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert "agent_mode" in updated + assert updated["agent_mode"]["enabled"] is False + assert updated["agent_mode"]["tools"] == [] + assert keys == ["agent_mode"] + + @pytest.mark.parametrize( + "agent_mode", + ["invalid", 123], + ) + def test_agent_mode_type_validation(self, agent_mode): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": agent_mode}) + + def test_agent_mode_empty_list_defaults(self): + config = {"agent_mode": []} + updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert updated["agent_mode"]["enabled"] is False + assert updated["agent_mode"]["tools"] == [] + + def test_enabled_must_be_bool(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", {"agent_mode": {"enabled": "yes"}}) + + def test_strategy_must_be_valid(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "strategy": "invalid"}} + ) + + def test_tools_must_be_list(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": "not-list"}} + ) + + def test_old_tool_dataset_requires_id(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True}}]}} + ) + + def test_old_tool_dataset_id_must_be_uuid(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": "bad"}}]}}, + ) + + def test_old_tool_dataset_id_not_exists(self, mocker): + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists", + return_value=False, + ) + dataset_id = str(uuid.uuid4()) + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": True, "id": dataset_id}}]}}, + ) + + def test_old_tool_enabled_must_be_bool(self): + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", + {"agent_mode": {"enabled": True, "tools": [{"dataset": {"enabled": "yes", "id": str(uuid.uuid4())}}]}}, + ) + + @pytest.mark.parametrize("missing_key", ["provider_type", "provider_id", "tool_name", "tool_parameters"]) + def test_new_style_tool_requires_fields(self, missing_key): + tool = {"enabled": True, "provider_type": "type", "provider_id": "id", "tool_name": "tool"} + tool.pop(missing_key, None) + with pytest.raises(ValueError): + AgentChatAppConfigManager.validate_agent_mode_and_set_defaults( + "tenant", {"agent_mode": {"enabled": True, "tools": [tool]}} + ) + + def test_valid_old_and_new_style_tools(self, mocker): + mocker.patch( + "core.app.apps.agent_chat.app_config_manager.DatasetConfigManager.is_dataset_exists", + return_value=True, + ) + dataset_id = str(uuid.uuid4()) + config = { + "agent_mode": { + "enabled": True, + "strategy": PlanningStrategy.ROUTER.value, + "tools": [ + {"dataset": {"id": dataset_id}}, + { + "provider_type": "builtin", + "provider_id": "p1", + "tool_name": "tool", + "tool_parameters": {}, + }, + ], + } + } + + updated, _ = AgentChatAppConfigManager.validate_agent_mode_and_set_defaults("tenant", config) + assert updated["agent_mode"]["tools"][0]["dataset"]["enabled"] is False + assert updated["agent_mode"]["tools"][1]["enabled"] is False diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py new file mode 100644 index 0000000000..53f26d1592 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -0,0 +1,296 @@ +import contextlib + +import pytest +from pydantic import ValidationError + +from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + + +class DummyAccount: + def __init__(self, user_id): + self.id = user_id + self.session_id = f"session-{user_id}" + + +@pytest.fixture +def generator(mocker): + gen = AgentChatAppGenerator() + mocker.patch( + "core.app.apps.agent_chat.app_generator.current_app", + new=mocker.MagicMock(_get_current_object=mocker.MagicMock()), + ) + mocker.patch("core.app.apps.agent_chat.app_generator.contextvars.copy_context", return_value="ctx") + return gen + + +class TestAgentChatAppGeneratorGenerate: + def test_generate_rejects_blocking_mode(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate(app_model=app_model, user=user, args={}, invoke_from=mocker.MagicMock(), streaming=False) + + def test_generate_requires_query(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate(app_model=app_model, user=user, args={"inputs": {}}, invoke_from=mocker.MagicMock()) + + def test_generate_rejects_non_string_query(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + with pytest.raises(ValueError): + generator.generate( + app_model=app_model, + user=user, + args={"query": 123, "inputs": {}}, + invoke_from=mocker.MagicMock(), + ) + + def test_generate_override_requires_debugger(self, generator, mocker): + app_model = mocker.MagicMock() + user = DummyAccount("user") + + with pytest.raises(ValueError): + generator.generate( + app_model=app_model, + user=user, + args={"query": "hi", "inputs": {}, "model_config": {"model": {"provider": "p"}}}, + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_success_with_debugger_override(self, generator, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + user = DummyAccount("user") + invoke_from = InvokeFrom.DEBUGGER + + generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config) + generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="response") + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.config_validate", + return_value={"validated": True}, + ) + app_config = mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config", + return_value=app_config, + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings", + return_value=["file-obj"], + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ConversationService.get_conversation", + return_value=mocker.MagicMock(id="conv"), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.TraceQueueManager", + return_value=mocker.MagicMock(), + ) + + queue_manager = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager", + return_value=queue_manager, + ) + + thread_obj = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.threading.Thread", + return_value=thread_obj, + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert", + return_value={"result": "ok"}, + ) + app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=invoke_from) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", + return_value=app_entity, + ) + + args = { + "query": "hello", + "inputs": {"name": "world"}, + "conversation_id": "conv", + "model_config": {"model": {"provider": "p"}}, + "files": [{"id": "f1"}], + } + + result = generator.generate(app_model=app_model, user=user, args=args, invoke_from=invoke_from, streaming=True) + + assert result == {"result": "ok"} + thread_obj.start.assert_called_once() + + def test_generate_without_file_config(self, generator, mocker): + app_model = mocker.MagicMock(id="app1", tenant_id="tenant", mode="agent-chat") + app_model_config = mocker.MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "p"}} + + user = DummyAccount("user") + + generator._get_app_model_config = mocker.MagicMock(return_value=app_model_config) + generator._prepare_user_inputs = mocker.MagicMock(return_value={"x": 1}) + generator._init_generate_records = mocker.MagicMock( + return_value=(mocker.MagicMock(id="conv", mode="agent-chat"), mocker.MagicMock(id="msg")) + ) + generator._handle_response = mocker.MagicMock(return_value="response") + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppConfigManager.get_app_config", + return_value=mocker.MagicMock(variables={}, prompt_template=mocker.MagicMock(), external_data_variables=[]), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.ModelConfigConverter.convert", + return_value=mocker.MagicMock(), + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.FileUploadConfigManager.convert", + return_value=None, + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.file_factory.build_from_mappings", + return_value=["file-obj"], + ) + mocker.patch( + "core.app.apps.agent_chat.app_generator.TraceQueueManager", + return_value=mocker.MagicMock(), + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.MessageBasedAppQueueManager", + return_value=mocker.MagicMock(), + ) + + thread_obj = mocker.MagicMock() + mocker.patch( + "core.app.apps.agent_chat.app_generator.threading.Thread", + return_value=thread_obj, + ) + + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateResponseConverter.convert", + return_value={"result": "ok"}, + ) + app_entity = mocker.MagicMock(task_id="task", user_id="user", invoke_from=InvokeFrom.WEB_APP) + mocker.patch( + "core.app.apps.agent_chat.app_generator.AgentChatAppGenerateEntity", + return_value=app_entity, + ) + + args = {"query": "hello", "inputs": {"name": "world"}} + + result = generator.generate( + app_model=app_model, + user=user, + args=args, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert result == {"result": "ok"} + + +class TestAgentChatAppGeneratorWorker: + @pytest.fixture(autouse=True) + def patch_context(self, mocker): + @contextlib.contextmanager + def ctx_manager(*args, **kwargs): + yield + + mocker.patch("core.app.apps.agent_chat.app_generator.preserve_flask_contexts", ctx_manager) + + def test_generate_worker_handles_generate_task_stopped(self, generator, mocker): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = GenerateTaskStoppedError() + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + queue_manager.publish_error.assert_not_called() + + @pytest.mark.parametrize( + "error", + [ + InvokeAuthorizationError("bad"), + ValidationError.from_exception_data("TestModel", []), + ValueError("bad"), + Exception("bad"), + ], + ) + def test_generate_worker_publishes_errors(self, generator, mocker, error): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = error + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + assert queue_manager.publish_error.called + + def test_generate_worker_logs_value_error_when_debug(self, generator, mocker): + queue_manager = mocker.MagicMock() + generator._get_conversation = mocker.MagicMock(return_value=mocker.MagicMock()) + generator._get_message = mocker.MagicMock(return_value=mocker.MagicMock()) + + runner = mocker.MagicMock() + runner.run.side_effect = ValueError("bad") + mocker.patch("core.app.apps.agent_chat.app_generator.AgentChatAppRunner", return_value=runner) + mocker.patch("core.app.apps.agent_chat.app_generator.db.session.close") + + mocker.patch("core.app.apps.agent_chat.app_generator.dify_config", new=mocker.MagicMock(DEBUG=True)) + logger = mocker.patch("core.app.apps.agent_chat.app_generator.logger") + + generator._generate_worker( + flask_app=mocker.MagicMock(), + context=mocker.MagicMock(), + application_generate_entity=mocker.MagicMock(), + queue_manager=queue_manager, + conversation_id="conv", + message_id="msg", + ) + + logger.exception.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py new file mode 100644 index 0000000000..5603115b30 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -0,0 +1,413 @@ +import pytest + +from core.agent.entities import AgentEntity +from core.app.apps.agent_chat.app_runner import AgentChatAppRunner +from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey + + +@pytest.fixture +def runner(): + return AgentChatAppRunner() + + +class TestAgentChatAppRunnerRun: + def test_run_app_not_found(self, runner, mocker): + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", agent=mocker.MagicMock()) + generate_entity = mocker.MagicMock(app_config=app_config, inputs={}, query="q", files=[], stream=True) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=None) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + def test_run_moderation_error_direct_output(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("bad")) + mocker.patch.object(runner, "direct_output") + + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + runner.direct_output.assert_called_once() + + def test_run_annotation_reply_short_circuits(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + user_id="user", + invoke_from=mocker.MagicMock(), + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + annotation = mocker.MagicMock(id="anno", content="answer") + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=annotation) + mocker.patch.object(runner, "direct_output") + + queue_manager = mocker.MagicMock() + runner.run(generate_entity, queue_manager, mocker.MagicMock(), mocker.MagicMock()) + + queue_manager.publish.assert_called_once() + runner.direct_output.assert_called_once() + + def test_run_hosting_moderation_short_circuits(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock() + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock(), + conversation_id=None, + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=True) + + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + def test_run_model_schema_missing(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = None + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock()) + + @pytest.mark.parametrize( + ("mode", "expected_runner"), + [ + (LLMMode.CHAT, "CotChatAgentRunner"), + (LLMMode.COMPLETION, "CotCompletionAgentRunner"), + ], + ) + def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: mode} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + runner_cls = mocker.MagicMock() + mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls) + + runner_instance = mocker.MagicMock() + runner_cls.return_value = runner_instance + runner_instance.run.return_value = [] + mocker.patch.object(runner, "_handle_invoke_result") + + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + runner_instance.run.assert_called_once() + runner._handle_invoke_result.assert_called_once() + + def test_run_invalid_llm_mode_raises(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + def test_run_function_calling_strategy_selected_by_features(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [ModelFeature.TOOL_CALL] + model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + runner_cls = mocker.MagicMock() + mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls) + + runner_instance = mocker.MagicMock() + runner_cls.return_value = runner_instance + runner_instance.run.return_value = [] + mocker.patch.object(runner, "_handle_invoke_result") + + runner.run(generate_entity, mocker.MagicMock(), conversation, message) + + assert app_config.agent.strategy == AgentEntity.Strategy.FUNCTION_CALLING + runner_instance.run.assert_called_once() + + def test_run_conversation_not_found(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, None], + ) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) + + def test_run_message_not_found(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.FUNCTION_CALLING) + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, mocker.MagicMock(id="conv"), None], + ) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg")) + + def test_run_invalid_agent_strategy_raises(self, runner, mocker): + app_record = mocker.MagicMock(id="app1", tenant_id="tenant") + app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock()) + app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m") + + generate_entity = mocker.MagicMock( + app_config=app_config, + inputs={}, + query="q", + files=[], + stream=True, + model_conf=mocker.MagicMock( + provider_model_bundle=mocker.MagicMock(), + model="m", + provider="p", + credentials={"k": "v"}, + ), + conversation_id="conv", + invoke_from=mocker.MagicMock(), + user_id="user", + ) + + mocker.patch("core.app.apps.agent_chat.app_runner.db.session.scalar", return_value=app_record) + mocker.patch.object(runner, "organize_prompt_messages", return_value=([], None)) + mocker.patch.object(runner, "moderation_for_inputs", return_value=(None, {}, "q")) + mocker.patch.object(runner, "query_app_annotations_to_reply", return_value=None) + mocker.patch.object(runner, "check_hosting_moderation", return_value=False) + + model_schema = mocker.MagicMock() + model_schema.features = [] + model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT} + + llm_instance = mocker.MagicMock() + llm_instance.model_type_instance.get_model_schema.return_value = model_schema + mocker.patch("core.app.apps.agent_chat.app_runner.ModelInstance", return_value=llm_instance) + + conversation = mocker.MagicMock(id="conv") + message = mocker.MagicMock(id="msg") + mocker.patch( + "core.app.apps.agent_chat.app_runner.db.session.scalar", + side_effect=[app_record, conversation, message], + ) + + with pytest.raises(ValueError): + runner.run(generate_entity, mocker.MagicMock(), conversation, message) diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py new file mode 100644 index 0000000000..e861a0c684 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_generate_response_converter.py @@ -0,0 +1,195 @@ +from collections.abc import Generator + +from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestAgentChatAppGenerateResponseConverterBlocking: + def test_convert_blocking_full_response(self): + blocking = ChatbotAppBlockingResponse( + task_id="task", + data=ChatbotAppBlockingResponse.Data( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata={"a": 1}, + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert result["event"] == "message" + assert result["answer"] == "answer" + assert result["metadata"] == {"a": 1} + + def test_convert_blocking_simple_response_with_dict_metadata(self): + blocking = ChatbotAppBlockingResponse( + task_id="task", + data=ChatbotAppBlockingResponse.Data( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata={ + "retriever_resources": [ + { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", + "segment_id": "s1", + "position": 1, + "data_source_type": "file", + "document_name": "doc", + "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", + "content": "content", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], + } + ], + "annotation_reply": {"id": "a"}, + "usage": {"prompt_tokens": 1}, + }, + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "annotation_reply" not in result["metadata"] + assert "usage" not in result["metadata"] + + def test_convert_blocking_simple_response_with_non_dict_metadata(self): + blocking = ChatbotAppBlockingResponse.model_construct( + task_id="task", + data=ChatbotAppBlockingResponse.Data.model_construct( + id="id", + mode="agent-chat", + conversation_id="conv", + message_id="msg", + answer="answer", + metadata="bad", + created_at=123, + ), + ) + + result = AgentChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert result["metadata"] == {} + + +class TestAgentChatAppGenerateResponseConverterStream: + def build_stream(self) -> Generator[ChatbotAppStreamResponse, None, None]: + def _gen(): + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=1, + stream_response=PingStreamResponse(task_id="t"), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=2, + stream_response=MessageStreamResponse(task_id="t", id="m1", answer="hi"), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=3, + stream_response=MessageEndStreamResponse( + task_id="t", + id="m1", + metadata={ + "retriever_resources": [ + { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", + "segment_id": "s1", + "position": 1, + "data_source_type": "file", + "document_name": "doc", + "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", + "content": "content", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], + "summary": "summary", + "extra": "ignored", + } + ], + "annotation_reply": {"id": "a"}, + "usage": {"prompt_tokens": 1}, + }, + ), + ) + yield ChatbotAppStreamResponse( + conversation_id="conv", + message_id="msg", + created_at=4, + stream_response=ErrorStreamResponse(task_id="t", err=RuntimeError("bad")), + ) + + return _gen() + + def test_convert_stream_full_response(self): + items = list(AgentChatAppGenerateResponseConverter.convert_stream_full_response(self.build_stream())) + assert items[0] == "ping" + assert items[1]["event"] == "message" + assert "answer" in items[1] + assert items[2]["event"] == "message_end" + assert items[3]["event"] == "error" + + def test_convert_stream_simple_response(self): + items = list(AgentChatAppGenerateResponseConverter.convert_stream_simple_response(self.build_stream())) + assert items[0] == "ping" + # Assert the message event structure and content at items[1] + assert items[1]["event"] == "message" + assert items[1]["answer"] == "hi" or "hi" in items[1]["answer"] + assert items[2]["event"] == "message_end" + assert "metadata" in items[2] + metadata = items[2]["metadata"] + assert "annotation_reply" not in metadata + assert "usage" not in metadata + assert metadata["retriever_resources"] == [ + { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", + "segment_id": "s1", + "position": 1, + "data_source_type": "file", + "document_name": "doc", + "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", + "content": "content", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], + "summary": "summary", + } + ] + assert items[3]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/chat/__init__.py b/api/tests/unit_tests/core/app/apps/chat/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py new file mode 100644 index 0000000000..271d007be6 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_config_manager.py @@ -0,0 +1,113 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom, ModelConfigEntity, PromptTemplateEntity +from core.app.apps.chat.app_config_manager import ChatAppConfigManager +from models.model import AppMode + + +class TestChatAppConfigManager: + def test_get_app_config_uses_override_dict(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value) + app_model_config = SimpleNamespace(id="config-1", to_dict=lambda: {"model": "m"}) + override = {"model": "override"} + + model_entity = ModelConfigEntity(provider="p", model="m") + prompt_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hi", + ) + + with ( + patch("core.app.apps.chat.app_config_manager.ModelConfigManager.convert", return_value=model_entity), + patch( + "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.convert", return_value=prompt_entity + ), + patch( + "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch("core.app.apps.chat.app_config_manager.DatasetConfigManager.convert", return_value=None), + patch("core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.convert", return_value=([], [])), + ): + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=None, + override_config_dict=override, + ) + + assert app_config.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert app_config.app_model_config_dict == override + assert app_config.app_mode == AppMode.CHAT + + def test_config_validate_filters_related_keys(self): + config = {"extra": 1} + + def _add_key(key, value): + def _inner(*args, **kwargs): + config = args[-1] + config = {**config, key: value} + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.chat.app_config_manager.ModelConfigManager.validate_and_set_defaults", + side_effect=_add_key("model", 1), + ), + patch( + "core.app.apps.chat.app_config_manager.BasicVariablesConfigManager.validate_and_set_defaults", + side_effect=_add_key("inputs", 2), + ), + patch( + "core.app.apps.chat.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 3), + ), + patch( + "core.app.apps.chat.app_config_manager.PromptTemplateConfigManager.validate_and_set_defaults", + side_effect=_add_key("prompt", 4), + ), + patch( + "core.app.apps.chat.app_config_manager.DatasetConfigManager.validate_and_set_defaults", + side_effect=_add_key("dataset", 5), + ), + patch( + "core.app.apps.chat.app_config_manager.OpeningStatementConfigManager.validate_and_set_defaults", + side_effect=_add_key("opening_statement", 6), + ), + patch( + "core.app.apps.chat.app_config_manager.SuggestedQuestionsAfterAnswerConfigManager.validate_and_set_defaults", + side_effect=_add_key("suggested_questions_after_answer", 7), + ), + patch( + "core.app.apps.chat.app_config_manager.SpeechToTextConfigManager.validate_and_set_defaults", + side_effect=_add_key("speech_to_text", 8), + ), + patch( + "core.app.apps.chat.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 9), + ), + patch( + "core.app.apps.chat.app_config_manager.RetrievalResourceConfigManager.validate_and_set_defaults", + side_effect=_add_key("retriever_resource", 10), + ), + patch( + "core.app.apps.chat.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 11), + ), + ): + filtered = ChatAppConfigManager.config_validate(tenant_id="t1", config=config) + + assert filtered["model"] == 1 + assert filtered["inputs"] == 2 + assert filtered["file_upload"] == 3 + assert filtered["prompt"] == 4 + assert filtered["dataset"] == 5 + assert filtered["opening_statement"] == 6 + assert filtered["suggested_questions_after_answer"] == 7 + assert filtered["speech_to_text"] == 8 + assert filtered["text_to_speech"] == 9 + assert filtered["retriever_resource"] == 10 + assert filtered["sensitive_word_avoidance"] == 11 diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py new file mode 100644 index 0000000000..3cdffbb4cd --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -0,0 +1,280 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.app.apps.chat.app_generator import ChatAppGenerator +from core.app.apps.chat.app_runner import ChatAppRunner +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueAnnotationReplyEvent +from core.moderation.base import ModerationError +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from models.model import AppMode + + +class DummyGenerateEntity: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + +class DummyQueueManager: + def __init__(self, *args, **kwargs): + self.published = [] + + def publish_error(self, error, pub_from): + self.published.append((error, pub_from)) + + def publish(self, event, pub_from): + self.published.append((event, pub_from)) + + +class TestChatAppGenerator: + def test_generate_requires_query(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + generator.generate( + app_model=SimpleNamespace(), + user=SimpleNamespace(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_rejects_non_string_query(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + generator.generate( + app_model=SimpleNamespace(), + user=SimpleNamespace(), + args={"query": 1, "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_debugger_overrides_model_config(self): + generator = ChatAppGenerator() + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + user = SimpleNamespace(id="user-1", session_id="session-1") + args = {"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}} + + with ( + patch("core.app.apps.chat.app_generator.ConversationService.get_conversation", return_value=None), + patch("core.app.apps.chat.app_generator.ChatAppConfigManager.config_validate", return_value={"x": 1}), + patch( + "core.app.apps.chat.app_generator.ChatAppConfigManager.get_app_config", + return_value=SimpleNamespace( + variables=[], external_data_variables=[], app_model_config_dict={}, app_mode=AppMode.CHAT + ), + ), + patch("core.app.apps.chat.app_generator.ModelConfigConverter.convert", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.FileUploadConfigManager.convert", return_value=None), + patch("core.app.apps.chat.app_generator.file_factory.build_from_mappings", return_value=[]), + patch("core.app.apps.chat.app_generator.ChatAppGenerateEntity", DummyGenerateEntity), + patch("core.app.apps.chat.app_generator.TraceQueueManager", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.MessageBasedAppQueueManager", DummyQueueManager), + patch( + "core.app.apps.chat.app_generator.ChatAppGenerateResponseConverter.convert", return_value={"ok": True} + ), + patch.object(ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {})), + patch.object(ChatAppGenerator, "_prepare_user_inputs", return_value={}), + patch.object( + ChatAppGenerator, + "_init_generate_records", + return_value=(SimpleNamespace(id="c1", mode="chat"), SimpleNamespace(id="m1")), + ), + patch.object(ChatAppGenerator, "_handle_response", return_value={"response": True}), + patch("core.app.apps.chat.app_generator.copy_current_request_context", side_effect=lambda f: f), + patch("core.app.apps.chat.app_generator.threading.Thread") as mock_thread, + ): + mock_thread.return_value.start.return_value = None + result = generator.generate(app_model, user, args, InvokeFrom.DEBUGGER, streaming=False) + + assert result == {"ok": True} + + def test_generate_rejects_model_config_override_for_non_debugger(self): + generator = ChatAppGenerator() + with pytest.raises(ValueError): + with ( + patch.object( + ChatAppGenerator, "_get_app_model_config", return_value=SimpleNamespace(to_dict=lambda: {}) + ), + ): + generator.generate( + app_model=SimpleNamespace(tenant_id="t1", id="a1", mode=AppMode.CHAT.value), + user=SimpleNamespace(id="u1", session_id="s1"), + args={"query": "hi", "inputs": {}, "model_config": {"foo": "bar"}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_generate_worker_handles_exceptions(self): + generator = ChatAppGenerator() + queue_manager = DummyQueueManager() + entity = DummyGenerateEntity(task_id="t1", user_id="u1") + + with ( + patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()), + patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=InvokeAuthorizationError()), + patch("core.app.apps.chat.app_generator.db.session.close"), + ): + generator._generate_worker( + flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))), + application_generate_entity=entity, + queue_manager=queue_manager, + conversation_id="c1", + message_id="m1", + ) + + assert queue_manager.published + + with ( + patch.object(ChatAppGenerator, "_get_conversation", return_value=SimpleNamespace()), + patch.object(ChatAppGenerator, "_get_message", return_value=SimpleNamespace()), + patch("core.app.apps.chat.app_generator.ChatAppRunner.run", side_effect=GenerateTaskStoppedError()), + patch("core.app.apps.chat.app_generator.db.session.close"), + ): + generator._generate_worker( + flask_app=Mock(app_context=Mock(return_value=Mock(__enter__=Mock(), __exit__=Mock()))), + application_generate_entity=entity, + queue_manager=queue_manager, + conversation_id="c1", + message_id="m1", + ) + + +class TestChatAppRunner: + def test_run_raises_when_app_missing(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", tenant_id="tenant-1", prompt_template=None, external_data_variables=[] + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with patch("core.app.apps.chat.app_runner.db.session.scalar", return_value=None): + with pytest.raises(ValueError): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + def test_run_moderation_error_direct_output(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", side_effect=ModerationError("blocked")), + patch.object(ChatAppRunner, "direct_output") as mock_direct, + ): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) + + mock_direct.assert_called_once() + + def test_run_annotation_reply_short_circuits(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + annotation = SimpleNamespace(id="ann-1", content="answer") + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=annotation), + patch.object(ChatAppRunner, "direct_output") as mock_direct, + ): + queue_manager = DummyQueueManager() + runner.run(app_generate_entity, queue_manager, SimpleNamespace(), SimpleNamespace(id="m1")) + + assert any(isinstance(item[0], QueueAnnotationReplyEvent) for item in queue_manager.published) + mock_direct.assert_called_once() + + def test_run_returns_when_hosting_moderation_blocks(self): + runner = ChatAppRunner() + app_config = SimpleNamespace( + app_id="app-1", + tenant_id="tenant-1", + prompt_template=None, + external_data_variables=[], + dataset=None, + additional_features=None, + ) + app_generate_entity = DummyGenerateEntity( + app_config=app_config, + model_conf=SimpleNamespace(provider_model_bundle=None, model=None, parameters={}, app_model_config_dict={}), + inputs={}, + query="hi", + files=[], + file_upload_config=None, + conversation_id=None, + stream=False, + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + ) + + with ( + patch( + "core.app.apps.chat.app_runner.db.session.scalar", + return_value=SimpleNamespace(id="app-1", tenant_id="tenant-1"), + ), + patch.object(ChatAppRunner, "organize_prompt_messages", return_value=([], [])), + patch.object(ChatAppRunner, "moderation_for_inputs", return_value=(None, {}, "hi")), + patch.object(ChatAppRunner, "query_app_annotations_to_reply", return_value=None), + patch.object(ChatAppRunner, "check_hosting_moderation", return_value=True), + ): + runner.run(app_generate_entity, DummyQueueManager(), SimpleNamespace(), SimpleNamespace(id="m1")) diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 1931e230b2..67b3777c40 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -9,8 +9,8 @@ from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from core.model_runtime.entities.message_entities import ImagePromptMessageContent -from core.workflow.file.enums import FileTransferMethod, FileType +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py new file mode 100644 index 0000000000..01272ba052 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/chat/test_generate_response_converter.py @@ -0,0 +1,65 @@ +from collections.abc import Generator + +from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ChatbotAppBlockingResponse, + ChatbotAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestChatAppGenerateResponseConverter: + def test_convert_blocking_simple_response_metadata(self): + data = ChatbotAppBlockingResponse.Data( + id="msg-1", + mode="chat", + conversation_id="c1", + message_id="m1", + answer="hi", + metadata={"usage": {"total_tokens": 1}}, + created_at=1, + ) + blocking = ChatbotAppBlockingResponse(task_id="t1", data=data) + + response = ChatAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "usage" not in response["metadata"] + + def test_convert_stream_responses(self): + def stream() -> Generator[ChatbotAppStreamResponse, None, None]: + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=PingStreamResponse(task_id="t1"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageStreamResponse(task_id="t1", id="m1", answer="hi"), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")), + ) + yield ChatbotAppStreamResponse( + conversation_id="c1", + message_id="m1", + created_at=1, + stream_response=MessageEndStreamResponse(task_id="t1", id="m1"), + ) + + full = list(ChatAppGenerateResponseConverter.convert_stream_full_response(stream())) + assert full[0] == "ping" + assert full[1]["event"] == "message" + assert full[2]["event"] == "error" + + simple = list(ChatAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert simple[0] == "ping" + assert simple[-1]["event"] == "message_end" diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index cd5ea8986a..b0789bbc1e 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -3,9 +3,9 @@ from types import SimpleNamespace import pytest from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport -from core.workflow.runtime import GraphRuntimeState -from core.workflow.runtime.variable_pool import VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState +from dify_graph.runtime.variable_pool import VariablePool +from dify_graph.system_variable import SystemVariable def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 5508a117c1..72430a3347 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,8 @@ from collections.abc import Mapping, Sequence from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from core.workflow.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from core.workflow.variables.segments import ArrayFileSegment, FileSegment +from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from dify_graph.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 1c36b4d12b..4ed7d73cd0 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -4,9 +4,9 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable def _build_converter(): diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 0a9794e41c..5879e8fb9b 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -2,9 +2,9 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable def _build_converter() -> WorkflowResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index d25bff92dc..374af5ddc4 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun import uuid from collections.abc import Mapping from dataclasses import dataclass +from datetime import UTC, datetime from typing import Any from unittest.mock import Mock @@ -23,9 +24,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeType -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -66,7 +67,7 @@ class TestWorkflowResponseConverter: node_execution_id=node_execution_id or str(uuid.uuid4()), node_id="test-node-id", node_title="Test Node", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, start_at=naive_utc_now(), in_iteration_id=None, in_loop_id=None, @@ -83,7 +84,7 @@ class TestWorkflowResponseConverter: """Create a QueueNodeSucceededEvent for testing.""" return QueueNodeSucceededEvent( node_id="test-node-id", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_execution_id=node_execution_id, start_at=naive_utc_now(), in_iteration_id=None, @@ -108,7 +109,7 @@ class TestWorkflowResponseConverter: error="oops", retry_index=1, node_id="test-node-id", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="test code", provider_type="built-in", provider_id="code", @@ -234,6 +235,50 @@ class TestWorkflowResponseConverter: assert response.data.process_data == {} assert response.data.process_data_truncated is False + def test_workflow_node_finish_response_prefers_event_finished_at( + self, + monkeypatch: pytest.MonkeyPatch, + ) -> None: + """Finished timestamps should come from the event, not delayed queue processing time.""" + converter = self.create_workflow_response_converter() + start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None) + delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None) + + monkeypatch.setattr( + "core.app.apps.common.workflow_response_converter.naive_utc_now", + lambda: delayed_processing_time, + ) + converter.workflow_start_to_stream_response( + task_id="bootstrap", + workflow_run_id="run-id", + workflow_id="wf-id", + reason=WorkflowStartReason.INITIAL, + ) + + event = QueueNodeSucceededEvent( + node_id="test-node-id", + node_type=BuiltinNodeTypes.CODE, + node_execution_id="node-exec-1", + start_at=start_at, + finished_at=finished_at, + in_iteration_id=None, + in_loop_id=None, + inputs={}, + process_data={}, + outputs={}, + execution_metadata={}, + ) + + response = converter.workflow_node_finish_to_stream_response( + event=event, + task_id="test-task-id", + ) + + assert response is not None + assert response.data.elapsed_time == 2.0 + assert response.data.finished_at == int(finished_at.timestamp()) + def test_workflow_node_retry_response_uses_truncated_process_data(self): """Test that node retry response uses get_response_process_data().""" converter = self.create_workflow_response_converter() @@ -319,7 +364,7 @@ class TestWorkflowResponseConverter: iteration_event = QueueNodeSucceededEvent( node_id="iteration-node", - node_type=NodeType.ITERATION, + node_type=BuiltinNodeTypes.ITERATION, node_execution_id=str(uuid.uuid4()), start_at=naive_utc_now(), in_iteration_id=None, @@ -336,7 +381,7 @@ class TestWorkflowResponseConverter: ) assert response is None - loop_event = iteration_event.model_copy(update={"node_type": NodeType.LOOP}) + loop_event = iteration_event.model_copy(update={"node_type": BuiltinNodeTypes.LOOP}) response = converter.workflow_node_finish_to_stream_response( event=loop_event, task_id="test-task-id", @@ -478,7 +523,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_value, process_data=large_value, @@ -523,7 +568,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_value, process_data=large_value, @@ -562,7 +607,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_value, process_data=large_value, @@ -600,7 +645,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: return QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=inputs, process_data=process_data, @@ -614,7 +659,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: return QueueNodeFailedEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=inputs, process_data=process_data, @@ -628,7 +673,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: return QueueNodeExceptionEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=inputs, process_data=process_data, @@ -690,7 +735,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: start_event = QueueNodeStartedEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="Test Node", node_run_index=1, start_at=naive_utc_now(), @@ -706,7 +751,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeRetryEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_title="Test Node", node_run_index=1, start_at=naive_utc_now(), @@ -748,7 +793,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: start_event = QueueIterationStartEvent( node_execution_id="test_iter_exec_id", node_id="test_iteration", - node_type=NodeType.ITERATION, + node_type=BuiltinNodeTypes.ITERATION, node_title="Test Iteration", node_run_index=0, start_at=naive_utc_now(), @@ -776,7 +821,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: start_event = QueueLoopStartEvent( node_execution_id="test_loop_exec_id", node_id="test_loop", - node_type=NodeType.LOOP, + node_type=BuiltinNodeTypes.LOOP, node_title="Test Loop", start_at=naive_utc_now(), inputs=large_inputs, @@ -806,7 +851,7 @@ class TestWorkflowResponseConverterServiceApiTruncation: event = QueueNodeSucceededEvent( node_execution_id="test_node_exec_id", node_id="test_node", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=naive_utc_now(), inputs=large_inputs, process_data=large_process_data, diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py new file mode 100644 index 0000000000..51f33bac35 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -0,0 +1,162 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.app.apps.completion.app_runner as module +from core.app.apps.completion.app_runner import CompletionAppRunner +from core.moderation.base import ModerationError +from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + + +@pytest.fixture +def runner(): + return CompletionAppRunner() + + +def _build_app_config(dataset=None, external_tools=None, additional_features=None): + app_config = MagicMock() + app_config.app_id = "app1" + app_config.tenant_id = "tenant" + app_config.prompt_template = MagicMock() + app_config.dataset = dataset + app_config.external_data_variables = external_tools or [] + app_config.additional_features = additional_features + app_config.app_model_config_dict = {"file_upload": {"enabled": True}} + return app_config + + +def _build_generate_entity(app_config, file_upload_config=None): + model_conf = MagicMock( + provider_model_bundle="bundle", + model="model", + parameters={"max_tokens": 10}, + stop=["stop"], + ) + return SimpleNamespace( + app_config=app_config, + model_conf=model_conf, + inputs={"qvar": "query_from_input"}, + query="original_query", + files=[], + file_upload_config=file_upload_config, + stream=True, + user_id="user", + invoke_from=MagicMock(), + ) + + +class TestCompletionAppRunner: + def test_run_app_not_found(self, runner, mocker): + session = mocker.MagicMock() + session.scalar.return_value = None + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + with pytest.raises(ValueError): + runner.run(app_generate_entity, MagicMock(), MagicMock()) + + def test_run_moderation_error_outputs_direct(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(side_effect=ModerationError("blocked")) + runner.direct_output = MagicMock() + runner._handle_invoke_result = MagicMock() + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + runner.direct_output.assert_called_once() + runner._handle_invoke_result.assert_not_called() + + def test_run_hosting_moderation_stops(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=True) + runner._handle_invoke_result = MagicMock() + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + runner._handle_invoke_result.assert_not_called() + + def test_run_dataset_and_external_tools_flow(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + session.close = MagicMock() + mocker.patch.object(module.db, "session", session) + + retrieve_config = MagicMock(query_variable="qvar") + dataset_config = MagicMock(dataset_ids=["ds"], retrieve_config=retrieve_config) + additional_features = MagicMock(show_retrieve_source=True) + app_config = _build_app_config( + dataset=dataset_config, + external_tools=["tool"], + additional_features=additional_features, + ) + + file_upload_config = MagicMock() + file_upload_config.image_config.detail = ImagePromptMessageContent.DETAIL.HIGH + + app_generate_entity = _build_generate_entity(app_config, file_upload_config=file_upload_config) + + runner.organize_prompt_messages = MagicMock(side_effect=[(["pm1"], ["stop"]), (["pm2"], ["stop"])]) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.fill_in_inputs_from_external_data_tools = MagicMock(return_value=app_generate_entity.inputs) + runner.check_hosting_moderation = MagicMock(return_value=False) + runner.recalc_llm_max_tokens = MagicMock() + runner._handle_invoke_result = MagicMock() + + dataset_retrieval = MagicMock() + dataset_retrieval.retrieve.return_value = ("ctx", ["file1"]) + mocker.patch.object(module, "DatasetRetrieval", return_value=dataset_retrieval) + + model_instance = MagicMock() + model_instance.invoke_llm.return_value = "invoke_result" + mocker.patch.object(module, "ModelInstance", return_value=model_instance) + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg", tenant_id="tenant")) + + dataset_retrieval.retrieve.assert_called_once() + assert dataset_retrieval.retrieve.call_args.kwargs["query"] == "query_from_input" + runner._handle_invoke_result.assert_called_once() + + def test_run_uses_low_image_detail_default(self, runner, mocker): + app_record = MagicMock(id="app1", tenant_id="tenant") + + session = mocker.MagicMock() + session.scalar.return_value = app_record + mocker.patch.object(module.db, "session", session) + + app_config = _build_app_config() + app_generate_entity = _build_generate_entity(app_config, file_upload_config=None) + + runner.organize_prompt_messages = MagicMock(return_value=([], None)) + runner.moderation_for_inputs = MagicMock(return_value=(None, app_generate_entity.inputs, "query")) + runner.check_hosting_moderation = MagicMock(return_value=True) + + runner.run(app_generate_entity, MagicMock(), MagicMock(id="msg")) + + assert ( + runner.organize_prompt_messages.call_args.kwargs["image_detail_config"] + == ImagePromptMessageContent.DETAIL.LOW + ) diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py new file mode 100644 index 0000000000..024bd8f302 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_app_config_manager.py @@ -0,0 +1,122 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import core.app.apps.completion.app_config_manager as module +from core.app.app_config.entities import EasyUIBasedAppModelConfigFrom +from core.app.apps.completion.app_config_manager import CompletionAppConfigManager +from models.model import AppMode + + +class TestCompletionAppConfigManager: + def test_get_app_config_with_override(self, mocker): + app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value) + app_model_config = MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "x"}} + + override_config = {"model": {"provider": "override"}} + + mocker.patch.object(module.ModelConfigManager, "convert", return_value="model") + mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt") + mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation") + mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset") + mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features") + mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=(["v1"], ["ext1"])) + mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = CompletionAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + override_config_dict=override_config, + ) + + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.ARGS + assert result.app_model_config_dict == override_config + assert result.variables == ["v1"] + assert result.external_data_variables == ["ext1"] + assert result.app_mode == AppMode.COMPLETION + + def test_get_app_config_without_override_uses_model_config(self, mocker): + app_model = MagicMock(tenant_id="tenant", id="app1", mode=AppMode.COMPLETION.value) + app_model_config = MagicMock(id="cfg1") + app_model_config.to_dict.return_value = {"model": {"provider": "x"}} + + mocker.patch.object(module.ModelConfigManager, "convert", return_value="model") + mocker.patch.object(module.PromptTemplateConfigManager, "convert", return_value="prompt") + mocker.patch.object(module.SensitiveWordAvoidanceConfigManager, "convert", return_value="moderation") + mocker.patch.object(module.DatasetConfigManager, "convert", return_value="dataset") + mocker.patch.object(CompletionAppConfigManager, "convert_features", return_value="features") + mocker.patch.object(module.BasicVariablesConfigManager, "convert", return_value=([], [])) + mocker.patch.object(module, "CompletionAppConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = CompletionAppConfigManager.get_app_config(app_model=app_model, app_model_config=app_model_config) + + assert result.app_model_config_from == EasyUIBasedAppModelConfigFrom.APP_LATEST_CONFIG + assert result.app_model_config_dict == {"model": {"provider": "x"}} + + def test_config_validate_filters_related_keys(self, mocker): + config = { + "model": {"provider": "x"}, + "variables": ["v"], + "file_upload": {"enabled": True}, + "prompt": {"template": "t"}, + "dataset": {"enabled": True}, + "tts": {"enabled": True}, + "more_like_this": {"enabled": True}, + "moderation": {"enabled": True}, + "extra": "drop", + } + + mocker.patch.object( + module.ModelConfigManager, + "validate_and_set_defaults", + return_value=(config, ["model"]), + ) + mocker.patch.object( + module.BasicVariablesConfigManager, + "validate_and_set_defaults", + return_value=(config, ["variables"]), + ) + mocker.patch.object( + module.FileUploadConfigManager, + "validate_and_set_defaults", + return_value=(config, ["file_upload"]), + ) + mocker.patch.object( + module.PromptTemplateConfigManager, + "validate_and_set_defaults", + return_value=(config, ["prompt"]), + ) + mocker.patch.object( + module.DatasetConfigManager, + "validate_and_set_defaults", + return_value=(config, ["dataset"]), + ) + mocker.patch.object( + module.TextToSpeechConfigManager, + "validate_and_set_defaults", + return_value=(config, ["tts"]), + ) + mocker.patch.object( + module.MoreLikeThisConfigManager, + "validate_and_set_defaults", + return_value=(config, ["more_like_this"]), + ) + mocker.patch.object( + module.SensitiveWordAvoidanceConfigManager, + "validate_and_set_defaults", + return_value=(config, ["moderation"]), + ) + + filtered = CompletionAppConfigManager.config_validate("tenant", config) + + assert "extra" not in filtered + assert set(filtered.keys()) == { + "model", + "variables", + "file_upload", + "prompt", + "dataset", + "tts", + "more_like_this", + "moderation", + } diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py new file mode 100644 index 0000000000..2714757353 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -0,0 +1,321 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pydantic import ValidationError + +import core.app.apps.completion.app_generator as module +from core.app.apps.completion.app_generator import CompletionAppGenerator +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError +from services.errors.app import MoreLikeThisDisabledError +from services.errors.message import MessageNotExistsError + + +@pytest.fixture +def generator(mocker): + gen = CompletionAppGenerator() + + mocker.patch.object(module, "copy_current_request_context", side_effect=lambda fn: fn) + + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "current_app", MagicMock(_get_current_object=MagicMock(return_value=flask_app))) + + thread = MagicMock() + mocker.patch.object(module.threading, "Thread", return_value=thread) + + mocker.patch.object(module, "MessageBasedAppQueueManager", return_value=MagicMock()) + mocker.patch.object(module, "TraceQueueManager", return_value=MagicMock()) + mocker.patch.object(module, "CompletionAppGenerateEntity", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + return gen + + +def _build_app_model(): + return MagicMock(tenant_id="tenant", id="app1", mode="completion") + + +def _build_user(): + return MagicMock(id="user", session_id="session") + + +def _build_app_model_config(): + config = MagicMock(id="cfg") + config.to_dict.return_value = {"model": {"provider": "x"}} + return config + + +class TestCompletionAppGenerator: + def test_generate_invalid_query_type(self, generator): + with pytest.raises(ValueError): + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": 123, "inputs": {}, "files": []}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + def test_generate_override_not_debugger(self, generator): + with pytest.raises(ValueError): + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {}, "files": [], "model_config": {}}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + def test_generate_success_no_file_config(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None) + mocker.patch.object(module.file_factory, "build_from_mappings") + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + conversation = MagicMock(id="conv", mode="completion") + message = MagicMock(id="msg") + mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message)) + + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {"a": 1}, "files": []}, + invoke_from=InvokeFrom.WEB_APP, + streaming=True, + ) + + assert result == "converted" + module.file_factory.build_from_mappings.assert_not_called() + + def test_generate_success_with_files(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + + file_extra_config = MagicMock() + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config) + mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"]) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + mocker.patch.object(module.CompletionAppConfigManager, "get_app_config", return_value=app_config) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + conversation = MagicMock(id="conv", mode="completion") + message = MagicMock(id="msg") + mocker.patch.object(generator, "_init_generate_records", return_value=(conversation, message)) + + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {"a": 1}, "files": [{"id": "f"}]}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + assert result == "converted" + module.file_factory.build_from_mappings.assert_called_once() + + def test_generate_override_model_config_debugger(self, generator, mocker): + app_model_config = _build_app_model_config() + mocker.patch.object(generator, "_get_app_model_config", return_value=app_model_config) + + override_config = {"model": {"provider": "override"}} + mocker.patch.object(module.CompletionAppConfigManager, "config_validate", return_value=override_config) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + get_app_config = mocker.patch.object( + module.CompletionAppConfigManager, + "get_app_config", + return_value=app_config, + ) + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=None) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + mocker.patch.object( + generator, + "_init_generate_records", + return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")), + ) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + generator.generate( + app_model=_build_app_model(), + user=_build_user(), + args={"query": "q", "inputs": {}, "files": [], "model_config": override_config}, + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + assert get_app_config.call_args.kwargs["override_config_dict"] == override_config + + def test_generate_more_like_this_message_not_found(self, generator, mocker): + session = mocker.MagicMock() + session.scalar.return_value = None + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MessageNotExistsError): + generator.generate_more_like_this( + app_model=_build_app_model(), + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_disabled(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=False, more_like_this_dict={"enabled": False}) + + message = MagicMock() + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MoreLikeThisDisabledError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_app_model_config_missing(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = None + + message = MagicMock() + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(MoreLikeThisDisabledError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_message_config_none(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True}) + + message = MagicMock(app_model_config=None) + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + with pytest.raises(ValueError): + generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + ) + + def test_generate_more_like_this_success(self, generator, mocker): + app_model = _build_app_model() + app_model.app_model_config = MagicMock(more_like_this=True, more_like_this_dict={"enabled": True}) + + message = MagicMock() + message.message_files = [{"id": "f"}] + message.inputs = {"a": 1} + message.query = "q" + + app_model_config = MagicMock() + app_model_config.to_dict.return_value = { + "model": {"completion_params": {"temperature": 0.1}}, + "file_upload": {"enabled": True}, + } + message.app_model_config = app_model_config + + session = mocker.MagicMock() + session.scalar.return_value = message + mocker.patch.object(module.db, "session", session) + + file_extra_config = MagicMock() + mocker.patch.object(module.FileUploadConfigManager, "convert", return_value=file_extra_config) + mocker.patch.object(module.file_factory, "build_from_mappings", return_value=["file1"]) + + app_config = MagicMock(variables=["v"], to_dict=MagicMock(return_value={})) + get_app_config = mocker.patch.object( + module.CompletionAppConfigManager, + "get_app_config", + return_value=app_config, + ) + mocker.patch.object(module.ModelConfigConverter, "convert", return_value=MagicMock()) + + mocker.patch.object( + generator, + "_init_generate_records", + return_value=(MagicMock(id="conv", mode="completion"), MagicMock(id="msg")), + ) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.CompletionAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator.generate_more_like_this( + app_model=app_model, + message_id="msg", + user=_build_user(), + invoke_from=InvokeFrom.WEB_APP, + stream=True, + ) + + assert result == "converted" + override_dict = get_app_config.call_args.kwargs["override_config_dict"] + assert override_dict["model"]["completion_params"]["temperature"] == 0.9 + + @pytest.mark.parametrize( + ("error", "should_publish"), + [ + (GenerateTaskStoppedError(), False), + (InvokeAuthorizationError("bad"), True), + ( + ValidationError.from_exception_data( + "Model", + [{"type": "missing", "loc": ("x",), "msg": "Field required", "input": {}}], + ), + True, + ), + (ValueError("bad"), True), + (RuntimeError("boom"), True), + ], + ) + def test_generate_worker_error_handling(self, generator, mocker, error, should_publish): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + + session = mocker.MagicMock() + mocker.patch.object(module.db, "session", session) + + mocker.patch.object(generator, "_get_message", return_value=MagicMock()) + + runner_instance = MagicMock() + runner_instance.run.side_effect = error + mocker.patch.object(module, "CompletionAppRunner", return_value=runner_instance) + + queue_manager = MagicMock() + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=MagicMock(), + queue_manager=queue_manager, + message_id="msg", + ) + + assert queue_manager.publish_error.called is should_publish diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py new file mode 100644 index 0000000000..0136dbf5ad --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_generate_response_converter.py @@ -0,0 +1,169 @@ +from collections.abc import Generator + +from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + CompletionAppBlockingResponse, + CompletionAppStreamResponse, + ErrorStreamResponse, + MessageEndStreamResponse, + MessageStreamResponse, + PingStreamResponse, +) + + +class TestCompletionAppGenerateResponseConverter: + def test_convert_blocking_full_response(self): + blocking = CompletionAppBlockingResponse( + task_id="task", + data=CompletionAppBlockingResponse.Data( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata={"k": "v"}, + created_at=123, + ), + ) + + result = CompletionAppGenerateResponseConverter.convert_blocking_full_response(blocking) + + assert result["event"] == "message" + assert result["task_id"] == "task" + assert result["message_id"] == "msg" + assert result["answer"] == "answer" + assert result["metadata"] == {"k": "v"} + + def test_convert_blocking_simple_response_metadata_simplified(self): + metadata = { + "retriever_resources": [ + { + "dataset_id": "dataset-1", + "dataset_name": "Dataset 1", + "document_id": "document-1", + "segment_id": "s", + "position": 1, + "data_source_type": "file", + "document_name": "doc", + "score": 0.9, + "hit_count": 2, + "word_count": 128, + "segment_position": 3, + "index_node_hash": "abc1234", + "content": "c", + "page": 5, + "title": "Citation Title", + "files": [{"id": "file-1"}], + "summary": "sum", + "extra": "x", + } + ], + "annotation_reply": {"a": 1}, + "usage": {"t": 2}, + } + blocking = CompletionAppBlockingResponse( + task_id="task", + data=CompletionAppBlockingResponse.Data( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata=metadata, + created_at=123, + ), + ) + + result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert "annotation_reply" not in result["metadata"] + assert "usage" not in result["metadata"] + assert result["metadata"]["retriever_resources"][0]["dataset_id"] == "dataset-1" + assert result["metadata"]["retriever_resources"][0]["document_id"] == "document-1" + assert result["metadata"]["retriever_resources"][0]["segment_id"] == "s" + assert result["metadata"]["retriever_resources"][0]["data_source_type"] == "file" + assert result["metadata"]["retriever_resources"][0]["segment_position"] == 3 + assert result["metadata"]["retriever_resources"][0]["index_node_hash"] == "abc1234" + assert "extra" not in result["metadata"]["retriever_resources"][0] + + def test_convert_blocking_simple_response_metadata_not_dict(self): + data = CompletionAppBlockingResponse.Data.model_construct( + id="id", + mode="completion", + message_id="msg", + answer="answer", + metadata="bad", + created_at=123, + ) + blocking = CompletionAppBlockingResponse.model_construct(task_id="task", data=data) + + result = CompletionAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert result["metadata"] == {} + + def test_convert_stream_full_response(self): + def stream() -> Generator[AppStreamResponse, None, None]: + yield CompletionAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + message_id="m", + created_at=1, + ) + yield CompletionAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + message_id="m", + created_at=2, + ) + yield CompletionAppStreamResponse( + stream_response=MessageStreamResponse(task_id="t", id="1", answer="ok"), + message_id="m", + created_at=3, + ) + + result = list(CompletionAppGenerateResponseConverter.convert_stream_full_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "error" + assert result[1]["code"] == "invalid_param" + assert result[2]["event"] == "message" + + def test_convert_stream_simple_response(self): + def stream() -> Generator[AppStreamResponse, None, None]: + yield CompletionAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + message_id="m", + created_at=1, + ) + yield CompletionAppStreamResponse( + stream_response=MessageEndStreamResponse( + task_id="t", + id="end", + metadata={ + "retriever_resources": [ + { + "segment_id": "s", + "position": 1, + "document_name": "doc", + "score": 0.9, + "content": "c", + "summary": "sum", + } + ], + "annotation_reply": {"a": 1}, + "usage": {"t": 2}, + }, + ), + message_id="m", + created_at=2, + ) + yield CompletionAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + message_id="m", + created_at=3, + ) + + result = list(CompletionAppGenerateResponseConverter.convert_stream_simple_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "message_end" + assert "annotation_reply" not in result[1]["metadata"] + assert "usage" not in result[1]["metadata"] + assert result[2]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py new file mode 100644 index 0000000000..5d4c9bcde0 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_config_manager.py @@ -0,0 +1,55 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import core.app.apps.pipeline.pipeline_config_manager as module +from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager +from models.model import AppMode + + +def test_get_pipeline_config(mocker): + pipeline = MagicMock(tenant_id="tenant", id="pipe1") + workflow = MagicMock(id="wf1") + + mocker.patch.object( + module.WorkflowVariablesConfigManager, + "convert_rag_pipeline_variable", + return_value=["var1"], + ) + mocker.patch.object(module, "PipelineConfig", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + result = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow, start_node_id="start") + + assert result.tenant_id == "tenant" + assert result.app_id == "pipe1" + assert result.workflow_id == "wf1" + assert result.app_mode == AppMode.RAG_PIPELINE + assert result.rag_pipeline_variables == ["var1"] + + +def test_config_validate_filters_related_keys(mocker): + config = { + "file_upload": {"enabled": True}, + "tts": {"enabled": True}, + "moderation": {"enabled": True}, + "extra": "drop", + } + + mocker.patch.object( + module.FileUploadConfigManager, + "validate_and_set_defaults", + return_value=(config, ["file_upload"]), + ) + mocker.patch.object( + module.TextToSpeechConfigManager, + "validate_and_set_defaults", + return_value=(config, ["tts"]), + ) + mocker.patch.object( + module.SensitiveWordAvoidanceConfigManager, + "validate_and_set_defaults", + return_value=(config, ["moderation"]), + ) + + filtered = PipelineConfigManager.config_validate("tenant", config) + + assert set(filtered.keys()) == {"file_upload", "tts", "moderation"} diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py new file mode 100644 index 0000000000..94ed8166b9 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -0,0 +1,111 @@ +from collections.abc import Generator + +from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.task_entities import ( + AppStreamResponse, + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + + +def test_convert_blocking_full_and_simple_response(): + blocking = WorkflowAppBlockingResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowAppBlockingResponse.Data( + id="id", + workflow_id="wf", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"k": "v"}, + error=None, + elapsed_time=0.1, + total_tokens=10, + total_steps=1, + created_at=1, + finished_at=2, + ), + ) + + full = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking) + simple = WorkflowAppGenerateResponseConverter.convert_blocking_simple_response(blocking) + + assert full == simple + assert full["workflow_run_id"] == "run" + assert full["data"]["status"] == WorkflowExecutionStatus.SUCCEEDED + + +def test_convert_stream_full_response(): + def stream() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse( + stream_response=PingStreamResponse(task_id="t"), + workflow_run_id="run", + ) + yield WorkflowAppStreamResponse( + stream_response=ErrorStreamResponse(task_id="t", err=ValueError("bad")), + workflow_run_id="run", + ) + + result = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(stream())) + + assert result[0] == "ping" + assert result[1]["event"] == "error" + assert result[1]["code"] == "invalid_param" + + +def test_convert_stream_simple_response_node_ignore_details(): + node_start = NodeStartStreamResponse( + task_id="t", + workflow_run_id="run", + data=NodeStartStreamResponse.Data( + id="nid", + node_id="node", + node_type="type", + title="Title", + index=1, + predecessor_node_id=None, + inputs={"a": 1}, + inputs_truncated=False, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t", + workflow_run_id="run", + data=NodeFinishStreamResponse.Data( + id="nid", + node_id="node", + node_type="type", + title="Title", + index=1, + predecessor_node_id=None, + inputs={"a": 1}, + inputs_truncated=False, + process_data=None, + process_data_truncated=False, + outputs={"b": 2}, + outputs_truncated=False, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + error=None, + elapsed_time=0.1, + execution_metadata=None, + created_at=1, + finished_at=2, + files=[], + ), + ) + + def stream() -> Generator[AppStreamResponse, None, None]: + yield WorkflowAppStreamResponse(stream_response=node_start, workflow_run_id="run") + yield WorkflowAppStreamResponse(stream_response=node_finish, workflow_run_id="run") + + result = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream())) + + assert result[0]["event"] == "node_started" + assert result[0]["data"]["inputs"] is None + assert result[1]["event"] == "node_finished" + assert result[1]["data"]["inputs"] is None diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py new file mode 100644 index 0000000000..06face41fe --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generator.py @@ -0,0 +1,699 @@ +import contextlib +from types import SimpleNamespace +from unittest.mock import MagicMock, PropertyMock + +import pytest + +import core.app.apps.pipeline.pipeline_generator as module +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.entities.datasource_entities import DatasourceProviderType + + +class FakeRagPipelineGenerateEntity(SimpleNamespace): + class SingleIterationRunEntity(SimpleNamespace): + pass + + class SingleLoopRunEntity(SimpleNamespace): + pass + + def model_dump(self): + return dict(self.__dict__) + + +@pytest.fixture +def generator(mocker): + gen = module.PipelineGenerator() + + mocker.patch.object(module, "RagPipelineGenerateEntity", FakeRagPipelineGenerateEntity) + mocker.patch.object(module, "RagPipelineInvokeEntity", side_effect=lambda **kwargs: kwargs) + mocker.patch.object(module.contexts, "plugin_tool_providers", SimpleNamespace(set=MagicMock())) + mocker.patch.object(module.contexts, "plugin_tool_providers_lock", SimpleNamespace(set=MagicMock())) + + return gen + + +def _build_pipeline_dataset(): + return SimpleNamespace( + id="ds", + name="dataset", + description="desc", + chunk_structure="chunk", + built_in_field_enabled=True, + tenant_id="tenant", + ) + + +def _build_pipeline(): + pipeline = MagicMock(tenant_id="tenant", id="pipe") + pipeline.retrieve_dataset.return_value = _build_pipeline_dataset() + return pipeline + + +def _build_workflow(): + return MagicMock(id="wf", graph_dict={"nodes": [], "edges": []}, tenant_id="tenant") + + +def _build_user(): + return MagicMock(id="user", name="User", session_id="session") + + +def _build_args(): + return { + "inputs": {"k": "v"}, + "start_node_id": "start", + "datasource_type": DatasourceProviderType.LOCAL_FILE.value, + "datasource_info_list": [{"name": "file"}], + } + + +def _patch_session(mocker, session): + mocker.patch.object(module, "Session", return_value=session) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + +def _dummy_preserve(*args, **kwargs): + return contextlib.nullcontext() + + +class DummySession: + def __init__(self): + self.scalar = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +def test_generate_dataset_missing(generator, mocker): + pipeline = _build_pipeline() + pipeline.retrieve_dataset.return_value = None + + session = DummySession() + _patch_session(mocker, session) + + with pytest.raises(ValueError): + generator.generate( + pipeline=pipeline, + workflow=_build_workflow(), + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + ) + + +def test_generate_debugger_calls_generate(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=[{"name": "file"}], + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + mocker.patch.object(generator, "_generate", return_value={"result": "ok"}) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.DEBUGGER, + streaming=True, + ) + + assert result == {"result": "ok"} + + +def test_generate_published_pipeline_creates_documents_and_delay(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + datasource_info_list = [{"name": "file1"}, {"name": "file2"}] + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=datasource_info_list, + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch("services.dataset_service.DocumentService.get_documents_position", return_value=1) + + document1 = SimpleNamespace( + id="doc1", + position=1, + data_source_type=DatasourceProviderType.LOCAL_FILE, + data_source_info="{}", + name="file1", + indexing_status="", + error=None, + enabled=True, + ) + document2 = SimpleNamespace( + id="doc2", + position=2, + data_source_type=DatasourceProviderType.LOCAL_FILE, + data_source_info="{}", + name="file2", + indexing_status="", + error=None, + enabled=True, + ) + mocker.patch.object(generator, "_build_document", side_effect=[document1, document2]) + + mocker.patch.object(module, "DocumentPipelineExecutionLog", return_value=MagicMock()) + + db_session = MagicMock() + mocker.patch.object(module.db, "session", db_session) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + task_proxy = MagicMock() + mocker.patch.object(module, "RagPipelineTaskProxy", return_value=task_proxy) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, + streaming=False, + ) + + assert result["batch"] + assert len(result["documents"]) == 2 + task_proxy.delay.assert_called_once() + + +def test_generate_is_retry_calls_generate(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + generator, + "_format_datasource_info_list", + return_value=[{"name": "file"}], + ) + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", rag_pipeline_variables=[]), + ) + mocker.patch.object(generator, "_prepare_user_inputs", return_value={"k": "v"}) + + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + + mocker.patch.object(generator, "_generate", return_value={"result": "ok"}) + + result = generator.generate( + pipeline=pipeline, + workflow=workflow, + user=_build_user(), + args=_build_args(), + invoke_from=InvokeFrom.PUBLISHED_PIPELINE, + streaming=True, + is_retry=True, + ) + + assert result == {"result": "ok"} + + +def test_generate_worker_handles_errors(generator, mocker): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + mocker.patch.object(module.db, "session", MagicMock(close=MagicMock())) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + application_generate_entity = FakeRagPipelineGenerateEntity( + app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"), + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + ) + + session = DummySession() + session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")] + _patch_session(mocker, session) + + runner_instance = MagicMock() + runner_instance.run.side_effect = ValueError("bad") + mocker.patch.object(module, "PipelineRunner", return_value=runner_instance) + + queue_manager = MagicMock() + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + context=contextlib.nullcontext(), + variable_loader=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + queue_manager.publish_error.assert_called_once() + + +def test_generate_worker_sets_system_user_id_for_external_call(generator, mocker): + flask_app = MagicMock() + flask_app.app_context.return_value = contextlib.nullcontext() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + mocker.patch.object(module.db, "session", MagicMock(close=MagicMock())) + mocker.patch.object(type(module.db), "engine", new_callable=PropertyMock, return_value=MagicMock()) + + application_generate_entity = FakeRagPipelineGenerateEntity( + app_config=SimpleNamespace(tenant_id="tenant", app_id="pipe", workflow_id="wf"), + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + ) + + session = DummySession() + session.scalar.side_effect = [MagicMock(), MagicMock(session_id="session")] + _patch_session(mocker, session) + + runner_instance = MagicMock() + mocker.patch.object(module, "PipelineRunner", return_value=runner_instance) + + generator._generate_worker( + flask_app=flask_app, + application_generate_entity=application_generate_entity, + queue_manager=MagicMock(), + context=contextlib.nullcontext(), + variable_loader=MagicMock(), + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + assert module.PipelineRunner.call_args.kwargs["system_user_id"] == "session" + + +def test_generate_raises_when_workflow_not_found(generator, mocker): + flask_app = MagicMock() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object(module.db, "session", session) + + with pytest.raises(ValueError): + generator._generate( + flask_app=flask_app, + context=contextlib.nullcontext(), + pipeline=_build_pipeline(), + workflow_id="wf", + user=_build_user(), + application_generate_entity=FakeRagPipelineGenerateEntity( + task_id="t", + app_config=SimpleNamespace(app_id="pipe"), + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + ), + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + ) + + +def test_generate_success_returns_converted(generator, mocker): + flask_app = MagicMock() + mocker.patch.object(module, "preserve_flask_contexts", _dummy_preserve) + + workflow = MagicMock(id="wf", tenant_id="tenant", app_id="pipe", graph_dict={}) + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = workflow + mocker.patch.object(module.db, "session", session) + + queue_manager = MagicMock() + mocker.patch.object(module, "PipelineQueueManager", return_value=queue_manager) + + worker_thread = MagicMock() + mocker.patch.object(module.threading, "Thread", return_value=worker_thread) + + mocker.patch.object(generator, "_get_draft_var_saver_factory", return_value=MagicMock()) + mocker.patch.object(generator, "_handle_response", return_value="response") + mocker.patch.object(module.WorkflowAppGenerateResponseConverter, "convert", return_value="converted") + + result = generator._generate( + flask_app=flask_app, + context=contextlib.nullcontext(), + pipeline=_build_pipeline(), + workflow_id="wf", + user=_build_user(), + application_generate_entity=FakeRagPipelineGenerateEntity( + task_id="t", + app_config=SimpleNamespace(app_id="pipe"), + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + ), + invoke_from=InvokeFrom.DEBUGGER, + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + streaming=True, + ) + + assert result == "converted" + + +def test_single_iteration_generate_validates_inputs(generator, mocker): + with pytest.raises(ValueError): + generator.single_iteration_generate(_build_pipeline(), _build_workflow(), "", _build_user(), {}) + + with pytest.raises(ValueError): + generator.single_iteration_generate( + _build_pipeline(), _build_workflow(), "node", _build_user(), {"inputs": None} + ) + + +def test_single_iteration_generate_dataset_required(generator, mocker): + pipeline = _build_pipeline() + pipeline.retrieve_dataset.return_value = None + + session = DummySession() + _patch_session(mocker, session) + + with pytest.raises(ValueError): + generator.single_iteration_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + ) + + +def test_single_iteration_generate_success(generator, mocker): + pipeline = _build_pipeline() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock())) + + mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock()) + mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock()) + + mocker.patch.object(generator, "_generate", return_value={"ok": True}) + + result = generator.single_iteration_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + streaming=False, + ) + + assert result == {"ok": True} + + +def test_single_loop_generate_success(generator, mocker): + pipeline = _build_pipeline() + + session = DummySession() + _patch_session(mocker, session) + + mocker.patch.object( + module.PipelineConfigManager, + "get_pipeline_config", + return_value=SimpleNamespace(app_id="pipe", tenant_id="tenant"), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object( + module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + return_value=MagicMock(), + ) + mocker.patch.object(module.db, "session", MagicMock(return_value=MagicMock())) + + mocker.patch.object(module, "WorkflowDraftVariableService", return_value=MagicMock()) + mocker.patch.object(module, "DraftVarLoader", return_value=MagicMock()) + + mocker.patch.object(generator, "_generate", return_value={"ok": True}) + + result = generator.single_loop_generate( + pipeline, + _build_workflow(), + "node", + _build_user(), + {"inputs": {"a": 1}}, + streaming=False, + ) + + assert result == {"ok": True} + + +def test_handle_response_value_error_triggers_generate_task_stopped(generator, mocker): + pipeline = _build_pipeline() + workflow = _build_workflow() + app_entity = FakeRagPipelineGenerateEntity(task_id="t") + + task_pipeline = MagicMock() + task_pipeline.process.side_effect = ValueError("I/O operation on closed file.") + mocker.patch.object(module, "WorkflowAppGenerateTaskPipeline", return_value=task_pipeline) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=app_entity, + workflow=workflow, + queue_manager=MagicMock(), + user=_build_user(), + draft_var_saver_factory=MagicMock(), + stream=False, + ) + + +def test_build_document_sets_metadata_for_builtin_fields(generator, mocker): + class DummyDocument(SimpleNamespace): + pass + + mocker.patch.object(module, "Document", side_effect=lambda **kwargs: DummyDocument(**kwargs)) + + document = generator._build_document( + tenant_id="tenant", + dataset_id="ds", + built_in_field_enabled=True, + datasource_type=DatasourceProviderType.LOCAL_FILE, + datasource_info={"name": "file"}, + created_from="rag-pipeline", + position=1, + account=_build_user(), + batch="batch", + document_form="text", + ) + + assert document.name == "file" + assert document.doc_metadata + + +def test_build_document_invalid_datasource_type(generator): + with pytest.raises(ValueError): + generator._build_document( + tenant_id="tenant", + dataset_id="ds", + built_in_field_enabled=False, + datasource_type="invalid", + datasource_info={}, + created_from="rag-pipeline", + position=1, + account=_build_user(), + batch="batch", + document_form="text", + ) + + +def test_format_datasource_info_list_non_online_drive(generator): + result = generator._format_datasource_info_list( + DatasourceProviderType.LOCAL_FILE, + [{"name": "file"}], + _build_pipeline(), + _build_workflow(), + "start", + _build_user(), + ) + + assert result == [{"name": "file"}] + + +def test_format_datasource_info_list_missing_node_data(generator): + workflow = MagicMock(graph_dict={"nodes": []}) + + with pytest.raises(ValueError): + generator._format_datasource_info_list( + DatasourceProviderType.ONLINE_DRIVE, + [], + _build_pipeline(), + workflow, + "start", + _build_user(), + ) + + +def test_format_datasource_info_list_online_drive_folder(generator, mocker): + workflow = MagicMock( + graph_dict={ + "nodes": [ + { + "id": "start", + "data": { + "plugin_id": "p", + "provider_name": "provider", + "datasource_name": "drive", + "credential_id": "cred", + }, + } + ] + } + ) + + runtime = MagicMock() + runtime.runtime = SimpleNamespace(credentials=None) + runtime.datasource_provider_type.return_value = DatasourceProviderType.ONLINE_DRIVE + + mocker.patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=runtime, + ) + mocker.patch.object(module.DatasourceProviderService, "get_datasource_credentials", return_value={"k": "v"}) + + mocker.patch.object( + generator, + "_get_files_in_folder", + side_effect=lambda *args, **kwargs: args[4].append({"id": "f"}), + ) + + result = generator._format_datasource_info_list( + DatasourceProviderType.ONLINE_DRIVE, + [{"id": "folder", "type": "folder", "name": "Folder", "bucket": "b"}], + _build_pipeline(), + workflow, + "start", + _build_user(), + ) + + assert result == [{"id": "f"}] + + +def test_get_files_in_folder_recurses_and_collects(generator): + class File: + def __init__(self, id, name, type): + self.id = id + self.name = name + self.type = type + + class FilesPage: + def __init__(self, files, is_truncated=False, next_page_parameters=None): + self.files = files + self.is_truncated = is_truncated + self.next_page_parameters = next_page_parameters + + class Result: + def __init__(self, result): + self.result = result + + class Runtime: + def __init__(self): + self.calls = [] + + def datasource_provider_type(self): + return DatasourceProviderType.ONLINE_DRIVE + + def online_drive_browse_files(self, user_id, request, provider_type): + self.calls.append(request.next_page_parameters) + if request.prefix == "fd": + return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])]) + if request.next_page_parameters is None: + return iter( + [ + Result( + [FilesPage([File("f1", "file", "file"), File("fd", "folder", "folder")], True, {"page": 2})] + ) + ] + ) + return iter([Result([FilesPage([File("f2", "file2", "file")], False, None)])]) + + runtime = Runtime() + all_files = [] + + generator._get_files_in_folder( + datasource_runtime=runtime, + prefix="root", + bucket="b", + user_id="user", + all_files=all_files, + datasource_info={}, + ) + + assert {f["id"] for f in all_files} == {"f1", "f2"} diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py new file mode 100644 index 0000000000..72f7552bd1 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -0,0 +1,57 @@ +import pytest + +import core.app.apps.pipeline.pipeline_queue_manager as module +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import ( + QueueErrorEvent, + QueueMessageEndEvent, + QueueStopEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowSucceededEvent, +) +from dify_graph.model_runtime.entities.llm_entities import LLMResult + + +def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=True) + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.APPLICATION_MANAGER) + + manager.stop_listen.assert_called_once() + + +def test_publish_stop_events_trigger_stop_listen(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=False) + + for event in [ + QueueErrorEvent(error=ValueError("bad")), + QueueMessageEndEvent(llm_result=LLMResult.model_construct()), + QueueWorkflowSucceededEvent(), + QueueWorkflowFailedEvent(error="failed", exceptions_count=1), + QueueWorkflowPartialSuccessEvent(exceptions_count=1), + ]: + manager.stop_listen.reset_mock() + manager._publish(event, PublishFrom.TASK_PIPELINE) + manager.stop_listen.assert_called_once() + + +def test_publish_non_stop_event_no_stop_listen(mocker): + manager = PipelineQueueManager(task_id="t", user_id="u", invoke_from=InvokeFrom.WEB_APP, app_mode="rag") + manager._q = mocker.MagicMock() + manager.stop_listen = mocker.MagicMock() + manager._is_stopped = mocker.MagicMock(return_value=False) + + non_stop_event = mocker.MagicMock(spec=module.AppQueueEvent) + manager._publish(non_stop_event, PublishFrom.TASK_PIPELINE) + manager.stop_listen.assert_not_called() diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py new file mode 100644 index 0000000000..eec95b7f39 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -0,0 +1,297 @@ +""" +Unit tests for PipelineRunner behavior. +Asserts correct event handling, error propagation, and user invocation logic. +Primary collaborators: PipelineRunner, InvokeFrom, GraphRunFailedEvent, UserFrom, and mocked dependencies. +Cross-references: core.app.apps.pipeline.pipeline_runner, core.app.entities.app_invoke_entities. +""" + +"""Unit tests for PipelineRunner behavior. + +This module validates core control-flow outcomes for +``core.app.apps.pipeline.pipeline_runner``: app/workflow lookup, graph +initialization guards, invoke-source to user-source resolution, and failed-run +event handling. Invariants asserted here include strict graph-config +validation, correct ``InvokeFrom`` to ``UserFrom`` mapping, and publishing +error paths driven by ``GraphRunFailedEvent`` through mocked collaborators. +Primary collaborators include ``PipelineRunner``, +``core.app.entities.app_invoke_entities.InvokeFrom``, ``GraphRunFailedEvent``, +``UserFrom``, and patched DB/runtime dependencies used by the runner. +""" + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import core.app.apps.pipeline.pipeline_runner as module +from core.app.apps.pipeline.pipeline_runner import PipelineRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.graph_events import GraphRunFailedEvent + + +def _build_app_generate_entity() -> SimpleNamespace: + app_config = SimpleNamespace(app_id="pipe", workflow_id="wf", tenant_id="tenant") + return SimpleNamespace( + app_config=app_config, + invoke_from=InvokeFrom.WEB_APP, + user_id="user", + trace_manager=MagicMock(), + inputs={"input1": "v1"}, + files=[], + workflow_execution_id="run", + document_id="doc", + original_document_id=None, + batch="batch", + dataset_id="ds", + datasource_type="local_file", + datasource_info={"name": "file"}, + start_node_id="start", + call_depth=0, + single_iteration_run=None, + single_loop_run=None, + ) + + +@pytest.fixture +def runner(): + app_generate_entity = _build_app_generate_entity() + queue_manager = MagicMock() + variable_loader = MagicMock() + workflow = MagicMock() + workflow_execution_repository = MagicMock() + workflow_node_execution_repository = MagicMock() + + return PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=queue_manager, + variable_loader=variable_loader, + workflow=workflow, + system_user_id="sys", + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + ) + + +def test_get_app_id(runner): + assert runner._get_app_id() == "pipe" + + +def test_get_workflow_returns_workflow(mocker, runner): + pipeline = MagicMock(tenant_id="tenant", id="pipe") + workflow = MagicMock(id="wf") + + query = MagicMock() + query.where.return_value.first.return_value = workflow + mocker.patch.object(module.db, "session", MagicMock(query=MagicMock(return_value=query))) + + result = runner.get_workflow(pipeline=pipeline, workflow_id="wf") + + assert result == workflow + + +def test_init_rag_pipeline_graph_invalid_config(mocker, runner): + workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={}) + + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + workflow.graph_dict = {"nodes": "bad", "edges": []} + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + workflow.graph_dict = {"nodes": [], "edges": "bad"} + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + +def test_init_rag_pipeline_graph_not_found(mocker, runner): + workflow = MagicMock(id="wf", tenant_id="tenant", graph_dict={"nodes": [], "edges": []}) + mocker.patch.object(module.Graph, "init", return_value=None) + + with pytest.raises(ValueError): + runner._init_rag_pipeline_graph(workflow=workflow, graph_runtime_state=MagicMock()) + + +def test_update_document_status_on_failure(mocker, runner): + document = MagicMock() + + query = MagicMock() + query.where.return_value.first.return_value = document + + session = MagicMock() + session.query.return_value = query + mocker.patch.object(module.db, "session", session) + + event = GraphRunFailedEvent(error="boom") + + runner._update_document_status(event, document_id="doc", dataset_id="ds") + + assert document.indexing_status == "error" + assert document.error == "boom" + session.commit.assert_called_once() + + +def test_run_pipeline_not_found(mocker): + app_generate_entity = _build_app_generate_entity() + app_generate_entity.invoke_from = InvokeFrom.WEB_APP + app_generate_entity.single_iteration_run = None + app_generate_entity.single_loop_run = None + + query = MagicMock() + query.where.return_value.first.return_value = None + + session = MagicMock() + session.query.return_value = query + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + with pytest.raises(ValueError): + runner.run() + + +def test_run_workflow_not_initialized(mocker): + app_generate_entity = _build_app_generate_entity() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + session = MagicMock() + session.query.return_value = query_pipeline + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + runner.get_workflow = MagicMock(return_value=None) + + with pytest.raises(ValueError): + runner.run() + + +def test_run_single_iteration_path(mocker): + app_generate_entity = _build_app_generate_entity() + app_generate_entity.single_iteration_run = MagicMock() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + query_end_user = MagicMock() + query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + + session = MagicMock() + session.query.side_effect = [query_end_user, query_pipeline] + mocker.patch.object(module.db, "session", session) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=MagicMock(), + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT) + runner.get_workflow = MagicMock( + return_value=MagicMock( + id="wf", + tenant_id="tenant", + app_id="pipe", + graph_dict={}, + type="rag-pipeline", + version="v1", + ) + ) + runner._prepare_single_node_execution = MagicMock(return_value=("graph", "pool", "state")) + runner._update_document_status = MagicMock() + runner._handle_event = MagicMock() + + workflow_entry = MagicMock() + workflow_entry.graph_engine = MagicMock() + workflow_entry.run.return_value = [MagicMock()] + mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) + + mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) + + runner.run() + + runner._prepare_single_node_execution.assert_called_once() + runner._handle_event.assert_called() + + +def test_run_normal_path_builds_graph(mocker): + app_generate_entity = _build_app_generate_entity() + + pipeline = MagicMock(id="pipe") + query_pipeline = MagicMock() + query_pipeline.where.return_value.first.return_value = pipeline + + query_end_user = MagicMock() + query_end_user.where.return_value.first.return_value = MagicMock(session_id="sess") + + session = MagicMock() + session.query.side_effect = [query_end_user, query_pipeline] + mocker.patch.object(module.db, "session", session) + + workflow = MagicMock( + id="wf", + tenant_id="tenant", + app_id="pipe", + graph_dict={"nodes": [], "edges": []}, + environment_variables=[], + rag_pipeline_variables=[{"variable": "input1", "belong_to_node_id": "start"}], + type="rag-pipeline", + version="v1", + ) + + runner = PipelineRunner( + application_generate_entity=app_generate_entity, + queue_manager=MagicMock(), + variable_loader=MagicMock(), + workflow=workflow, + system_user_id="sys", + workflow_execution_repository=MagicMock(), + workflow_node_execution_repository=MagicMock(), + ) + + runner._resolve_user_from = MagicMock(return_value=UserFrom.ACCOUNT) + runner.get_workflow = MagicMock(return_value=workflow) + runner._init_rag_pipeline_graph = MagicMock(return_value="graph") + runner._update_document_status = MagicMock() + runner._handle_event = MagicMock() + + mocker.patch.object( + module.RAGPipelineVariable, + "model_validate", + return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), + ) + mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + workflow_entry = MagicMock() + workflow_entry.graph_engine = MagicMock() + workflow_entry.run.return_value = [] + mocker.patch.object(module, "WorkflowEntry", return_value=workflow_entry) + mocker.patch.object(module, "WorkflowPersistenceLayer", return_value=MagicMock()) + + runner.run() + + runner._init_rag_pipeline_graph.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py index f0d9afc0db..f48a7fb38e 100644 --- a/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_advanced_chat_app_generator.py @@ -11,6 +11,7 @@ from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.task_pipeline import message_cycle_manager from core.app.task_pipeline.message_cycle_manager import MessageCycleManager +from models.enums import ConversationFromSource from models.model import AppMode, Conversation, Message @@ -92,7 +93,7 @@ def test_init_generate_records_marks_existing_conversation(): system_instruction_tokens=0, status="normal", invoke_from=InvokeFrom.WEB_APP.value, - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id="user-id", from_account_id=None, ) @@ -124,12 +125,12 @@ def test_message_cycle_manager_uses_new_conversation_flag(monkeypatch): def start(self): self.started = True - def fake_thread(**kwargs): + def fake_thread(*args, **kwargs): thread = DummyThread(**kwargs) captured["thread"] = thread return thread - monkeypatch.setattr(message_cycle_manager, "Thread", fake_thread) + monkeypatch.setattr(message_cycle_manager, "Timer", fake_thread) manager = MessageCycleManager(application_generate_entity=entity, task_state=MagicMock()) thread = manager.generate_conversation_name(conversation_id="existing-conversation-id", query="hello") diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index 04c8696525..a3ced02394 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,7 +1,9 @@ +from unittest.mock import MagicMock + import pytest from core.app.apps.base_app_generator import BaseAppGenerator -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -366,3 +368,132 @@ def test_validate_inputs_optional_file_with_empty_string_ignores_default(): ) assert result is None + + +class TestBaseAppGeneratorExtras: + def test_prepare_user_inputs_converts_files_and_lists(self, monkeypatch): + base_app_generator = BaseAppGenerator() + + variables = [ + VariableEntity( + variable="file", + label="file", + type=VariableEntityType.FILE, + required=False, + allowed_file_types=[], + allowed_file_extensions=[], + allowed_file_upload_methods=[], + ), + VariableEntity( + variable="file_list", + label="file_list", + type=VariableEntityType.FILE_LIST, + required=False, + allowed_file_types=[], + allowed_file_extensions=[], + allowed_file_upload_methods=[], + ), + VariableEntity( + variable="json", + label="json", + type=VariableEntityType.JSON_OBJECT, + required=False, + ), + ] + + monkeypatch.setattr( + "core.app.apps.base_app_generator.file_factory.build_from_mapping", + lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + ) + monkeypatch.setattr( + "core.app.apps.base_app_generator.file_factory.build_from_mappings", + lambda mappings, tenant_id, config: ["file-1", "file-2"], + ) + + user_inputs = { + "file": {"id": "file-id"}, + "file_list": [{"id": "file-1"}, {"id": "file-2"}], + "json": {"key": "value"}, + } + + prepared = base_app_generator._prepare_user_inputs( + user_inputs=user_inputs, + variables=variables, + tenant_id="tenant-id", + ) + + assert prepared["file"] == "file-object" + assert prepared["file_list"] == ["file-1", "file-2"] + assert prepared["json"] == {"key": "value"} + + def test_prepare_user_inputs_rejects_invalid_dict_inputs(self): + base_app_generator = BaseAppGenerator() + variables = [ + VariableEntity( + variable="text", + label="text", + type=VariableEntityType.TEXT_INPUT, + required=False, + ) + ] + + with pytest.raises(ValueError, match="must be a string"): + base_app_generator._prepare_user_inputs( + user_inputs={"text": {"unexpected": "dict"}}, + variables=variables, + tenant_id="tenant-id", + ) + + def test_prepare_user_inputs_rejects_invalid_list_inputs(self): + base_app_generator = BaseAppGenerator() + variables = [ + VariableEntity( + variable="text", + label="text", + type=VariableEntityType.TEXT_INPUT, + required=False, + ) + ] + + with pytest.raises(ValueError, match="must be a string"): + base_app_generator._prepare_user_inputs( + user_inputs={"text": [{"unexpected": "dict"}]}, + variables=variables, + tenant_id="tenant-id", + ) + + def test_convert_to_event_stream(self): + base_app_generator = BaseAppGenerator() + + assert base_app_generator.convert_to_event_stream({"ok": True}) == {"ok": True} + + def _gen(): + yield {"delta": "hi"} + yield "ping" + + converted = list(base_app_generator.convert_to_event_stream(_gen())) + + assert converted[0].startswith("data: ") + assert "\n\n" in converted[0] + assert converted[1] == "event: ping\n\n" + + def test_get_draft_var_saver_factory_debugger(self): + from core.app.entities.app_invoke_entities import InvokeFrom + from dify_graph.enums import BuiltinNodeTypes + from models import Account + + base_app_generator = BaseAppGenerator() + account = Account(name="Tester", email="tester@example.com") + account.id = "account-id" + account.tenant_id = "tenant-id" + + factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) + saver = factory( + session=MagicMock(), + app_id="app-id", + node_id="node-id", + node_type=BuiltinNodeTypes.START, + node_execution_id="node-exec-id", + ) + + assert saver is not None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py new file mode 100644 index 0000000000..c6dc20ffc6 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueErrorEvent + + +class DummyQueueManager(AppQueueManager): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.published = [] + + def _publish(self, event, pub_from): + self.published.append((event, pub_from)) + + +class TestBaseAppQueueManager: + def test_init_requires_user_id(self): + with pytest.raises(ValueError): + DummyQueueManager(task_id="t1", user_id="", invoke_from=InvokeFrom.SERVICE_API) + + def test_publish_error_records_event(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + manager.publish_error(ValueError("boom"), PublishFrom.TASK_PIPELINE) + + assert isinstance(manager.published[0][0], QueueErrorEvent) + + def test_set_stop_flag_checks_user(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.get.return_value = b"end-user-u1" + AppQueueManager.set_stop_flag(task_id="t1", invoke_from=InvokeFrom.SERVICE_API, user_id="u1") + + mock_redis.setex.assert_called_once() + + def test_set_stop_flag_no_user_check(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + AppQueueManager.set_stop_flag_no_user_check(task_id="t1") + + mock_redis.setex.assert_called_once() + + def test_is_stopped_reads_cache(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = b"1" + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + assert manager._is_stopped() is True + + def test_check_for_sqlalchemy_models_raises(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + bad = SimpleNamespace(_sa_instance_state=True) + with pytest.raises(TypeError): + manager._check_for_sqlalchemy_models(bad) diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py new file mode 100644 index 0000000000..aabeb54553 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -0,0 +1,442 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, + PromptTemplateEntity, +) +from core.app.apps.base_app_runner import AppRunner +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError +from models.model import AppMode + + +class _DummyParameterRule: + def __init__(self, name: str, use_template: str | None = None) -> None: + self.name = name + self.use_template = use_template + + +class _QueueRecorder: + def __init__(self) -> None: + self.events: list[object] = [] + + def publish(self, event, pub_from): + _ = pub_from + self.events.append(event) + + +class TestAppRunner: + def test_recalc_llm_max_tokens_updates_parameters(self, monkeypatch): + runner = AppRunner() + + model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[_DummyParameterRule("max_tokens")], + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="mock", + model_schema=model_schema, + parameters={"max_tokens": 30}, + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.ModelInstance", + lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 80), + ) + + runner.recalc_llm_max_tokens(model_config, prompt_messages=[AssistantPromptMessage(content="hi")]) + + assert model_config.parameters["max_tokens"] == 20 + + def test_recalc_llm_max_tokens_returns_minus_one_when_no_context(self, monkeypatch): + runner = AppRunner() + + model_schema = SimpleNamespace( + model_properties={}, + parameter_rules=[_DummyParameterRule("max_tokens")], + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="mock", + model_schema=model_schema, + parameters={"max_tokens": 30}, + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.ModelInstance", + lambda provider_model_bundle, model: SimpleNamespace(get_llm_num_tokens=lambda messages: 10), + ) + + assert runner.recalc_llm_max_tokens(model_config, prompt_messages=[]) == -1 + + def test_direct_output_streaming_publishes_chunks_and_end(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + app_generate_entity = SimpleNamespace(model_conf=SimpleNamespace(model="mock"), stream=True) + + monkeypatch.setattr("core.app.apps.base_app_runner.time.sleep", lambda _: None) + + runner.direct_output( + queue_manager=queue, + app_generate_entity=app_generate_entity, + prompt_messages=[], + text="hi", + stream=True, + ) + + assert any(isinstance(event, QueueLLMChunkEvent) for event in queue.events) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + + def test_handle_invoke_result_direct_publishes_end_event(self): + runner = AppRunner() + queue = _QueueRecorder() + llm_result = LLMResult( + model="mock", + prompt_messages=[], + message=AssistantPromptMessage(content="done"), + usage=LLMUsage.empty_usage(), + ) + + runner._handle_invoke_result( + invoke_result=llm_result, + queue_manager=queue, + stream=False, + ) + + assert isinstance(queue.events[-1], QueueMessageEndEvent) + + def test_handle_invoke_result_invalid_type_raises(self): + runner = AppRunner() + queue = _QueueRecorder() + + with pytest.raises(NotImplementedError): + runner._handle_invoke_result( + invoke_result=["unexpected"], + queue_manager=queue, + stream=True, + ) + + def test_organize_prompt_messages_simple_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="chat", stop=["STOP"]) + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="hello", + ) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.SimplePromptTransform.get_prompt", + lambda self, **kwargs: (["simple-message"], ["simple-stop"]), + ) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["simple-message"] + assert stop == ["simple-stop"] + + def test_organize_prompt_messages_advanced_completion_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="completion", stop=[""]) + captured: dict[str, object] = {} + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="answer", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="U", assistant="A"), + ), + ) + + def _fake_advanced_prompt(self, **kwargs): + captured.update(kwargs) + return ["advanced-completion-message"] + + monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["advanced-completion-message"] + assert stop == [""] + memory_config = captured["memory_config"] + assert memory_config.role_prefix.user == "U" + assert memory_config.role_prefix.assistant == "A" + + def test_organize_prompt_messages_advanced_chat_template(self, monkeypatch): + runner = AppRunner() + model_config = SimpleNamespace(mode="chat", stop=[""]) + captured: dict[str, object] = {} + prompt_template_entity = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + messages=[ + AdvancedChatMessageEntity(text="hello", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="world", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + + def _fake_advanced_prompt(self, **kwargs): + captured.update(kwargs) + return ["advanced-chat-message"] + + monkeypatch.setattr("core.app.apps.base_app_runner.AdvancedPromptTransform.get_prompt", _fake_advanced_prompt) + + prompt_messages, stop = runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=model_config, + prompt_template_entity=prompt_template_entity, + inputs={}, + files=[], + query="q", + ) + + assert prompt_messages == ["advanced-chat-message"] + assert stop == [""] + assert len(captured["prompt_template"]) == 2 + + def test_organize_prompt_messages_advanced_missing_templates_raise(self): + runner = AppRunner() + + with pytest.raises(InvokeBadRequestError, match="Advanced completion prompt template is required"): + runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=SimpleNamespace(mode="completion", stop=[]), + prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED), + inputs={}, + files=[], + ) + + with pytest.raises(InvokeBadRequestError, match="Advanced chat prompt template is required"): + runner.organize_prompt_messages( + app_record=SimpleNamespace(mode=AppMode.CHAT.value), + model_config=SimpleNamespace(mode="chat", stop=[]), + prompt_template_entity=PromptTemplateEntity(prompt_type=PromptTemplateEntity.PromptType.ADVANCED), + inputs={}, + files=[], + ) + + def test_handle_invoke_result_stream_routes_chunks_and_builds_message(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + warning_logger = MagicMock() + monkeypatch.setattr("core.app.apps.base_app_runner._logger.warning", warning_logger) + + image_content = ImagePromptMessageContent( + url="https://example.com/image.png", format="png", mime_type="image/png" + ) + + def _stream(): + yield LLMResultChunk( + model="stream-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage.model_construct( + content=[ + "a", + TextPromptMessageContent(data="b"), + SimpleNamespace(data="c"), + image_content, + ] + ), + ), + ) + + runner._handle_invoke_result( + invoke_result=_stream(), + queue_manager=queue, + stream=True, + agent=False, + ) + + assert isinstance(queue.events[0], QueueLLMChunkEvent) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + assert queue.events[-1].llm_result.message.content == "abc" + warning_logger.assert_called_once() + + def test_handle_invoke_result_stream_agent_mode_handles_multimodal_errors(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + exception_logger = MagicMock() + monkeypatch.setattr("core.app.apps.base_app_runner._logger.exception", exception_logger) + + monkeypatch.setattr( + runner, + "_handle_multimodal_image_content", + MagicMock(side_effect=RuntimeError("failed to save image")), + ) + usage = LLMUsage.empty_usage() + + def _stream(): + yield LLMResultChunk( + model="agent-model", + prompt_messages=[AssistantPromptMessage(content="prompt")], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + ImagePromptMessageContent( + url="https://example.com/image.png", + format="png", + mime_type="image/png", + ), + TextPromptMessageContent(data="done"), + ] + ), + usage=usage, + ), + ) + + runner._handle_invoke_result_stream( + invoke_result=_stream(), + queue_manager=queue, + agent=True, + message_id="message-id", + user_id="user-id", + tenant_id="tenant-id", + ) + + assert isinstance(queue.events[0], QueueAgentMessageEvent) + assert isinstance(queue.events[-1], QueueMessageEndEvent) + assert queue.events[-1].llm_result.usage == usage + exception_logger.assert_called_once() + + def test_handle_multimodal_image_content_fallback_return_branch(self, monkeypatch): + runner = AppRunner() + + class _ToggleBool: + def __init__(self, values: list[bool]): + self._values = values + self._index = 0 + + def __bool__(self): + value = self._values[min(self._index, len(self._values) - 1)] + self._index += 1 + return value + + content = SimpleNamespace( + url=_ToggleBool([False, False]), + base64_data=_ToggleBool([True, False]), + mime_type="image/png", + ) + + db_session = SimpleNamespace(add=MagicMock(), commit=MagicMock(), refresh=MagicMock()) + monkeypatch.setattr("core.app.apps.base_app_runner.ToolFileManager", lambda: MagicMock()) + monkeypatch.setattr("core.app.apps.base_app_runner.db", SimpleNamespace(session=db_session)) + + queue_manager = SimpleNamespace(invoke_from=InvokeFrom.SERVICE_API, publish=MagicMock()) + + runner._handle_multimodal_image_content( + content=content, + message_id="message-id", + user_id="user-id", + tenant_id="tenant-id", + queue_manager=queue_manager, + ) + + db_session.add.assert_not_called() + queue_manager.publish.assert_not_called() + + def test_check_hosting_moderation_direct_output_called(self, monkeypatch): + runner = AppRunner() + queue = _QueueRecorder() + app_generate_entity = SimpleNamespace(stream=False) + + monkeypatch.setattr( + "core.app.apps.base_app_runner.HostingModerationFeature.check", + lambda self, application_generate_entity, prompt_messages: True, + ) + direct_output = MagicMock() + monkeypatch.setattr(runner, "direct_output", direct_output) + + result = runner.check_hosting_moderation( + application_generate_entity=app_generate_entity, + queue_manager=queue, + prompt_messages=[], + ) + + assert result is True + assert direct_output.called + + def test_fill_in_inputs_from_external_data_tools(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.ExternalDataFetch.fetch", + lambda self, tenant_id, app_id, external_data_tools, inputs, query: {"foo": "bar"}, + ) + + result = runner.fill_in_inputs_from_external_data_tools( + tenant_id="tenant", + app_id="app", + external_data_tools=[], + inputs={}, + query="q", + ) + + assert result == {"foo": "bar"} + + def test_moderation_for_inputs_returns_result(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.InputModeration.check", + lambda self, app_id, tenant_id, app_config, inputs, query, message_id, trace_manager: (True, {}, ""), + ) + app_generate_entity = SimpleNamespace(app_config=SimpleNamespace(), trace_manager=None) + + result = runner.moderation_for_inputs( + app_id="app", + tenant_id="tenant", + app_generate_entity=app_generate_entity, + inputs={}, + query="q", + message_id="msg", + ) + + assert result == (True, {}, "") + + def test_query_app_annotations_to_reply(self, monkeypatch): + runner = AppRunner() + monkeypatch.setattr( + "core.app.apps.base_app_runner.AnnotationReplyFeature.query", + lambda self, app_record, message, query, user_id, invoke_from: "reply", + ) + + response = runner.query_app_annotations_to_reply( + app_record=SimpleNamespace(), + message=SimpleNamespace(), + query="hello", + user_id="user", + invoke_from=InvokeFrom.WEB_APP, + ) + + assert response == "reply" diff --git a/api/tests/unit_tests/core/app/apps/test_exc.py b/api/tests/unit_tests/core/app/apps/test_exc.py new file mode 100644 index 0000000000..e41c78e89e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_exc.py @@ -0,0 +1,7 @@ +from core.app.apps.exc import GenerateTaskStoppedError + + +class TestAppsExceptions: + def test_generate_task_stopped_error(self): + err = GenerateTaskStoppedError("stopped") + assert str(err) == "stopped" diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py index 87b8dc51e7..1250ac5ecf 100644 --- a/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_generator.py @@ -13,9 +13,11 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.app.apps import message_based_app_generator +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from models.model import AppMode, Conversation, Message +from services.errors.app_model_config import AppModelConfigBrokenError class DummyModelConf: @@ -125,3 +127,55 @@ def test_init_generate_records_sets_conversation_fields_for_chat_entity(): assert entity.conversation_id == "generated-conversation-id" assert entity.is_new_conversation is True assert conversation.id == "generated-conversation-id" + + +class TestMessageBasedAppGeneratorExtras: + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = MessageBasedAppGenerator() + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.message_based_app_generator.EasyUIBasedGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=_make_chat_generate_entity(_make_app_config(AppMode.CHAT)), + queue_manager=SimpleNamespace(), + conversation=SimpleNamespace(id="conv"), + message=SimpleNamespace(id="msg"), + user=SimpleNamespace(), + stream=False, + ) + + def test_get_app_model_config_requires_valid_config(self, monkeypatch): + generator = MessageBasedAppGenerator() + app_model = SimpleNamespace(id="app", app_model_config_id=None, app_model_config=None) + + with pytest.raises(AppModelConfigBrokenError): + generator._get_app_model_config(app_model, conversation=None) + + conversation = SimpleNamespace(app_model_config_id="missing-id") + monkeypatch.setattr( + message_based_app_generator, "db", SimpleNamespace(session=SimpleNamespace(scalar=lambda _: None)) + ) + + with pytest.raises(AppModelConfigBrokenError): + generator._get_app_model_config(app_model=SimpleNamespace(id="app"), conversation=conversation) + + def test_get_conversation_introduction_handles_missing_inputs(self): + app_config = _make_app_config(AppMode.CHAT) + app_config.additional_features.opening_statement = "Hello {{name}}" + entity = _make_chat_generate_entity(app_config) + entity.inputs = {} + + generator = MessageBasedAppGenerator() + + assert generator._get_conversation_introduction(entity) == "Hello {name}" diff --git a/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py new file mode 100644 index 0000000000..847ad0ce9b --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_based_app_queue_manager.py @@ -0,0 +1,65 @@ +from unittest.mock import Mock, patch + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueErrorEvent, QueueMessageEndEvent, QueueStopEvent + + +class TestMessageBasedAppQueueManager: + def test_publish_stops_on_terminal_events(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager.stop_listen = Mock() + manager._is_stopped = Mock(return_value=False) + + manager._publish(QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), Mock()) + manager.stop_listen.assert_called_once() + + def test_publish_raises_when_stopped(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager._is_stopped = Mock(return_value=True) + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueErrorEvent(error=ValueError("boom")), PublishFrom.APPLICATION_MANAGER) + + def test_publish_enqueues_message_end(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + manager = MessageBasedAppQueueManager( + task_id="t1", + user_id="u1", + invoke_from=InvokeFrom.SERVICE_API, + conversation_id="c1", + app_mode="chat", + message_id="m1", + ) + + manager._is_stopped = Mock(return_value=False) + manager.stop_listen = Mock() + + manager._publish(QueueMessageEndEvent(), PublishFrom.TASK_PIPELINE) + + assert manager._q.qsize() == 1 diff --git a/api/tests/unit_tests/core/app/apps/test_message_generator.py b/api/tests/unit_tests/core/app/apps/test_message_generator.py new file mode 100644 index 0000000000..25377e633e --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_message_generator.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock, patch + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode + + +class TestMessageGenerator: + def test_get_response_topic(self): + channel = Mock() + channel.topic.return_value = "topic" + + with patch("core.app.apps.message_generator.get_pubsub_broadcast_channel", return_value=channel): + topic = MessageGenerator.get_response_topic(AppMode.WORKFLOW, "run-1") + + assert topic == "topic" + expected_key = MessageGenerator._make_channel_key(AppMode.WORKFLOW, "run-1") + channel.topic.assert_called_once_with(expected_key) + + def test_retrieve_events_passes_arguments(self): + with ( + patch("core.app.apps.message_generator.MessageGenerator.get_response_topic", return_value="topic"), + patch( + "core.app.apps.message_generator.stream_topic_events", return_value=iter([{"event": "ping"}]) + ) as mock_stream, + ): + events = list(MessageGenerator.retrieve_events(AppMode.WORKFLOW, "run-1", idle_timeout=1, ping_interval=2)) + + assert events == [{"event": "ping"}] + mock_stream.assert_called_once() diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 97c993928e..2f73a8cda8 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -1,39 +1,36 @@ import sys import time -from pathlib import Path from types import ModuleType, SimpleNamespace from typing import Any -API_DIR = str(Path(__file__).resolve().parents[5]) -if API_DIR not in sys.path: - sys.path.insert(0, API_DIR) - -import core.workflow.nodes.human_input.entities # noqa: F401 +import dify_graph.nodes.human_input.entities # noqa: F401 from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult, PauseRequestedEvent -from core.workflow.nodes.base.entities import BaseNodeData, OutputVariableEntity, RetryConfig -from core.workflow.nodes.base.node import Node -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.node_events import NodeRunResult, PauseRequestedEvent +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.nodes.base.node import Node +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: ops_stub = ModuleType("core.ops.ops_trace_manager") @@ -47,11 +44,12 @@ if "core.ops.ops_trace_manager" not in sys.modules: class _StubToolNodeData(BaseNodeData): + type: NodeType = BuiltinNodeTypes.TOOL pause_on: bool = False class _StubToolNode(Node[_StubToolNodeData]): - node_type = NodeType.TOOL + node_type = BuiltinNodeTypes.TOOL @classmethod def version(cls) -> str: @@ -93,23 +91,24 @@ class _StubToolNode(Node[_StubToolNodeData]): def _patch_tool_node(mocker): original_create_node = DifyNodeFactory.create_node - def _patched_create_node(self, node_config: dict[str, object]) -> Node: - node_data = node_config.get("data", {}) - if isinstance(node_data, dict) and node_data.get("type") == NodeType.TOOL.value: + def _patched_create_node(self, node_config: dict[str, object] | NodeConfigDict) -> Node: + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_data = typed_node_config["data"] + if node_data.type == BuiltinNodeTypes.TOOL: return _StubToolNode( - id=str(node_config["id"]), - config=node_config, + id=str(typed_node_config["id"]), + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, ) - return original_create_node(self, node_config) + return original_create_node(self, typed_node_config) mocker.patch.object(DifyNodeFactory, "create_node", _patched_create_node) def _node_data(node_type: NodeType, data: BaseNodeData) -> dict[str, object]: node_data = data.model_dump() - node_data["type"] = node_type.value + node_data["type"] = str(node_type) return node_data @@ -125,11 +124,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: ) nodes = [ - {"id": "start", "data": _node_data(NodeType.START, start_data)}, - {"id": "tool_a", "data": _node_data(NodeType.TOOL, tool_data_a)}, - {"id": "tool_b", "data": _node_data(NodeType.TOOL, tool_data_b)}, - {"id": "tool_c", "data": _node_data(NodeType.TOOL, tool_data_c)}, - {"id": "end", "data": _node_data(NodeType.END, end_data)}, + {"id": "start", "data": _node_data(BuiltinNodeTypes.START, start_data)}, + {"id": "tool_a", "data": _node_data(BuiltinNodeTypes.TOOL, tool_data_a)}, + {"id": "tool_b", "data": _node_data(BuiltinNodeTypes.TOOL, tool_data_b)}, + {"id": "tool_c", "data": _node_data(BuiltinNodeTypes.TOOL, tool_data_c)}, + {"id": "end", "data": _node_data(BuiltinNodeTypes.END, end_data)}, ] edges = [ {"source": "start", "target": "tool_a"}, @@ -142,11 +141,11 @@ def _build_graph_config(*, pause_on: str | None) -> dict[str, object]: def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> Graph: graph_config = _build_graph_config(pause_on=pause_on) - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", @@ -158,7 +157,7 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G graph_runtime_state=runtime_state, ) - return Graph.init(graph_config=graph_config, node_factory=node_factory) + return Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") def _build_runtime_state(run_id: str) -> GraphRuntimeState: diff --git a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py index 7b5447c01e..a7714c56ce 100644 --- a/api/tests/unit_tests/core/app/apps/test_streaming_utils.py +++ b/api/tests/unit_tests/core/app/apps/test_streaming_utils.py @@ -6,6 +6,7 @@ import queue import pytest from core.app.apps.message_based_app_generator import MessageBasedAppGenerator +from core.app.apps.streaming_utils import _normalize_terminal_events, stream_topic_events from core.app.entities.task_entities import StreamEvent from models.model import AppMode @@ -78,3 +79,30 @@ def test_retrieve_events_calls_on_subscribe_after_subscription(monkeypatch): assert event["event"] == StreamEvent.WORKFLOW_FINISHED.value with pytest.raises(StopIteration): next(generator) + + +def test_normalize_terminal_events_defaults(): + assert _normalize_terminal_events(None) == { + StreamEvent.WORKFLOW_FINISHED.value, + StreamEvent.WORKFLOW_PAUSED.value, + } + + +def test_stream_topic_events_emits_ping_and_idle_timeout(monkeypatch): + topic = FakeTopic() + times = [1000.0, 1000.0, 1001.0, 1001.0, 1002.0] + + def fake_time(): + return times.pop(0) + + monkeypatch.setattr("core.app.apps.streaming_utils.time.time", fake_time) + + generator = stream_topic_events( + topic=topic, + idle_timeout=10.0, + ping_interval=1.0, + ) + + assert next(generator) == StreamEvent.PING.value + # next receive yields None -> ping interval triggers + assert next(generator) == StreamEvent.PING.value diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py new file mode 100644 index 0000000000..3f1dd14569 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -0,0 +1,261 @@ +from __future__ import annotations + +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueLoopCompletedEvent, + QueueTextChunkEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, +) +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable + + +class TestWorkflowBasedAppRunner: + def test_resolve_user_from(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + assert runner._resolve_user_from(InvokeFrom.EXPLORE) == UserFrom.ACCOUNT + assert runner._resolve_user_from(InvokeFrom.DEBUGGER) == UserFrom.ACCOUNT + assert runner._resolve_user_from(InvokeFrom.WEB_APP) == UserFrom.END_USER + + def test_init_graph_validates_graph_structure(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + + with pytest.raises(ValueError, match="nodes or edges not found"): + runner._init_graph( + graph_config={}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with pytest.raises(ValueError, match="nodes in workflow graph must be a list"): + runner._init_graph( + graph_config={"nodes": {}, "edges": []}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with pytest.raises(ValueError, match="edges in workflow graph must be a list"): + runner._init_graph( + graph_config={"nodes": [], "edges": {}}, + graph_runtime_state=runtime_state, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + def test_prepare_single_node_execution_requires_run(self): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + + workflow = SimpleNamespace(environment_variables=[], graph_dict={}) + + with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): + runner._prepare_single_node_execution(workflow, None, None) + + def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): + runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + + graph_config = { + "nodes": [{"id": "node-1", "data": {"type": "start", "version": "1"}}], + "edges": [], + } + workflow = SimpleNamespace(tenant_id="tenant", id="workflow", graph_dict=graph_config) + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + lambda **kwargs: SimpleNamespace(), + ) + + class _NodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + from core.app.apps import workflow_app_runner + + monkeypatch.setattr( + workflow_app_runner, + "resolve_workflow_node_class", + lambda **_kwargs: _NodeCls, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="node-1", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="iteration_id", + node_type_label="iteration", + ) + + assert graph is not None + assert variable_pool is graph_runtime_state.variable_pool + + def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append((event, publish_from)) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + graph_runtime_state.register_paused_node("node-1") + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + + emails: list[dict] = [] + + class _Dispatch: + def apply_async(self, *, kwargs, queue): + emails.append({"kwargs": kwargs, "queue": queue}) + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.dispatch_human_input_email_task", + _Dispatch(), + ) + + reason = HumanInputRequired( + form_id="form", + form_content="content", + node_id="node-1", + node_title="Node", + ) + + runner._handle_event(workflow_entry, GraphRunStartedEvent()) + runner._handle_event(workflow_entry, GraphRunSucceededEvent(outputs={"ok": True})) + runner._handle_event(workflow_entry, GraphRunPausedEvent(reasons=[reason], outputs={})) + + assert any(isinstance(event, QueueWorkflowStartedEvent) for event, _ in published) + assert any(isinstance(event, QueueWorkflowSucceededEvent) for event, _ in published) + paused_event = next(event for event, _ in published if isinstance(event, QueueWorkflowPausedEvent)) + assert paused_event.paused_nodes == ["node-1"] + assert emails + + def test_handle_node_events_publishes_queue_events(self): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + + runner._handle_event( + workflow_entry, + NodeRunStartedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=datetime.utcnow(), + ), + ) + runner._handle_event( + workflow_entry, + NodeRunStreamChunkEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + selector=["node", "text"], + chunk="hi", + is_final=False, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunAgentLogEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + message_id="msg", + label="label", + node_execution_id="exec", + parent_id=None, + error=None, + status="done", + data={}, + metadata={}, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunIterationSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="Iter", + start_at=datetime.utcnow(), + inputs={}, + outputs={"ok": True}, + metadata={}, + steps=1, + ), + ) + runner._handle_event( + workflow_entry, + NodeRunLoopFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="Loop", + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + metadata={}, + steps=1, + error="boom", + ), + ) + + assert any(isinstance(event, QueueTextChunkEvent) for event in published) + assert any(isinstance(event, QueueAgentLogEvent) for event in published) + assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) + assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index f4efb240c0..1388279221 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -4,8 +4,8 @@ import pytest from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.graph_events.graph import GraphRunPausedEvent +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index f5903d28bd..178e26118e 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -7,9 +7,11 @@ import pytest from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner +from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from models.workflow import Workflow @@ -105,3 +107,57 @@ def test_run_uses_single_node_execution_branch( assert entry_kwargs["invoke_from"] == InvokeFrom.DEBUGGER assert entry_kwargs["variable_pool"] is variable_pool assert entry_kwargs["graph_runtime_state"] is graph_runtime_state + + +def test_single_node_run_validates_target_node_config(monkeypatch) -> None: + runner = WorkflowBasedAppRunner( + queue_manager=MagicMock(spec=AppQueueManager), + variable_loader=MagicMock(), + app_id="app", + ) + + workflow = MagicMock(spec=Workflow) + workflow.id = "workflow" + workflow.tenant_id = "tenant" + workflow.graph_dict = { + "nodes": [ + { + "id": "loop-node", + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + }, + } + ], + "edges": [], + } + + _, _, graph_runtime_state = _make_graph_state() + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + with ( + patch("core.app.apps.workflow_app_runner.DifyNodeFactory"), + patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()), + patch("core.app.apps.workflow_app_runner.load_into_variable_pool"), + patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index c30b925d88..65c6bd6654 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -10,12 +10,12 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph_events.graph import GraphRunPausedEvent -from core.workflow.nodes.human_input.entities import FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph_events.graph import GraphRunPausedEvent +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType +from dify_graph.system_variable import SystemVariable from models.account import Account diff --git a/api/tests/unit_tests/core/app/apps/workflow/__init__.py b/api/tests/unit_tests/core/app/apps/workflow/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py new file mode 100644 index 0000000000..f8dd6bf609 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_config_manager.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager +from models.model import AppMode + + +class TestWorkflowAppConfigManager: + def test_get_app_config(self): + app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW.value) + workflow = SimpleNamespace(id="wf-1", features_dict={}) + + with ( + patch( + "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.convert", + return_value=None, + ), + patch( + "core.app.apps.workflow.app_config_manager.WorkflowVariablesConfigManager.convert", + return_value=[], + ), + ): + app_config = WorkflowAppConfigManager.get_app_config(app_model, workflow) + + assert app_config.workflow_id == "wf-1" + assert app_config.app_mode == AppMode.WORKFLOW + + def test_config_validate_filters_keys(self): + def _add_key(key, value): + def _inner(*args, **kwargs): + # Support both positional and keyword arguments for config + if "config" in kwargs: + config = kwargs["config"] + elif len(args) > 0: + config = args[0] + else: + config = {} + config[key] = value + return config, [key] + + return _inner + + with ( + patch( + "core.app.apps.workflow.app_config_manager.FileUploadConfigManager.validate_and_set_defaults", + side_effect=_add_key("file_upload", 1), + ), + patch( + "core.app.apps.workflow.app_config_manager.TextToSpeechConfigManager.validate_and_set_defaults", + side_effect=_add_key("text_to_speech", 2), + ), + patch( + "core.app.apps.workflow.app_config_manager.SensitiveWordAvoidanceConfigManager.validate_and_set_defaults", + side_effect=_add_key("sensitive_word_avoidance", 3), + ), + ): + filtered = WorkflowAppConfigManager.config_validate(tenant_id="t1", config={}) + + assert filtered["file_upload"] == 1 + assert filtered["text_to_speech"] == 2 + assert filtered["sensitive_word_avoidance"] == 3 diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py new file mode 100644 index 0000000000..09ad078a70 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_generator_extra.py @@ -0,0 +1,188 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.ops.ops_trace_manager import TraceQueueManager +from models.model import AppMode + + +class TestWorkflowAppGeneratorValidation: + def test_should_prepare_user_inputs(self): + generator = WorkflowAppGenerator() + + assert generator._should_prepare_user_inputs({}) is True + assert generator._should_prepare_user_inputs({SKIP_PREPARE_USER_INPUTS_KEY: True}) is False + + def test_single_iteration_generate_validates_args(self): + generator = WorkflowAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args={"inputs": {}}, + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_iteration_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args={}, + streaming=False, + ) + + def test_single_loop_generate_validates_args(self): + generator = WorkflowAppGenerator() + + with pytest.raises(ValueError, match="node_id is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="", + user=SimpleNamespace(), + args=SimpleNamespace(inputs={}), + streaming=False, + ) + + with pytest.raises(ValueError, match="inputs is required"): + generator.single_loop_generate( + app_model=SimpleNamespace(), + workflow=SimpleNamespace(), + node_id="node", + user=SimpleNamespace(), + args=SimpleNamespace(inputs=None), + streaming=False, + ) + + +class TestWorkflowAppGeneratorHandleResponse: + def test_handle_response_closed_file_raises_stopped(self, monkeypatch): + generator = WorkflowAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + extras={}, + trace_manager=None, + workflow_execution_id="run-id", + call_depth=0, + ) + + class _Pipeline: + def __init__(self, **kwargs) -> None: + _ = kwargs + + def process(self): + raise ValueError("I/O operation on closed file.") + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppGenerateTaskPipeline", + _Pipeline, + ) + + with pytest.raises(GenerateTaskStoppedError): + generator._handle_response( + application_generate_entity=application_generate_entity, + workflow=SimpleNamespace(), + queue_manager=SimpleNamespace(), + user=SimpleNamespace(), + draft_var_saver_factory=lambda **kwargs: None, + stream=False, + ) + + +class TestWorkflowAppGeneratorGenerate: + def test_generate_skips_prepare_inputs_when_flag_set(self, monkeypatch): + generator = WorkflowAppGenerator() + + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.WorkflowAppConfigManager.get_app_config", + lambda app_model, workflow: app_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.FileUploadConfigManager.convert", + lambda features_dict, is_vision=False: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.file_factory.build_from_mappings", + lambda **kwargs: [], + ) + DummyTraceQueueManager = type( + "_DummyTraceQueueManager", + (TraceQueueManager,), + { + "__init__": lambda self, app_id=None, user_id=None: ( + setattr(self, "app_id", app_id) or setattr(self, "user_id", user_id) + ) + }, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.TraceQueueManager", + DummyTraceQueueManager, + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + lambda **kwargs: SimpleNamespace(), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.db", + SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)), + ) + monkeypatch.setattr( + "core.app.apps.workflow.app_generator.sessionmaker", + lambda **kwargs: SimpleNamespace(), + ) + + prepare_inputs = pytest.fail + monkeypatch.setattr(generator, "_prepare_user_inputs", lambda **kwargs: prepare_inputs()) + + monkeypatch.setattr(generator, "_generate", lambda **kwargs: {"ok": True}) + + result = generator.generate( + app_model=SimpleNamespace(id="app", tenant_id="tenant"), + workflow=SimpleNamespace(features_dict={}), + user=SimpleNamespace(id="user", session_id="session"), + args={"inputs": {}, SKIP_PREPARE_USER_INPUTS_KEY: True}, + invoke_from=InvokeFrom.WEB_APP, + streaming=False, + call_depth=0, + ) + + assert result == {"ok": True} diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py new file mode 100644 index 0000000000..6133be9867 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_app_queue_manager.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +import pytest + +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager +from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.queue_entities import QueueMessageEndEvent, QueuePingEvent + + +class TestWorkflowAppQueueManager: + def test_publish_stop_events_trigger_stop(self): + manager = WorkflowAppQueueManager( + task_id="task", + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + app_mode="workflow", + ) + manager._is_stopped = lambda: True + + with pytest.raises(GenerateTaskStoppedError): + manager._publish(QueueMessageEndEvent(llm_result=None), PublishFrom.APPLICATION_MANAGER) + + def test_publish_non_stop_event_does_not_raise(self): + manager = WorkflowAppQueueManager( + task_id="task", + user_id="user", + invoke_from=InvokeFrom.DEBUGGER, + app_mode="workflow", + ) + + manager._publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_errors.py b/api/tests/unit_tests/core/app/apps/workflow/test_errors.py new file mode 100644 index 0000000000..7461e06833 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_errors.py @@ -0,0 +1,9 @@ +from core.app.apps.workflow.errors import WorkflowPausedInBlockingModeError + + +class TestWorkflowErrors: + def test_workflow_paused_in_blocking_mode_error_attributes(self): + err = WorkflowPausedInBlockingModeError() + assert err.error_code == "workflow_paused_in_blocking_mode" + assert err.code == 400 + assert "blocking response mode" in err.description diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py new file mode 100644 index 0000000000..62e94a7580 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -0,0 +1,133 @@ +from collections.abc import Generator + +from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.entities.task_entities import ( + ErrorStreamResponse, + NodeFinishStreamResponse, + NodeStartStreamResponse, + PingStreamResponse, + WorkflowAppBlockingResponse, + WorkflowAppStreamResponse, +) +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + + +class TestWorkflowGenerateResponseConverter: + def test_blocking_full_response(self): + blocking = WorkflowAppBlockingResponse( + task_id="t1", + workflow_run_id="r1", + data=WorkflowAppBlockingResponse.Data( + id="exec-1", + workflow_id="wf-1", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=1.2, + total_tokens=10, + total_steps=2, + created_at=1, + finished_at=2, + ), + ) + response = WorkflowAppGenerateResponseConverter.convert_blocking_full_response(blocking) + assert response["workflow_run_id"] == "r1" + + def test_stream_simple_response_node_events(self): + node_start = NodeStartStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeStartStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + created_at=1, + ), + ) + node_finish = NodeFinishStreamResponse( + task_id="t1", + workflow_run_id="r1", + data=NodeFinishStreamResponse.Data( + id="e1", + node_id="n1", + node_type="answer", + title="Answer", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + elapsed_time=0.1, + created_at=1, + finished_at=2, + ), + ) + + def stream() -> Generator[WorkflowAppStreamResponse, None, None]: + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=PingStreamResponse(task_id="t1")) + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_start) + yield WorkflowAppStreamResponse(workflow_run_id="r1", stream_response=node_finish) + yield WorkflowAppStreamResponse( + workflow_run_id="r1", stream_response=ErrorStreamResponse(task_id="t1", err=ValueError("boom")) + ) + + converted = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(stream())) + assert converted[0] == "ping" + assert converted[1]["event"] == "node_started" + assert converted[2]["event"] == "node_finished" + assert converted[3]["event"] == "error" + + def test_convert_stream_simple_response_handles_ping_and_nodes(self): + def _gen(): + yield WorkflowAppStreamResponse(stream_response=PingStreamResponse(task_id="task")) + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=NodeStartStreamResponse( + task_id="task", + workflow_run_id="run", + data=NodeStartStreamResponse.Data( + id="node-exec", + node_id="node", + node_type="start", + title="Start", + index=1, + created_at=1, + ), + ), + ) + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=NodeFinishStreamResponse( + task_id="task", + workflow_run_id="run", + data=NodeFinishStreamResponse.Data( + id="node-exec", + node_id="node", + node_type="start", + title="Start", + index=1, + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={}, + created_at=1, + finished_at=2, + elapsed_time=1.0, + error=None, + ), + ), + ) + + chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_simple_response(_gen())) + + assert chunks[0] == "ping" + assert chunks[1]["event"] == "node_started" + assert chunks[2]["event"] == "node_finished" + + def test_convert_stream_full_response_handles_error(self): + def _gen(): + yield WorkflowAppStreamResponse( + workflow_run_id="run", + stream_response=ErrorStreamResponse(task_id="task", err=ValueError("boom")), + ) + + chunks = list(WorkflowAppGenerateResponseConverter.convert_stream_full_response(_gen())) + + assert chunks[0]["event"] == "error" diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 32cb1ed47c..5b23e71035 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -7,9 +7,9 @@ from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from models.account import Account from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py new file mode 100644 index 0000000000..f35710d207 --- /dev/null +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -0,0 +1,868 @@ +from __future__ import annotations + +from contextlib import contextmanager +from datetime import datetime +from types import SimpleNamespace + +import pytest + +from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueErrorEvent, + QueueHumanInputFormFilledEvent, + QueueHumanInputFormTimeoutEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueuePingEvent, + QueueStopEvent, + QueueTextChunkEvent, + QueueWorkflowFailedEvent, + QueueWorkflowPartialSuccessEvent, + QueueWorkflowPausedEvent, + QueueWorkflowStartedEvent, + QueueWorkflowSucceededEvent, +) +from core.app.entities.task_entities import ( + ErrorStreamResponse, + MessageAudioEndStreamResponse, + MessageAudioStreamResponse, + PingStreamResponse, + WorkflowFinishStreamResponse, + WorkflowPauseStreamResponse, + WorkflowStartStreamResponse, +) +from core.base.tts.app_generator_tts_publisher import AudioTrunk +from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from models.enums import CreatorUserRole +from models.model import AppMode, EndUser + + +def _make_pipeline(): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="user", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + user = SimpleNamespace(id="user", session_id="session") + + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None), + user=user, + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + + return pipeline + + +class TestWorkflowGenerateTaskPipeline: + def test_to_blocking_response_handles_pause(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowPauseStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowPauseStreamResponse.Data( + workflow_run_id="run", + status=WorkflowExecutionStatus.PAUSED, + outputs={}, + created_at=1, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.status == WorkflowExecutionStatus.PAUSED + + def test_to_blocking_response_handles_finish(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowFinishStreamResponse( + task_id="task", + workflow_run_id="run", + data=WorkflowFinishStreamResponse.Data( + id="run", + workflow_id="workflow-id", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={"ok": True}, + error=None, + elapsed_time=1.0, + total_tokens=5, + total_steps=2, + created_at=1, + finished_at=2, + ), + ) + + response = pipeline._to_blocking_response(_gen()) + + assert response.data.outputs == {"ok": True} + + def test_listen_audio_msg_returns_audio_stream(self): + pipeline = _make_pipeline() + publisher = SimpleNamespace(check_and_get_audio=lambda: AudioTrunk(status="stream", audio="data")) + + response = pipeline._listen_audio_msg(publisher=publisher, task_id="task") + + assert isinstance(response, MessageAudioStreamResponse) + + def test_handle_ping_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.ping_stream_response = lambda: PingStreamResponse(task_id="task") + + responses = list(pipeline._handle_ping_event(QueuePingEvent())) + + assert isinstance(responses[0], PingStreamResponse) + + def test_handle_error_event(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + responses = list(pipeline._handle_error_event(QueueErrorEvent(error=ValueError("boom")))) + + assert isinstance(responses[0], ValueError) + + def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" + + @contextmanager + def _fake_session(): + yield SimpleNamespace() + + monkeypatch.setattr(pipeline, "_database_session", _fake_session) + monkeypatch.setattr(pipeline, "_save_workflow_app_log", lambda **kwargs: None) + + responses = list(pipeline._handle_workflow_started_event(QueueWorkflowStartedEvent())) + + assert pipeline._workflow_execution_id == "run-id" + assert responses == ["started"] + + def test_handle_node_succeeded_event_saves_output(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + pipeline._save_output_for_event = lambda event, node_execution_id: None + pipeline._workflow_execution_id = "run-id" + + event = QueueNodeSucceededEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + ) + + responses = list(pipeline._handle_node_succeeded_event(event)) + + assert responses == ["done"] + + def test_handle_workflow_failed_event_yields_error(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + pipeline._base_task_pipeline.handle_error = lambda **kwargs: ValueError("boom") + pipeline._base_task_pipeline.error_to_stream_response = lambda err: err + + responses = list( + pipeline._handle_workflow_failed_and_stop_events(QueueWorkflowFailedEvent(error="fail", exceptions_count=1)) + ) + + assert responses[0] == "finish" + + def test_handle_text_chunk_event_publishes_tts(self): + pipeline = _make_pipeline() + published: list[object] = [] + + class _Publisher: + def publish(self, message): + published.append(message) + + event = QueueTextChunkEvent(text="hi", from_variable_selector=["x"]) + queue_message = SimpleNamespace(event=event) + + responses = list( + pipeline._handle_text_chunk_event(event, tts_publisher=_Publisher(), queue_message=queue_message) + ) + + assert responses[0].data.text == "hi" + assert published == [queue_message] + + def test_dispatch_event_handles_node_failed(self): + pipeline = _make_pipeline() + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "done" + + event = QueueNodeFailedEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="err", + ) + + assert list(pipeline._dispatch_event(event)) == ["done"] + + def test_handle_stop_event_yields_finish(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + + responses = list( + pipeline._handle_workflow_failed_and_stop_events( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL) + ) + ) + + assert responses == ["finish"] + + def test_save_workflow_app_log_created_from(self): + pipeline = _make_pipeline() + pipeline._application_generate_entity.invoke_from = InvokeFrom.SERVICE_API + pipeline._user_id = "user" + added: list[object] = [] + + class _Session: + def add(self, item): + added.append(item) + + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + + assert added + + def test_iteration_loop_and_human_input_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._workflow_response_converter.workflow_iteration_start_to_stream_response = lambda **kwargs: "iter" + pipeline._workflow_response_converter.workflow_iteration_next_to_stream_response = lambda **kwargs: "next" + pipeline._workflow_response_converter.workflow_iteration_completed_to_stream_response = lambda **kwargs: "done" + pipeline._workflow_response_converter.workflow_loop_start_to_stream_response = lambda **kwargs: "loop" + pipeline._workflow_response_converter.workflow_loop_next_to_stream_response = lambda **kwargs: "loop_next" + pipeline._workflow_response_converter.workflow_loop_completed_to_stream_response = lambda **kwargs: "loop_done" + pipeline._workflow_response_converter.human_input_form_filled_to_stream_response = lambda **kwargs: "filled" + pipeline._workflow_response_converter.human_input_form_timeout_to_stream_response = lambda **kwargs: "timeout" + pipeline._workflow_response_converter.handle_agent_log = lambda **kwargs: "log" + + iter_start = QueueIterationStartEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + iter_next = QueueIterationNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + node_run_index=1, + ) + iter_done = QueueIterationCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_start = QueueLoopStartEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + loop_next = QueueLoopNextEvent( + index=1, + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + node_run_index=1, + ) + loop_done = QueueLoopCompletedEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=datetime.utcnow(), + node_run_index=1, + ) + filled_event = QueueHumanInputFormFilledEvent( + node_execution_id="exec", + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="title", + rendered_content="content", + action_id="action", + action_text="action", + ) + timeout_event = QueueHumanInputFormTimeoutEvent( + node_id="node", + node_type=BuiltinNodeTypes.LLM, + node_title="title", + expiration_time=datetime.utcnow(), + ) + agent_event = QueueAgentLogEvent( + id="log", + label="label", + node_execution_id="exec", + parent_id=None, + error=None, + status="done", + data={}, + metadata={}, + node_id="node", + ) + + assert list(pipeline._handle_iteration_start_event(iter_start)) == ["iter"] + assert list(pipeline._handle_iteration_next_event(iter_next)) == ["next"] + assert list(pipeline._handle_iteration_completed_event(iter_done)) == ["done"] + assert list(pipeline._handle_loop_start_event(loop_start)) == ["loop"] + assert list(pipeline._handle_loop_next_event(loop_next)) == ["loop_next"] + assert list(pipeline._handle_loop_completed_event(loop_done)) == ["loop_done"] + assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] + assert list(pipeline._handle_human_input_form_timeout_event(timeout_event)) == ["timeout"] + assert list(pipeline._handle_agent_log_event(agent_event)) == ["log"] + + def test_wrapper_process_stream_response_emits_audio_end(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + class _Publisher: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def check_and_get_audio(self): + self.calls += 1 + if self.calls == 1: + return AudioTrunk(status="stream", audio="data") + if self.calls == 2: + return None + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + return None + + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert any(isinstance(item, MessageAudioStreamResponse) for item in responses) + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_init_with_end_user_sets_role_and_system_user(self): + app_config = WorkflowUIBasedAppConfig( + tenant_id="tenant", + app_id="app", + app_mode=AppMode.WORKFLOW, + additional_features=AppAdditionalFeatures(), + variables=[], + workflow_id="workflow-id", + ) + application_generate_entity = WorkflowAppGenerateEntity.model_construct( + task_id="task", + app_config=app_config, + inputs={}, + files=[], + user_id="end-user-id", + stream=False, + invoke_from=InvokeFrom.WEB_APP, + trace_manager=None, + workflow_execution_id="run-id", + extras={}, + call_depth=0, + ) + workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + queue_manager = SimpleNamespace(invoke_from=InvokeFrom.WEB_APP, graph_runtime_state=None) + end_user = EndUser(tenant_id="tenant", type="session", name="user", session_id="session-id") + end_user.id = "end-user-id" + + pipeline = WorkflowAppGenerateTaskPipeline( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=end_user, + stream=False, + draft_var_saver_factory=lambda **kwargs: None, + ) + + assert pipeline._created_by_role == CreatorUserRole.END_USER + assert pipeline._workflow_system_variables.user_id == "session-id" + + def test_process_returns_stream_and_blocking_variants(self): + pipeline = _make_pipeline() + pipeline._base_task_pipeline.stream = True + pipeline._wrapper_process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + stream_response = list(pipeline.process()) + assert len(stream_response) == 1 + assert stream_response[0].workflow_run_id is None + + pipeline._base_task_pipeline.stream = False + pipeline._wrapper_process_stream_response = lambda **kwargs: iter( + [ + WorkflowFinishStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowFinishStreamResponse.Data( + id="run-id", + workflow_id="workflow-id", + status=WorkflowExecutionStatus.SUCCEEDED, + outputs={}, + error=None, + elapsed_time=0.1, + total_tokens=0, + total_steps=0, + created_at=1, + finished_at=2, + ), + ) + ] + ) + + blocking_response = pipeline.process() + assert blocking_response.workflow_run_id == "run-id" + + def test_to_blocking_response_handles_error_and_unexpected_end(self): + pipeline = _make_pipeline() + + def _error_gen(): + yield ErrorStreamResponse(task_id="task", err=ValueError("boom")) + + with pytest.raises(ValueError, match="boom"): + pipeline._to_blocking_response(_error_gen()) + + def _unexpected_gen(): + yield PingStreamResponse(task_id="task") + + with pytest.raises(ValueError, match="queue listening stopped unexpectedly"): + pipeline._to_blocking_response(_unexpected_gen()) + + def test_to_stream_response_tracks_workflow_run_id(self): + pipeline = _make_pipeline() + + def _gen(): + yield WorkflowStartStreamResponse( + task_id="task", + workflow_run_id="run-id", + data=WorkflowStartStreamResponse.Data( + id="run-id", + workflow_id="workflow-id", + inputs={}, + created_at=1, + ), + ) + yield PingStreamResponse(task_id="task") + + stream_responses = list(pipeline._to_stream_response(_gen())) + assert stream_responses[0].workflow_run_id == "run-id" + assert stream_responses[1].workflow_run_id == "run-id" + + def test_listen_audio_msg_returns_none_without_publisher(self): + pipeline = _make_pipeline() + assert pipeline._listen_audio_msg(publisher=None, task_id="task") is None + + def test_wrapper_process_stream_response_without_tts(self): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = {} + pipeline._process_stream_response = lambda **kwargs: iter([PingStreamResponse(task_id="task")]) + + responses = list(pipeline._wrapper_process_stream_response()) + assert responses == [PingStreamResponse(task_id="task")] + + def test_wrapper_process_stream_response_final_audio_none_then_finish(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([]) + + sleep_spy = [] + + class _Publisher: + def __init__(self, *args, **kwargs): + self.calls = 0 + + def check_and_get_audio(self): + self.calls += 1 + if self.calls == 1: + return None + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + _ = message + + time_values = iter([0.0, 0.0, 0.2]) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: next(time_values)) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.time.sleep", lambda _: sleep_spy.append(True) + ) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert sleep_spy + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_wrapper_process_stream_response_handles_audio_exception(self, monkeypatch): + pipeline = _make_pipeline() + pipeline._workflow_features_dict = { + "text_to_speech": {"enabled": True, "autoPlay": "enabled", "voice": "v", "language": "en"} + } + pipeline._process_stream_response = lambda **kwargs: iter([]) + + class _Publisher: + def __init__(self, *args, **kwargs): + self.called = False + + def check_and_get_audio(self): + if not self.called: + self.called = True + raise RuntimeError("tts failure") + return AudioTrunk(status="finish", audio="") + + def publish(self, message): + _ = message + + logger_exception = [] + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.time.time", lambda: 0.0) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.logger.exception", + lambda *args, **kwargs: logger_exception.append((args, kwargs)), + ) + monkeypatch.setattr( + "core.app.apps.workflow.generate_task_pipeline.AppGeneratorTTSPublisher", + _Publisher, + ) + + responses = list(pipeline._wrapper_process_stream_response()) + + assert logger_exception + assert any(isinstance(item, MessageAudioEndStreamResponse) for item in responses) + + def test_database_session_rolls_back_on_error(self, monkeypatch): + pipeline = _make_pipeline() + calls = {"commit": 0, "rollback": 0} + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def commit(self): + calls["commit"] += 1 + + def rollback(self): + calls["rollback"] += 1 + + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) + + with pytest.raises(RuntimeError, match="db error"): + with pipeline._database_session(): + raise RuntimeError("db error") + + assert calls["commit"] == 0 + assert calls["rollback"] == 1 + + def test_node_retry_and_started_handlers_cover_none_and_value(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + + retry_event = QueueNodeRetryEvent( + node_execution_id="exec", + node_id="node", + node_title="title", + node_type=BuiltinNodeTypes.LLM, + node_run_index=1, + start_at=datetime.utcnow(), + provider_type="provider", + provider_id="provider-id", + error="error", + retry_index=1, + ) + started_event = QueueNodeStartedEvent( + node_execution_id="exec", + node_id="node", + node_title="title", + node_type=BuiltinNodeTypes.LLM, + node_run_index=1, + start_at=datetime.utcnow(), + provider_type="provider", + provider_id="provider-id", + ) + + pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: None + assert list(pipeline._handle_node_retry_event(retry_event)) == [] + pipeline._workflow_response_converter.workflow_node_retry_to_stream_response = lambda **kwargs: "retry" + assert list(pipeline._handle_node_retry_event(retry_event)) == ["retry"] + + pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: None + assert list(pipeline._handle_node_started_event(started_event)) == [] + pipeline._workflow_response_converter.workflow_node_start_to_stream_response = lambda **kwargs: "started" + assert list(pipeline._handle_node_started_event(started_event)) == ["started"] + + def test_handle_node_exception_event_saves_output(self): + pipeline = _make_pipeline() + saved_ids: list[str] = [] + pipeline._workflow_response_converter.workflow_node_finish_to_stream_response = lambda **kwargs: "failed" + pipeline._save_output_for_event = lambda event, node_execution_id: saved_ids.append(node_execution_id) + + event = QueueNodeExceptionEvent( + node_execution_id="exec-id", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=datetime.utcnow(), + inputs={}, + outputs={}, + process_data={}, + error="boom", + ) + + responses = list(pipeline._handle_node_failed_events(event)) + assert responses == ["failed"] + assert saved_ids == ["exec-id"] + + def test_success_partial_and_pause_handlers(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + + pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" + assert list(pipeline._handle_workflow_succeeded_event(QueueWorkflowSucceededEvent(outputs={}))) == ["finish"] + assert list( + pipeline._handle_workflow_partial_success_event( + QueueWorkflowPartialSuccessEvent(exceptions_count=2, outputs={}) + ) + ) == ["finish"] + + pipeline._workflow_response_converter.workflow_pause_to_stream_response = lambda **kwargs: [ + "pause-a", + "pause-b", + ] + pause_event = QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=["node"]) + assert list(pipeline._handle_workflow_paused_event(pause_event)) == ["pause-a", "pause-b"] + + def test_text_chunk_handler_returns_empty_when_text_missing(self): + pipeline = _make_pipeline() + event = QueueTextChunkEvent.model_construct(text=None, from_variable_selector=None) + assert list(pipeline._handle_text_chunk_event(event)) == [] + + def test_dispatch_event_direct_failed_and_unhandled_paths(self): + pipeline = _make_pipeline() + pipeline._workflow_execution_id = "run-id" + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) + assert list(pipeline._dispatch_event(QueuePingEvent())) == ["ping"] + + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["workflow-failed"]) + assert list(pipeline._dispatch_event(QueueWorkflowFailedEvent(error="failed", exceptions_count=1))) == [ + "workflow-failed" + ] + + assert list(pipeline._dispatch_event(SimpleNamespace())) == [] + + def test_process_stream_response_main_match_paths_and_cleanup(self): + pipeline = _make_pipeline() + pipeline._graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + start_at=0.0, + ) + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [ + SimpleNamespace(event=QueueWorkflowStartedEvent()), + SimpleNamespace(event=QueueTextChunkEvent(text="hello")), + SimpleNamespace(event=QueuePingEvent()), + SimpleNamespace(event=QueueErrorEvent(error="e")), + ] + ) + pipeline._handle_workflow_started_event = lambda event, **kwargs: iter(["started"]) + pipeline._handle_text_chunk_event = lambda event, **kwargs: iter(["text"]) + pipeline._dispatch_event = lambda event, **kwargs: iter(["dispatched"]) + pipeline._handle_error_event = lambda event, **kwargs: iter(["error"]) + publisher_calls: list[object] = [] + + class _Publisher: + def publish(self, message): + publisher_calls.append(message) + + responses = list(pipeline._process_stream_response(tts_publisher=_Publisher())) + assert responses == ["started", "text", "dispatched", "error"] + assert publisher_calls == [None] + + def test_process_stream_response_break_paths(self): + pipeline = _make_pipeline() + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueWorkflowFailedEvent(error="fail", exceptions_count=1))] + ) + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["failed"]) + assert list(pipeline._process_stream_response()) == ["failed"] + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueWorkflowPausedEvent(reasons=[], outputs={}, paused_nodes=[]))] + ) + pipeline._handle_workflow_paused_event = lambda event, **kwargs: iter(["paused"]) + assert list(pipeline._process_stream_response()) == ["paused"] + + pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( + [SimpleNamespace(event=QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL))] + ) + pipeline._handle_workflow_failed_and_stop_events = lambda event, **kwargs: iter(["stopped"]) + assert list(pipeline._process_stream_response()) == ["stopped"] + + def test_save_workflow_app_log_covers_invoke_from_variants(self): + pipeline = _make_pipeline() + pipeline._user_id = "user-id" + added: list[object] = [] + + class _Session: + def add(self, item): + added.append(item) + + pipeline._application_generate_entity.invoke_from = InvokeFrom.EXPLORE + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert added[-1].created_from == "installed-app" + + pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert added[-1].created_from == "web-app" + + count_before = len(added) + pipeline._application_generate_entity.invoke_from = InvokeFrom.DEBUGGER + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id="run-id") + assert len(added) == count_before + + pipeline._application_generate_entity.invoke_from = InvokeFrom.WEB_APP + pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) + assert len(added) == count_before + + def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + pipeline = _make_pipeline() + saver_calls: list[tuple[object, object]] = [] + captured_factory_args: dict[str, object] = {} + + class _Saver: + def save(self, process_data, outputs): + saver_calls.append((process_data, outputs)) + + def _factory(**kwargs): + captured_factory_args.update(kwargs) + return _Saver() + + class _Begin: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb): + return False + + class _Session: + def __init__(self, *args, **kwargs): + _ = args, kwargs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return _Begin() + + pipeline._draft_var_saver_factory = _factory + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) + monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) + + event = QueueNodeSucceededEvent( + node_execution_id="exec-id", + node_id="node-id", + node_type=BuiltinNodeTypes.START, + in_loop_id="loop-id", + start_at=datetime.utcnow(), + process_data={"k": "v"}, + outputs={"out": 1}, + ) + pipeline._save_output_for_event(event=event, node_execution_id="exec-id") + + assert captured_factory_args["node_execution_id"] == "exec-id" + assert captured_factory_args["enclosing_node_id"] == "loop-id" + assert saver_calls == [({"k": "v"}, {"out": 1})] diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index d3ae577d0d..bdc889d941 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -3,16 +3,16 @@ from datetime import datetime from unittest.mock import Mock from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.protocols.command_channel import CommandChannel -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import StringVariable -from core.workflow.variables.segments import Segment +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.protocols.command_channel import CommandChannel +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import StringVariable +from dify_graph.variables.segments import Segment class MockReadOnlyVariablePool: @@ -78,7 +78,7 @@ def test_persists_conversation_variables_from_assigner_output(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) @@ -100,7 +100,7 @@ def test_skips_when_outputs_missing(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) layer.on_event(event) updater.update.assert_not_called() @@ -112,7 +112,7 @@ def test_skips_non_assigner_nodes(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.LLM) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) layer.on_event(event) updater.update.assert_not_called() @@ -137,7 +137,7 @@ def test_skips_non_conversation_variables(): layer = ConversationVariablePersistenceLayer(updater) layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=NodeType.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) layer.on_event(event) updater.update.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 539f0cb581..035f0ee05c 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -13,17 +13,17 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.graph_engine.entities.commands import GraphEngineCommand -from core.workflow.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from core.workflow.graph_events.graph import ( +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.graph_engine.entities.commands import GraphEngineCommand +from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from dify_graph.graph_events.graph import ( GraphRunFailedEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from core.workflow.variables.segments import Segment +from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool +from dify_graph.variables.segments import Segment from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 40f58c9ddf..13fbca6e26 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -25,9 +25,9 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher -from core.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from core.model_runtime.entities.message_entities import TextPromptMessageContent from core.ops.ops_trace_manager import TraceQueueManager +from dify_graph.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py new file mode 100644 index 0000000000..582990c88a --- /dev/null +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -0,0 +1,425 @@ +""" +Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method. + +This test suite ensures that the files array is correctly populated in the message_end +SSE event, which is critical for vision/image chat responses to render correctly. + +Test Coverage: +- Files array populated when MessageFile records exist +- Files array is None when no MessageFile records exist +- Correct signed URL generation for LOCAL_FILE transfer method +- Correct URL handling for REMOTE_URL transfer method +- Correct URL handling for TOOL_FILE transfer method +- Proper file metadata formatting (filename, mime_type, size, extension) +""" + +import uuid +from unittest.mock import MagicMock, Mock, patch + +import pytest +from sqlalchemy.orm import Session + +from core.app.entities.task_entities import MessageEndStreamResponse +from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline +from dify_graph.file.enums import FileTransferMethod +from models.model import MessageFile, UploadFile + + +class TestMessageEndStreamResponseFiles: + """Test suite for files array population in message_end SSE event.""" + + @pytest.fixture + def mock_pipeline(self): + """Create a mock EasyUIBasedGenerateTaskPipeline instance.""" + pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline) + pipeline._message_id = str(uuid.uuid4()) + pipeline._task_state = Mock() + pipeline._task_state.metadata = Mock() + pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"}) + pipeline._task_state.llm_result = Mock() + pipeline._task_state.llm_result.usage = Mock() + pipeline._application_generate_entity = Mock() + pipeline._application_generate_entity.task_id = str(uuid.uuid4()) + return pipeline + + @pytest.fixture + def mock_message_file_local(self): + """Create a mock MessageFile with LOCAL_FILE transfer method.""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + message_file.transfer_method = FileTransferMethod.LOCAL_FILE + message_file.upload_file_id = str(uuid.uuid4()) + message_file.url = None + message_file.type = "image" + return message_file + + @pytest.fixture + def mock_message_file_remote(self): + """Create a mock MessageFile with REMOTE_URL transfer method.""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + message_file.transfer_method = FileTransferMethod.REMOTE_URL + message_file.upload_file_id = None + message_file.url = "https://example.com/image.jpg" + message_file.type = "image" + return message_file + + @pytest.fixture + def mock_message_file_tool(self): + """Create a mock MessageFile with TOOL_FILE transfer method.""" + message_file = Mock(spec=MessageFile) + message_file.id = str(uuid.uuid4()) + message_file.message_id = str(uuid.uuid4()) + message_file.transfer_method = FileTransferMethod.TOOL_FILE + message_file.upload_file_id = None + message_file.url = "tool_file_123.png" + message_file.type = "image" + return message_file + + @pytest.fixture + def mock_upload_file(self, mock_message_file_local): + """Create a mock UploadFile.""" + upload_file = Mock(spec=UploadFile) + upload_file.id = mock_message_file_local.upload_file_id + upload_file.name = "test_image.png" + upload_file.mime_type = "image/png" + upload_file.size = 1024 + upload_file.extension = "png" + return upload_file + + def test_message_end_with_no_files(self, mock_pipeline): + """Test that files array is None when no MessageFile records exist.""" + # Arrange + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is None + assert result.id == mock_pipeline._message_id + assert result.metadata == {"test": "metadata"} + + def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file): + """Test that files array is populated correctly for LOCAL_FILE transfer method.""" + # Arrange + mock_message_file_local.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + # First query: MessageFile + mock_message_files_result = Mock() + mock_message_files_result.all.return_value = [mock_message_file_local] + + # Second query: UploadFile (batch query to avoid N+1) + mock_upload_files_result = Mock() + mock_upload_files_result.all.return_value = [mock_upload_file] + + # Setup scalars to return different results for different queries + call_count = [0] # Use list to allow modification in nested function + + def scalars_side_effect(query): + call_count[0] += 1 + # First call is for MessageFile, second call is for UploadFile + if call_count[0] == 1: + return mock_message_files_result + else: + return mock_upload_files_result + + mock_session.scalars.side_effect = scalars_side_effect + mock_get_url.return_value = "https://example.com/signed-url?signature=abc123" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert file_dict["related_id"] == mock_message_file_local.id + assert file_dict["filename"] == "test_image.png" + assert file_dict["mime_type"] == "image/png" + assert file_dict["size"] == 1024 + assert file_dict["extension"] == ".png" + assert file_dict["type"] == "image" + assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value + assert "https://example.com/signed-url" in file_dict["url"] + assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id + assert file_dict["remote_url"] == "" + + # Verify database queries + # Should be called twice: once for MessageFile, once for UploadFile + assert mock_session.scalars.call_count == 2 + mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id)) + + def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote): + """Test that files array is populated correctly for REMOTE_URL transfer method.""" + # Arrange + mock_message_file_remote.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_remote] + mock_session.scalars.return_value = mock_scalars_result + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert file_dict["related_id"] == mock_message_file_remote.id + assert file_dict["filename"] == "image.jpg" + assert file_dict["url"] == "https://example.com/image.jpg" + assert file_dict["extension"] == ".jpg" + assert file_dict["type"] == "image" + assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value + assert file_dict["remote_url"] == "https://example.com/image.jpg" + assert file_dict["upload_file_id"] == mock_message_file_remote.id + + # Verify only one query for message_files is made + mock_session.scalars.assert_called_once() + + def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool): + """Test that files array is populated correctly for TOOL_FILE with HTTP URL.""" + # Arrange + mock_message_file_tool.message_id = mock_pipeline._message_id + mock_message_file_tool.url = "https://example.com/tool_file.png" + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_tool] + mock_session.scalars.return_value = mock_scalars_result + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert file_dict["url"] == "https://example.com/tool_file.png" + assert file_dict["filename"] == "tool_file.png" + assert file_dict["extension"] == ".png" + assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value + + def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool): + """Test that files array is populated correctly for TOOL_FILE with local path.""" + # Arrange + mock_message_file_tool.message_id = mock_pipeline._message_id + mock_message_file_tool.url = "tool_file_123.png" + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_tool] + mock_session.scalars.return_value = mock_scalars_result + + mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert "https://example.com/signed-tool-file.png" in file_dict["url"] + assert file_dict["filename"] == "tool_file_123.png" + assert file_dict["extension"] == ".png" + assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value + + # Verify tool file signing was called + mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png") + + def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool): + """Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin.""" + mock_message_file_tool.message_id = mock_pipeline._message_id + mock_message_file_tool.url = "tool_file_abc.verylongextension" + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + mock_scalars_result = Mock() + mock_scalars_result.all.return_value = [mock_message_file_tool] + mock_session.scalars.return_value = mock_scalars_result + mock_sign_tool.return_value = "https://example.com/signed.bin" + + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + assert result.files is not None + file_dict = result.files[0] + assert file_dict["extension"] == ".bin" + mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin") + + def test_message_end_with_multiple_files( + self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file + ): + """Test that files array contains all MessageFile records when multiple exist.""" + # Arrange + mock_message_file_local.message_id = mock_pipeline._message_id + mock_message_file_remote.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + # First query: MessageFile + mock_message_files_result = Mock() + mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote] + + # Second query: UploadFile (batch query to avoid N+1) + mock_upload_files_result = Mock() + mock_upload_files_result.all.return_value = [mock_upload_file] + + # Setup scalars to return different results for different queries + call_count = [0] # Use list to allow modification in nested function + + def scalars_side_effect(query): + call_count[0] += 1 + # First call is for MessageFile, second call is for UploadFile + if call_count[0] == 1: + return mock_message_files_result + else: + return mock_upload_files_result + + mock_session.scalars.side_effect = scalars_side_effect + mock_get_url.return_value = "https://example.com/signed-url?signature=abc123" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 2 + + # Verify both files are present + file_ids = [f["related_id"] for f in result.files] + assert mock_message_file_local.id in file_ids + assert mock_message_file_remote.id in file_ids + + def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local): + """Test fallback when UploadFile is not found for LOCAL_FILE.""" + # Arrange + mock_message_file_local.message_id = mock_pipeline._message_id + + with ( + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db, + patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class, + patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url, + ): + mock_engine = MagicMock() + mock_db.engine = mock_engine + + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + # Mock database queries + # First query: MessageFile + mock_message_files_result = Mock() + mock_message_files_result.all.return_value = [mock_message_file_local] + + # Second query: UploadFile (batch query) - returns empty list (not found) + mock_upload_files_result = Mock() + mock_upload_files_result.all.return_value = [] # UploadFile not found + + # Setup scalars to return different results for different queries + call_count = [0] # Use list to allow modification in nested function + + def scalars_side_effect(query): + call_count[0] += 1 + # First call is for MessageFile, second call is for UploadFile + if call_count[0] == 1: + return mock_message_files_result + else: + return mock_upload_files_result + + mock_session.scalars.side_effect = scalars_side_effect + mock_get_url.return_value = "https://example.com/fallback-url?signature=def456" + + # Act + result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline) + + # Assert + assert isinstance(result, MessageEndStreamResponse) + assert result.files is not None + assert len(result.files) == 1 + + file_dict = result.files[0] + assert "https://example.com/fallback-url" in file_dict["url"] + # Verify fallback URL was generated using upload_file_id from message_file + mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id)) diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py new file mode 100644 index 0000000000..0f8a846d11 --- /dev/null +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -0,0 +1,60 @@ +from datetime import UTC, datetime +from unittest.mock import Mock + +import pytest + +from core.app.workflow.layers.persistence import ( + PersistenceWorkflowInfo, + WorkflowPersistenceLayer, + _NodeRuntimeSnapshot, +) +from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType +from dify_graph.node_events import NodeRunResult + + +def _build_layer() -> WorkflowPersistenceLayer: + application_generate_entity = Mock() + application_generate_entity.inputs = {} + + return WorkflowPersistenceLayer( + application_generate_entity=application_generate_entity, + workflow_info=PersistenceWorkflowInfo( + workflow_id="workflow-id", + workflow_type=WorkflowType.WORKFLOW, + version="1", + graph_data={}, + ), + workflow_execution_repository=Mock(), + workflow_node_execution_repository=Mock(), + ) + + +def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None: + layer = _build_layer() + node_execution = Mock() + node_execution.id = "node-exec-1" + node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + node_execution.update_from_mapping = Mock() + + layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot( + node_id="node-id", + title="LLM", + predecessor_node_id=None, + iteration_id="iter-1", + loop_id=None, + created_at=node_execution.created_at, + ) + + event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None) + delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None) + monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time) + + layer._update_node_execution( + node_execution, + NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), + WorkflowNodeExecutionStatus.SUCCEEDED, + finished_at=event_finished_at, + ) + + assert node_execution.finished_at == event_finished_at + assert node_execution.elapsed_time == 2.0 diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py new file mode 100644 index 0000000000..3759b6aa37 --- /dev/null +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -0,0 +1,390 @@ +import base64 +import queue +from unittest.mock import MagicMock + +import pytest + +from core.base.tts.app_generator_tts_publisher import ( + AppGeneratorTTSPublisher, + AudioTrunk, + _invoice_tts, + _process_future, +) + +# ========================= +# Fixtures +# ========================= + + +@pytest.fixture +def mock_model_instance(mocker): + model = mocker.MagicMock() + model.invoke_tts.return_value = [b"audio1", b"audio2"] + model.get_tts_voices.return_value = [{"value": "voice1"}, {"value": "voice2"}] + return model + + +@pytest.fixture +def mock_model_manager(mocker, mock_model_instance): + manager = mocker.MagicMock() + manager.get_default_model_instance.return_value = mock_model_instance + mocker.patch( + "core.base.tts.app_generator_tts_publisher.ModelManager", + return_value=manager, + ) + return manager + + +@pytest.fixture(autouse=True) +def patch_threads(mocker): + """Prevent real threads from starting during tests""" + mocker.patch("threading.Thread.start", return_value=None) + + +# ========================= +# AudioTrunk Tests +# ========================= + + +class TestAudioTrunk: + def test_audio_trunk_initialization(self): + trunk = AudioTrunk("responding", b"data") + assert trunk.status == "responding" + assert trunk.audio == b"data" + + +# ========================= +# _invoice_tts Tests +# ========================= + + +class TestInvoiceTTS: + @pytest.mark.parametrize( + "text", + [None, "", " "], + ) + def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): + result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + assert result is None + mock_model_instance.invoke_tts.assert_not_called() + + def test_invoice_tts_valid_text(self, mock_model_instance): + result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + mock_model_instance.invoke_tts.assert_called_once_with( + content_text="hello", + user="responding_tts", + tenant_id="tenant", + voice="voice1", + ) + assert result == [b"audio1", b"audio2"] + + +# ========================= +# _process_future Tests +# ========================= + + +class TestProcessFuture: + def test_process_future_normal_flow(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.return_value = [b"abc"] + + future_queue.put(future) + future_queue.put(None) + + _process_future(future_queue, audio_queue) + + first = audio_queue.get() + assert first.status == "responding" + assert first.audio == base64.b64encode(b"abc") + + finish = audio_queue.get() + assert finish.status == "finish" + + def test_process_future_empty_result(self): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.return_value = None + + future_queue.put(future) + future_queue.put(None) + + _process_future(future_queue, audio_queue) + + finish = audio_queue.get() + assert finish.status == "finish" + + def test_process_future_exception(self, mocker): + future_queue = queue.Queue() + audio_queue = queue.Queue() + + future = MagicMock() + future.result.side_effect = Exception("error") + + future_queue.put(future) + + _process_future(future_queue, audio_queue) + + finish = audio_queue.get() + assert finish.status == "finish" + + +# ========================= +# AppGeneratorTTSPublisher Tests +# ========================= + + +class TestAppGeneratorTTSPublisher: + def test_initialization_valid_voice(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + assert publisher.voice == "voice1" + assert publisher.max_sentence == 2 + assert publisher.msg_text == "" + + def test_initialization_invalid_voice_fallback(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "invalid_voice") + assert publisher.voice == "voice1" + + def test_publish_puts_message_in_queue(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + message = MagicMock() + publisher.publish(message) + assert publisher._msg_queue.get() == message + + def test_check_and_get_audio_no_audio(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + result = publisher.check_and_get_audio() + assert result is None + + def test_check_and_get_audio_non_finish_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + trunk = AudioTrunk("responding", b"abc") + publisher._audio_queue.put(trunk) + + result = publisher.check_and_get_audio() + + assert result.status == "responding" + assert publisher._last_audio_event == trunk + + def test_check_and_get_audio_finish_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + finish_trunk = AudioTrunk("finish", b"") + publisher._audio_queue.put(finish_trunk) + + result = publisher.check_and_get_audio() + + assert result.status == "finish" + publisher.executor.shutdown.assert_called_once() + + def test_check_and_get_audio_cached_finish(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher._last_audio_event = AudioTrunk("finish", b"") + + result = publisher.check_and_get_audio() + + assert result.status == "finish" + publisher.executor.shutdown.assert_called_once() + + @pytest.mark.parametrize( + ("text", "expected_sentences", "expected_remaining"), + [ + ("Hello world.", ["Hello world."], ""), + ("Hello world! How are you?", ["Hello world!", " How are you?"], ""), + ("No punctuation", [], "No punctuation"), + ("", [], ""), + ], + ) + def test_extract_sentence(self, mock_model_manager, text, expected_sentences, expected_remaining): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + sentences, remaining = publisher._extract_sentence(text) + assert sentences == expected_sentences + assert remaining == expected_remaining + + def test_runtime_handles_none_message_with_buffer(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher.msg_text = "Hello." + + publisher._msg_queue.put(None) + publisher._runtime() + + publisher.executor.submit.assert_called_once() + + def test_runtime_handles_none_message_without_buffer(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + publisher.msg_text = " " + + publisher._msg_queue.put(None) + publisher._runtime() + + publisher.executor.submit.assert_not_called() + + def test_runtime_sentence_threshold_triggers_submit(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + # Force sentence extraction to hit threshold condition + mocker.patch.object( + publisher, + "_extract_sentence", + return_value=(["Hello world.", " Second sentence."], ""), + ) + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world. Second sentence." + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_text_chunk_event(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world." + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_node_succeeded_event_with_output(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueNodeSucceededEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueNodeSucceededEvent) + event.event.outputs = {"output": "Hello world."} + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.executor.submit.called + + def test_runtime_handles_node_succeeded_event_without_output(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueNodeSucceededEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueNodeSucceededEvent) + event.event.outputs = None + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + publisher.executor.submit.assert_not_called() + + def test_runtime_handles_agent_message_event_list_content(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueAgentMessageEvent + from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + ) + + chunk = LLMResultChunk( + model="model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data="Hello "), + ImagePromptMessageContent(format="png", mime_type="image/png", base64_data="a"), + ] + ), + ), + ) + event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk)) + + mocker.patch.object(publisher, "_extract_sentence", return_value=([], "")) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "Hello " + + def test_runtime_handles_agent_message_event_empty_content(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueAgentMessageEvent + from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + + chunk = LLMResultChunk( + model="model", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=""), + ), + ) + event = MagicMock(event=QueueAgentMessageEvent(chunk=chunk)) + + mocker.patch.object(publisher, "_extract_sentence", return_value=([], "")) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "" + + def test_runtime_resets_msg_text_when_text_tmp_not_str(self, mock_model_manager, mocker): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher.executor = MagicMock() + + from core.app.entities.queue_entities import QueueTextChunkEvent + + event = MagicMock() + event.event = MagicMock(spec=QueueTextChunkEvent) + event.event.text = "Hello world. Another sentence." + + mocker.patch.object(publisher, "_extract_sentence", return_value=(["A.", "B."], None)) + + publisher._msg_queue.put(event) + publisher._msg_queue.put(None) + + publisher._runtime() + + assert publisher.msg_text == "" + + def test_runtime_exception_path(self, mock_model_manager): + publisher = AppGeneratorTTSPublisher("tenant", "voice1") + publisher._msg_queue = MagicMock() + publisher._msg_queue.get.side_effect = Exception("error") + + publisher._runtime() diff --git a/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py new file mode 100644 index 0000000000..4c1aa33540 --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_agent_tool_callback_handler.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock + +import pytest + +import core.callback_handler.agent_tool_callback_handler as module + +# ----------------------------- +# Fixtures +# ----------------------------- + + +@pytest.fixture +def enable_debug(mocker): + mocker.patch.object(module.dify_config, "DEBUG", True) + + +@pytest.fixture +def disable_debug(mocker): + mocker.patch.object(module.dify_config, "DEBUG", False) + + +@pytest.fixture +def mock_print(mocker): + return mocker.patch("builtins.print") + + +@pytest.fixture +def handler(): + return module.DifyAgentCallbackHandler(color="blue") + + +# ----------------------------- +# get_colored_text Tests +# ----------------------------- + + +class TestGetColoredText: + @pytest.mark.parametrize( + ("color", "expected_code"), + [ + ("blue", "36;1"), + ("yellow", "33;1"), + ("pink", "38;5;200"), + ("green", "32;1"), + ("red", "31;1"), + ], + ) + def test_get_colored_text_valid_colors(self, color, expected_code): + text = "hello" + result = module.get_colored_text(text, color) + assert expected_code in result + assert text in result + assert result.endswith("\u001b[0m") + + def test_get_colored_text_invalid_color_raises(self): + with pytest.raises(KeyError): + module.get_colored_text("hello", "invalid") + + def test_get_colored_text_empty_string(self): + result = module.get_colored_text("", "green") + assert "\u001b[" in result + + +# ----------------------------- +# print_text Tests +# ----------------------------- + + +class TestPrintText: + def test_print_text_without_color(self, mock_print): + module.print_text("hello") + mock_print.assert_called_once_with("hello", end="", file=None) + + def test_print_text_with_color(self, mocker, mock_print): + mock_get_color = mocker.patch( + "core.callback_handler.agent_tool_callback_handler.get_colored_text", + return_value="colored_text", + ) + + module.print_text("hello", color="green") + + mock_get_color.assert_called_once_with("hello", "green") + mock_print.assert_called_once_with("colored_text", end="", file=None) + + def test_print_text_with_file_flush(self, mocker): + mock_file = MagicMock() + mock_print = mocker.patch("builtins.print") + + module.print_text("hello", file=mock_file) + + mock_print.assert_called_once_with("hello", end="", file=mock_file) + mock_file.flush.assert_called_once() + + def test_print_text_with_end_parameter(self, mock_print): + module.print_text("hello", end="\n") + mock_print.assert_called_once_with("hello", end="\n", file=None) + + +# ----------------------------- +# DifyAgentCallbackHandler Tests +# ----------------------------- + + +class TestDifyAgentCallbackHandler: + def test_init_default_color(self): + handler = module.DifyAgentCallbackHandler() + assert handler.color == "green" + assert handler.current_loop == 1 + + def test_on_tool_start_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_start("tool1", {"a": 1}) + + mock_print_text.assert_called() + + def test_on_tool_start_debug_disabled(self, handler, disable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_start("tool1", {"a": 1}) + + mock_print_text.assert_not_called() + + def test_on_tool_end_debug_enabled_and_trace(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + mock_trace_manager = MagicMock() + + handler.on_tool_end( + tool_name="tool1", + tool_inputs={"a": 1}, + tool_outputs="output", + message_id="msg1", + timer=123, + trace_manager=mock_trace_manager, + ) + + assert mock_print_text.call_count >= 1 + mock_trace_manager.add_trace_task.assert_called_once() + + def test_on_tool_end_without_trace_manager(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_end( + tool_name="tool1", + tool_inputs={}, + tool_outputs="output", + ) + + assert mock_print_text.call_count >= 1 + + def test_on_tool_error_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_error(Exception("error")) + + mock_print_text.assert_called_once() + + def test_on_tool_error_debug_disabled(self, handler, disable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_tool_error(Exception("error")) + + mock_print_text.assert_not_called() + + @pytest.mark.parametrize("thought", ["thinking", ""]) + def test_on_agent_start(self, handler, enable_debug, mocker, thought): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_agent_start(thought) + + mock_print_text.assert_called() + + def test_on_agent_finish_increments_loop(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + current_loop = handler.current_loop + handler.on_agent_finish() + + assert handler.current_loop == current_loop + 1 + mock_print_text.assert_called() + + def test_on_datasource_start_debug_enabled(self, handler, enable_debug, mocker): + mock_print_text = mocker.patch("core.callback_handler.agent_tool_callback_handler.print_text") + + handler.on_datasource_start("ds1", {"x": 1}) + + mock_print_text.assert_called_once() + + def test_ignore_agent_property(self, disable_debug, handler): + assert handler.ignore_agent is True + + def test_ignore_chat_model_property(self, disable_debug, handler): + assert handler.ignore_chat_model is True + + def test_ignore_properties_when_debug_enabled(self, enable_debug, handler): + assert handler.ignore_agent is False + assert handler.ignore_chat_model is False diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py new file mode 100644 index 0000000000..b37c4c57a1 --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -0,0 +1,162 @@ +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.callback_handler.index_tool_callback_handler import ( + DatasetIndexToolCallbackHandler, +) + + +@pytest.fixture +def mock_queue_manager(mocker): + return mocker.Mock() + + +@pytest.fixture +def handler(mock_queue_manager, mocker): + mocker.patch( + "core.callback_handler.index_tool_callback_handler.db", + ) + return DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id="app-1", + message_id="msg-1", + user_id="user-1", + invoke_from=mocker.Mock(), + ) + + +class TestOnQuery: + @pytest.mark.parametrize( + ("invoke_from", "expected_role"), + [ + (InvokeFrom.EXPLORE, "account"), + (InvokeFrom.DEBUGGER, "account"), + (InvokeFrom.WEB_APP, "end_user"), + ], + ) + def test_on_query_success_roles(self, mocker, mock_queue_manager, invoke_from, expected_role): + # Arrange + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + handler = DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id="app-1", + message_id="msg-1", + user_id="user-1", + invoke_from=mocker.Mock(), + ) + + handler._invoke_from = invoke_from + + # Act + handler.on_query("test query", "dataset-1") + + # Assert + mock_db.session.add.assert_called_once() + dataset_query = mock_db.session.add.call_args.args[0] + assert dataset_query.created_by_role == expected_role + mock_db.session.commit.assert_called_once() + + def test_on_query_none_values(self, mocker, mock_queue_manager): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + handler = DatasetIndexToolCallbackHandler( + queue_manager=mock_queue_manager, + app_id=None, + message_id=None, + user_id=None, + invoke_from=None, + ) + + handler.on_query(None, None) + + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +class TestOnToolEnd: + def test_on_tool_end_no_metadata(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + document = mocker.Mock() + document.metadata = None + + handler.on_tool_end([document]) + + mock_db.session.commit.assert_not_called() + + def test_on_tool_end_dataset_document_not_found(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + mock_db.session.scalar.return_value = None + + document = mocker.Mock() + document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} + + handler.on_tool_end([document]) + + mock_db.session.scalar.assert_called_once() + + def test_on_tool_end_parent_child_index_with_child(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + mock_dataset_doc = mocker.Mock() + from core.callback_handler.index_tool_callback_handler import IndexStructureType + + mock_dataset_doc.doc_form = IndexStructureType.PARENT_CHILD_INDEX + mock_dataset_doc.dataset_id = "dataset-1" + mock_dataset_doc.id = "doc-1" + + mock_child_chunk = mocker.Mock() + mock_child_chunk.segment_id = "segment-1" + + mock_db.session.scalar.side_effect = [mock_dataset_doc, mock_child_chunk] + + document = mocker.Mock() + document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} + + mock_query = mocker.Mock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + + handler.on_tool_end([document]) + + mock_query.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_on_tool_end_non_parent_child_index(self, handler, mocker): + mock_db = mocker.patch("core.callback_handler.index_tool_callback_handler.db") + + mock_dataset_doc = mocker.Mock() + mock_dataset_doc.doc_form = "OTHER" + + mock_db.session.scalar.return_value = mock_dataset_doc + + document = mocker.Mock() + document.metadata = { + "document_id": "doc-1", + "doc_id": "node-1", + "dataset_id": "dataset-1", + } + + mock_query = mocker.Mock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + + handler.on_tool_end([document]) + + mock_query.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_on_tool_end_empty_documents(self, handler): + handler.on_tool_end([]) + + +class TestReturnRetrieverResourceInfo: + def test_publish_called(self, handler, mock_queue_manager, mocker): + mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent") + + resources = [mocker.Mock()] + + handler.return_retriever_resource_info(resources) + + mock_queue_manager.publish.assert_called_once() diff --git a/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py new file mode 100644 index 0000000000..131fb006ed --- /dev/null +++ b/api/tests/unit_tests/core/callback_handler/test_workflow_tool_callback_handler.py @@ -0,0 +1,184 @@ +from unittest.mock import MagicMock, call + +import pytest + +from core.callback_handler.workflow_tool_callback_handler import ( + DifyWorkflowCallbackHandler, +) + + +class DummyToolInvokeMessage: + """Lightweight dummy to simulate ToolInvokeMessage behavior.""" + + def __init__(self, json_value: str): + self._json_value = json_value + + def model_dump_json(self): + return self._json_value + + +@pytest.fixture +def handler(): + """Fixture to create handler instance with deterministic color.""" + instance = DifyWorkflowCallbackHandler() + instance.color = "blue" + return instance + + +@pytest.fixture +def mock_print_text(mocker): + """Mock print_text to avoid real stdout printing.""" + return mocker.patch("core.callback_handler.workflow_tool_callback_handler.print_text") + + +class TestDifyWorkflowCallbackHandler: + def test_on_tool_execution_single_output_success(self, handler, mock_print_text): + # Arrange + tool_name = "test_tool" + tool_inputs = {"a": 1} + message = DummyToolInvokeMessage('{"key": "value"}') + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs=tool_inputs, + tool_outputs=[message], + ) + ) + + # Assert + assert results == [message] + assert mock_print_text.call_count == 4 + mock_print_text.assert_has_calls( + [ + call("\n[on_tool_execution]\n", color="blue"), + call("Tool: test_tool\n", color="blue"), + call( + "Outputs: " + message.model_dump_json()[:1000] + "\n", + color="blue", + ), + call("\n"), + ] + ) + + def test_on_tool_execution_multiple_outputs(self, handler, mock_print_text): + # Arrange + tool_name = "multi_tool" + outputs = [ + DummyToolInvokeMessage('{"id": 1}'), + DummyToolInvokeMessage('{"id": 2}'), + ] + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=outputs, + ) + ) + + # Assert + assert results == outputs + assert mock_print_text.call_count == 4 * len(outputs) + + def test_on_tool_execution_empty_iterable(self, handler, mock_print_text): + # Arrange + tool_name = "empty_tool" + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[], + ) + ) + + # Assert + assert results == [] + mock_print_text.assert_not_called() + + @pytest.mark.parametrize( + ("invalid_outputs", "expected_exception"), + [ + (None, TypeError), + (123, TypeError), + ("not_iterable", AttributeError), + ], + ) + def test_on_tool_execution_invalid_outputs_type(self, handler, invalid_outputs, expected_exception): + # Arrange + tool_name = "invalid_tool" + + # Act & Assert + with pytest.raises(expected_exception): + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=invalid_outputs, + ) + ) + + def test_on_tool_execution_long_json_truncation(self, handler, mock_print_text): + # Arrange + tool_name = "long_json_tool" + long_json = "x" * 1500 + message = DummyToolInvokeMessage(long_json) + + # Act + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[message], + ) + ) + + # Assert + expected_truncated = long_json[:1000] + mock_print_text.assert_any_call( + "Outputs: " + expected_truncated + "\n", + color="blue", + ) + + def test_on_tool_execution_model_dump_json_exception(self, handler, mock_print_text): + # Arrange + tool_name = "exception_tool" + bad_message = MagicMock() + bad_message.model_dump_json.side_effect = ValueError("JSON error") + + # Act & Assert + with pytest.raises(ValueError): + list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[bad_message], + ) + ) + + # Ensure first two prints happened before failure + assert mock_print_text.call_count >= 2 + + def test_on_tool_execution_none_message_id_and_trace_manager(self, handler, mock_print_text): + # Arrange + tool_name = "optional_params_tool" + message = DummyToolInvokeMessage('{"data": "ok"}') + + # Act + results = list( + handler.on_tool_execution( + tool_name=tool_name, + tool_inputs={}, + tool_outputs=[message], + message_id=None, + timer=None, + trace_manager=None, + ) + ) + + assert results == [message] + assert mock_print_text.call_count == 4 diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py new file mode 100644 index 0000000000..5482b4db52 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_plugin.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +from configs import dify_config +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType + + +class ConcreteDatasourcePlugin(DatasourcePlugin): + """ + Concrete implementation of DatasourcePlugin for testing purposes. + Since DatasourcePlugin is an ABC, we need a concrete class to instantiate it. + """ + + def datasource_provider_type(self) -> str: + return DatasourceProviderType.LOCAL_FILE + + +class TestDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + + # Act + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Act + provider_type = plugin.datasource_provider_type() + # Call the base class method to ensure it's covered + base_provider_type = DatasourcePlugin.datasource_provider_type(plugin) + + # Assert + assert provider_type == DatasourceProviderType.LOCAL_FILE + assert base_provider_type == DatasourceProviderType.LOCAL_FILE + + def test_fork_datasource_runtime(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_entity_copy = MagicMock(spec=DatasourceEntity) + mock_entity.model_copy.return_value = mock_entity_copy + + runtime = MagicMock(spec=DatasourceRuntime) + new_runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + + plugin = ConcreteDatasourcePlugin(entity=mock_entity, runtime=runtime, icon=icon) + + # Act + new_plugin = plugin.fork_datasource_runtime(new_runtime) + + # Assert + assert isinstance(new_plugin, ConcreteDatasourcePlugin) + assert new_plugin.entity == mock_entity_copy + assert new_plugin.runtime == new_runtime + assert new_plugin.icon == icon + mock_entity.model_copy.assert_called_once() + + def test_get_icon_url(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon.png" + tenant_id = "test-tenant-id" + + plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon) + + # Mocking dify_config.CONSOLE_API_URL + with patch.object(dify_config, "CONSOLE_API_URL", "https://api.dify.ai"): + # Act + icon_url = plugin.get_icon_url(tenant_id) + + # Assert + expected_url = ( + f"https://api.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={icon}" + ) + assert icon_url == expected_url diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py new file mode 100644 index 0000000000..6a3d21a33d --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_provider.py @@ -0,0 +1,265 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_plugin import DatasourcePlugin +from core.datasource.__base.datasource_provider import DatasourcePluginProviderController +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.entities.provider_entities import ProviderConfig +from core.tools.errors import ToolProviderCredentialValidationError + + +class ConcreteDatasourcePluginProviderController(DatasourcePluginProviderController): + """ + Concrete implementation of DatasourcePluginProviderController for testing purposes. + """ + + def get_datasource(self, datasource_name: str) -> DatasourcePlugin: + return MagicMock(spec=DatasourcePlugin) + + +class TestDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + tenant_id = "test-tenant-id" + + # Act + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Assert + assert controller.entity == mock_entity + assert controller.tenant_id == tenant_id + + def test_need_credentials(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + tenant_id = "test-tenant-id" + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Case 1: credentials_schema is None + mock_entity.credentials_schema = None + assert controller.need_credentials is False + + # Case 2: credentials_schema is empty + mock_entity.credentials_schema = [] + assert controller.need_credentials is False + + # Case 3: credentials_schema has items + mock_entity.credentials_schema = [MagicMock()] + assert controller.need_credentials is True + + @patch("core.datasource.__base.datasource_provider.PluginToolManager") + def test_validate_credentials(self, mock_manager_class): + # Arrange + mock_manager = mock_manager_class.return_value + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.identity = MagicMock() + mock_entity.identity.name = "test-provider" + tenant_id = "test-tenant-id" + user_id = "test-user-id" + credentials = {"api_key": "secret"} + + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id) + + # Act: Successful validation + mock_manager.validate_datasource_credentials.return_value = True + controller._validate_credentials(user_id, credentials) + + mock_manager.validate_datasource_credentials.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + provider="test-provider", + credentials=credentials, + ) + + # Act: Failed validation + mock_manager.validate_datasource_credentials.return_value = False + with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"): + controller._validate_credentials(user_id, credentials) + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.LOCAL_FILE + + def test_validate_credentials_format_empty_schema(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + credentials = {} + + # Act & Assert (Should not raise anything) + controller.validate_credentials_format(credentials) + + def test_validate_credentials_format_unknown_credential(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.identity = MagicMock() + mock_entity.identity.name = "test-provider" + mock_entity.credentials_schema = [] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + credentials = {"unknown": "value"} + + # Act & Assert + with pytest.raises( + ToolProviderCredentialValidationError, match="credential unknown not found in provider test-provider" + ): + controller.validate_credentials_format(credentials) + + def test_validate_credentials_format_required_missing(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "api_key" + mock_config.required = True + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + with pytest.raises(ToolProviderCredentialValidationError, match="credential api_key is required"): + controller.validate_credentials_format({}) + + def test_validate_credentials_format_not_required_null(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "optional" + mock_config.required = False + mock_config.default = None + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + credentials = {"optional": None} + controller.validate_credentials_format(credentials) + assert credentials["optional"] is None + + def test_validate_credentials_format_type_mismatch_text(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "text_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.TEXT_INPUT + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act & Assert + with pytest.raises(ToolProviderCredentialValidationError, match="credential text_field should be string"): + controller.validate_credentials_format({"text_field": 123}) + + def test_validate_credentials_format_select_validation(self): + # Arrange + mock_option = MagicMock() + mock_option.value = "opt1" + + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "select_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.SELECT + mock_config.options = [mock_option] + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Case 1: Value not string + with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be string"): + controller.validate_credentials_format({"select_field": 123}) + + # Case 2: Options not list + mock_config.options = "invalid" + with pytest.raises( + ToolProviderCredentialValidationError, match="credential select_field options should be list" + ): + controller.validate_credentials_format({"select_field": "opt1"}) + + # Case 3: Value not in options + mock_config.options = [mock_option] + with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be one of"): + controller.validate_credentials_format({"select_field": "invalid_opt"}) + + def test_get_datasource_base(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + result = DatasourcePluginProviderController.get_datasource(controller, "test") + + # Assert + assert result is None + + def test_validate_credentials_format_hits_pop(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "valid_field" + mock_config.required = True + mock_config.type = ProviderConfig.Type.TEXT_INPUT + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {"valid_field": "valid_value"} + controller.validate_credentials_format(credentials) + + # Assert + assert "valid_field" in credentials + assert credentials["valid_field"] == "valid_value" + + def test_validate_credentials_format_hits_continue(self): + # Arrange + mock_config = MagicMock(spec=ProviderConfig) + mock_config.name = "optional_field" + mock_config.required = False + mock_config.default = None + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {"optional_field": None} + controller.validate_credentials_format(credentials) + + # Assert + assert credentials["optional_field"] is None + + def test_validate_credentials_format_default_values(self): + # Arrange + mock_config_text = MagicMock(spec=ProviderConfig) + mock_config_text.name = "text_def" + mock_config_text.required = False + mock_config_text.type = ProviderConfig.Type.TEXT_INPUT + mock_config_text.default = 123 # Int default, should be converted to str + + mock_config_other = MagicMock(spec=ProviderConfig) + mock_config_other.name = "other_def" + mock_config_other.required = False + mock_config_other.type = "OTHER" + mock_config_other.default = "fallback" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.credentials_schema = [mock_config_text, mock_config_other] + controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test") + + # Act + credentials = {} + controller.validate_credentials_format(credentials) + + # Assert + assert credentials["text_def"] == "123" + assert credentials["other_def"] == "fallback" diff --git a/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py b/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py new file mode 100644 index 0000000000..2bca9155e9 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/__base/test_datasource_runtime.py @@ -0,0 +1,26 @@ +from core.app.entities.app_invoke_entities import InvokeFrom +from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime +from core.datasource.entities.datasource_entities import DatasourceInvokeFrom + + +class TestDatasourceRuntime: + def test_init(self): + runtime = DatasourceRuntime( + tenant_id="test-tenant", + datasource_id="test-ds", + invoke_from=InvokeFrom.DEBUGGER, + datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE, + credentials={"key": "val"}, + runtime_parameters={"p": "v"}, + ) + assert runtime.tenant_id == "test-tenant" + assert runtime.datasource_id == "test-ds" + assert runtime.credentials["key"] == "val" + + def test_fake_datasource_runtime(self): + # This covers the FakeDatasourceRuntime class and its __init__ + runtime = FakeDatasourceRuntime() + assert runtime.tenant_id == "fake_tenant_id" + assert runtime.datasource_id == "fake_datasource_id" + assert runtime.invoke_from == InvokeFrom.DEBUGGER + assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE diff --git a/api/tests/unit_tests/core/datasource/entities/test_api_entities.py b/api/tests/unit_tests/core/datasource/entities/test_api_entities.py new file mode 100644 index 0000000000..9855b4040a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_api_entities.py @@ -0,0 +1,150 @@ +from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity +from core.datasource.entities.datasource_entities import DatasourceParameter +from core.tools.entities.common_entities import I18nObject + + +def test_datasource_api_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + + entity = DatasourceApiEntity( + author="author", name="name", label=label, description=description, labels=["l1", "l2"] + ) + + assert entity.author == "author" + assert entity.name == "name" + assert entity.label == label + assert entity.description == description + assert entity.labels == ["l1", "l2"] + assert entity.parameters is None + assert entity.output_schema is None + + +def test_datasource_provider_api_entity_defaults(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + entity = DatasourceProviderApiEntity( + id="id", author="author", name="name", description=description, icon="icon", label=label, type="type" + ) + + assert entity.id == "id" + assert entity.datasources == [] + assert entity.is_team_authorization is False + assert entity.allow_delete is True + assert entity.plugin_id == "" + assert entity.plugin_unique_identifier == "" + assert entity.labels == [] + + +def test_datasource_provider_api_entity_convert_none_to_empty_list(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + # Implicitly testing the field_validator "convert_none_to_empty_list" + entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=None, # type: ignore + ) + + assert entity.datasources == [] + + +def test_datasource_provider_api_entity_to_dict(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + # Create a parameter that should be converted + param = DatasourceParameter.get_simple_instance( + name="test_param", typ=DatasourceParameter.DatasourceParameterType.SYSTEM_FILES, required=True + ) + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=[param] + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + masked_credentials={"key": "masked"}, + datasources=[ds_entity], + labels=["l1"], + ) + + result = provider_entity.to_dict() + + assert result["id"] == "id" + assert result["author"] == "author" + assert result["name"] == "name" + assert result["description"] == description.to_dict() + assert result["icon"] == "icon" + assert result["label"] == label.to_dict() + assert result["type"] == "type" + assert result["team_credentials"] == {"key": "masked"} + assert result["is_team_authorization"] is False + assert result["allow_delete"] is True + assert result["labels"] == ["l1"] + + # Check if parameter type was converted from SYSTEM_FILES to files + assert result["datasources"][0]["parameters"][0]["type"] == "files" + + +def test_datasource_provider_api_entity_to_dict_no_params(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=None + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=[ds_entity], + ) + + result = provider_entity.to_dict() + assert result["datasources"][0]["parameters"] is None + + +def test_datasource_provider_api_entity_to_dict_other_param_type(): + description = I18nObject(en_US="desc", zh_Hans="描述") + label = I18nObject(en_US="label", zh_Hans="标签") + + param = DatasourceParameter.get_simple_instance( + name="test_param", typ=DatasourceParameter.DatasourceParameterType.STRING, required=True + ) + + ds_entity = DatasourceApiEntity( + author="author", name="ds_name", label=label, description=description, parameters=[param] + ) + + provider_entity = DatasourceProviderApiEntity( + id="id", + author="author", + name="name", + description=description, + icon="icon", + label=label, + type="type", + datasources=[ds_entity], + ) + + result = provider_entity.to_dict() + assert result["datasources"][0]["parameters"][0]["type"] == "string" diff --git a/api/tests/unit_tests/core/datasource/entities/test_common_entities.py b/api/tests/unit_tests/core/datasource/entities/test_common_entities.py new file mode 100644 index 0000000000..0ee4928105 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_common_entities.py @@ -0,0 +1,31 @@ +from core.datasource.entities.common_entities import I18nObject + + +def test_i18n_object_fallback(): + # Only en_US provided + obj = I18nObject(en_US="Hello") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "Hello" + assert obj.pt_BR == "Hello" + assert obj.ja_JP == "Hello" + + # Some fields provided + obj = I18nObject(en_US="Hello", zh_Hans="你好") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "你好" + assert obj.pt_BR == "Hello" + assert obj.ja_JP == "Hello" + + +def test_i18n_object_all_fields(): + obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは") + assert obj.en_US == "Hello" + assert obj.zh_Hans == "你好" + assert obj.pt_BR == "Olá" + assert obj.ja_JP == "こんにちは" + + +def test_i18n_object_to_dict(): + obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは") + expected_dict = {"en_US": "Hello", "zh_Hans": "你好", "pt_BR": "Olá", "ja_JP": "こんにちは"} + assert obj.to_dict() == expected_dict diff --git a/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py b/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py new file mode 100644 index 0000000000..a8c8d31537 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/entities/test_datasource_entities.py @@ -0,0 +1,275 @@ +from unittest.mock import patch + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceInvokeMeta, + DatasourceLabel, + DatasourceMessage, + DatasourceParameter, + DatasourceProviderEntity, + DatasourceProviderEntityWithPlugin, + DatasourceProviderIdentity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, + GetOnlineDocumentPageContentResponse, + GetWebsiteCrawlRequest, + OnlineDocumentInfo, + OnlineDocumentPage, + OnlineDocumentPageContent, + OnlineDocumentPagesMessage, + OnlineDriveBrowseFilesRequest, + OnlineDriveBrowseFilesResponse, + OnlineDriveDownloadFileRequest, + OnlineDriveFile, + OnlineDriveFileBucket, + WebsiteCrawlMessage, + WebSiteInfo, + WebSiteInfoDetail, +) +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolLabelEnum + + +def test_datasource_provider_type(): + assert DatasourceProviderType.value_of("online_document") == DatasourceProviderType.ONLINE_DOCUMENT + assert DatasourceProviderType.value_of("local_file") == DatasourceProviderType.LOCAL_FILE + + with pytest.raises(ValueError, match="invalid mode value invalid"): + DatasourceProviderType.value_of("invalid") + + +def test_datasource_parameter_type(): + param_type = DatasourceParameter.DatasourceParameterType.STRING + assert param_type.as_normal_type() == "string" + assert param_type.cast_value("test") == "test" + + param_type = DatasourceParameter.DatasourceParameterType.NUMBER + assert param_type.cast_value("123") == 123 + + +def test_datasource_parameter(): + param = DatasourceParameter.get_simple_instance( + name="test_param", + typ=DatasourceParameter.DatasourceParameterType.STRING, + required=True, + options=["opt1", "opt2"], + ) + assert param.name == "test_param" + assert param.type == DatasourceParameter.DatasourceParameterType.STRING + assert param.required is True + assert len(param.options) == 2 + assert param.options[0].value == "opt1" + + param_no_options = DatasourceParameter.get_simple_instance( + name="test_param_2", typ=DatasourceParameter.DatasourceParameterType.NUMBER, required=False + ) + assert param_no_options.options == [] + + # Test init_frontend_parameter + # For STRING, it should just return the value as is (or cast to str) + frontend_param = param.init_frontend_parameter("val") + assert frontend_param == "val" + + # Test parameter type methods + assert DatasourceParameter.DatasourceParameterType.STRING.as_normal_type() == "string" + assert DatasourceParameter.DatasourceParameterType.NUMBER.as_normal_type() == "number" + assert DatasourceParameter.DatasourceParameterType.SECRET_INPUT.as_normal_type() == "string" + + assert DatasourceParameter.DatasourceParameterType.NUMBER.cast_value("10.5") == 10.5 + assert DatasourceParameter.DatasourceParameterType.BOOLEAN.cast_value("true") is True + assert DatasourceParameter.DatasourceParameterType.FILES.cast_value(["f1", "f2"]) == ["f1", "f2"] + + +def test_datasource_identity(): + label = I18nObject(en_US="label", zh_Hans="标签") + identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider", icon="icon") + assert identity.author == "author" + assert identity.name == "name" + assert identity.label == label + assert identity.provider == "provider" + assert identity.icon == "icon" + + +def test_datasource_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider") + description = I18nObject(en_US="desc", zh_Hans="描述") + + entity = DatasourceEntity( + identity=identity, + description=description, + parameters=None, # Should be handled by validator + ) + assert entity.parameters == [] + + param = DatasourceParameter.get_simple_instance("p1", DatasourceParameter.DatasourceParameterType.STRING, True) + entity_with_params = DatasourceEntity(identity=identity, description=description, parameters=[param]) + assert entity_with_params.parameters == [param] + + +def test_datasource_provider_identity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon.png", label=label, tags=[ToolLabelEnum.SEARCH] + ) + + assert identity.author == "author" + assert identity.name == "name" + assert identity.description == description + assert identity.icon == "icon.png" + assert identity.label == label + assert identity.tags == [ToolLabelEnum.SEARCH] + + # Test generate_datasource_icon_url + with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config: + mock_config.CONSOLE_API_URL = "http://api.example.com" + url = identity.generate_datasource_icon_url("tenant123") + assert "http://api.example.com/console/api/workspaces/current/plugin/icon" in url + assert "tenant_id=tenant123" in url + assert "filename=icon.png" in url + + # Test hardcoded icon + identity.icon = "https://assets.dify.ai/images/File%20Upload.svg" + assert identity.generate_datasource_icon_url("tenant123") == identity.icon + + # Test with empty CONSOLE_API_URL + identity.icon = "test.png" + with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config: + mock_config.CONSOLE_API_URL = None + url = identity.generate_datasource_icon_url("tenant123") + assert url.startswith("/console/api/workspaces/current/plugin/icon") + + +def test_datasource_provider_entity(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon", label=label + ) + + entity = DatasourceProviderEntity( + identity=identity, + provider_type=DatasourceProviderType.ONLINE_DOCUMENT, + credentials_schema=[], + oauth_schema=None, + ) + assert entity.identity == identity + assert entity.provider_type == DatasourceProviderType.ONLINE_DOCUMENT + assert entity.credentials_schema == [] + + +def test_datasource_provider_entity_with_plugin(): + label = I18nObject(en_US="label", zh_Hans="标签") + description = I18nObject(en_US="desc", zh_Hans="描述") + identity = DatasourceProviderIdentity( + author="author", name="name", description=description, icon="icon", label=label + ) + + entity = DatasourceProviderEntityWithPlugin( + identity=identity, provider_type=DatasourceProviderType.ONLINE_DOCUMENT, datasources=[] + ) + assert entity.datasources == [] + + +def test_datasource_invoke_meta(): + meta = DatasourceInvokeMeta(time_cost=1.5, error="some error", tool_config={"k": "v"}) + assert meta.time_cost == 1.5 + assert meta.error == "some error" + assert meta.tool_config == {"k": "v"} + + d = meta.to_dict() + assert d == {"time_cost": 1.5, "error": "some error", "tool_config": {"k": "v"}} + + empty_meta = DatasourceInvokeMeta.empty() + assert empty_meta.time_cost == 0.0 + assert empty_meta.error is None + assert empty_meta.tool_config == {} + + error_meta = DatasourceInvokeMeta.error_instance("fatal error") + assert error_meta.time_cost == 0.0 + assert error_meta.error == "fatal error" + assert error_meta.tool_config == {} + + +def test_datasource_label(): + label_obj = I18nObject(en_US="label", zh_Hans="标签") + ds_label = DatasourceLabel(name="name", label=label_obj, icon="icon") + assert ds_label.name == "name" + assert ds_label.label == label_obj + assert ds_label.icon == "icon" + + +def test_online_document_models(): + page = OnlineDocumentPage( + page_id="p1", + page_name="name", + page_icon={"type": "emoji"}, + type="page", + last_edited_time="2023-01-01", + parent_id=None, + ) + assert page.page_id == "p1" + + info = OnlineDocumentInfo(workspace_id="w1", workspace_name="name", workspace_icon="icon", total=1, pages=[page]) + assert info.total == 1 + + msg = OnlineDocumentPagesMessage(result=[info]) + assert msg.result == [info] + + req = GetOnlineDocumentPageContentRequest(workspace_id="w1", page_id="p1", type="page") + assert req.workspace_id == "w1" + + content = OnlineDocumentPageContent(workspace_id="w1", page_id="p1", content="hello") + assert content.content == "hello" + + resp = GetOnlineDocumentPageContentResponse(result=content) + assert resp.result == content + + +def test_website_crawl_models(): + req = GetWebsiteCrawlRequest(crawl_parameters={"url": "http://test.com"}) + assert req.crawl_parameters == {"url": "http://test.com"} + + detail = WebSiteInfoDetail(source_url="http://test.com", content="content", title="title", description="desc") + assert detail.title == "title" + + info = WebSiteInfo(status="completed", web_info_list=[detail], total=1, completed=1) + assert info.status == "completed" + + msg = WebsiteCrawlMessage(result=info) + assert msg.result == info + + # Test default values + msg_default = WebsiteCrawlMessage() + assert msg_default.result.status == "" + assert msg_default.result.web_info_list == [] + + +def test_online_drive_models(): + file = OnlineDriveFile(id="f1", name="file.txt", size=100, type="file") + assert file.name == "file.txt" + + bucket = OnlineDriveFileBucket(bucket="b1", files=[file], is_truncated=False, next_page_parameters=None) + assert bucket.bucket == "b1" + + req = OnlineDriveBrowseFilesRequest(bucket="b1", prefix="folder1", max_keys=10, next_page_parameters=None) + assert req.prefix == "folder1" + + resp = OnlineDriveBrowseFilesResponse(result=[bucket]) + assert resp.result == [bucket] + + dl_req = OnlineDriveDownloadFileRequest(id="f1", bucket="b1") + assert dl_req.id == "f1" + + +def test_datasource_message(): + # Use proper dict for message to avoid Pydantic Union validation ambiguity/crashes + msg = DatasourceMessage(type="text", message={"text": "hello"}) + assert msg.message.text == "hello" + + msg_json = DatasourceMessage(type="json", message={"json_object": {"k": "v"}}) + assert msg_json.message.json_object == {"k": "v"} diff --git a/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py b/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py new file mode 100644 index 0000000000..5bf7362a8a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/local_file/test_local_file_plugin.py @@ -0,0 +1,57 @@ +from unittest.mock import MagicMock + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, +) +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin + + +class TestLocalFileDatasourcePlugin: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test-tenant-id" + icon = "test-icon" + plugin_unique_identifier = "test-plugin-id" + + # Act + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.tenant_id == tenant_id + assert plugin.plugin_unique_identifier == plugin_unique_identifier + assert plugin.entity == mock_entity + assert plugin.runtime == mock_runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE + + def test_get_icon_url(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceEntity) + mock_runtime = MagicMock(spec=DatasourceRuntime) + icon = "test-icon" + plugin = LocalFileDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon=icon, plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.get_icon_url("any-tenant-id") == icon diff --git a/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py b/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py new file mode 100644 index 0000000000..af2369ac4e --- /dev/null +++ b/api/tests/unit_tests/core/datasource/local_file/test_local_file_provider.py @@ -0,0 +1,96 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin +from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController + + +class TestLocalFileDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + # Act + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.LOCAL_FILE + + def test_validate_credentials(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + # Should not raise any exception + controller._validate_credentials("user_id", {"key": "value"}) + + def test_get_datasource_success(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "test_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + datasource = controller.get_datasource("test_datasource") + + # Assert + assert isinstance(datasource, LocalFileDatasourcePlugin) + assert datasource.entity == mock_datasource_entity + assert datasource.tenant_id == tenant_id + assert datasource.icon == "test_icon" + assert datasource.plugin_unique_identifier == plugin_unique_identifier + + def test_get_datasource_not_found(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + + controller = LocalFileDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name test_datasource not found"): + controller.get_datasource("test_datasource") diff --git a/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py b/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py new file mode 100644 index 0000000000..e3a217725a --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_document/test_online_document_plugin.py @@ -0,0 +1,151 @@ +from unittest.mock import MagicMock, patch + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceProviderType, + GetOnlineDocumentPageContentRequest, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin + + +class TestOnlineDocumentDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + # Act + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.tenant_id == tenant_id + assert plugin.icon == icon + assert plugin.plugin_unique_identifier == plugin_unique_identifier + + def test_get_online_document_pages(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test_key"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + datasource_parameters = {"param": "value"} + provider_type = "test_type" + + mock_generator = MagicMock() + + # Patch PluginDatasourceManager to isolate plugin behavior from external dependencies + with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.get_online_document_pages.return_value = mock_generator + + # Act + result = plugin.get_online_document_pages( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert result == mock_generator + mock_manager_instance.get_online_document_pages.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def test_get_online_document_page_content(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test_key"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + datasource_parameters = MagicMock(spec=GetOnlineDocumentPageContentRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.get_online_document_page_content.return_value = mock_generator + + # Act + result = plugin.get_online_document_page_content( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert result == mock_generator + mock_manager_instance.get_online_document_page_content.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + plugin = OnlineDocumentDatasourcePlugin( + entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act + result = plugin.datasource_provider_type() + + # Assert + assert result == DatasourceProviderType.ONLINE_DOCUMENT diff --git a/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py b/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py new file mode 100644 index 0000000000..cfdd05e0b2 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_document/test_online_document_provider.py @@ -0,0 +1,100 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin +from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController + + +class TestOnlineDocumentDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_uid" + tenant_id = "test_tenant_id" + + # Act + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Assert + assert controller.provider_type == DatasourceProviderType.ONLINE_DOCUMENT + + def test_get_datasource_success(self): + # Arrange + from core.datasource.entities.datasource_entities import DatasourceIdentity + + mock_datasource_entity = MagicMock(spec=DatasourceEntity) + mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity) + mock_datasource_entity.identity.name = "target_datasource" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity = MagicMock() + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_uid" + tenant_id = "test_tenant_id" + + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id="test_plugin_id", + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Act + result = controller.get_datasource("target_datasource") + + # Assert + assert isinstance(result, OnlineDocumentDatasourcePlugin) + assert result.entity == mock_datasource_entity + assert result.tenant_id == tenant_id + assert result.icon == "test_icon" + assert result.plugin_unique_identifier == plugin_unique_identifier + assert result.runtime.tenant_id == tenant_id + + def test_get_datasource_not_found(self): + # Arrange + from core.datasource.entities.datasource_entities import DatasourceIdentity + + mock_datasource_entity = MagicMock(spec=DatasourceEntity) + mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity) + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + mock_entity.datasources = [mock_datasource_entity] + + controller = OnlineDocumentDatasourcePluginProviderController( + entity=mock_entity, + plugin_id="test_plugin_id", + plugin_unique_identifier="test_plugin_uid", + tenant_id="test_tenant_id", + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name missing_datasource not found"): + controller.get_datasource("missing_datasource") diff --git a/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py new file mode 100644 index 0000000000..6c8b644871 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_plugin.py @@ -0,0 +1,147 @@ +from unittest.mock import MagicMock, patch + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceIdentity, + DatasourceProviderType, + OnlineDriveBrowseFilesRequest, + OnlineDriveDownloadFileRequest, +) +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin + + +class TestOnlineDriveDatasourcePlugin: + def test_init(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + # Act + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.entity == entity + assert plugin.runtime == runtime + assert plugin.tenant_id == tenant_id + assert plugin.icon == icon + assert plugin.plugin_unique_identifier == plugin_unique_identifier + + def test_online_drive_browse_files(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"token": "test_token"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + request = MagicMock(spec=OnlineDriveBrowseFilesRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.online_drive_browse_files.return_value = mock_generator + + # Act + result = plugin.online_drive_browse_files(user_id=user_id, request=request, provider_type=provider_type) + + # Assert + assert result == mock_generator + mock_manager_instance.online_drive_browse_files.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def test_online_drive_download_file(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + identity = MagicMock(spec=DatasourceIdentity) + entity.identity = identity + identity.provider = "test_provider" + identity.name = "test_name" + + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"token": "test_token"} + + tenant_id = "test_tenant" + icon = "test_icon" + plugin_unique_identifier = "test_plugin_id" + + plugin = OnlineDriveDatasourcePlugin( + entity=entity, + runtime=runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + user_id = "test_user" + request = MagicMock(spec=OnlineDriveDownloadFileRequest) + provider_type = "test_type" + + mock_generator = MagicMock() + + with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager: + mock_manager_instance = MockManager.return_value + mock_manager_instance.online_drive_download_file.return_value = mock_generator + + # Act + result = plugin.online_drive_download_file(user_id=user_id, request=request, provider_type=provider_type) + + # Assert + assert result == mock_generator + mock_manager_instance.online_drive_download_file.assert_called_once_with( + tenant_id=tenant_id, + user_id=user_id, + datasource_provider="test_provider", + datasource_name="test_name", + credentials=runtime.credentials, + request=request, + provider_type=provider_type, + ) + + def test_datasource_provider_type(self): + # Arrange + entity = MagicMock(spec=DatasourceEntity) + runtime = MagicMock(spec=DatasourceRuntime) + plugin = OnlineDriveDatasourcePlugin( + entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act + result = plugin.datasource_provider_type() + + # Assert + assert result == DatasourceProviderType.ONLINE_DRIVE diff --git a/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py new file mode 100644 index 0000000000..2824ddd8ed --- /dev/null +++ b/api/tests/unit_tests/core/datasource/online_drive/test_online_drive_provider.py @@ -0,0 +1,83 @@ +from unittest.mock import MagicMock + +import pytest + +from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType +from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin +from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController + + +class TestOnlineDriveDatasourcePluginProviderController: + def test_init(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + plugin_id = "test_plugin_id" + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + # Act + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self): + # Arrange + mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.ONLINE_DRIVE + + def test_get_datasource_success(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "test_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + mock_entity.identity.icon = "test_icon" + + plugin_unique_identifier = "test_plugin_unique_identifier" + tenant_id = "test_tenant_id" + + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + datasource = controller.get_datasource("test_datasource") + + # Assert + assert isinstance(datasource, OnlineDriveDatasourcePlugin) + assert datasource.entity == mock_datasource_entity + assert datasource.tenant_id == tenant_id + assert datasource.icon == "test_icon" + assert datasource.plugin_unique_identifier == plugin_unique_identifier + assert datasource.runtime.tenant_id == tenant_id + + def test_get_datasource_not_found(self): + # Arrange + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity.name = "other_datasource" + + mock_entity = MagicMock() + mock_entity.datasources = [mock_datasource_entity] + + controller = OnlineDriveDatasourcePluginProviderController( + entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant" + ) + + # Act & Assert + with pytest.raises(ValueError, match="Datasource with name test_datasource not found"): + controller.get_datasource("test_datasource") diff --git a/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py new file mode 100644 index 0000000000..7cd1fdf06b --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_datasource_file_manager.py @@ -0,0 +1,410 @@ +import base64 +import hashlib +import hmac +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from core.datasource.datasource_file_manager import DatasourceFileManager +from models.model import MessageFile, UploadFile +from models.tools import ToolFile + + +class TestDatasourceFileManager: + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.os.urandom") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_sign_file(self, mock_config, mock_urandom, mock_time): + # Setup + mock_config.FILES_URL = "http://localhost:5001" + mock_config.SECRET_KEY = "test_secret" + mock_time.return_value = 1700000000 + mock_urandom.return_value = b"1234567890abcdef" # 16 bytes + + datasource_file_id = "file_id_123" + extension = ".png" + + # Execute + signed_url = DatasourceFileManager.sign_file(datasource_file_id, extension) + + # Verify + assert signed_url.startswith("http://localhost:5001/files/datasources/file_id_123.png?") + assert "timestamp=1700000000" in signed_url + assert f"nonce={mock_urandom.return_value.hex()}" in signed_url + assert "sign=" in signed_url + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.os.urandom") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time): + # Setup + mock_config.FILES_URL = "http://localhost:5001" + mock_config.SECRET_KEY = None # Empty secret + mock_time.return_value = 1700000000 + mock_urandom.return_value = b"1234567890abcdef" + + # Execute + signed_url = DatasourceFileManager.sign_file("file_id", ".png") + assert "sign=" in signed_url + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_verify_file(self, mock_config, mock_time): + # Setup + mock_config.SECRET_KEY = "test_secret" + mock_config.FILES_ACCESS_TIMEOUT = 300 + mock_time.return_value = 1700000000 + + datasource_file_id = "file_id_123" + timestamp = "1699999800" # 200 seconds ago + nonce = "some_nonce" + + # Manually calculate sign + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + secret_key = b"test_secret" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + # Execute & Verify Success + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True + + # Verify Failure - Wrong Sign + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False + + # Verify Failure - Timeout + mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout) + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False + + @patch("core.datasource.datasource_file_manager.time.time") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_verify_file_empty_secret(self, mock_config, mock_time): + # Setup + mock_config.SECRET_KEY = "" # Empty string secret + mock_config.FILES_ACCESS_TIMEOUT = 300 + mock_time.return_value = 1700000000 + + datasource_file_id = "file_id_123" + timestamp = "1699999800" + nonce = "some_nonce" + + # Calculate with empty secret + data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}" + sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + + assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + user_id = "user_123" + tenant_id = "tenant_456" + file_binary = b"fake binary data" + mimetype = "image/png" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mimetype, + filename="test.png", + ) + + # Verify + assert upload_file.tenant_id == tenant_id + assert upload_file.name == "test.png" + assert upload_file.size == len(file_binary) + assert upload_file.mime_type == mimetype + assert upload_file.key == f"datasources/{tenant_id}/unique_hex.png" + + mock_storage.save.assert_called_once_with(upload_file.key, file_binary) + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw_filename_no_extension(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + user_id = "user_123" + tenant_id = "tenant_456" + file_binary = b"fake binary data" + mimetype = "image/png" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + file_binary=file_binary, + mimetype=mimetype, + filename="test", # No extension + ) + + # Verify + assert upload_file.name == "test.png" # Should append extension + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + @patch("core.datasource.datasource_file_manager.guess_extension") + def test_create_file_by_raw_unknown_extension(self, mock_guess_ext, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_guess_ext.return_value = None # Cannot guess + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id="user", + tenant_id="tenant", + conversation_id=None, + file_binary=b"data", + mimetype="application/x-unknown", + ) + + # Verify + assert upload_file.extension == ".bin" + assert upload_file.name == "unique_hex.bin" + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + @patch("core.datasource.datasource_file_manager.dify_config") + def test_create_file_by_raw_no_filename(self, mock_config, mock_uuid, mock_storage, mock_db): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_config.STORAGE_TYPE = "local" + + # Execute + upload_file = DatasourceFileManager.create_file_by_raw( + user_id="user_123", + tenant_id="tenant_456", + conversation_id=None, + file_binary=b"data", + mimetype="application/pdf", + ) + + # Verify + assert upload_file.name == "unique_hex.pdf" + assert upload_file.extension == ".pdf" + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_mimetype_from_guess(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"bits" + mock_response.headers = {} # No content-type in headers + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.png" + ) + + # Verify + assert tool_file.mimetype == "image/png" # Guessed from .png in URL + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_mimetype_default(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"bits" + mock_response.headers = {} + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", + tenant_id="tenant_456", + file_url="https://example.com/unknown", # No extension, no headers + ) + + # Verify + assert tool_file.mimetype == "application/octet-stream" + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + @patch("core.datasource.datasource_file_manager.uuid4") + def test_create_file_by_url_success(self, mock_uuid, mock_storage, mock_db, mock_ssrf): + # Setup + mock_uuid.return_value = MagicMock(hex="unique_hex") + mock_response = MagicMock() + mock_response.content = b"downloaded bits" + mock_response.headers = {"Content-Type": "image/jpeg"} + mock_ssrf.get.return_value = mock_response + + # Execute + tool_file = DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.jpg" + ) + + # Verify + assert tool_file.mimetype == "image/jpeg" + assert tool_file.size == len(b"downloaded bits") + assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg" + mock_storage.save.assert_called_once() + + @patch("core.datasource.datasource_file_manager.ssrf_proxy") + def test_create_file_by_url_timeout(self, mock_ssrf): + # Setup + mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout") + + # Execute & Verify + with pytest.raises(ValueError, match="timeout when downloading file"): + DatasourceFileManager.create_file_by_url( + user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/large.file" + ) + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary(self, mock_storage, mock_db): + # Setup + mock_upload_file = MagicMock(spec=UploadFile) + mock_upload_file.key = "some_key" + mock_upload_file.mime_type = "image/png" + + mock_query = mock_db.session.query.return_value + mock_where = mock_query.where.return_value + mock_where.first.return_value = mock_upload_file + + mock_storage.load_once.return_value = b"file content" + + # Execute + result = DatasourceFileManager.get_file_binary("file_id") + + # Verify + assert result == (b"file content", "image/png") + + # Case: Not found + mock_where.first.return_value = None + assert DatasourceFileManager.get_file_binary("unknown") is None + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id(self, mock_storage, mock_db): + # Setup + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = "http://localhost/files/tools/tool_id.png" + + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.file_key = "tool_key" + mock_tool_file.mimetype = "image/png" + + # Mock query sequence + def mock_query(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + elif model == ToolFile: + m.where.return_value.first.return_value = mock_tool_file + return m + + mock_db.session.query.side_effect = mock_query + mock_storage.load_once.return_value = b"tool content" + + # Execute + result = DatasourceFileManager.get_file_binary_by_message_file_id("msg_file_id") + + # Verify + assert result == (b"tool content", "image/png") + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id_with_extension(self, mock_storage, mock_db): + # Test that it correctly parses tool_id even with extension in URL + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = "http://localhost/files/tools/abcdef.png" + + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "abcdef" + mock_tool_file.file_key = "tk" + mock_tool_file.mimetype = "image/png" + + def mock_query(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + else: + m.where.return_value.first.return_value = mock_tool_file + return m + + mock_db.session.query.side_effect = mock_query + mock_storage.load_once.return_value = b"bits" + + result = DatasourceFileManager.get_file_binary_by_message_file_id("m") + assert result == (b"bits", "image/png") + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db): + # Setup common mock + mock_query_obj = MagicMock() + mock_db.session.query.return_value = mock_query_obj + mock_query_obj.where.return_value.first.return_value = None + + # Case 1: Message file not found + assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None + + # Case 2: Message file found but tool file not found + mock_message_file = MagicMock(spec=MessageFile) + mock_message_file.url = None + + def mock_query_v2(model): + m = MagicMock() + if model == MessageFile: + m.where.return_value.first.return_value = mock_message_file + else: + m.where.return_value.first.return_value = None + return m + + mock_db.session.query.side_effect = mock_query_v2 + assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None + + @patch("core.datasource.datasource_file_manager.db") + @patch("core.datasource.datasource_file_manager.storage") + def test_get_file_generator_by_upload_file_id(self, mock_storage, mock_db): + # Setup + mock_upload_file = MagicMock(spec=UploadFile) + mock_upload_file.key = "upload_key" + mock_upload_file.mime_type = "text/plain" + + mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file + + mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"]) + + # Execute + stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("upload_id") + + # Verify + assert mimetype == "text/plain" + assert list(stream) == [b"chunk1", b"chunk2"] + + # Case: Not found + mock_db.session.query.return_value.where.return_value.first.return_value = None + stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none") + assert stream is None + assert mimetype is None diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index 9ee1df8bdc..d5eeae912c 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -1,10 +1,16 @@ import types from collections.abc import Generator +import pytest + +from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager -from core.datasource.entities.datasource_entities import DatasourceMessage -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent +from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType +from core.datasource.errors import DatasourceProviderNotFoundError +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -15,6 +21,22 @@ def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, Non ) +def _drain_generator(gen: Generator[DatasourceMessage, None, object]) -> tuple[list[DatasourceMessage], object | None]: + messages: list[DatasourceMessage] = [] + try: + while True: + messages.append(next(gen)) + except StopIteration as e: + return messages, e.value + + +def _invalidate_recyclable_contextvars() -> None: + """ + Ensure RecyclableContextVar.get() raises LookupError until reset by code under test. + """ + RecyclableContextVar.increment_thread_recycles() + + def test_get_icon_url_calls_runtime(mocker): fake_runtime = mocker.Mock() fake_runtime.get_icon_url.return_value = "https://icon" @@ -30,6 +52,119 @@ def test_get_icon_url_calls_runtime(mocker): DatasourceManager.get_datasource_runtime.assert_called_once() +def test_get_datasource_runtime_delegates_to_provider_controller(mocker): + provider_controller = mocker.Mock() + provider_controller.get_datasource.return_value = object() + mocker.patch.object(DatasourceManager, "get_datasource_plugin_provider", return_value=provider_controller) + + runtime = DatasourceManager.get_datasource_runtime( + provider_id="prov/x", + datasource_name="ds", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + assert runtime is provider_controller.get_datasource.return_value + provider_controller.get_datasource.assert_called_once_with("ds") + + +@pytest.mark.parametrize( + ("datasource_type", "controller_path"), + [ + ( + DatasourceProviderType.ONLINE_DOCUMENT, + "core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.ONLINE_DRIVE, + "core.datasource.datasource_manager.OnlineDriveDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.WEBSITE_CRAWL, + "core.datasource.datasource_manager.WebsiteCrawlDatasourcePluginProviderController", + ), + ( + DatasourceProviderType.LOCAL_FILE, + "core.datasource.datasource_manager.LocalFileDatasourcePluginProviderController", + ), + ], +) +def test_get_datasource_plugin_provider_creates_controller_and_caches(mocker, datasource_type, controller_path): + _invalidate_recyclable_contextvars() + + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + fetch = mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + ctrl_cls = mocker.patch(controller_path) + + first = DatasourceManager.get_datasource_plugin_provider( + provider_id=f"prov/{datasource_type.value}", + tenant_id="t1", + datasource_type=datasource_type, + ) + second = DatasourceManager.get_datasource_plugin_provider( + provider_id=f"prov/{datasource_type.value}", + tenant_id="t1", + datasource_type=datasource_type, + ) + + assert first is second + assert fetch.call_count == 1 + assert ctrl_cls.call_count == 1 + + +def test_get_datasource_plugin_provider_raises_when_provider_entity_missing(mocker): + _invalidate_recyclable_contextvars() + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=None, + ) + + with pytest.raises(DatasourceProviderNotFoundError, match="plugin provider prov/notfound not found"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/notfound", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + + +def test_get_datasource_plugin_provider_raises_for_unsupported_type(mocker): + _invalidate_recyclable_contextvars() + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + + with pytest.raises(ValueError, match="Unsupported datasource type"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/x", + tenant_id="t1", + datasource_type=types.SimpleNamespace(), # not a DatasourceProviderType at runtime + ) + + +def test_get_datasource_plugin_provider_raises_when_controller_none(mocker): + _invalidate_recyclable_contextvars() + provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq") + mocker.patch( + "core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider", + return_value=provider_entity, + ) + mocker.patch( + "core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController", + return_value=None, + ) + + with pytest.raises(DatasourceProviderNotFoundError, match="Datasource provider prov/x not found"): + DatasourceManager.get_datasource_plugin_provider( + provider_id="prov/x", + tenant_id="t1", + datasource_type=DatasourceProviderType.ONLINE_DOCUMENT, + ) + + def test_stream_online_results_yields_messages_online_document(mocker): # stub runtime to yield a text message def _doc_messages(**_): @@ -60,6 +195,148 @@ def test_stream_online_results_yields_messages_online_document(mocker): assert msgs[0].message.text == "hello" +def test_stream_online_results_sets_credentials_and_returns_empty_dict_online_document(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def get_online_document_page_content(self, **_kwargs): + yield from _gen_messages_text_only("hello") + + runtime = _Runtime() + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="cred", + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + messages, final_value = _drain_generator(gen) + + assert runtime.runtime.credentials == {"token": "t"} + assert [m.message.text for m in messages] == ["hello"] + assert final_value == {} + + +def test_stream_online_results_raises_when_missing_params(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def get_online_document_page_content(self, **_kwargs): + yield from _gen_messages_text_only("never") + + def online_drive_download_file(self, **_kwargs): + yield from _gen_messages_text_only("never") + + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=_Runtime()) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={}, + ) + + with pytest.raises(ValueError, match="datasource_param is required for ONLINE_DOCUMENT streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + with pytest.raises(ValueError, match="online_drive_request is required for ONLINE_DRIVE streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + +def test_stream_online_results_yields_messages_and_returns_empty_dict_online_drive(mocker): + class _Runtime: + def __init__(self) -> None: + self.runtime = types.SimpleNamespace(credentials=None) + + def online_drive_download_file(self, **_kwargs): + yield from _gen_messages_text_only("drive") + + runtime = _Runtime() + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ) + + gen = DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="cred", + datasource_param=None, + online_drive_request=types.SimpleNamespace(id="fid", bucket="b"), + ) + messages, final_value = _drain_generator(gen) + + assert runtime.runtime.credentials == {"token": "t"} + assert [m.message.text for m in messages] == ["drive"] + assert final_value == {} + + +def test_stream_online_results_raises_for_unsupported_stream_type(mocker): + mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=mocker.Mock()) + mocker.patch( + "core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials", + return_value={}, + ) + + with pytest.raises(ValueError, match="Unsupported datasource type for streaming"): + list( + DatasourceManager.stream_online_results( + user_id="u1", + datasource_name="ds", + datasource_type="website_crawl", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + datasource_param=None, + online_drive_request=None, + ) + ) + + def test_stream_node_events_emits_events_online_document(mocker): # make manager's low-level stream produce TEXT only mocker.patch.object( @@ -93,6 +370,260 @@ def test_stream_node_events_emits_events_online_document(mocker): assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED +def test_stream_node_events_builds_file_and_variables_from_messages(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"), + meta={}, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, + message=DatasourceMessage.TextMessage(text="hello"), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text="http://example.com"), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="a", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="b", stream=True), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.VARIABLE, + message=DatasourceMessage.VariableMessage(variable_name="x", variable_value=1, stream=False), + meta=None, + ) + yield DatasourceMessage( + type=DatasourceMessage.MessageType.JSON, + message=DatasourceMessage.JsonMessage(json_object={"k": "v"}), + meta=None, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + fake_tool_file = types.SimpleNamespace(mimetype="image/png") + + class _Session: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def scalar(self, _stmt): + return fake_tool_file + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) + mocker.patch( + "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE + ) + built = File( + tenant_id="t1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="tool_file_1", + extension=".png", + mime_type="image/png", + storage_key="k", + ) + build_from_mapping = mocker.patch( + "core.datasource.datasource_manager.file_factory.build_from_mapping", + return_value=built, + ) + + variable_pool = mocker.Mock() + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={"k": "v"}, + datasource_info={"info": "x"}, + variable_pool=variable_pool, + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + + build_from_mapping.assert_called_once() + variable_pool.add.assert_not_called() + + assert any(isinstance(e, StreamChunkEvent) and e.chunk == "hello" for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.chunk.startswith("Link: http") for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "a" for e in events) + assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "b" for e in events) + assert isinstance(events[-2], StreamChunkEvent) + assert events[-2].is_final is True + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.outputs["v"] == "ab" + assert events[-1].node_run_result.outputs["x"] == 1 + + +def test_stream_node_events_raises_when_toolfile_missing(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/missing.png"), + meta={}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + class _Session: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def scalar(self, _stmt): + return None + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) + + with pytest.raises(ValueError, match="ToolFile not found for file_id=missing, tenant_id=t1"): + list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_document", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={}, + variable_pool=mocker.Mock(), + datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"), + online_drive_request=None, + ) + ) + + +def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + file_in = File( + tenant_id="t1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="tf", + extension=".pdf", + mime_type="application/pdf", + storage_key="k", + ) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.FileMessage(file_marker="file_marker"), + meta={"file": file_in}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + + variable_pool = mocker.Mock() + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="online_drive", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={"k": "v"}, + variable_pool=variable_pool, + datasource_param=None, + online_drive_request=types.SimpleNamespace(id="id", bucket="b"), + ) + ) + + variable_pool.add.assert_called_once() + assert variable_pool.add.call_args[0][0] == ["nodeA", "file"] + assert variable_pool.add.call_args[0][1] == file_in + + completed = events[-1] + assert isinstance(completed, StreamCompletedEvent) + assert completed.node_run_result.outputs["file"] == file_in + assert completed.node_run_result.outputs["datasource_type"] == DatasourceProviderType.ONLINE_DRIVE + + +def test_stream_node_events_skips_file_build_for_non_online_types(mocker): + mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored")) + + def _transformed(**_kwargs): + yield DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE_LINK, + message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"), + meta={}, + ) + + mocker.patch( + "core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages", + side_effect=_transformed, + ) + build_from_mapping = mocker.patch("core.datasource.datasource_manager.file_factory.build_from_mapping") + + events = list( + DatasourceManager.stream_node_events( + node_id="nodeA", + user_id="u1", + datasource_name="ds", + datasource_type="website_crawl", + provider_id="p/x", + tenant_id="t1", + provider="prov", + plugin_id="plug", + credential_id="", + parameters_for_log={}, + datasource_info={}, + variable_pool=mocker.Mock(), + datasource_param=None, + online_drive_request=None, + ) + ) + + build_from_mapping.assert_not_called() + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.outputs["file"] is None + + def test_get_upload_file_by_id_builds_file(mocker): # fake UploadFile row fake_row = types.SimpleNamespace( @@ -133,3 +664,27 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + + +def test_get_upload_file_by_id_raises_when_missing(mocker): + class _Q: + def where(self, *_args, **_kwargs): + return self + + def first(self): + return None + + class _S: + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def query(self, *_): + return _Q() + + mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S()) + + with pytest.raises(ValueError, match="UploadFile not found for file_id=fid, tenant_id=t1"): + DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") diff --git a/api/tests/unit_tests/core/datasource/test_errors.py b/api/tests/unit_tests/core/datasource/test_errors.py new file mode 100644 index 0000000000..95986415b1 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/test_errors.py @@ -0,0 +1,64 @@ +from unittest.mock import MagicMock + +from core.datasource.entities.datasource_entities import DatasourceInvokeMeta +from core.datasource.errors import ( + DatasourceApiSchemaError, + DatasourceEngineInvokeError, + DatasourceInvokeError, + DatasourceNotFoundError, + DatasourceNotSupportedError, + DatasourceParameterValidationError, + DatasourceProviderCredentialValidationError, + DatasourceProviderNotFoundError, +) + + +class TestErrors: + def test_datasource_provider_not_found_error(self): + error = DatasourceProviderNotFoundError("Provider not found") + assert str(error) == "Provider not found" + assert isinstance(error, ValueError) + + def test_datasource_not_found_error(self): + error = DatasourceNotFoundError("Datasource not found") + assert str(error) == "Datasource not found" + assert isinstance(error, ValueError) + + def test_datasource_parameter_validation_error(self): + error = DatasourceParameterValidationError("Validation failed") + assert str(error) == "Validation failed" + assert isinstance(error, ValueError) + + def test_datasource_provider_credential_validation_error(self): + error = DatasourceProviderCredentialValidationError("Credential validation failed") + assert str(error) == "Credential validation failed" + assert isinstance(error, ValueError) + + def test_datasource_not_supported_error(self): + error = DatasourceNotSupportedError("Not supported") + assert str(error) == "Not supported" + assert isinstance(error, ValueError) + + def test_datasource_invoke_error(self): + error = DatasourceInvokeError("Invoke error") + assert str(error) == "Invoke error" + assert isinstance(error, ValueError) + + def test_datasource_api_schema_error(self): + error = DatasourceApiSchemaError("API schema error") + assert str(error) == "API schema error" + assert isinstance(error, ValueError) + + def test_datasource_engine_invoke_error(self): + mock_meta = MagicMock(spec=DatasourceInvokeMeta) + error = DatasourceEngineInvokeError(meta=mock_meta) + assert error.meta == mock_meta + assert isinstance(error, Exception) + + def test_datasource_engine_invoke_error_init(self): + # Test initialization with meta + meta = DatasourceInvokeMeta(time_cost=1.5, error="Engine failed") + error = DatasourceEngineInvokeError(meta=meta) + assert error.meta == meta + assert error.meta.time_cost == 1.5 + assert error.meta.error == "Engine failed" diff --git a/api/tests/unit_tests/core/datasource/test_file_upload.py b/api/tests/unit_tests/core/datasource/test_file_upload.py index ad86190e00..63b86e64fc 100644 --- a/api/tests/unit_tests/core/datasource/test_file_upload.py +++ b/api/tests/unit_tests/core/datasource/test_file_upload.py @@ -35,7 +35,7 @@ TEST COVERAGE OVERVIEW: - Tests hash consistency and determinism 6. Invalid Filename Handling (TestInvalidFilenameHandling) - - Validates rejection of filenames with invalid characters (/, \\, :, *, ?, ", <, >, |) + - Validates rejection of filenames with path separators (/, \\) - Tests filename length truncation (max 200 characters) - Prevents path traversal attacks - Handles edge cases like empty filenames @@ -535,30 +535,23 @@ class TestInvalidFilenameHandling: @pytest.mark.parametrize( "invalid_char", - ["/", "\\", ":", "*", "?", '"', "<", ">", "|"], + ["/", "\\"], ) def test_filename_contains_invalid_characters(self, invalid_char): """Test detection of invalid characters in filename. - Security-critical test that validates rejection of dangerous filename characters. + Security-critical test that validates rejection of path separators. These characters are blocked because they: - / and \\ : Directory separators, could enable path traversal - - : : Drive letter separator on Windows, reserved character - - * and ? : Wildcards, could cause issues in file operations - - " : Quote character, could break command-line operations - - < and > : Redirection operators, command injection risk - - | : Pipe operator, command injection risk Blocking these characters prevents: - Path traversal attacks (../../etc/passwd) - - Command injection - - File system corruption - - Cross-platform compatibility issues + - ZIP entry traversal issues + - Ambiguous path handling """ # Arrange - Create filename with invalid character filename = f"test{invalid_char}file.txt" - # Define complete list of invalid characters - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act - Check if filename contains any invalid character has_invalid_char = any(c in filename for c in invalid_chars) @@ -570,7 +563,7 @@ class TestInvalidFilenameHandling: """Test that valid filenames pass validation.""" # Arrange filename = "valid_file-name_123.txt" - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act has_invalid_char = any(c in filename for c in invalid_chars) @@ -578,6 +571,16 @@ class TestInvalidFilenameHandling: # Assert assert has_invalid_char is False + @pytest.mark.parametrize("safe_char", [":", "*", "?", '"', "<", ">", "|"]) + def test_filename_allows_safe_metadata_characters(self, safe_char): + """Test that non-separator punctuation remains allowed in filenames.""" + filename = f"candidate{safe_char}resume.txt" + invalid_chars = ["/", "\\"] + + has_invalid_char = any(c in filename for c in invalid_chars) + + assert has_invalid_char is False + def test_extremely_long_filename_truncation(self): """Test handling of extremely long filenames.""" # Arrange @@ -904,7 +907,7 @@ class TestFilenameValidation: """Test that filenames with spaces are handled correctly.""" # Arrange filename = "my document with spaces.pdf" - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act - Check for invalid characters has_invalid = any(c in filename for c in invalid_chars) @@ -921,7 +924,7 @@ class TestFilenameValidation: "مستند.txt", # Arabic "ファイル.jpg", # Japanese ] - invalid_chars = ["/", "\\", ":", "*", "?", '"', "<", ">", "|"] + invalid_chars = ["/", "\\"] # Act & Assert - Unicode should be allowed for filename in unicode_filenames: diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py new file mode 100644 index 0000000000..43f582feb7 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -0,0 +1,337 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.entities.datasource_entities import DatasourceMessage +from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer +from dify_graph.file import File +from dify_graph.file.enums import FileTransferMethod, FileType +from models.tools import ToolFile + + +class TestDatasourceFileMessageTransformer: + def test_transform_text_and_link_messages(self): + # Setup + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.TEXT, message=DatasourceMessage.TextMessage(text="hello") + ), + DatasourceMessage( + type=DatasourceMessage.MessageType.LINK, + message=DatasourceMessage.TextMessage(text="https://example.com"), + ), + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 2 + assert result[0].type == DatasourceMessage.MessageType.TEXT + assert result[0].message.text == "hello" + assert result[1].type == DatasourceMessage.MessageType.LINK + assert result[1].message.text == "https://example.com" + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + def test_transform_image_message_success(self, mock_guess_ext, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "file_id_123" + mock_tool_file.mimetype = "image/png" + mock_manager.create_file_by_url.return_value = mock_tool_file + mock_guess_ext.return_value = ".png" + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, + message=DatasourceMessage.TextMessage(text="https://example.com/image.png"), + meta={"some": "meta"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1", conversation_id="conv1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/file_id_123.png" + assert result[0].meta == {"some": "meta"} + mock_manager.create_file_by_url.assert_called_once_with( + user_id="user1", tenant_id="tenant1", file_url="https://example.com/image.png", conversation_id="conv1" + ) + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + def test_transform_image_message_failure(self, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_manager.create_file_by_url.side_effect = Exception("Download failed") + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, + message=DatasourceMessage.TextMessage(text="https://example.com/image.png"), + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.TEXT + assert "Failed to download image" in result[0].message.text + assert "Download failed" in result[0].message.text + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + def test_transform_blob_message_image(self, mock_guess_ext, mock_tool_file_manager_cls): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_456" + mock_tool_file.mimetype = "image/jpeg" + mock_manager.create_file_by_raw.return_value = mock_tool_file + mock_guess_ext.return_value = ".jpg" + + blob_data = b"fake-image-bits" + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=blob_data), + meta={"mime_type": "image/jpeg", "file_name": "test.jpg"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/blob_id_456.jpg" + mock_manager.create_file_by_raw.assert_called_once() + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + @patch("core.datasource.utils.message_transformer.guess_extension") + @patch("core.datasource.utils.message_transformer.guess_type") + def test_transform_blob_message_binary_guess_mimetype( + self, mock_guess_type, mock_guess_ext, mock_tool_file_manager_cls + ): + # Setup + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_789" + mock_tool_file.mimetype = "application/pdf" + mock_manager.create_file_by_raw.return_value = mock_tool_file + mock_guess_type.return_value = ("application/pdf", None) + mock_guess_ext.return_value = ".pdf" + + blob_data = b"fake-pdf-bits" + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=blob_data), + meta={"file_name": "test.pdf"}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK + assert result[0].message.text == "/files/datasources/blob_id_789.pdf" + + def test_transform_blob_message_invalid_type(self): + # Setup + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, message=DatasourceMessage.TextMessage(text="not a blob") + ) + ] + + # Execute & Verify + with pytest.raises(ValueError, match="unexpected message type"): + list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + def test_transform_file_tool_file_image(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.TOOL_FILE + mock_file.related_id = "related_123" + mock_file.extension = ".png" + mock_file.type = FileType.IMAGE + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="ignored"), + meta={"file": mock_file}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK + assert result[0].message.text == "/files/datasources/related_123.png" + + def test_transform_file_tool_file_binary(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.TOOL_FILE + mock_file.related_id = "related_456" + mock_file.extension = ".txt" + mock_file.type = FileType.DOCUMENT + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="ignored"), + meta={"file": mock_file}, + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.LINK + assert result[0].message.text == "/files/datasources/related_456.txt" + + def test_transform_file_other_transfer_method(self): + # Setup + mock_file = MagicMock(spec=File) + mock_file.transfer_method = FileTransferMethod.REMOTE_URL + + msg = DatasourceMessage( + type=DatasourceMessage.MessageType.FILE, + message=DatasourceMessage.TextMessage(text="remote image"), + meta={"file": mock_file}, + ) + messages = [msg] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0] == msg + + def test_transform_other_message_type(self): + # JSON type is yielded by the default 'else' block or the 'yield message' at the end + msg = DatasourceMessage( + type=DatasourceMessage.MessageType.JSON, message=DatasourceMessage.JsonMessage(json_object={"k": "v"}) + ) + messages = [msg] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify + assert len(result) == 1 + assert result[0] == msg + + def test_get_datasource_file_url(self): + # Test with extension + url = DatasourceFileMessageTransformer.get_datasource_file_url("file1", ".jpg") + assert url == "/files/datasources/file1.jpg" + + # Test without extension + url = DatasourceFileMessageTransformer.get_datasource_file_url("file2", None) + assert url == "/files/datasources/file2.bin" + + def test_transform_blob_message_no_meta_filename(self): + # This tests line 70 where filename might be None + with patch("core.datasource.utils.message_transformer.ToolFileManager") as mock_tool_file_manager_cls: + mock_manager = mock_tool_file_manager_cls.return_value + mock_tool_file = MagicMock(spec=ToolFile) + mock_tool_file.id = "blob_id_no_name" + mock_tool_file.mimetype = "application/octet-stream" + mock_manager.create_file_by_raw.return_value = mock_tool_file + + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.BLOB, + message=DatasourceMessage.BlobMessage(blob=b"data"), + meta={}, # No mime_type, no file_name + ) + ] + + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK + assert result[0].message.text == "/files/datasources/blob_id_no_name.bin" + + @patch("core.datasource.utils.message_transformer.ToolFileManager") + def test_transform_image_message_not_text_message(self, mock_tool_file_manager_cls): + # This tests line 24-26 where it checks if message is instance of TextMessage + messages = [ + DatasourceMessage( + type=DatasourceMessage.MessageType.IMAGE, message=DatasourceMessage.BlobMessage(blob=b"not-text") + ) + ] + + # Execute + result = list( + DatasourceFileMessageTransformer.transform_datasource_invoke_messages( + messages=iter(messages), user_id="user1", tenant_id="tenant1" + ) + ) + + # Verify - should yield unchanged if it's not a TextMessage + assert len(result) == 1 + assert result[0].type == DatasourceMessage.MessageType.IMAGE + assert isinstance(result[0].message, DatasourceMessage.BlobMessage) diff --git a/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py new file mode 100644 index 0000000000..2945eb5523 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_plugin.py @@ -0,0 +1,101 @@ +from collections.abc import Generator +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceEntity, + DatasourceProviderType, + WebsiteCrawlMessage, +) +from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin + + +class TestWebsiteCrawlDatasourcePlugin: + @pytest.fixture + def mock_entity(self): + entity = MagicMock(spec=DatasourceEntity) + entity.identity = MagicMock() + entity.identity.provider = "test-provider" + entity.identity.name = "test-name" + return entity + + @pytest.fixture + def mock_runtime(self): + runtime = MagicMock(spec=DatasourceRuntime) + runtime.credentials = {"api_key": "test-key"} + return runtime + + def test_init(self, mock_entity, mock_runtime): + # Arrange + tenant_id = "test-tenant-id" + icon = "test-icon" + plugin_unique_identifier = "test-plugin-id" + + # Act + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id=tenant_id, + icon=icon, + plugin_unique_identifier=plugin_unique_identifier, + ) + + # Assert + assert plugin.tenant_id == tenant_id + assert plugin.plugin_unique_identifier == plugin_unique_identifier + assert plugin.entity == mock_entity + assert plugin.runtime == mock_runtime + assert plugin.icon == icon + + def test_datasource_provider_type(self, mock_entity, mock_runtime): + # Arrange + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test" + ) + + # Act & Assert + assert plugin.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL + + def test_get_website_crawl(self, mock_entity, mock_runtime): + # Arrange + plugin = WebsiteCrawlDatasourcePlugin( + entity=mock_entity, + runtime=mock_runtime, + tenant_id="test-tenant-id", + icon="test-icon", + plugin_unique_identifier="test-plugin-id", + ) + + user_id = "test-user-id" + datasource_parameters = {"url": "https://example.com"} + provider_type = "firecrawl" + + mock_message = MagicMock(spec=WebsiteCrawlMessage) + + # Mock PluginDatasourceManager + with patch("core.datasource.website_crawl.website_crawl_plugin.PluginDatasourceManager") as mock_manager_class: + mock_manager = mock_manager_class.return_value + mock_manager.get_website_crawl.return_value = (msg for msg in [mock_message]) + + # Act + result = plugin.get_website_crawl( + user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type + ) + + # Assert + assert isinstance(result, Generator) + messages = list(result) + assert len(messages) == 1 + assert messages[0] == mock_message + + mock_manager.get_website_crawl.assert_called_once_with( + tenant_id="test-tenant-id", + user_id=user_id, + datasource_provider="test-provider", + datasource_name="test-name", + credentials={"api_key": "test-key"}, + datasource_parameters=datasource_parameters, + provider_type=provider_type, + ) diff --git a/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py new file mode 100644 index 0000000000..b7822ba800 --- /dev/null +++ b/api/tests/unit_tests/core/datasource/website_crawl/test_website_crawl_provider.py @@ -0,0 +1,95 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.datasource.__base.datasource_runtime import DatasourceRuntime +from core.datasource.entities.datasource_entities import ( + DatasourceProviderEntityWithPlugin, + DatasourceProviderType, +) +from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController + + +class TestWebsiteCrawlDatasourcePluginProviderController: + @pytest.fixture + def mock_entity(self): + entity = MagicMock(spec=DatasourceProviderEntityWithPlugin) + entity.datasources = [] + entity.identity = MagicMock() + entity.identity.icon = "test-icon" + return entity + + def test_init(self, mock_entity): + # Arrange + plugin_id = "test-plugin-id" + plugin_unique_identifier = "test-unique-id" + tenant_id = "test-tenant-id" + + # Act + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, + plugin_id=plugin_id, + plugin_unique_identifier=plugin_unique_identifier, + tenant_id=tenant_id, + ) + + # Assert + assert controller.entity == mock_entity + assert controller.plugin_id == plugin_id + assert controller.plugin_unique_identifier == plugin_unique_identifier + assert controller.tenant_id == tenant_id + + def test_provider_type(self, mock_entity): + # Arrange + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Act & Assert + assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL + + def test_get_datasource_success(self, mock_entity): + # Arrange + datasource_name = "test-datasource" + tenant_id = "test-tenant-id" + plugin_unique_identifier = "test-unique-id" + + mock_datasource_entity = MagicMock() + mock_datasource_entity.identity = MagicMock() + mock_datasource_entity.identity.name = datasource_name + mock_entity.datasources = [mock_datasource_entity] + + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id + ) + + # Act + with patch( + "core.datasource.website_crawl.website_crawl_provider.WebsiteCrawlDatasourcePlugin" + ) as mock_plugin_class: + mock_plugin_instance = mock_plugin_class.return_value + result = controller.get_datasource(datasource_name) + + # Assert + assert result == mock_plugin_instance + mock_plugin_class.assert_called_once() + args, kwargs = mock_plugin_class.call_args + assert kwargs["entity"] == mock_datasource_entity + assert isinstance(kwargs["runtime"], DatasourceRuntime) + assert kwargs["runtime"].tenant_id == tenant_id + assert kwargs["tenant_id"] == tenant_id + assert kwargs["icon"] == "test-icon" + assert kwargs["plugin_unique_identifier"] == plugin_unique_identifier + + def test_get_datasource_not_found(self, mock_entity): + # Arrange + datasource_name = "non-existent" + mock_entity.datasources = [] + + controller = WebsiteCrawlDatasourcePluginProviderController( + entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test" + ) + + # Act & Assert + with pytest.raises(ValueError, match=f"Datasource with name {datasource_name} not found"): + controller.get_datasource(datasource_name) diff --git a/api/tests/unit_tests/core/entities/test_entities_agent_entities.py b/api/tests/unit_tests/core/entities/test_entities_agent_entities.py new file mode 100644 index 0000000000..2437602695 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_agent_entities.py @@ -0,0 +1,9 @@ +from core.entities.agent_entities import PlanningStrategy + + +def test_planning_strategy_values_are_stable() -> None: + # Arrange / Act / Assert + assert PlanningStrategy.ROUTER.value == "router" + assert PlanningStrategy.REACT_ROUTER.value == "react_router" + assert PlanningStrategy.REACT.value == "react" + assert PlanningStrategy.FUNCTION_CALL.value == "function_call" diff --git a/api/tests/unit_tests/core/entities/test_entities_document_task.py b/api/tests/unit_tests/core/entities/test_entities_document_task.py new file mode 100644 index 0000000000..dd550930d7 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_document_task.py @@ -0,0 +1,18 @@ +from core.entities.document_task import DocumentTask + + +def test_document_task_keeps_indexing_identifiers() -> None: + # Arrange + document_ids = ("doc-1", "doc-2") + + # Act + task = DocumentTask( + tenant_id="tenant-1", + dataset_id="dataset-1", + document_ids=document_ids, + ) + + # Assert + assert task.tenant_id == "tenant-1" + assert task.dataset_id == "dataset-1" + assert task.document_ids == document_ids diff --git a/api/tests/unit_tests/core/entities/test_entities_embedding_type.py b/api/tests/unit_tests/core/entities/test_entities_embedding_type.py new file mode 100644 index 0000000000..5a82fc4842 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_embedding_type.py @@ -0,0 +1,7 @@ +from core.entities.embedding_type import EmbeddingInputType + + +def test_embedding_input_type_values_are_stable() -> None: + # Arrange / Act / Assert + assert EmbeddingInputType.DOCUMENT.value == "document" + assert EmbeddingInputType.QUERY.value == "query" diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py new file mode 100644 index 0000000000..2e4f6d34fb --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -0,0 +1,45 @@ +from core.entities.execution_extra_content import ( + ExecutionExtraContentDomainModel, + HumanInputContent, + HumanInputFormDefinition, + HumanInputFormSubmissionData, +) +from dify_graph.nodes.human_input.entities import FormInput, UserAction +from dify_graph.nodes.human_input.enums import FormInputType +from models.execution_extra_content import ExecutionContentType + + +def test_human_input_content_defaults_and_domain_alias() -> None: + # Arrange + form_definition = HumanInputFormDefinition( + form_id="form-1", + node_id="node-1", + node_title="Human Input", + form_content="Please confirm", + inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="answer")], + actions=[UserAction(id="confirm", title="Confirm")], + resolved_default_values={"answer": "yes"}, + expiration_time=1_700_000_000, + ) + submission_data = HumanInputFormSubmissionData( + node_id="node-1", + node_title="Human Input", + rendered_content="Please confirm", + action_id="confirm", + action_text="Confirm", + ) + + # Act + content = HumanInputContent( + workflow_run_id="workflow-run-1", + submitted=True, + form_definition=form_definition, + form_submission_data=submission_data, + ) + + # Assert + assert form_definition.model_config.get("frozen") is True + assert content.type == ExecutionContentType.HUMAN_INPUT + assert content.form_definition is form_definition + assert content.form_submission_data is submission_data + assert ExecutionExtraContentDomainModel is HumanInputContent diff --git a/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py b/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py new file mode 100644 index 0000000000..d25f20145f --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_knowledge_entities.py @@ -0,0 +1,45 @@ +from core.entities.knowledge_entities import ( + PipelineDataset, + PipelineDocument, + PipelineGenerateResponse, +) + + +def test_pipeline_dataset_normalizes_none_description() -> None: + # Arrange / Act + dataset = PipelineDataset( + id="dataset-1", + name="Dataset", + description=None, + chunk_structure="parent-child", + ) + + # Assert + assert dataset.description == "" + + +def test_pipeline_generate_response_builds_nested_models() -> None: + # Arrange + dataset = PipelineDataset( + id="dataset-1", + name="Dataset", + description="Knowledge base", + chunk_structure="parent-child", + ) + document = PipelineDocument( + id="doc-1", + position=1, + data_source_type="file", + data_source_info={"name": "spec.pdf"}, + name="spec.pdf", + indexing_status="completed", + enabled=True, + ) + + # Act + response = PipelineGenerateResponse(batch="batch-1", dataset=dataset, documents=[document]) + + # Assert + assert response.batch == "batch-1" + assert response.dataset.id == "dataset-1" + assert response.documents[0].id == "doc-1" diff --git a/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py b/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py new file mode 100644 index 0000000000..5449c63b45 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_mcp_provider.py @@ -0,0 +1,450 @@ +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities import mcp_provider as mcp_provider_module +from core.entities.mcp_provider import ( + DEFAULT_EXPIRES_IN, + DEFAULT_TOKEN_TYPE, + MCPProviderEntity, +) +from core.mcp.types import OAuthTokens + + +def _build_mcp_provider_entity() -> MCPProviderEntity: + now = datetime(2025, 1, 1, tzinfo=UTC) + return MCPProviderEntity( + id="provider-1", + provider_id="server-1", + name="Example MCP", + tenant_id="tenant-1", + user_id="user-1", + server_url="encrypted-server-url", + headers={}, + timeout=30, + sse_read_timeout=300, + authed=False, + credentials={}, + tools=[], + icon={"en_US": "icon.png"}, + created_at=now, + updated_at=now, + ) + + +def test_from_db_model_maps_fields() -> None: + # Arrange + now = datetime(2025, 1, 1, tzinfo=UTC) + db_provider = SimpleNamespace( + id="provider-1", + server_identifier="server-1", + name="Example MCP", + tenant_id="tenant-1", + user_id="user-1", + server_url="encrypted-server-url", + headers={"Authorization": "enc"}, + timeout=15, + sse_read_timeout=120, + authed=True, + credentials={"access_token": "enc-token"}, + tool_dict=[{"name": "search"}], + icon=None, + created_at=now, + updated_at=now, + ) + + # Act + entity = MCPProviderEntity.from_db_model(db_provider) + + # Assert + assert entity.provider_id == "server-1" + assert entity.tools == [{"name": "search"}] + assert entity.icon == "" + + +def test_redirect_url_uses_console_api_url(monkeypatch: pytest.MonkeyPatch) -> None: + # Arrange + entity = _build_mcp_provider_entity() + monkeypatch.setattr(mcp_provider_module.dify_config, "CONSOLE_API_URL", "https://console.example.com") + + # Act + redirect_url = entity.redirect_url + + # Assert + assert redirect_url == "https://console.example.com/console/api/mcp/oauth/callback" + + +def test_client_metadata_for_authorization_code_flow() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "authorization_code"] + assert metadata.redirect_uris == [entity.redirect_url] + assert metadata.response_types == ["code"] + + +def test_client_metadata_for_client_credentials_flow() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = {"client_information": {"grant_types": ["client_credentials"]}} + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "client_credentials"] + assert metadata.redirect_uris == [] + assert metadata.response_types == [] + + +def test_client_metadata_prefers_nested_authorization_code_grant_type() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = { + "grant_type": "client_credentials", + "client_information": {"grant_types": ["authorization_code"]}, + } + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + # Act + metadata = entity.client_metadata + + # Assert + assert metadata.grant_types == ["refresh_token", "authorization_code"] + assert metadata.redirect_uris == [entity.redirect_url] + assert metadata.response_types == ["code"] + + +def test_provider_icon_returns_icon_dict_as_is() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}}) + + # Act + icon = entity.provider_icon + + # Assert + assert icon == {"en_US": "icon.png"} + + +def test_provider_icon_uses_signed_url_for_plain_path() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": "icons/mcp.png"}) + + with patch( + "core.entities.mcp_provider.file_helpers.get_signed_file_url", + return_value="https://signed.example.com/icons/mcp.png", + ) as mock_get_signed_url: + # Act + icon = entity.provider_icon + + # Assert + mock_get_signed_url.assert_called_once_with("icons/mcp.png") + assert icon == "https://signed.example.com/icons/mcp.png" + + +def test_to_api_response_without_sensitive_data_skips_auth_related_work() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"icon": {"en_US": "icon.png"}}) + + with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"): + # Act + response = entity.to_api_response(include_sensitive=False) + + # Assert + assert response["author"] == "Anonymous" + assert response["masked_headers"] == {} + assert response["is_dynamic_registration"] is True + assert "authentication" not in response + + +def test_to_api_response_with_sensitive_data_includes_masked_values() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy( + update={ + "credentials": {"client_information": {"is_dynamic_registration": False}}, + "icon": {"en_US": "icon.png"}, + } + ) + + with patch.object(MCPProviderEntity, "masked_server_url", return_value="https://api.example.com/******"): + with patch.object(MCPProviderEntity, "masked_headers", return_value={"Authorization": "Be****"}): + with patch.object(MCPProviderEntity, "masked_credentials", return_value={"client_id": "cl****"}): + # Act + response = entity.to_api_response(user_name="Rajat", include_sensitive=True) + + # Assert + assert response["author"] == "Rajat" + assert response["masked_headers"] == {"Authorization": "Be****"} + assert response["authentication"] == {"client_id": "cl****"} + assert response["is_dynamic_registration"] is False + + +def test_retrieve_client_information_decrypts_nested_secret() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = {"client_information": {"client_id": "client-1", "encrypted_client_secret": "enc-secret"}} + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="plain-secret") as mock_decrypt: + # Act + client_info = entity.retrieve_client_information() + + # Assert + assert client_info is not None + assert client_info.client_id == "client-1" + assert client_info.client_secret == "plain-secret" + mock_decrypt.assert_called_once_with("tenant-1", "enc-secret") + + +def test_retrieve_client_information_returns_none_for_missing_data() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + result_empty = entity.retrieve_client_information() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}): + # Act + result_invalid = entity.retrieve_client_information() + + # Assert + assert result_empty is None + assert result_invalid is None + + +def test_masked_server_url_hides_path_segments() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object( + MCPProviderEntity, + "decrypt_server_url", + return_value="https://api.example.com/v1/mcp?query=1", + ): + # Act + masked_url = entity.masked_server_url() + + # Assert + assert masked_url == "https://api.example.com/******?query=1" + + +def test_mask_value_covers_short_and_long_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + # Act + short_masked = entity._mask_value("short") + long_masked = entity._mask_value("abcdefghijkl") + + # Assert + assert short_masked == "*****" + assert long_masked == "ab********kl" + + +def test_masked_headers_masks_all_decrypted_header_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "abcdefgh"}): + # Act + masked = entity.masked_headers() + + # Assert + assert masked == {"Authorization": "ab****gh"} + + +def test_masked_credentials_handles_nested_secret_fields() -> None: + # Arrange + entity = _build_mcp_provider_entity() + credentials = { + "client_information": { + "client_id": "client-id", + "encrypted_client_secret": "encrypted-value", + "client_secret": "plain-secret", + } + } + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value=credentials): + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="decrypted-secret"): + # Act + masked = entity.masked_credentials() + + # Assert + assert masked["client_id"] == "cl*****id" + assert masked["client_secret"] == "pl********et" + + +def test_masked_credentials_returns_empty_for_missing_client_information() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={}): + # Act + masked_empty = entity.masked_credentials() + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"client_information": "invalid"}): + # Act + masked_invalid = entity.masked_credentials() + + # Assert + assert masked_empty == {} + assert masked_invalid == {} + + +def test_retrieve_tokens_returns_defaults_when_optional_fields_missing() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}}) + + with patch.object( + MCPProviderEntity, + "decrypt_credentials", + return_value={"access_token": "token", "expires_in": "", "refresh_token": "refresh"}, + ): + # Act + tokens = entity.retrieve_tokens() + + # Assert + assert isinstance(tokens, OAuthTokens) + assert tokens.access_token == "token" + assert tokens.token_type == DEFAULT_TOKEN_TYPE + assert tokens.expires_in == DEFAULT_EXPIRES_IN + assert tokens.refresh_token == "refresh" + + +def test_retrieve_tokens_returns_none_when_access_token_missing() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"credentials": {"token": "encrypted"}}) + + with patch.object(MCPProviderEntity, "decrypt_credentials", return_value={"access_token": ""}) as mock_decrypt: + # Act + tokens = entity.retrieve_tokens() + + # Assert + mock_decrypt.assert_called_once() + assert tokens is None + + +def test_decrypt_server_url_delegates_to_encrypter() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch("core.entities.mcp_provider.encrypter.decrypt_token", return_value="https://api.example.com") as mock: + # Act + decrypted = entity.decrypt_server_url() + + # Assert + mock.assert_called_once_with("tenant-1", "encrypted-server-url") + assert decrypted == "https://api.example.com" + + +def test_decrypt_authentication_injects_authorization_for_oauth() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy(update={"authed": True, "headers": {}}) + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={}): + with patch.object( + MCPProviderEntity, + "retrieve_tokens", + return_value=OAuthTokens(access_token="abc123", token_type="bearer"), + ): + # Act + headers = entity.decrypt_authentication() + + # Assert + assert headers["Authorization"] == "Bearer abc123" + + +def test_decrypt_authentication_does_not_overwrite_existing_headers() -> None: + # Arrange + entity = _build_mcp_provider_entity().model_copy( + update={"authed": True, "headers": {"Authorization": "encrypted-header"}} + ) + + with patch.object(MCPProviderEntity, "decrypt_headers", return_value={"Authorization": "existing"}): + with patch.object( + MCPProviderEntity, + "retrieve_tokens", + return_value=OAuthTokens(access_token="abc", token_type="bearer"), + ) as mock_tokens: + # Act + headers = entity.decrypt_authentication() + + # Assert + mock_tokens.assert_not_called() + assert headers == {"Authorization": "existing"} + + +def test_decrypt_dict_returns_empty_for_empty_input() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + # Act + decrypted = entity._decrypt_dict({}) + + # Assert + assert decrypted == {} + + +def test_decrypt_dict_returns_original_data_when_no_encrypted_fields() -> None: + # Arrange + entity = _build_mcp_provider_entity() + input_data = {"nested": {"k": "v"}, "count": 2, "empty": ""} + + # Act + result = entity._decrypt_dict(input_data) + + # Assert + assert result is input_data + + +def test_decrypt_dict_only_decrypts_top_level_string_values() -> None: + # Arrange + entity = _build_mcp_provider_entity() + decryptor = Mock() + decryptor.decrypt.return_value = {"api_key": "plain-key"} + + def _fake_create_provider_encrypter(*, tenant_id: str, config: list, cache): + assert tenant_id == "tenant-1" + assert any(item.name == "api_key" for item in config) + return decryptor, None + + with patch("core.tools.utils.encryption.create_provider_encrypter", side_effect=_fake_create_provider_encrypter): + # Act + result = entity._decrypt_dict( + { + "api_key": "encrypted-key", + "nested": {"client_id": "unchanged"}, + "empty": "", + "count": 2, + } + ) + + # Assert + decryptor.decrypt.assert_called_once_with({"api_key": "encrypted-key"}) + assert result["api_key"] == "plain-key" + assert result["nested"] == {"client_id": "unchanged"} + assert result["count"] == 2 + + +def test_decrypt_headers_and_credentials_delegate_to_decrypt_dict() -> None: + # Arrange + entity = _build_mcp_provider_entity() + + with patch.object(MCPProviderEntity, "_decrypt_dict", side_effect=[{"h": "v"}, {"c": "v"}]) as mock: + # Act + headers = entity.decrypt_headers() + credentials = entity.decrypt_credentials() + + # Assert + assert mock.call_count == 2 + assert headers == {"h": "v"} + assert credentials == {"c": "v"} diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py new file mode 100644 index 0000000000..7a3d5e84ed --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -0,0 +1,92 @@ +"""Unit tests for model entity behavior and invariants. + +Covers DefaultModelEntity, DefaultModelProviderEntity, ModelStatus, +ProviderModelWithStatusEntity, and SimpleModelProviderEntity. Assumes i18n +labels are provided via I18nObject, model metadata aligns with FetchFrom and +ModelType expectations, and ProviderEntity/ConfigurateMethod interactions +drive provider mapping behavior. +""" + +import pytest + +from core.entities.model_entities import ( + DefaultModelEntity, + DefaultModelProviderEntity, + ModelStatus, + ProviderModelWithStatusEntity, + SimpleModelProviderEntity, +) +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: + return ProviderModelWithStatusEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + status=status, + ) + + +def test_simple_model_provider_entity_maps_from_provider_entity() -> None: + # Arrange + provider_entity = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + + # Act + simple_provider = SimpleModelProviderEntity(provider_entity) + + # Assert + assert simple_provider.provider == "openai" + assert simple_provider.label.en_US == "OpenAI" + assert simple_provider.supported_model_types == [ModelType.LLM] + + +def test_provider_model_with_status_raises_for_known_error_statuses() -> None: + # Arrange + expectations = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + for status, message in expectations.items(): + # Act / Assert + with pytest.raises(ValueError, match=message): + _build_model_with_status(status).raise_for_status() + + +def test_provider_model_with_status_allows_active_and_credential_removed() -> None: + # Arrange + active_model = _build_model_with_status(ModelStatus.ACTIVE) + removed_model = _build_model_with_status(ModelStatus.CREDENTIAL_REMOVED) + + # Act / Assert + active_model.raise_for_status() + removed_model.raise_for_status() + + +def test_default_model_entity_accepts_model_field_name() -> None: + # Arrange / Act + default_model = DefaultModelEntity( + model="gpt-4o-mini", + model_type=ModelType.LLM, + provider=DefaultModelProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + ), + ) + + # Assert + assert default_model.model == "gpt-4o-mini" + assert default_model.provider.provider == "openai" diff --git a/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py b/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py new file mode 100644 index 0000000000..20b7bf2a9f --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_parameter_entities.py @@ -0,0 +1,22 @@ +from core.entities.parameter_entities import ( + AppSelectorScope, + CommonParameterType, + ModelSelectorScope, + ToolSelectorScope, +) + + +def test_common_parameter_type_values_are_stable() -> None: + # Arrange / Act / Assert + assert CommonParameterType.SECRET_INPUT.value == "secret-input" + assert CommonParameterType.MODEL_SELECTOR.value == "model-selector" + assert CommonParameterType.DYNAMIC_SELECT.value == "dynamic-select" + assert CommonParameterType.ARRAY.value == "array" + assert CommonParameterType.OBJECT.value == "object" + + +def test_selector_scope_values_are_stable() -> None: + # Arrange / Act / Assert + assert AppSelectorScope.WORKFLOW.value == "workflow" + assert ModelSelectorScope.TEXT_EMBEDDING.value == "text-embedding" + assert ToolSelectorScope.BUILTIN.value == "builtin" diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py new file mode 100644 index 0000000000..95d58757f1 --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -0,0 +1,1870 @@ +from __future__ import annotations + +from contextlib import contextmanager +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from constants import HIDDEN_VALUE +from core.entities.model_entities import ModelStatus +from core.entities.provider_configuration import ProviderConfiguration, ProviderConfigurations +from core.entities.provider_entities import ( + CredentialConfiguration, + CustomConfiguration, + CustomModelConfiguration, + CustomProviderConfiguration, + ModelLoadBalancingConfiguration, + ModelSettings, + ProviderQuotaType, + QuotaConfiguration, + QuotaUnit, + RestrictModel, + SystemConfiguration, + SystemConfigurationStatus, +) +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from models.enums import CredentialSourceType +from models.provider import ProviderType +from models.provider_ids import ModelProviderID + +_UNSET = object() + + +def _build_provider_configuration(*, provider_name: str = "openai") -> ProviderConfiguration: + provider_entity = ProviderEntity( + provider=provider_name, + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=1_000, + quota_used=0, + is_valid=True, + restrict_models=[], + ) + ], + ) + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + return ProviderConfiguration( + tenant_id="tenant-1", + provider=provider_entity, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + +def _build_ai_model(name: str, *, model_type: ModelType = ModelType.LLM) -> AIModelEntity: + return AIModelEntity( + model=name, + label=I18nObject(en_US=name), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _exec_result( + *, + scalar_one_or_none: Any = _UNSET, + scalar: Any = _UNSET, + scalars_all: Any = _UNSET, + scalars_first: Any = _UNSET, +) -> Mock: + result = Mock() + if scalar_one_or_none is not _UNSET: + result.scalar_one_or_none.return_value = scalar_one_or_none + if scalar is not _UNSET: + result.scalar.return_value = scalar + if scalars_all is not _UNSET or scalars_first is not _UNSET: + scalars = Mock() + if scalars_all is not _UNSET: + scalars.all.return_value = scalars_all + if scalars_first is not _UNSET: + scalars.first.return_value = scalars_first + result.scalars.return_value = scalars + return result + + +@contextmanager +def _patched_session(session: Mock): + with patch("core.entities.provider_configuration.db") as mock_db: + mock_db.engine = Mock() + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + mock_session_cls.return_value.__enter__.return_value = session + yield mock_session_cls + + +def _build_secret_provider_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ] + ) + + +def _build_secret_model_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ], + ) + + +def test_extract_secret_variables_returns_only_secret_inputs() -> None: + configuration = _build_provider_configuration() + credential_form_schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + ), + ] + + secret_variables = configuration.extract_secret_variables(credential_form_schemas) + assert secret_variables == ["api_key"] + + +def test_obfuscated_credentials_masks_only_secret_fields() -> None: + configuration = _build_provider_configuration() + credential_form_schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + ), + ] + + with patch( + "core.entities.provider_configuration.encrypter.obfuscated_token", + side_effect=lambda value: f"masked-{value[-2:]}", + ): + obfuscated = configuration.obfuscated_credentials( + credentials={"api_key": "sk-test-1234", "endpoint": "https://api.example.com"}, + credential_form_schemas=credential_form_schemas, + ) + + assert obfuscated["api_key"] == "masked-34" + assert obfuscated["endpoint"] == "https://api.example.com" + + +def test_provider_configurations_behave_like_keyed_container() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + + configurations[provider_key] = configuration + + assert "openai" in configurations + assert configurations["openai"] is configuration + assert configurations.get("openai") is configuration + assert configurations.to_list() == [configuration] + assert list(configurations) == [(provider_key, configuration)] + + +def test_provider_configurations_get_models_forwards_filters() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + configurations[provider_key] = configuration + expected_model = Mock() + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[expected_model]) as mock_get: + models = configurations.get_models(provider="openai", model_type=ModelType.LLM, only_active=True) + + mock_get.assert_called_once_with(ModelType.LLM, True) + assert models == [expected_model] + + +def test_provider_configurations_get_models_skips_non_matching_provider_filter() -> None: + configuration = _build_provider_configuration() + provider_key = str(ModelProviderID("openai")) + configurations = ProviderConfigurations(tenant_id="tenant-1") + configurations[provider_key] = configuration + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[Mock()]) as mock_get: + models = configurations.get_models(provider="anthropic", model_type=ModelType.LLM, only_active=True) + + assert models == [] + mock_get.assert_not_called() + + +def test_get_current_credentials_custom_provider_checks_current_credential() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + current_credential_id="credential-1", + current_credential_name="Primary", + available_credentials=[], + ) + + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + assert credentials == {"api_key": "provider-key"} + assert mock_check.call_count == 1 + assert mock_check.call_args.kwargs["credential_id"] == "credential-1" + assert mock_check.call_args.kwargs["provider"] == "openai" + + +def test_get_current_credentials_custom_provider_checks_all_available_credentials() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + available_credentials=[ + CredentialConfiguration(credential_id="cred-1", credential_name="First"), + CredentialConfiguration(credential_id="cred-2", credential_name="Second"), + ], + ) + + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + assert credentials == {"api_key": "provider-key"} + assert [c.kwargs["credential_id"] for c in mock_check.call_args_list] == ["cred-1", "cred-2"] + assert all(c.kwargs["provider"] == "openai" for c in mock_check.call_args_list) + + +def test_get_system_configuration_status_returns_none_when_current_quota_missing() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.current_quota_type = ProviderQuotaType.FREE + + status = configuration.get_system_configuration_status() + assert status is None + + +def test_get_provider_names_supports_legacy_and_full_plugin_id() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider = "langgenius/openai/openai" + + provider_names = configuration._get_provider_names() + assert provider_names == ["langgenius/openai/openai", "openai"] + + +def test_generate_next_api_key_name_uses_highest_numeric_suffix() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [ + SimpleNamespace(credential_name="API KEY 9"), + SimpleNamespace(credential_name="legacy"), + SimpleNamespace(credential_name=" API KEY 2 "), + ] + + name = configuration._generate_next_api_key_name(session=session, query_factory=lambda: Mock()) + assert name == "API KEY 10" + + +def test_generate_next_api_key_name_falls_back_to_default_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + + def _raise_query_error(): + raise RuntimeError("boom") + + name = configuration._generate_next_api_key_name(session=session, query_factory=_raise_query_error) + assert name == "API KEY 1" + + +def test_generate_provider_and_custom_model_names_delegate_to_shared_generator() -> None: + configuration = _build_provider_configuration() + + with patch.object(configuration, "_generate_next_api_key_name", return_value="API KEY 7") as mock_generator: + provider_name = configuration._generate_provider_credential_name(session=Mock()) + custom_model_name = configuration._generate_custom_model_credential_name( + model="gpt-4o", + model_type=ModelType.LLM, + session=Mock(), + ) + + assert provider_name == "API KEY 7" + assert custom_model_name == "API KEY 7" + assert mock_generator.call_count == 2 + + +def test_get_provider_credential_uses_specific_lookup_when_id_provided() -> None: + configuration = _build_provider_configuration() + + with patch.object(configuration, "_get_specific_provider_credential", return_value={"api_key": "***"}) as mock_get: + credential = configuration.get_provider_credential("credential-1") + + assert credential == {"api_key": "***"} + mock_get.assert_called_once_with("credential-1") + + +def test_validate_provider_credentials_handles_hidden_secret_value() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="openai_api_key", + label=I18nObject(en_US="API Key"), + type=FormType.SECRET_INPUT, + ) + ] + ) + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="encrypted-old-key") + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): + with patch( + "core.entities.provider_configuration.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc::{value}", + ): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE, "region": "us"}, + credential_id="credential-1", + session=session, + ) + + assert validated["openai_api_key"] == "enc::restored-key" + assert validated["region"] == "us" + mock_factory.provider_credentials_validate.assert_called_once_with( + provider="openai", + credentials={"openai_api_key": "restored-key", "region": "us"}, + ) + + +def test_validate_provider_credentials_opens_session_when_not_passed() -> None: + configuration = _build_provider_configuration() + mock_session = Mock() + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"region": "us"} + + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + with patch("core.entities.provider_configuration.db") as mock_db: + mock_db.engine = Mock() + mock_session_cls.return_value.__enter__.return_value = mock_session + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + validated = configuration.validate_provider_credentials(credentials={"region": "us"}) + + assert validated == {"region": "us"} + mock_session_cls.assert_called_once() + + +def test_switch_preferred_provider_type_returns_early_when_no_change_or_unsupported() -> None: + configuration = _build_provider_configuration() + + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + configuration.switch_preferred_provider_type(ProviderType.SYSTEM) + mock_session_cls.assert_not_called() + + configuration.preferred_provider_type = ProviderType.CUSTOM + configuration.system_configuration.enabled = False + with patch("core.entities.provider_configuration.Session") as mock_session_cls: + configuration.switch_preferred_provider_type(ProviderType.SYSTEM) + mock_session_cls.assert_not_called() + + +def test_switch_preferred_provider_type_updates_existing_record_with_session() -> None: + configuration = _build_provider_configuration() + configuration.preferred_provider_type = ProviderType.CUSTOM + session = Mock() + existing_record = SimpleNamespace(preferred_provider_type="custom") + session.execute.return_value.scalars.return_value.first.return_value = existing_record + + configuration.switch_preferred_provider_type(ProviderType.SYSTEM, session=session) + + assert existing_record.preferred_provider_type == ProviderType.SYSTEM + session.commit.assert_called_once() + + +def test_switch_preferred_provider_type_creates_record_when_missing() -> None: + configuration = _build_provider_configuration() + configuration.preferred_provider_type = ProviderType.SYSTEM + session = Mock() + session.execute.return_value.scalars.return_value.first.return_value = None + + configuration.switch_preferred_provider_type(ProviderType.CUSTOM, session=session) + + assert session.add.call_count == 1 + session.commit.assert_called_once() + + +def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: + configuration = _build_provider_configuration() + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) + mock_factory.get_model_schema.assert_called_once_with( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"api_key": "x"}, + ) + + +def test_get_provider_model_returns_none_when_model_not_found() -> None: + configuration = _build_provider_configuration() + fake_model = SimpleNamespace(model="other-model") + + with patch.object(ProviderConfiguration, "get_provider_models", return_value=[fake_model]): + selected = configuration.get_provider_model(ModelType.LLM, "gpt-4o") + + assert selected is None + + +def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> None: + configuration = _build_provider_configuration() + configuration.provider.position = {"llm": ["b-model", "a-model"]} + configuration.model_settings = [ + ModelSettings(model="a-model", model_type=ModelType.LLM, enabled=False, load_balancing_configs=[]) + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("a-model"), _build_ai_model("b-model"), _build_ai_model("a-model")], + ) + mock_factory = Mock() + mock_factory.get_provider_schema.return_value = provider_schema + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) + active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) + + assert [model.model for model in all_models] == ["b-model", "a-model"] + assert [model.status for model in all_models] == [ModelStatus.ACTIVE, ModelStatus.DISABLED] + assert [model.model for model in active_models] == ["b-model"] + + +def test_get_custom_provider_models_sets_status_for_removed_credentials_and_invalid_lb_configs() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="custom-model", + model_type=ModelType.LLM, + credentials=None, + available_model_credentials=[CredentialConfiguration(credential_id="c-1", credential_name="first")], + ) + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-model")], + ) + model_setting_map = { + ModelType.LLM: { + "base-model": ModelSettings( + model="base-model", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=True, + load_balancing_configs=[ + ModelLoadBalancingConfiguration( + id="lb-base", + name="LB Base", + credentials={}, + credential_source_type=CredentialSourceType.PROVIDER, + ) + ], + ), + "custom-model": ModelSettings( + model="custom-model", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=True, + load_balancing_configs=[ + ModelLoadBalancingConfiguration( + id="lb-custom", + name="LB Custom", + credentials={}, + credential_source_type=CredentialSourceType.CUSTOM_MODEL, + ) + ], + ), + } + } + + with patch.object(ProviderConfiguration, "get_model_schema", return_value=_build_ai_model("custom-model")): + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map=model_setting_map, + ) + + status_map = {model.model: model.status for model in models} + invalid_lb_map = {model.model: model.has_invalid_load_balancing_configs for model in models} + assert status_map["base-model"] == ModelStatus.ACTIVE + assert status_map["custom-model"] == ModelStatus.CREDENTIAL_REMOVED + assert invalid_lb_map["base-model"] is True + assert invalid_lb_map["custom-model"] is True + + +def test_validator_adds_predefined_model_for_customizable_provider_with_restrictions() -> None: + provider = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + ) + system_configuration = SystemConfiguration( + enabled=True, + credentials={"api_key": "test-key"}, + current_quota_type=ProviderQuotaType.TRIAL, + quota_configurations=[ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="restricted", base_model_name="base-model", model_type=ModelType.LLM) + ], + ) + ], + ) + with patch("core.entities.provider_configuration.original_provider_configurate_methods", {}): + configuration = ProviderConfiguration( + tenant_id="tenant-1", + provider=provider, + preferred_provider_type=ProviderType.SYSTEM, + using_provider_type=ProviderType.SYSTEM, + system_configuration=system_configuration, + custom_configuration=CustomConfiguration(provider=None, models=[]), + model_settings=[], + ) + + assert ConfigurateMethod.PREDEFINED_MODEL in configuration.provider.configurate_methods + + +def test_get_current_credentials_system_handles_disable_and_restricted_base_model() -> None: + configuration = _build_provider_configuration() + configuration.model_settings = [ + ModelSettings(model="gpt-4o", model_type=ModelType.LLM, enabled=False, load_balancing_configs=[]) + ] + + with pytest.raises(ValueError, match="Model gpt-4o is disabled"): + configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + + configuration.model_settings = [] + configuration.system_configuration.quota_configurations[0].restrict_models = [ + RestrictModel(model="gpt-4o", base_model_name="base-model", model_type=ModelType.LLM) + ] + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials["base_model_name"] == "base-model" + + +def test_get_current_credentials_prefers_model_specific_custom_credentials() -> None: + configuration = _build_provider_configuration() + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"api_key": "model-key"}, + ) + ] + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials == {"api_key": "model-key"} + + +def test_get_system_configuration_status_falsey_quota_returns_unsupported() -> None: + class _FalseyQuota: + quota_type = ProviderQuotaType.TRIAL + is_valid = True + + def __bool__(self) -> bool: + return False + + configuration = _build_provider_configuration() + configuration.system_configuration.quota_configurations = [_FalseyQuota()] # type: ignore[list-item] + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED + + +def test_get_provider_credential_default_uses_custom_provider_credentials() -> None: + configuration = _build_provider_configuration() + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + obfuscated = configuration.get_provider_credential() + assert obfuscated == {"api_key": "provider-key"} + + +def test_custom_configuration_availability_and_provider_record_helpers() -> None: + configuration = _build_provider_configuration() + assert not configuration.is_custom_configuration_available() + + configuration.custom_configuration.provider = CustomProviderConfiguration( + credentials={"api_key": "provider-key"}, + available_credentials=[CredentialConfiguration(credential_id="cred-1", credential_name="Main")], + ) + assert configuration.is_custom_configuration_available() + + configuration.custom_configuration.provider = None + configuration.custom_configuration.models = [ + CustomModelConfiguration(model="gpt-4o", model_type=ModelType.LLM, credentials={"api_key": "model-key"}) + ] + assert configuration.is_custom_configuration_available() + + session = Mock() + provider_record = SimpleNamespace(id="provider-1") + session.execute.return_value.scalar_one_or_none.return_value = provider_record + assert configuration._get_provider_record(session) is provider_record + + session.execute.return_value.scalar_one_or_none.return_value = None + assert configuration._get_provider_record(session) is None + + +def test_check_provider_credential_name_exists_and_model_setting_lookup() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = "existing-id" + assert configuration._check_provider_credential_name_exists("Main", session) + + session.execute.return_value.scalar_one_or_none.return_value = None + assert not configuration._check_provider_credential_name_exists("Main", session, exclude_id="cred-2") + + setting = SimpleNamespace(id="setting-1") + session.execute.return_value.scalars.return_value.first.return_value = setting + assert configuration._get_provider_model_setting(ModelType.LLM, "gpt-4o", session) is setting + + +def test_validate_provider_credentials_handles_invalid_original_json() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-key"} + + +def test_generate_next_api_key_name_returns_default_when_no_records() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [] + + name = configuration._generate_next_api_key_name(session=session, query_factory=lambda: Mock()) + assert name == "API KEY 1" + + +def test_create_provider_credential_creates_provider_record_when_missing() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.flush.side_effect = lambda: None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with patch.object( + ProviderConfiguration, + "_generate_provider_credential_name", + return_value="API KEY 2", + ): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.create_provider_credential({"api_key": "raw"}, None) + + assert session.add.call_count == 2 + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(provider_type=ProviderType.CUSTOM, session=session) + + +def test_create_provider_credential_marks_existing_provider_as_valid() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + assert provider_record.is_valid is True + assert provider_record.credential_id == "existing-cred" + session.commit.assert_called_once() + + +def test_create_provider_credential_auto_activates_when_no_active_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache"): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type"): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + assert provider_record.is_valid is True + assert provider_record.credential_id is not None + session.commit.assert_called_once() + + +def test_create_provider_credential_raises_when_duplicate_name_exists() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.create_provider_credential({"api_key": "raw"}, "Main") + + +def test_update_provider_credential_success_updates_and_invalidates_cache() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1", encrypted_config="{}", credential_name="Old", updated_at=None) + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object( + ProviderConfiguration, + "_update_load_balancing_configs_with_credential", + ) as mock_lb: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.update_provider_credential( + credentials={"api_key": "raw"}, + credential_id="cred-1", + credential_name="New Name", + ) + + assert credential_record.credential_name == "New Name" + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + mock_lb.assert_called_once() + + +def test_update_provider_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", None) + + +def test_update_load_balancing_configs_updates_all_matching_configs() -> None: + configuration = _build_provider_configuration() + session = Mock() + lb_config = SimpleNamespace(id="lb-1", encrypted_config="old", name="old", updated_at=None) + session.execute.return_value.scalars.return_value.all.return_value = [lb_config] + credential_record = SimpleNamespace(encrypted_config='{"api_key":"enc"}', credential_name="API KEY 3") + + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration._update_load_balancing_configs_with_credential( + credential_id="cred-1", + credential_record=credential_record, + credential_source=CredentialSourceType.PROVIDER, + session=session, + ) + + assert lb_config.encrypted_config == '{"api_key":"enc"}' + assert lb_config.name == "API KEY 3" + mock_cache.return_value.delete.assert_called_once() + session.commit.assert_called_once() + + +def test_update_load_balancing_configs_returns_when_no_matching_configs() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalars.return_value.all.return_value = [] + + configuration._update_load_balancing_configs_with_credential( + credential_id="cred-1", + credential_record=SimpleNamespace(encrypted_config="{}", credential_name="Main"), + credential_source=CredentialSourceType.PROVIDER, + session=session, + ) + + session.commit.assert_not_called() + + +def test_delete_provider_credential_removes_provider_record_when_last_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[]), + _exec_result(scalar=1), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_provider_credential("cred-1") + + assert any(call.args and call.args[0] is provider_record for call in session.delete.call_args_list) + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(provider_type=ProviderType.SYSTEM, session=session) + + +def test_delete_provider_credential_raises_when_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.delete_provider_credential("cred-1") + + +def test_delete_provider_credential_unsets_active_credential_when_more_available() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + lb_config = SimpleNamespace(id="lb-1") + provider_record = SimpleNamespace(id="provider-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[lb_config]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_provider_credential("cred-1") + + assert provider_record.credential_id is None + assert mock_cache.return_value.delete.call_count == 2 + mock_switch.assert_called_once_with(provider_type=ProviderType.SYSTEM, session=session) + + +def test_switch_active_provider_credential_success_and_failures() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.switch_active_provider_credential("cred-1") + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(ValueError, match="Provider record not found"): + configuration.switch_active_provider_credential("cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_record = SimpleNamespace(id="provider-1", credential_id=None, updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch.object(ProviderConfiguration, "switch_preferred_provider_type") as mock_switch: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.switch_active_provider_credential("cred-1") + + assert provider_record.credential_id == "cred-1" + mock_cache.return_value.delete.assert_called_once() + mock_switch.assert_called_once_with(ProviderType.CUSTOM, session=session) + + +def test_get_custom_model_record_supports_plugin_id_alias() -> None: + configuration = _build_provider_configuration(provider_name="langgenius/openai/openai") + session = Mock() + custom_model_record = SimpleNamespace(id="model-1") + session.execute.return_value.scalar_one_or_none.return_value = custom_model_record + + result = configuration._get_custom_model_record(ModelType.LLM, "gpt-4o", session) + assert result is custom_model_record + + +def test_get_specific_custom_model_credential_success_and_not_found() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + record = SimpleNamespace(id="cred-1", credential_name="Main", encrypted_config='{"openai_api_key":"enc"}') + session.execute.return_value.scalar_one_or_none.return_value = record + + with _patched_session(session): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch.object(ProviderConfiguration, "obfuscated_credentials", return_value={"openai_api_key": "***"}): + response = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert response["current_credential_id"] == "cred-1" + assert response["credentials"] == {"openai_api_key": "***"} + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential with id cred-1 not found"): + configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + credential_name="Main", + encrypted_config="{invalid-json", + ) + with _patched_session(session): + invalid_json = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert invalid_json["credentials"] == {} + + +def test_check_custom_model_credential_name_exists_respects_exclusion() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + assert configuration._check_custom_model_credential_name_exists( + ModelType.LLM, "gpt-4o", "Main", session, exclude_id="other-id" + ) + + session.execute.return_value.scalar_one_or_none.return_value = None + assert not configuration._check_custom_model_credential_name_exists(ModelType.LLM, "gpt-4o", "Main", session) + + +def test_get_custom_model_credential_uses_specific_id_or_configuration_fallback() -> None: + configuration = _build_provider_configuration() + with patch.object( + ProviderConfiguration, + "_get_specific_custom_model_credential", + return_value={"current_credential_id": "cred-1"}, + ) as mock_specific: + result = configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert result == {"current_credential_id": "cred-1"} + mock_specific.assert_called_once() + + configuration.provider.model_credential_schema = _build_secret_model_schema() + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="gpt-4o", + model_type=ModelType.LLM, + credentials={"openai_api_key": "raw"}, + current_credential_id="cred-1", + current_credential_name="Main", + ) + ] + with patch.object(ProviderConfiguration, "obfuscated_credentials", return_value={"openai_api_key": "***"}): + fallback = configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", None) + assert fallback == { + "current_credential_id": "cred-1", + "current_credential_name": "Main", + "credentials": {"openai_api_key": "***"}, + } + + configuration.custom_configuration.models = [] + assert configuration.get_custom_model_credential(ModelType.LLM, "gpt-4o", None) is None + + +def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless_path() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc"}' + ) + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + assert validated == {"openai_api_key": "enc-new"} + + session = Mock() + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"region": "us"} + with _patched_session(session): + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"region": "us"}, + ) + assert validated == {"region": "us"} + + +def test_create_update_delete_custom_model_credential_flow() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.flush.side_effect = lambda: None + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + credential_record = SimpleNamespace(id="cred-1", encrypted_config="{}", credential_name="Old", updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_custom_model_credential_name", return_value="API KEY 1"): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc"}, + ): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, None) + assert session.add.call_count == 2 + assert mock_cache.return_value.delete.call_count == 1 + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc2"}, + ): + with patch.object( + ProviderConfiguration, + "_get_custom_model_record", + return_value=provider_model_record, + ): + with patch.object( + ProviderConfiguration, + "_update_load_balancing_configs_with_credential", + ) as mock_lb: + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="New Name", + credential_id="cred-1", + ) + assert credential_record.credential_name == "New Name" + assert mock_cache.return_value.delete.call_count == 1 + mock_lb.assert_called_once() + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + lb_config = SimpleNamespace(id="lb-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[lb_config]), + _exec_result(scalar=2), + ] + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert provider_model_record.credential_id is None + assert mock_cache.return_value.delete.call_count == 2 + + +def test_add_model_credential_to_model_and_switch_custom_model_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + session.add.assert_called_once() + session.commit.assert_called_once() + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with pytest.raises(ValueError, match="Can't add same credential"): + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-2") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.add_model_credential_to_model(ModelType.LLM, "gpt-4o", "cred-2") + assert provider_model_record.credential_id == "cred-2" + mock_cache.return_value.delete.assert_called_once() + + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(ValueError, match="custom model record not found"): + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id=None, updated_at=None) + session.execute.return_value.scalar_one_or_none.return_value = credential_record + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.switch_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + assert provider_model_record.credential_id == "cred-1" + mock_cache.return_value.delete.assert_called_once() + + +def test_delete_custom_model_and_model_setting_methods() -> None: + configuration = _build_provider_configuration() + session = Mock() + provider_model_record = SimpleNamespace(id="model-1") + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + with patch("core.entities.provider_configuration.ProviderCredentialsCache") as mock_cache: + configuration.delete_custom_model(ModelType.LLM, "gpt-4o") + session.delete.assert_called_once_with(provider_model_record) + session.commit.assert_called_once() + mock_cache.return_value.delete.assert_called_once() + + session = Mock() + existing = SimpleNamespace(enabled=False, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + assert configuration.enable_model(ModelType.LLM, "gpt-4o") is existing + assert existing.enabled is True + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.enable_model(ModelType.LLM, "gpt-4o") + assert created.enabled is True + + session = Mock() + existing = SimpleNamespace(enabled=True, load_balancing_enabled=True, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + assert configuration.disable_model(ModelType.LLM, "gpt-4o") is existing + assert existing.enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.disable_model(ModelType.LLM, "gpt-4o") + assert created.enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.get_provider_model_setting(ModelType.LLM, "gpt-4o") + assert result is existing + + +def test_model_load_balancing_enable_disable_and_switch_preferred_provider_type_without_session() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar.return_value = 1 + with _patched_session(session): + with pytest.raises(ValueError, match="must be more than 1"): + configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + + session = Mock() + session.execute.return_value.scalar.return_value = 2 + existing = SimpleNamespace(load_balancing_enabled=False, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert result is existing + assert existing.load_balancing_enabled is True + + session = Mock() + session.execute.return_value.scalar.return_value = 2 + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.enable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert created.load_balancing_enabled is True + + session = Mock() + existing = SimpleNamespace(load_balancing_enabled=True, updated_at=None) + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=existing): + result = configuration.disable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert result is existing + assert existing.load_balancing_enabled is False + + session = Mock() + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_model_setting", return_value=None): + created = configuration.disable_model_load_balancing(ModelType.LLM, "gpt-4o") + assert created.load_balancing_enabled is False + + configuration.preferred_provider_type = ProviderType.SYSTEM + switch_session = Mock() + with _patched_session(switch_session): + switch_session.execute.return_value.scalars.return_value.first.return_value = None + configuration.switch_preferred_provider_type(ProviderType.CUSTOM) + assert any( + call.args and call.args[0].__class__.__name__ == "TenantPreferredModelProvider" + for call in switch_session.add.call_args_list + ) + switch_session.commit.assert_called() + + +def test_system_and_custom_provider_model_helpers_cover_remaining_skip_paths() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.CUSTOMIZABLE_MODEL], + models=[_build_ai_model("llm-model")], + ) + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target", base_model_name="base", model_type=ModelType.LLM), + ], + ), + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="target", base_model_name="base", model_type=ModelType.LLM), + RestrictModel(model="error-model", base_model_name="base", model_type=ModelType.LLM), + RestrictModel(model="none-model", base_model_name="base", model_type=ModelType.LLM), + RestrictModel( + model="embed-model", + base_model_name="base", + model_type=ModelType.TEXT_EMBEDDING, + ), + ], + ), + ] + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + + def _system_schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-model": + raise RuntimeError("boom") + if model == "none-model": + return None + if model == "embed-model": + return _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING) + return _build_ai_model("target") + + with patch( + "core.entities.provider_configuration.original_provider_configurate_methods", + {"openai": [ConfigurateMethod.CUSTOMIZABLE_MODEL]}, + ): + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_system_schema): + system_models = configuration._get_system_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={ + ModelType.LLM: { + "target": ModelSettings( + model="target", + model_type=ModelType.LLM, + enabled=False, + load_balancing_configs=[], + ) + } + }, + ) + assert any(model.model == "target" and model.status == ModelStatus.DISABLED for model in system_models) + + configuration.using_provider_type = ProviderType.CUSTOM + configuration.custom_configuration.provider = CustomProviderConfiguration(credentials={"api_key": "provider-key"}) + configuration.custom_configuration.models = [ + CustomModelConfiguration( + model="skip-model-type", + model_type=ModelType.TEXT_EMBEDDING, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="skip-unadded", + model_type=ModelType.LLM, + credentials={"k": "v"}, + unadded_to_model_list=True, + ), + CustomModelConfiguration( + model="skip-filter", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="error-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="none-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + CustomModelConfiguration( + model="disabled-custom", + model_type=ModelType.LLM, + credentials={"k": "v"}, + ), + ] + + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[_build_ai_model("base-disabled")], + ) + model_setting_map = { + ModelType.LLM: { + "base-disabled": ModelSettings( + model="base-disabled", + model_type=ModelType.LLM, + enabled=False, + load_balancing_enabled=True, + load_balancing_configs=[ModelLoadBalancingConfiguration(id="lb-1", name="lb", credentials={})], + ), + "disabled-custom": ModelSettings( + model="disabled-custom", + model_type=ModelType.LLM, + enabled=False, + load_balancing_enabled=False, + load_balancing_configs=[], + ), + } + } + + def _custom_schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-custom": + raise RuntimeError("boom") + if model == "none-custom": + return None + return _build_ai_model(model) + + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_custom_schema): + custom_models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model="disabled-custom", + ) + assert any(model.model == "base-disabled" and model.status == ModelStatus.DISABLED for model in custom_models) + assert any(model.model == "disabled-custom" and model.status == ModelStatus.DISABLED for model in custom_models) + + +def test_get_current_credentials_skips_non_current_quota_restrictions() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.FREE, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="gpt-4o", base_model_name="free-base", model_type=ModelType.LLM), + ], + ), + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=0, + is_valid=True, + restrict_models=[ + RestrictModel(model="gpt-4o", base_model_name="trial-base", model_type=ModelType.LLM), + ], + ), + ] + + credentials = configuration.get_current_credentials(ModelType.LLM, "gpt-4o") + assert credentials["base_model_name"] == "trial-base" + + +def test_get_system_configuration_status_covers_disabled_and_quota_exceeded() -> None: + configuration = _build_provider_configuration() + configuration.system_configuration.enabled = False + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.UNSUPPORTED + + configuration.system_configuration.enabled = True + configuration.system_configuration.quota_configurations = [ + QuotaConfiguration( + quota_type=ProviderQuotaType.TRIAL, + quota_unit=QuotaUnit.TOKENS, + quota_limit=100, + quota_used=100, + is_valid=False, + restrict_models=[], + ) + ] + configuration.system_configuration.current_quota_type = ProviderQuotaType.TRIAL + assert configuration.get_system_configuration_status() == SystemConfigurationStatus.QUOTA_EXCEEDED + + +def test_get_specific_provider_credential_decrypts_and_obfuscates_credentials() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc-secret","region":"us"}' + ) + provider_record = SimpleNamespace(provider_name="aliased-openai") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw-secret"): + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") + + assert credentials == {"openai_api_key": "raw-secret", "region": "us"} + + +def test_get_specific_provider_credential_logs_when_decrypt_fails() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + encrypted_config='{"openai_api_key":"enc-secret"}' + ) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with patch( + "core.entities.provider_configuration.encrypter.decrypt_token", + side_effect=RuntimeError("boom"), + ): + with patch("core.entities.provider_configuration.logger.exception") as mock_logger: + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + credentials = configuration._get_specific_provider_credential("cred-1") + + assert credentials == {"openai_api_key": "enc-secret"} + mock_logger.assert_called_once() + + +def test_validate_provider_credentials_uses_empty_original_when_record_missing() -> None: + configuration = _build_provider_configuration() + configuration.provider.provider_credential_schema = _build_secret_provider_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + mock_factory = Mock() + mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_provider_credentials( + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-new"} + + +def test_create_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.add.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_provider_credential_name", return_value="API KEY 9"): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.create_provider_credential({"api_key": "raw"}, None) + + session.rollback.assert_called_once() + + +def test_update_provider_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", "Main") + + +def test_update_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + encrypted_config="{}", + credential_name="Main", + updated_at=None, + ) + session.commit.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.update_provider_credential({"api_key": "raw"}, "cred-1", "Main") + + session.rollback.assert_called_once() + + +def test_delete_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.delete.side_effect = RuntimeError("boom") + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=SimpleNamespace(id="cred-1")), + _exec_result(scalars_all=[]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.delete_provider_credential("cred-1") + + session.rollback.assert_called_once() + + +def test_switch_active_provider_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(id="cred-1") + session.commit.side_effect = RuntimeError("boom") + provider_record = SimpleNamespace(id="provider-1", credential_id=None, updated_at=None) + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record): + with pytest.raises(RuntimeError, match="boom"): + configuration.switch_active_provider_credential("cred-1") + + session.rollback.assert_called_once() + + +def test_get_specific_custom_model_credential_logs_when_decrypt_fails() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + credential_name="Main", + encrypted_config='{"openai_api_key":"enc-secret"}', + ) + + with _patched_session(session): + with patch("core.entities.provider_configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")): + with patch("core.entities.provider_configuration.logger.exception") as mock_logger: + with patch.object( + ProviderConfiguration, + "obfuscated_credentials", + side_effect=lambda credentials, credential_form_schemas: credentials, + ): + result = configuration._get_specific_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert result["credentials"] == {"openai_api_key": "enc-secret"} + mock_logger.assert_called_once() + + +def test_validate_custom_model_credentials_handles_invalid_original_json() -> None: + configuration = _build_provider_configuration() + configuration.provider.model_credential_schema = _build_secret_model_schema() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace(encrypted_config="{invalid-json") + mock_factory = Mock() + mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} + + with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): + validated = configuration.validate_custom_model_credentials( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"openai_api_key": HIDDEN_VALUE}, + credential_id="cred-1", + session=session, + ) + + assert validated == {"openai_api_key": "enc-new"} + + +def test_create_custom_model_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, "Main") + + +def test_create_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.add.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_generate_custom_model_credential_name", return_value="API KEY 4"): + with patch.object( + ProviderConfiguration, + "validate_custom_model_credentials", + return_value={"openai_api_key": "enc"}, + ): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.create_custom_model_credential(ModelType.LLM, "gpt-4o", {"k": "v"}, None) + + session.rollback.assert_called_once() + + +def test_update_custom_model_credential_raises_on_duplicate_name() -> None: + configuration = _build_provider_configuration() + session = Mock() + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=True): + with pytest.raises(ValueError, match="already exists"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + +def test_update_custom_model_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_custom_model_credentials", return_value={"k": "v"}): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + +def test_update_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = SimpleNamespace( + id="cred-1", + encrypted_config="{}", + credential_name="Main", + updated_at=None, + ) + session.commit.side_effect = RuntimeError("boom") + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_check_custom_model_credential_name_exists", return_value=False): + with patch.object(ProviderConfiguration, "validate_custom_model_credentials", return_value={"k": "v"}): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.update_custom_model_credential( + model_type=ModelType.LLM, + model="gpt-4o", + credentials={"k": "v"}, + credential_name="Main", + credential_id="cred-1", + ) + + session.rollback.assert_called_once() + + +def test_delete_custom_model_credential_raises_when_record_not_found() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.execute.return_value.scalar_one_or_none.return_value = None + + with _patched_session(session): + with pytest.raises(ValueError, match="Credential record not found"): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + +def test_delete_custom_model_credential_removes_custom_model_record_when_last_credential() -> None: + configuration = _build_provider_configuration() + session = Mock() + credential_record = SimpleNamespace(id="cred-1") + provider_model_record = SimpleNamespace(id="model-1", credential_id="cred-1", updated_at=None) + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=credential_record), + _exec_result(scalars_all=[]), + _exec_result(scalar=1), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=provider_model_record): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + assert any(call.args and call.args[0] is provider_model_record for call in session.delete.call_args_list) + + +def test_delete_custom_model_credential_rolls_back_on_error() -> None: + configuration = _build_provider_configuration() + session = Mock() + session.delete.side_effect = RuntimeError("boom") + session.execute.side_effect = [ + _exec_result(scalar_one_or_none=SimpleNamespace(id="cred-1")), + _exec_result(scalars_all=[]), + _exec_result(scalar=2), + ] + + with _patched_session(session): + with patch.object(ProviderConfiguration, "_get_custom_model_record", return_value=None): + with pytest.raises(RuntimeError, match="boom"): + configuration.delete_custom_model_credential(ModelType.LLM, "gpt-4o", "cred-1") + + session.rollback.assert_called_once() + + +def test_get_custom_provider_models_skips_schema_models_with_mismatched_type() -> None: + configuration = _build_provider_configuration() + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[ + _build_ai_model("llm-model", model_type=ModelType.LLM), + _build_ai_model("embed-model", model_type=ModelType.TEXT_EMBEDDING), + ], + ) + + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + ) + + assert any(model.model == "llm-model" for model in models) + assert all(model.model != "embed-model" for model in models) + + +def test_get_custom_provider_models_skips_custom_models_on_schema_error_or_none() -> None: + configuration = _build_provider_configuration() + configuration.custom_configuration.models = [ + CustomModelConfiguration(model="error-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + CustomModelConfiguration(model="none-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + CustomModelConfiguration(model="ok-custom", model_type=ModelType.LLM, credentials={"k": "v"}), + ] + provider_schema = ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=[], + ) + + def _schema(*, model_type: ModelType, model: str, credentials: dict | None): + if model == "error-custom": + raise RuntimeError("boom") + if model == "none-custom": + return None + return _build_ai_model(model) + + with patch("core.entities.provider_configuration.logger.warning") as mock_warning: + with patch.object(ProviderConfiguration, "get_model_schema", side_effect=_schema): + models = configuration._get_custom_provider_models( + model_types=[ModelType.LLM], + provider_schema=provider_schema, + model_setting_map={}, + ) + + assert mock_warning.call_count == 1 + assert any(model.model == "ok-custom" for model in models) + assert all(model.model != "none-custom" for model in models) diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py new file mode 100644 index 0000000000..c5bfd05a1e --- /dev/null +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -0,0 +1,72 @@ +import pytest + +from core.entities.parameter_entities import AppSelectorScope +from core.entities.provider_entities import ( + BasicProviderConfig, + ModelSettings, + ProviderConfig, + ProviderQuotaType, +) +from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType + + +def test_provider_quota_type_value_of_returns_enum_member() -> None: + # Arrange / Act + quota_type = ProviderQuotaType.value_of(ProviderQuotaType.TRIAL.value) + + # Assert + assert quota_type == ProviderQuotaType.TRIAL + + +def test_provider_quota_type_value_of_rejects_unknown_values() -> None: + # Arrange / Act / Assert + with pytest.raises(ValueError, match="No matching enum found"): + ProviderQuotaType.value_of("enterprise") + + +def test_basic_provider_config_type_value_of_handles_known_values() -> None: + # Arrange / Act + parameter_type = BasicProviderConfig.Type.value_of("text-input") + + # Assert + assert parameter_type == BasicProviderConfig.Type.TEXT_INPUT + + +def test_basic_provider_config_type_value_of_rejects_invalid_values() -> None: + # Arrange / Act / Assert + with pytest.raises(ValueError, match="invalid mode value"): + BasicProviderConfig.Type.value_of("unknown") + + +def test_provider_config_to_basic_provider_config_keeps_type_and_name() -> None: + # Arrange + provider_config = ProviderConfig( + type=BasicProviderConfig.Type.SELECT, + name="workspace", + scope=AppSelectorScope.ALL, + options=[ProviderConfig.Option(value="all", label=I18nObject(en_US="All"))], + ) + + # Act + basic_config = provider_config.to_basic_provider_config() + + # Assert + assert isinstance(basic_config, BasicProviderConfig) + assert basic_config.type == BasicProviderConfig.Type.SELECT + assert basic_config.name == "workspace" + + +def test_model_settings_accepts_model_field_name() -> None: + # Arrange / Act + settings = ModelSettings( + model="gpt-4o", + model_type=ModelType.LLM, + enabled=True, + load_balancing_enabled=False, + load_balancing_configs=[], + ) + + # Assert + assert settings.model == "gpt-4o" + assert settings.model_type == ModelType.LLM diff --git a/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py b/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py new file mode 100644 index 0000000000..399b531205 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_api_based_extension_requestor.py @@ -0,0 +1,137 @@ +import httpx +import pytest + +from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor +from models.api_based_extension import APIBasedExtensionPoint + + +def test_request_success(mocker): + # Mock httpx.Client and its context manager + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + result = requestor.request(APIBasedExtensionPoint.PING, {"foo": "bar"}) + + assert result == {"result": "success"} + mock_client_instance.request.assert_called_once_with( + method="POST", + url="http://example.com", + json={"point": APIBasedExtensionPoint.PING.value, "params": {"foo": "bar"}}, + headers={"Content-Type": "application/json", "Authorization": "Bearer test_key"}, + ) + + +def test_request_with_ssrf_proxy(mocker): + # Mock dify_config + mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080") + mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", "https://proxy:8081") + + # Mock httpx.Client + mock_client = mocker.MagicMock() + mock_client_class = mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance = mock_client.__enter__.return_value + + # Mock response + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + # Mock HTTPTransport + mock_transport = mocker.patch("httpx.HTTPTransport") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + requestor.request(APIBasedExtensionPoint.PING, {}) + + # Verify httpx.Client was called with mounts + mock_client_class.assert_called_once() + kwargs = mock_client_class.call_args.kwargs + assert "mounts" in kwargs + assert "http://" in kwargs["mounts"] + assert "https://" in kwargs["mounts"] + assert mock_transport.call_count == 2 + + +def test_request_with_only_one_proxy_config(mocker): + # Mock dify_config with only one proxy + mocker.patch("configs.dify_config.SSRF_PROXY_HTTP_URL", "http://proxy:8080") + mocker.patch("configs.dify_config.SSRF_PROXY_HTTPS_URL", None) + + # Mock httpx.Client + mock_client = mocker.MagicMock() + mock_client_class = mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance = mock_client.__enter__.return_value + + # Mock response + mock_response = mocker.MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": "success"} + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + requestor.request(APIBasedExtensionPoint.PING, {}) + + # Verify httpx.Client was called with mounts=None (default) + mock_client_class.assert_called_once() + kwargs = mock_client_class.call_args.kwargs + assert kwargs.get("mounts") is None + + +def test_request_timeout(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance.request.side_effect = httpx.TimeoutException("timeout") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request timeout"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_connection_error(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + mock_client_instance.request.side_effect = httpx.RequestError("error") + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request connection error"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_error_status_code(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 404 + mock_response.text = "Not Found" + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + with pytest.raises(ValueError, match="request error, status_code: 404, content: Not Found"): + requestor.request(APIBasedExtensionPoint.PING, {}) + + +def test_request_error_status_code_long_content(mocker): + mock_client = mocker.MagicMock() + mock_client_instance = mock_client.__enter__.return_value + mocker.patch("httpx.Client", return_value=mock_client) + + mock_response = mocker.MagicMock() + mock_response.status_code = 500 + mock_response.text = "A" * 200 # Testing truncation of content + mock_client_instance.request.return_value = mock_response + + requestor = APIBasedExtensionRequestor(api_endpoint="http://example.com", api_key="test_key") + expected_content = "A" * 100 + with pytest.raises(ValueError, match=f"request error, status_code: 500, content: {expected_content}"): + requestor.request(APIBasedExtensionPoint.PING, {}) diff --git a/api/tests/unit_tests/core/extension/test_extensible.py b/api/tests/unit_tests/core/extension/test_extensible.py new file mode 100644 index 0000000000..9bce0cd7c8 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_extensible.py @@ -0,0 +1,281 @@ +import json +import types +from unittest.mock import MagicMock, mock_open, patch + +import pytest + +from core.extension.extensible import Extensible + + +class TestExtensible: + def test_init(self): + tenant_id = "tenant_123" + config = {"key": "value"} + ext = Extensible(tenant_id, config) + assert ext.tenant_id == tenant_id + assert ext.config == config + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.Path.read_text") + @patch("core.extension.extensible.importlib.util.module_from_spec") + @patch("core.extension.extensible.sort_to_dict_by_position_map") + def test_scan_extensions_success( + self, + mock_sort, + mock_module_from_spec, + mock_read_text, + mock_exists, + mock_isdir, + mock_listdir, + mock_dirname, + mock_find_spec, + ): + # Setup + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [ + ["ext1"], # package_dir + ["ext1.py", "__builtin__"], # subdir_path + ] + mock_isdir.return_value = True + + mock_exists.return_value = True + mock_read_text.return_value = "10" + + # Use types.ModuleType to avoid MagicMock __dict__ issues + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + mock_sort.side_effect = lambda position_map, data, name_func: data + + # Execute + results = Extensible.scan_extensions() + + # Assert + assert len(results) == 1 + assert results[0].name == "ext1" + assert results[0].position == 10 + assert results[0].builtin is True + assert results[0].extension_class == MockExtension + + @patch("core.extension.extensible.importlib.util.find_spec") + def test_scan_extensions_package_not_found(self, mock_find_spec): + mock_find_spec.return_value = None + with pytest.raises(ImportError, match="Could not find package"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + def test_scan_extensions_skip_subdirs(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + mock_find_spec.return_value = package_spec + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["__pycache__", "not_a_dir", "missing_py_file"], []] + + mock_isdir.side_effect = [False, True] + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_not_builtin_success( + self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py", "schema.json"]] + mock_isdir.return_value = True + + # exists checks: only schema.json needs to exist + mock_exists.return_value = True + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + schema_content = json.dumps({"label": {"en": "Test"}, "form_schema": [{"name": "field1"}]}) + + with ( + patch("builtins.open", mock_open(read_data=schema_content)), + patch( + "core.extension.extensible.sort_to_dict_by_position_map", + side_effect=lambda position_map, data, name_func: data, + ), + ): + results = Extensible.scan_extensions() + + assert len(results) == 1 + assert results[0].name == "ext1" + assert results[0].builtin is False + assert results[0].label == {"en": "Test"} + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_not_builtin_missing_schema( + self, mock_module_from_spec, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + # exists: only schema.json checked, and return False + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.importlib.util.module_from_spec") + @patch("core.extension.extensible.os.path.exists") + def test_scan_extensions_no_extension_class( + self, mock_exists, mock_module_from_spec, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + # Mock not builtin + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + mock_mod.SomeOtherClass = type("SomeOtherClass", (), {}) + mock_module_from_spec.return_value = mock_mod + + # We need to ensure we don't crash if checking schema (but we won't reach there because class not found) + + with patch("core.extension.extensible.sort_to_dict_by_position_map", return_value=[]): + results = Extensible.scan_extensions() + + assert len(results) == 0 + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + def test_scan_extensions_module_import_error(self, mock_isdir, mock_listdir, mock_dirname, mock_find_spec): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + + mock_find_spec.side_effect = [package_spec, None] # No module spec + mock_dirname.return_value = "/path/to/pkg" + + mock_listdir.side_effect = [["ext1"], ["ext1.py"]] + mock_isdir.return_value = True + + with pytest.raises(ImportError, match="Failed to load module"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + def test_scan_extensions_general_exception(self, mock_find_spec): + mock_find_spec.side_effect = Exception("Unexpected error") + with pytest.raises(Exception, match="Unexpected error"): + Extensible.scan_extensions() + + @patch("core.extension.extensible.importlib.util.find_spec") + @patch("core.extension.extensible.os.path.dirname") + @patch("core.extension.extensible.os.listdir") + @patch("core.extension.extensible.os.path.isdir") + @patch("core.extension.extensible.os.path.exists") + @patch("core.extension.extensible.Path.read_text") + @patch("core.extension.extensible.importlib.util.module_from_spec") + def test_scan_extensions_builtin_without_position_file( + self, mock_module_from_spec, mock_read_text, mock_exists, mock_isdir, mock_listdir, mock_dirname, mock_find_spec + ): + package_spec = MagicMock() + package_spec.origin = "/path/to/pkg/__init__.py" + module_spec = MagicMock() + module_spec.loader = MagicMock() + + mock_find_spec.side_effect = [package_spec, module_spec] + mock_dirname.return_value = "/path/to/pkg" + mock_listdir.side_effect = [["ext1"], ["ext1.py", "__builtin__"]] + mock_isdir.return_value = True + + # builtin exists in listdir, but os.path.exists(builtin_file_path) returns False + mock_exists.return_value = False + + mock_mod = types.ModuleType("ext1") + + class MockExtension(Extensible): + pass + + mock_mod.MockExtension = MockExtension + mock_module_from_spec.return_value = mock_mod + + with patch( + "core.extension.extensible.sort_to_dict_by_position_map", + side_effect=lambda position_map, data, name_func: data, + ): + results = Extensible.scan_extensions() + + assert len(results) == 1 + assert results[0].position == 0 diff --git a/api/tests/unit_tests/core/extension/test_extension.py b/api/tests/unit_tests/core/extension/test_extension.py new file mode 100644 index 0000000000..4ad32d3840 --- /dev/null +++ b/api/tests/unit_tests/core/extension/test_extension.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.extension.extensible import ExtensionModule, ModuleExtension +from core.extension.extension import Extension + + +class TestExtension: + def setup_method(self): + # Reset the private class attribute before each test + Extension._Extension__module_extensions = {} + + def test_init(self): + # Mock scan_extensions for Moderation and ExternalDataTool + mock_mod_extensions = {"mod1": ModuleExtension(name="mod1")} + mock_ext_extensions = {"ext1": ModuleExtension(name="ext1")} + + extension = Extension() + + # We need to mock scan_extensions on the classes defined in Extension.module_classes + with ( + patch("core.extension.extension.Moderation.scan_extensions", return_value=mock_mod_extensions), + patch("core.extension.extension.ExternalDataTool.scan_extensions", return_value=mock_ext_extensions), + ): + extension.init() + + # Check if internal state is updated + internal_state = Extension._Extension__module_extensions + assert internal_state[ExtensionModule.MODERATION.value] == mock_mod_extensions + assert internal_state[ExtensionModule.EXTERNAL_DATA_TOOL.value] == mock_ext_extensions + + def test_module_extensions_success(self): + # Setup data + mock_extensions = {"name1": ModuleExtension(name="name1"), "name2": ModuleExtension(name="name2")} + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: mock_extensions} + + extension = Extension() + result = extension.module_extensions(ExtensionModule.MODERATION.value) + + assert len(result) == 2 + assert any(e.name == "name1" for e in result) + assert any(e.name == "name2" for e in result) + + def test_module_extensions_not_found(self): + extension = Extension() + with pytest.raises(ValueError, match="Extension Module unknown not found"): + extension.module_extensions("unknown") + + def test_module_extension_success(self): + mock_ext = ModuleExtension(name="test_ext") + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + result = extension.module_extension(ExtensionModule.MODERATION, "test_ext") + assert result == mock_ext + + def test_module_extension_module_not_found(self): + extension = Extension() + # ExtensionModule.MODERATION is "moderation" + with pytest.raises(ValueError, match="Extension Module moderation not found"): + extension.module_extension(ExtensionModule.MODERATION, "any") + + def test_module_extension_extension_not_found(self): + # We need a non-empty dict because 'if not module_extensions' in extension.py + # returns True for an empty dict, which raises the module not found error instead. + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"other": MagicMock()}} + + extension = Extension() + with pytest.raises(ValueError, match="Extension unknown not found"): + extension.module_extension(ExtensionModule.MODERATION, "unknown") + + def test_extension_class_success(self): + class MockClass: + pass + + mock_ext = ModuleExtension(name="test_ext", extension_class=MockClass) + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + result = extension.extension_class(ExtensionModule.MODERATION, "test_ext") + assert result == MockClass + + def test_extension_class_none(self): + mock_ext = ModuleExtension(name="test_ext", extension_class=None) + Extension._Extension__module_extensions = {ExtensionModule.MODERATION.value: {"test_ext": mock_ext}} + + extension = Extension() + with pytest.raises(AssertionError): + extension.extension_class(ExtensionModule.MODERATION, "test_ext") diff --git a/api/tests/unit_tests/core/external_data_tool/api/test_api.py b/api/tests/unit_tests/core/external_data_tool/api/test_api.py new file mode 100644 index 0000000000..1653124bd8 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/api/test_api.py @@ -0,0 +1,145 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.external_data_tool.api.api import ApiExternalDataTool +from models.api_based_extension import APIBasedExtensionPoint + + +def test_api_external_data_tool_name(): + assert ApiExternalDataTool.name == "api" + + +@patch("core.external_data_tool.api.api.db") +def test_validate_config_success(mock_db): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_db.session.scalar.return_value = mock_extension + + # Should not raise exception + ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"}) + + +def test_validate_config_missing_id(): + with pytest.raises(ValueError, match="api_based_extension_id is required"): + ApiExternalDataTool.validate_config("tenant_id", {}) + + +@patch("core.external_data_tool.api.api.db") +def test_validate_config_invalid_id(mock_db): + mock_db.session.scalar.return_value = None + + with pytest.raises(ValueError, match="api_based_extension_id is invalid"): + ApiExternalDataTool.validate_config("tenant_id", {"api_based_extension_id": "ext_id"}) + + +@pytest.fixture +def api_tool(): + # Use standard kwargs as it inherits from ExternalDataTool which is typically a Pydantic BaseModel + return ApiExternalDataTool( + tenant_id="tenant_id", app_id="app_id", variable="var1", config={"api_based_extension_id": "ext_id"} + ) + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_success(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"result": "success_result"} + + res = api_tool.query({"input1": "value1"}, "query_str") + + assert res == "success_result" + + mock_requestor_class.assert_called_once_with(api_endpoint="http://api", api_key="decrypted_key") + mock_requestor.request.assert_called_once_with( + point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, + params={"app_id": "app_id", "tool_variable": "var1", "inputs": {"input1": "value1"}, "query": "query_str"}, + ) + + +def test_query_missing_config(): + api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1") + api_tool.config = None # Force None + with pytest.raises(ValueError, match="config is required"): + api_tool.query({}, "") + + +def test_query_missing_extension_id(): + api_tool = ApiExternalDataTool(tenant_id="tenant_id", app_id="app_id", variable="var1", config={"dummy": "value"}) + with pytest.raises(AssertionError, match="api_based_extension_id is required"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +def test_query_invalid_extension(mock_db, api_tool): + mock_db.session.scalar.return_value = None + + with pytest.raises(ValueError, match=".*error: api_based_extension_id is invalid"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_requestor_init_error(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor_class.side_effect = Exception("init error") + + with pytest.raises(ValueError, match=".*error: init error"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_no_result_in_response(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"other": "value"} + + with pytest.raises(ValueError, match=".*error: result not found in response"): + api_tool.query({}, "") + + +@patch("core.external_data_tool.api.api.db") +@patch("core.external_data_tool.api.api.encrypter") +@patch("core.external_data_tool.api.api.APIBasedExtensionRequestor") +def test_query_result_not_string(mock_requestor_class, mock_encrypter, mock_db, api_tool): + mock_extension = MagicMock() + mock_extension.id = "ext_id" + mock_extension.tenant_id = "tenant_id" + mock_extension.api_endpoint = "http://api" + mock_extension.api_key = "encrypted_key" + mock_db.session.scalar.return_value = mock_extension + mock_encrypter.decrypt_token.return_value = "decrypted_key" + + mock_requestor = mock_requestor_class.return_value + mock_requestor.request.return_value = {"result": 123} # Not a string + + with pytest.raises(ValueError, match=".*error: result is not string"): + api_tool.query({}, "") diff --git a/api/tests/unit_tests/core/external_data_tool/test_base.py b/api/tests/unit_tests/core/external_data_tool/test_base.py new file mode 100644 index 0000000000..216cda83c5 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_base.py @@ -0,0 +1,66 @@ +import pytest + +from core.extension.extensible import ExtensionModule +from core.external_data_tool.base import ExternalDataTool + + +class TestExternalDataTool: + def test_module_attribute(self): + assert ExternalDataTool.module == ExtensionModule.EXTERNAL_DATA_TOOL + + def test_init(self): + # Create a concrete subclass to test init + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + return super().validate_config(tenant_id, config) + + def query(self, inputs: dict, query: str | None = None) -> str: + return super().query(inputs, query) + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1", config={"key": "value"}) + assert tool.tenant_id == "tenant_1" + assert tool.app_id == "app_1" + assert tool.variable == "var_1" + assert tool.config == {"key": "value"} + + def test_init_without_config(self): + # Create a concrete subclass to test init + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + pass + + def query(self, inputs: dict, query: str | None = None) -> str: + return "" + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") + assert tool.tenant_id == "tenant_1" + assert tool.app_id == "app_1" + assert tool.variable == "var_1" + assert tool.config is None + + def test_validate_config_raises_not_implemented(self): + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + return super().validate_config(tenant_id, config) + + def query(self, inputs: dict, query: str | None = None) -> str: + return "" + + with pytest.raises(NotImplementedError): + ConcreteTool.validate_config("tenant_1", {}) + + def test_query_raises_not_implemented(self): + class ConcreteTool(ExternalDataTool): + @classmethod + def validate_config(cls, tenant_id: str, config: dict): + pass + + def query(self, inputs: dict, query: str | None = None) -> str: + return super().query(inputs, query) + + tool = ConcreteTool(tenant_id="tenant_1", app_id="app_1", variable="var_1") + with pytest.raises(NotImplementedError): + tool.query({}) diff --git a/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py new file mode 100644 index 0000000000..86b461cf04 --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_external_data_fetch.py @@ -0,0 +1,115 @@ +from unittest.mock import patch + +import pytest +from flask import Flask + +from core.app.app_config.entities import ExternalDataVariableEntity +from core.external_data_tool.external_data_fetch import ExternalDataFetch + + +class TestExternalDataFetch: + @pytest.fixture + def app(self): + app = Flask(__name__) + return app + + def test_fetch_success(self, app): + with app.app_context(): + fetcher = ExternalDataFetch() + + # Setup mocks + tool1 = ExternalDataVariableEntity(variable="var1", type="type1", config={"c1": "v1"}) + tool2 = ExternalDataVariableEntity(variable="var2", type="type2", config={"c2": "v2"}) + + external_data_tools = [tool1, tool2] + inputs = {"input_key": "input_value"} + query = "test query" + + with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory: + # Create distinct mock instances for each tool to ensure deterministic results + # This approach is robust regardless of thread scheduling order + from unittest.mock import MagicMock + + def factory_side_effect(*args, **kwargs): + variable = kwargs.get("variable") + mock_instance = MagicMock() + if variable == "var1": + mock_instance.query.return_value = "result1" + elif variable == "var2": + mock_instance.query.return_value = "result2" + return mock_instance + + MockFactory.side_effect = factory_side_effect + + result_inputs = fetcher.fetch( + tenant_id="tenant1", + app_id="app1", + external_data_tools=external_data_tools, + inputs=inputs, + query=query, + ) + + # Each tool gets its deterministic result regardless of thread completion order + assert result_inputs["var1"] == "result1" + assert result_inputs["var2"] == "result2" + assert result_inputs["input_key"] == "input_value" + assert len(result_inputs) == 3 + + # Verify factory calls + assert MockFactory.call_count == 2 + MockFactory.assert_any_call( + name="type1", tenant_id="tenant1", app_id="app1", variable="var1", config={"c1": "v1"} + ) + MockFactory.assert_any_call( + name="type2", tenant_id="tenant1", app_id="app1", variable="var2", config={"c2": "v2"} + ) + + def test_fetch_no_tools(self): + # We don't necessarily need app_context if there are no tools, + # but fetch calls current_app._get_current_object() only inside the loop. + # Wait, let's look at the code. + # for tool in external_data_tools: + # executor.submit(..., current_app._get_current_object(), ...) + # So if external_data_tools is empty, it shouldn't access current_app. + fetcher = ExternalDataFetch() + inputs = {"input_key": "input_value"} + result_inputs = fetcher.fetch( + tenant_id="tenant1", app_id="app1", external_data_tools=[], inputs=inputs, query="test query" + ) + assert result_inputs == inputs + assert result_inputs is not inputs # Should be a copy + + def test_fetch_with_none_variable(self, app): + with app.app_context(): + fetcher = ExternalDataFetch() + tool = ExternalDataVariableEntity(variable="var1", type="type1", config={}) + + # Patch _query_external_data_tool to return None variable + with patch.object(ExternalDataFetch, "_query_external_data_tool") as mock_query: + mock_query.return_value = (None, "some_result") + + result_inputs = fetcher.fetch( + tenant_id="t1", app_id="a1", external_data_tools=[tool], inputs={"in": "val"}, query="q" + ) + + assert "var1" not in result_inputs + assert result_inputs == {"in": "val"} + + def test_query_external_data_tool(self, app): + fetcher = ExternalDataFetch() + tool = ExternalDataVariableEntity(variable="var1", type="type1", config={"k": "v"}) + + with patch("core.external_data_tool.external_data_fetch.ExternalDataToolFactory") as MockFactory: + mock_factory_instance = MockFactory.return_value + mock_factory_instance.query.return_value = "query_result" + + var, res = fetcher._query_external_data_tool( + flask_app=app, tenant_id="t1", app_id="a1", external_data_tool=tool, inputs={"i": "v"}, query="q" + ) + + assert var == "var1" + assert res == "query_result" + MockFactory.assert_called_once_with( + name="type1", tenant_id="t1", app_id="a1", variable="var1", config={"k": "v"} + ) + mock_factory_instance.query.assert_called_once_with(inputs={"i": "v"}, query="q") diff --git a/api/tests/unit_tests/core/external_data_tool/test_factory.py b/api/tests/unit_tests/core/external_data_tool/test_factory.py new file mode 100644 index 0000000000..6bb384b0ac --- /dev/null +++ b/api/tests/unit_tests/core/external_data_tool/test_factory.py @@ -0,0 +1,58 @@ +from unittest.mock import MagicMock, patch + +from core.extension.extensible import ExtensionModule +from core.external_data_tool.factory import ExternalDataToolFactory + + +def test_external_data_tool_factory_init(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_code_based_extension.extension_class.return_value = mock_extension_class + + name = "test_tool" + tenant_id = "tenant_123" + app_id = "app_456" + variable = "var_v" + config = {"key": "value"} + + factory = ExternalDataToolFactory(name, tenant_id, app_id, variable, config) + + mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name) + mock_extension_class.assert_called_once_with( + tenant_id=tenant_id, app_id=app_id, variable=variable, config=config + ) + + +def test_external_data_tool_factory_validate_config(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_code_based_extension.extension_class.return_value = mock_extension_class + + name = "test_tool" + tenant_id = "tenant_123" + config = {"key": "value"} + + ExternalDataToolFactory.validate_config(name, tenant_id, config) + + mock_code_based_extension.extension_class.assert_called_once_with(ExtensionModule.EXTERNAL_DATA_TOOL, name) + mock_extension_class.validate_config.assert_called_once_with(tenant_id, config) + + +def test_external_data_tool_factory_query(): + with patch("core.external_data_tool.factory.code_based_extension") as mock_code_based_extension: + mock_extension_class = MagicMock() + mock_extension_instance = MagicMock() + mock_extension_class.return_value = mock_extension_instance + mock_code_based_extension.extension_class.return_value = mock_extension_class + + mock_extension_instance.query.return_value = "query_result" + + factory = ExternalDataToolFactory("name", "tenant", "app", "var", {}) + + inputs = {"input_key": "input_value"} + query = "search_query" + + result = factory.query(inputs, query) + + assert result == "query_result" + mock_extension_instance.query.assert_called_once_with(inputs, query) diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index 4d4ccc2672..deebf41320 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,4 +1,4 @@ -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType def test_file(): diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py new file mode 100644 index 0000000000..b2783bdf99 --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_rule_config_generator.py @@ -0,0 +1,103 @@ +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser +from core.llm_generator.prompts import ( + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, +) + + +class TestRuleConfigGeneratorOutputParser: + def test_get_format_instructions(self): + parser = RuleConfigGeneratorOutputParser() + instructions = parser.get_format_instructions() + assert instructions == ( + RULE_CONFIG_PROMPT_GENERATE_TEMPLATE, + RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE, + RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE, + ) + + def test_parse_success(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"], + "opening_statement": "Hello!" +} +``` +""" + result = parser.parse(text) + assert result["prompt"] == "This is a prompt" + assert result["variables"] == ["var1", "var2"] + assert result["opening_statement"] == "Hello!" + + def test_parse_invalid_json(self): + parser = RuleConfigGeneratorOutputParser() + text = "invalid json" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Parsing text" in str(excinfo.value) + assert "could not find json block in the output" in str(excinfo.value) + + def test_parse_missing_keys(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"] +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "expected key `opening_statement` to be present" in str(excinfo.value) + + def test_parse_wrong_type_prompt(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": 123, + "variables": ["var1", "var2"], + "opening_statement": "Hello!" +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'prompt' to be a string" in str(excinfo.value) + + def test_parse_wrong_type_variables(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": "not a list", + "opening_statement": "Hello!" +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'variables' to be a list" in str(excinfo.value) + + def test_parse_wrong_type_opening_statement(self): + parser = RuleConfigGeneratorOutputParser() + text = """ +```json +{ + "prompt": "This is a prompt", + "variables": ["var1", "var2"], + "opening_statement": 123 +} +``` +""" + with pytest.raises(OutputParserError) as excinfo: + parser.parse(text) + assert "Expected 'opening_statement' to be a str" in str(excinfo.value) diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py new file mode 100644 index 0000000000..46c9dc6f9c --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -0,0 +1,402 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import ( + ResponseFormat, + _handle_native_json_schema, + _handle_prompt_based_schema, + _parse_structured_output, + _prepare_schema_for_model, + _set_response_format, + convert_boolean_to_string, + invoke_llm_with_structured_output, + remove_additional_properties, +) +from core.model_manager import ModelInstance +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType + + +class TestStructuredOutput: + def test_remove_additional_properties(self): + schema = { + "type": "object", + "properties": {"name": {"type": "string"}, "age": {"type": "integer"}}, + "additionalProperties": False, + "nested": {"type": "object", "additionalProperties": True}, + "items": [{"type": "object", "additionalProperties": False}], + } + remove_additional_properties(schema) + assert "additionalProperties" not in schema + assert "additionalProperties" not in schema["nested"] + assert "additionalProperties" not in schema["items"][0] + + # Test with non-dict input + remove_additional_properties(None) # Should not raise + remove_additional_properties([]) # Should not raise + + def test_convert_boolean_to_string(self): + schema = { + "type": "object", + "properties": { + "is_active": {"type": "boolean"}, + "tags": {"type": "array", "items": {"type": "boolean"}}, + "list_schema": [{"type": "boolean"}], + }, + } + convert_boolean_to_string(schema) + assert schema["properties"]["is_active"]["type"] == "string" + assert schema["properties"]["tags"]["items"]["type"] == "string" + assert schema["properties"]["list_schema"][0]["type"] == "string" + + # Test with non-dict input + convert_boolean_to_string(None) # Should not raise + convert_boolean_to_string([]) # Should not raise + + def test_parse_structured_output_valid(self): + text = '{"key": "value"}' + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_non_dict_valid_json(self): + # Even if it's valid JSON, if it's not a dict, it should try repair or fail + text = '["a", "b"]' + with patch("json_repair.loads") as mock_repair: + mock_repair.return_value = {"key": "value"} + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_not_dict_fail_via_validate(self): + # Force TypeAdapter to return a non-dict to trigger line 292 + with patch("pydantic.TypeAdapter.validate_json") as mock_validate: + mock_validate.return_value = ["a list"] + with pytest.raises(OutputParserError) as excinfo: + _parse_structured_output('["a list"]') + assert "Failed to parse structured output" in str(excinfo.value) + + def test_parse_structured_output_repair_success(self): + text = "{'key': 'value'}" # Invalid JSON (single quotes) + # json_repair should handle this + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_repair_list(self): + # Deepseek-r1 case: result is a list containing a dict + text = '[{"key": "value"}]' + assert _parse_structured_output(text) == {"key": "value"} + + def test_parse_structured_output_repair_list_no_dict(self): + # Deepseek-r1 case: result is a list with NO dict + text = "[1, 2, 3]" + assert _parse_structured_output(text) == {} + + def test_parse_structured_output_repair_fail(self): + text = "not a json at all" + with patch("json_repair.loads") as mock_repair: + mock_repair.return_value = "still not a dict or list" + with pytest.raises(OutputParserError): + _parse_structured_output(text) + + def test_set_response_format(self): + # Test JSON + params = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON], + ) + ] + _set_response_format(params, rules) + assert params["response_format"] == ResponseFormat.JSON + + # Test JSON_OBJECT + params = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_OBJECT], + ) + ] + _set_response_format(params, rules) + assert params["response_format"] == ResponseFormat.JSON_OBJECT + + def test_handle_native_json_schema(self): + provider = "openai" + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + structured_output_schema = {"type": "object"} + model_parameters = {} + rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_SCHEMA], + ) + ] + + updated_params = _handle_native_json_schema( + provider, model_schema, structured_output_schema, model_parameters, rules + ) + + assert "json_schema" in updated_params + assert json.loads(updated_params["json_schema"]) == {"schema": {"type": "object"}, "name": "llm_response"} + assert updated_params["response_format"] == ResponseFormat.JSON_SCHEMA + + def test_handle_native_json_schema_no_format_rule(self): + provider = "openai" + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + structured_output_schema = {"type": "object"} + model_parameters = {} + rules = [] + + updated_params = _handle_native_json_schema( + provider, model_schema, structured_output_schema, model_parameters, rules + ) + + assert "json_schema" in updated_params + assert "response_format" not in updated_params + + def test_handle_prompt_based_schema_with_system_prompt(self): + prompt_messages = [ + SystemPromptMessage(content="Existing system prompt"), + UserPromptMessage(content="User question"), + ] + schema = {"type": "object"} + + result = _handle_prompt_based_schema(prompt_messages, schema) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert "Existing system prompt" in result[0].content + assert json.dumps(schema) in result[0].content + assert isinstance(result[1], UserPromptMessage) + + def test_handle_prompt_based_schema_without_system_prompt(self): + prompt_messages = [UserPromptMessage(content="User question")] + schema = {"type": "object"} + + result = _handle_prompt_based_schema(prompt_messages, schema) + + assert len(result) == 2 + assert isinstance(result[0], SystemPromptMessage) + assert json.dumps(schema) in result[0].content + assert isinstance(result[1], UserPromptMessage) + + def test_prepare_schema_for_model_gemini(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gemini-1.5-pro" + schema = {"type": "object", "additionalProperties": False} + + result = _prepare_schema_for_model("google", model_schema, schema) + assert "additionalProperties" not in result + + def test_prepare_schema_for_model_ollama(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "llama3" + schema = {"type": "object"} + + result = _prepare_schema_for_model("ollama", model_schema, schema) + assert result == schema + + def test_prepare_schema_for_model_default(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.model = "gpt-4" + schema = {"type": "object"} + + result = _prepare_schema_for_model("openai", model_schema, schema) + assert result == {"schema": schema, "name": "llm_response"} + + def test_invoke_llm_with_structured_output_no_stream_native(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = True + model_schema.parameter_rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON_SCHEMA], + ) + ] + model_schema.model = "gpt-4o" + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content='{"result": "success"}') + mock_result.model = "gpt-4o" + mock_result.usage = LLMUsage.empty_usage() + mock_result.system_fingerprint = "fp_native" + mock_result.prompt_messages = [UserPromptMessage(content="hi")] + + model_instance.invoke_llm.return_value = mock_result + + result = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={"type": "object"}, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == {"result": "success"} + assert result.system_fingerprint == "fp_native" + + def test_invoke_llm_with_structured_output_no_stream_prompt_based(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [ + ParameterRule( + name="response_format", + label={"en_US": ""}, + type=ParameterType.STRING, + help={"en_US": ""}, + options=[ResponseFormat.JSON], + ) + ] + model_schema.model = "claude-3" + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content='{"result": "success"}') + mock_result.model = "claude-3" + mock_result.usage = LLMUsage.empty_usage() + mock_result.system_fingerprint = "fp_prompt" + mock_result.prompt_messages = [] + + model_instance.invoke_llm.return_value = mock_result + + result = invoke_llm_with_structured_output( + provider="anthropic", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={"type": "object"}, + stream=False, + ) + + assert isinstance(result, LLMResultWithStructuredOutput) + assert result.structured_output == {"result": "success"} + assert result.system_fingerprint == "fp_prompt" + + def test_invoke_llm_with_structured_output_no_string_error(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + + model_instance = MagicMock(spec=ModelInstance) + mock_result = MagicMock(spec=LLMResult) + mock_result.message = AssistantPromptMessage(content=[TextPromptMessageContent(data="not a string")]) + + model_instance.invoke_llm.return_value = mock_result + + with pytest.raises(OutputParserError) as excinfo: + invoke_llm_with_structured_output( + provider="anthropic", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[], + json_schema={}, + stream=False, + ) + assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value) + + def test_invoke_llm_with_structured_output_stream(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + model_schema.model = "gpt-4" + + model_instance = MagicMock(spec=ModelInstance) + + # Mock chunks + chunk1 = MagicMock(spec=LLMResultChunk) + chunk1.delta = LLMResultChunkDelta( + index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage() + ) + chunk1.prompt_messages = [UserPromptMessage(content="hi")] + chunk1.system_fingerprint = "fp1" + + chunk2 = MagicMock(spec=LLMResultChunk) + chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}')) + chunk2.prompt_messages = [UserPromptMessage(content="hi")] + chunk2.system_fingerprint = "fp1" + + chunk3 = MagicMock(spec=LLMResultChunk) + chunk3.delta = LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage( + content=[ + TextPromptMessageContent(data=" "), + ] + ), + ) + chunk3.prompt_messages = [UserPromptMessage(content="hi")] + chunk3.system_fingerprint = "fp1" + + event4 = MagicMock() + event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content="")) + + model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4] + + generator = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[UserPromptMessage(content="hi")], + json_schema={}, + stream=True, + ) + + chunks = list(generator) + assert len(chunks) == 5 + assert chunks[-1].structured_output == {"key": "value"} + assert chunks[-1].system_fingerprint == "fp1" + assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")] + + def test_invoke_llm_with_structured_output_stream_no_id_events(self): + model_schema = MagicMock(spec=AIModelEntity) + model_schema.support_structure_output = False + model_schema.parameter_rules = [] + model_schema.model = "gpt-4" + + model_instance = MagicMock(spec=ModelInstance) + model_instance.invoke_llm.return_value = [] + + generator = invoke_llm_with_structured_output( + provider="openai", + model_schema=model_schema, + model_instance=model_instance, + prompt_messages=[], + json_schema={}, + stream=True, + ) + + with pytest.raises(OutputParserError): + list(generator) + + def test_parse_structured_output_empty_string(self): + with pytest.raises(OutputParserError): + _parse_structured_output("") diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py new file mode 100644 index 0000000000..5b7640696f --- /dev/null +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -0,0 +1,589 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.app_config.entities import ModelConfig +from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload +from core.llm_generator.llm_generator import LLMGenerator +from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult +from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError + + +class TestLLMGenerator: + @pytest.fixture + def mock_model_instance(self): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + instance = MagicMock() + mock_manager.return_value.get_default_model_instance.return_value = instance + mock_manager.return_value.get_model_instance.return_value = instance + yield instance + + @pytest.fixture + def model_config_entity(self): + return ModelConfig(provider="openai", name="gpt-4", mode=LLMMode.CHAT, completion_params={"temperature": 0.7}) + + def test_generate_conversation_name_success(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Test Conversation Name"}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager") as mock_trace: + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "Test Conversation Name" + mock_trace.assert_called_once() + + def test_generate_conversation_name_truncated(self, mock_model_instance): + long_query = "a" * 2100 + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": "Short Name"}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", long_query) + assert name == "Short Name" + + def test_generate_conversation_name_empty_answer(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "" + mock_model_instance.invoke_llm.return_value = mock_response + + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "" + + def test_generate_conversation_name_json_repair(self, mock_model_instance): + mock_response = MagicMock() + # Invalid JSON that json_repair can fix + mock_response.message.get_text_content.return_value = "{'Your Output': 'Repaired Name'}" + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "Repaired Name" + + def test_generate_conversation_name_not_dict_result(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '["not a dict"]' + mock_model_instance.invoke_llm.return_value = mock_response + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "test query" + + def test_generate_conversation_name_no_output_in_dict(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"something": "else"}' + mock_model_instance.invoke_llm.return_value = mock_response + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert name == "test query" + + def test_generate_conversation_name_long_output(self, mock_model_instance): + long_output = "a" * 100 + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = json.dumps({"Your Output": long_output}) + mock_model_instance.invoke_llm.return_value = mock_response + + with patch("core.llm_generator.llm_generator.TraceQueueManager"): + name = LLMGenerator.generate_conversation_name("tenant_id", "test query") + assert len(name) == 78 # 75 + "..." + assert name.endswith("...") + + def test_generate_suggested_questions_after_answer_success(self, mock_model_instance): + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '["Question 1?", "Question 2?"]' + mock_model_instance.invoke_llm.return_value = mock_response + + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert len(questions) == 2 + assert questions[0] == "Question 1?" + + def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert questions == [] + + def test_generate_suggested_questions_after_answer_invoke_error(self, mock_model_instance): + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert questions == [] + + def test_generate_suggested_questions_after_answer_exception(self, mock_model_instance): + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") + assert questions == [] + + def test_generate_rule_config_no_variable_success(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=True + ) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "Generated Prompt" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert result["prompt"] == "Generated Prompt" + assert result["error"] == "" + + def test_generate_rule_config_no_variable_invoke_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=True + ) + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate rule config" in result["error"] + + def test_generate_rule_config_no_variable_exception(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=True + ) + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate rule config" in result["error"] + assert "Random error" in result["error"] + + def test_generate_rule_config_with_variable_success(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + # Mocking 3 calls for invoke_llm + mock_res1 = MagicMock() + mock_res1.message.get_text_content.return_value = "Step 1 Prompt" + + mock_res2 = MagicMock() + mock_res2.message.get_text_content.return_value = '"var1", "var2"' + + mock_res3 = MagicMock() + mock_res3.message.get_text_content.return_value = "Opening Statement" + + mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, mock_res3] + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert result["prompt"] == "Step 1 Prompt" + assert result["variables"] == ["var1", "var2"] + assert result["opening_statement"] == "Opening Statement" + assert result["error"] == "" + + def test_generate_rule_config_with_variable_step1_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + mock_model_instance.invoke_llm.side_effect = InvokeError("Step 1 Failed") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate prefix prompt" in result["error"] + + def test_generate_rule_config_with_variable_step2_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + mock_res1 = MagicMock() + mock_res1.message.get_text_content.return_value = "Step 1 Prompt" + + # Step 2 fails + mock_model_instance.invoke_llm.side_effect = [mock_res1, InvokeError("Step 2 Failed"), MagicMock()] + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate variables" in result["error"] + + def test_generate_rule_config_with_variable_step3_error(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + mock_res1 = MagicMock() + mock_res1.message.get_text_content.return_value = "Step 1 Prompt" + + mock_res2 = MagicMock() + mock_res2.message.get_text_content.return_value = '"var1"' + + # Step 3 fails + mock_model_instance.invoke_llm.side_effect = [mock_res1, mock_res2, InvokeError("Step 3 Failed")] + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to generate conversation opener" in result["error"] + + def test_generate_rule_config_with_variable_exception(self, mock_model_instance, model_config_entity): + payload = RuleGeneratePayload( + instruction="test instruction", model_config=model_config_entity, no_variable=False + ) + # Mock any step to throw Exception + mock_model_instance.invoke_llm.side_effect = Exception("Unexpected multi-step error") + + result = LLMGenerator.generate_rule_config("tenant_id", payload) + assert "Failed to handle unexpected exception" in result["error"] + assert "Unexpected multi-step error" in result["error"] + + def test_generate_code_python_success(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload( + instruction="print hello", code_language="python", model_config=model_config_entity + ) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "print('hello')" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_code("tenant_id", payload) + assert result["code"] == "print('hello')" + assert result["language"] == "python" + + def test_generate_code_javascript_success(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload( + instruction="console log hello", code_language="javascript", model_config=model_config_entity + ) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "console.log('hello')" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_code("tenant_id", payload) + assert result["code"] == "console.log('hello')" + assert result["language"] == "javascript" + + def test_generate_code_invoke_error(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + + result = LLMGenerator.generate_code("tenant_id", payload) + assert "Failed to generate code" in result["error"] + + def test_generate_code_exception(self, mock_model_instance, model_config_entity): + payload = RuleCodeGeneratePayload(instruction="error", code_language="python", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.generate_code("tenant_id", payload) + assert "An unexpected error occurred" in result["error"] + + def test_generate_qa_document_success(self, mock_model_instance): + mock_response = MagicMock(spec=LLMResult) + mock_response.message = MagicMock() + mock_response.message.get_text_content.return_value = "QA Document Content" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_qa_document("tenant_id", "query", "English") + assert result == "QA Document Content" + + def test_generate_qa_document_type_error(self, mock_model_instance): + mock_model_instance.invoke_llm.return_value = "Not an LLMResult" + + with pytest.raises(TypeError, match="Expected LLMResult when stream=False"): + LLMGenerator.generate_qa_document("tenant_id", "query", "English") + + def test_generate_structured_output_success(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"type": "object", "properties": {}}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + parsed_output = json.loads(result["output"]) + assert parsed_output["type"] == "object" + assert result["error"] == "" + + def test_generate_structured_output_json_repair(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "{'type': 'object'}" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + parsed_output = json.loads(result["output"]) + assert parsed_output["type"] == "object" + + def test_generate_structured_output_not_dict_or_list(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="generate schema", model_config=model_config_entity) + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "true" # parsed as bool + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + assert "An unexpected error occurred" in result["error"] + assert "Failed to parse structured output" in result["error"] + + def test_generate_structured_output_invoke_error(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke failed") + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + assert "Failed to generate JSON Schema" in result["error"] + + def test_generate_structured_output_exception(self, mock_model_instance, model_config_entity): + payload = RuleStructuredOutputPayload(instruction="error", model_config=model_config_entity) + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.generate_structured_output("tenant_id", payload) + assert "An unexpected error occurred" in result["error"] + + def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + + # Mock __instruction_modify_common call via invoke_llm + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert result == {"modified": "prompt"} + + def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + last_run = MagicMock() + last_run.query = "q" + last_run.answer = "a" + last_run.error = "e" + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert result == {"modified": "prompt"} + + def test_instruction_modify_workflow_app_not_found(self): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="App not found."): + LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", MagicMock()) + + def test_instruction_modify_workflow_no_workflow(self): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + with pytest.raises(ValueError, match="Workflow not found for the given app model."): + LLMGenerator.instruction_modify_workflow("t", "f", "n", "c", "i", MagicMock(), "o", workflow_service) + + def test_instruction_modify_workflow_success(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + + last_run = MagicMock() + last_run.node_type = "llm" + last_run.status = "s" + last_run.error = "e" + # Return regular values, not Mocks + last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]} + last_run.load_full_inputs.return_value = {"in": "val"} + + workflow_service.get_node_last_run.return_value = last_run + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "workflow"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "workflow"} + + def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "code"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + workflow_service.get_node_last_run.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "fallback"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "fallback"} + + def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + # Cause exception in node_type logic + workflow.graph_dict = {"graph": {"nodes": []}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + workflow_service.get_node_last_run.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "fallback"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "fallback"} + + def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "llm"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + + last_run = MagicMock() + last_run.node_type = "llm" + last_run.status = "s" + last_run.error = "e" + # Return regular empty list, not a Mock + last_run.execution_metadata_dict = {"agent_log": []} + last_run.load_full_inputs.return_value = {} + + workflow_service.get_node_last_run.return_value = last_run + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"modified": "workflow"}' + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + assert result == {"modified": "workflow"} + + def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): + # Testing placeholders replacement via instruction_modify_legacy for convenience + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"ok": true}' + mock_model_instance.invoke_llm.return_value = mock_response + + instruction = "Test {{#last_run#}} and {{#current#}} and {{#error_message#}}" + LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current_val", instruction, model_config_entity, "ideal" + ) + + # Verify the call to invoke_llm contains replaced instruction + args, kwargs = mock_model_instance.invoke_llm.call_args + prompt_messages = kwargs["prompt_messages"] + user_msg = prompt_messages[1].content + user_msg_dict = json.loads(user_msg) + assert "null" in user_msg_dict["instruction"] # because last_run is None and current is current_val etc. + assert "current_val" in user_msg_dict["instruction"] + + def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "No braces here" + mock_model_instance.invoke_llm.return_value = mock_response + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "An unexpected error occurred" in result["error"] + assert "Could not find a valid JSON object" in result["error"] + + def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "[1, 2, 3]" + mock_model_instance.invoke_llm.return_value = mock_response + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + # The exception message is "Expected a JSON object, but got list" + assert "An unexpected error occurred" in result["error"] + + def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): + with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + instance = MagicMock() + mock_manager.return_value.get_model_instance.return_value = instance + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = '{"ok": true}' + instance.invoke_llm.return_value = mock_response + + with patch("extensions.ext_database.db.session") as mock_session: + mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock() + workflow = MagicMock() + workflow.graph_dict = {"graph": {"nodes": [{"id": "node_id", "data": {"type": "other"}}]}} + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + workflow_service.get_node_last_run.return_value = None + + LLMGenerator.instruction_modify_workflow( + "tenant_id", + "flow_id", + "node_id", + "current", + "instruction", + model_config_entity, + "ideal", + workflow_service, + ) + + def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "Failed to generate code" in result["error"] + + def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + mock_model_instance.invoke_llm.side_effect = Exception("Random error") + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "An unexpected error occurred" in result["error"] + + def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): + with patch("extensions.ext_database.db.session.query") as mock_query: + mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + + mock_response = MagicMock() + mock_response.message.get_text_content.return_value = "No JSON here" + mock_model_instance.invoke_llm.return_value = mock_response + + result = LLMGenerator.instruction_modify_legacy( + "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal" + ) + assert "An unexpected error occurred" in result["error"] diff --git a/api/tests/unit_tests/core/logging/test_filters.py b/api/tests/unit_tests/core/logging/test_filters.py index 7c2767266f..a8b186ac8a 100644 --- a/api/tests/unit_tests/core/logging/test_filters.py +++ b/api/tests/unit_tests/core/logging/test_filters.py @@ -82,6 +82,68 @@ class TestTraceContextFilter: assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" assert log_record.span_id == "051581bf3bb55c45" + def test_otel_context_invalid_trace_id(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0 + mock_context.is_valid = True + mock_span.get_span_context.return_value = mock_context + + # Use mocks for base context to ensure we can test the fallback + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + + def test_otel_context_invalid_span_id(self, log_record): + from core.logging.filters import TraceContextFilter + + mock_span = mock.MagicMock() + mock_context = mock.MagicMock() + mock_context.trace_id = 0x5B8AA5A2D2C872E8321CF37308D69DF2 + mock_context.span_id = 0 + mock_context.is_valid = True + mock_span.get_span_context.return_value = mock_context + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=mock_span), + mock.patch("opentelemetry.trace.span.INVALID_TRACE_ID", 0), + mock.patch("opentelemetry.trace.span.INVALID_SPAN_ID", 0), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "5b8aa5a2d2c872e8321cf37308d69df2" + assert log_record.span_id == "" + + def test_otel_context_span_none(self, log_record): + from core.logging.filters import TraceContextFilter + + with ( + mock.patch("opentelemetry.trace.get_current_span", return_value=None), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + + def test_otel_context_exception(self, log_record): + from core.logging.filters import TraceContextFilter + + # Trigger exception in OTEL block + with ( + mock.patch("opentelemetry.trace.get_current_span", side_effect=Exception), + mock.patch("core.logging.filters.get_trace_id", return_value=""), + ): + filter = TraceContextFilter() + filter.filter(log_record) + assert log_record.trace_id == "" + class TestIdentityContextFilter: def test_sets_empty_identity_without_request_context(self, log_record): @@ -114,3 +176,119 @@ class TestIdentityContextFilter: result = filter.filter(log_record) assert result is True assert log_record.tenant_id == "" + + def test_sets_empty_identity_unauthenticated(self, log_record): + from core.logging.filters import IdentityContextFilter + + mock_user = mock.MagicMock() + mock_user.is_authenticated = False + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + assert log_record.user_id == "" + + def test_sets_identity_for_account(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockAccount: + pass + + mock_user = MockAccount() + mock_user.id = "account_id" + mock_user.current_tenant_id = "tenant_id" + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.Account", MockAccount), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "account_id" + assert log_record.user_type == "account" + + def test_sets_identity_for_account_no_tenant(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockAccount: + pass + + mock_user = MockAccount() + mock_user.id = "account_id" + mock_user.current_tenant_id = None + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.Account", MockAccount), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "" + assert log_record.user_id == "account_id" + assert log_record.user_type == "account" + + def test_sets_identity_for_end_user(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockEndUser: + pass + + class AnotherClass: + pass + + mock_user = MockEndUser() + mock_user.id = "end_user_id" + mock_user.tenant_id = "tenant_id" + mock_user.type = "custom_type" + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.model.EndUser", MockEndUser), + mock.patch("models.Account", AnotherClass), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "end_user_id" + assert log_record.user_type == "custom_type" + + def test_sets_identity_for_end_user_default_type(self, log_record): + from core.logging.filters import IdentityContextFilter + + class MockEndUser: + pass + + class AnotherClass: + pass + + mock_user = MockEndUser() + mock_user.id = "end_user_id" + mock_user.tenant_id = "tenant_id" + mock_user.type = None + mock_user.is_authenticated = True + + with ( + mock.patch("flask.has_request_context", return_value=True), + mock.patch("models.model.EndUser", MockEndUser), + mock.patch("models.Account", AnotherClass), + mock.patch("flask_login.current_user", mock_user), + ): + filter = IdentityContextFilter() + filter.filter(log_record) + + assert log_record.tenant_id == "tenant_id" + assert log_record.user_id == "end_user_id" + assert log_record.user_type == "end_user" diff --git a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py index 60f37b6de0..fe533e62af 100644 --- a/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py +++ b/api/tests/unit_tests/core/mcp/auth/test_auth_flow.py @@ -1,27 +1,39 @@ """Unit tests for MCP OAuth authentication flow.""" +import json from unittest.mock import Mock, patch +import httpx import pytest +from pydantic import ValidationError from core.entities.mcp_provider import MCPProviderEntity +from core.helper import ssrf_proxy from core.mcp.auth.auth_flow import ( OAUTH_STATE_EXPIRY_SECONDS, OAUTH_STATE_REDIS_KEY_PREFIX, OAuthCallbackState, _create_secure_redis_state, + _parse_token_response, _retrieve_redis_state, auth, + build_oauth_authorization_server_metadata_discovery_urls, + build_protected_resource_metadata_discovery_urls, check_support_resource_discovery, + client_credentials_flow, + discover_oauth_authorization_server_metadata, discover_oauth_metadata, + discover_protected_resource_metadata, exchange_authorization, generate_pkce_challenge, + get_effective_scope, handle_callback, refresh_authorization, register_client, start_authorization, ) from core.mcp.entities import AuthActionType, AuthResult +from core.mcp.error import MCPRefreshTokenError from core.mcp.types import ( LATEST_PROTOCOL_VERSION, OAuthClientInformation, @@ -764,3 +776,576 @@ class TestAuthOrchestration: auth(mock_provider, authorization_code="auth-code") assert "Existing OAuth client information is required" in str(exc_info.value) + + def test_generate_pkce_challenge(self): + verifier, challenge = generate_pkce_challenge() + assert verifier + assert challenge + assert "=" not in verifier + assert "=" not in challenge + + def test_build_protected_resource_metadata_discovery_urls(self): + # Case 1: WWW-Auth URL provided + urls = build_protected_resource_metadata_discovery_urls( + "https://auth.example.com/prm", "https://api.example.com" + ) + assert "https://auth.example.com/prm" in urls + assert "https://api.example.com/.well-known/oauth-protected-resource" in urls + + # Case 2: No WWW-Auth URL, with path + urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com/v1") + assert "https://api.example.com/.well-known/oauth-protected-resource/v1" in urls + assert "https://api.example.com/.well-known/oauth-protected-resource" in urls + + # Case 3: No path + urls = build_protected_resource_metadata_discovery_urls(None, "https://api.example.com") + assert urls == ["https://api.example.com/.well-known/oauth-protected-resource"] + + def test_build_protected_resource_metadata_discovery_urls_with_relative_hint(self): + urls = build_protected_resource_metadata_discovery_urls( + "/.well-known/oauth-protected-resource/tenant/mcp", + "https://api.example.com/tenant/mcp", + ) + assert urls == [ + "https://api.example.com/.well-known/oauth-protected-resource/tenant/mcp", + "https://api.example.com/.well-known/oauth-protected-resource", + ] + + def test_build_protected_resource_metadata_discovery_urls_ignores_scheme_less_hint(self): + urls = build_protected_resource_metadata_discovery_urls( + "/openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource/tenant/mcp", + "https://openapi-mcp.cn-hangzhou.aliyuncs.com/tenant/mcp", + ) + + assert urls == [ + "https://openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource/tenant/mcp", + "https://openapi-mcp.cn-hangzhou.aliyuncs.com/.well-known/oauth-protected-resource", + ] + + def test_build_oauth_authorization_server_metadata_discovery_urls(self): + # Case 1: with auth_server_url + urls = build_oauth_authorization_server_metadata_discovery_urls( + "https://auth.example.com", "https://api.example.com" + ) + assert "https://auth.example.com/.well-known/oauth-authorization-server" in urls + assert "https://auth.example.com/.well-known/openid-configuration" in urls + + # Case 2: with path + urls = build_oauth_authorization_server_metadata_discovery_urls(None, "https://api.example.com/tenant") + assert "https://api.example.com/.well-known/oauth-authorization-server/tenant" in urls + assert "https://api.example.com/tenant/.well-known/openid-configuration" in urls + + @patch("core.helper.ssrf_proxy.get") + def test_discover_protected_resource_metadata(self, mock_get): + # Success + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "resource": "https://api.example.com", + "authorization_servers": ["https://auth"], + } + mock_get.return_value = mock_response + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is not None + assert result.resource == "https://api.example.com" + + # 404 then Success + res404 = Mock() + res404.status_code = 404 + mock_get.side_effect = [res404, mock_response] + result = discover_protected_resource_metadata(None, "https://api.example.com/path") + assert result is not None + assert result.resource == "https://api.example.com" + + # Error handling + mock_get.side_effect = httpx.RequestError("Error") + result = discover_protected_resource_metadata(None, "https://api.example.com") + assert result is None + + @patch("core.helper.ssrf_proxy.get") + def test_discover_oauth_authorization_server_metadata(self, mock_get): + # Success + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "authorization_endpoint": "https://auth.example.com/auth", + "token_endpoint": "https://auth.example.com/token", + "response_types_supported": ["code"], + } + mock_get.return_value = mock_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is not None + assert result.authorization_endpoint == "https://auth.example.com/auth" + + # 404 + res404 = Mock() + res404.status_code = 404 + mock_get.side_effect = [res404, mock_response] + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com/tenant") + assert result is not None + assert result.authorization_endpoint == "https://auth.example.com/auth" + + # ValidationError + mock_response.json.return_value = {"invalid": "data"} + mock_get.side_effect = None + mock_get.return_value = mock_response + result = discover_oauth_authorization_server_metadata(None, "https://api.example.com") + assert result is None + + def test_get_effective_scope(self): + prm = ProtectedResourceMetadata( + resource="https://api.example.com", + authorization_servers=["https://auth"], + scopes_supported=["read", "write"], + ) + asm = OAuthMetadata( + authorization_endpoint="https://auth.example.com/auth", + token_endpoint="https://auth.example.com/token", + response_types_supported=["code"], + scopes_supported=["openid", "profile"], + ) + + # 1. WWW-Auth priority + assert get_effective_scope("scope1", prm, asm, "client") == "scope1" + # 2. PRM priority + assert get_effective_scope(None, prm, asm, "client") == "read write" + # 3. ASM priority + assert get_effective_scope(None, None, asm, "client") == "openid profile" + # 4. Client configured + assert get_effective_scope(None, None, None, "client") == "client" + + @patch("core.mcp.auth.auth_flow.redis_client") + def test_redis_state_management(self, mock_redis): + state_data = OAuthCallbackState( + provider_id="p1", + tenant_id="t1", + server_url="https://api", + metadata=None, + client_information=OAuthClientInformation(client_id="c1"), + code_verifier="cv", + redirect_uri="https://re", + ) + + # Create + state_key = _create_secure_redis_state(state_data) + assert state_key + mock_redis.setex.assert_called_once() + + # Retrieve Success + mock_redis.get.return_value = state_data.model_dump_json() + retrieved = _retrieve_redis_state(state_key) + assert retrieved.provider_id == "p1" + mock_redis.delete.assert_called_once() + + # Retrieve Failure - Not found + mock_redis.get.return_value = None + with pytest.raises(ValueError, match="expired or does not exist"): + _retrieve_redis_state("absent") + + # Retrieve Failure - Invalid JSON + mock_redis.get.return_value = "invalid" + with pytest.raises(ValueError, match="Invalid state parameter"): + _retrieve_redis_state("invalid") + + @patch("core.mcp.auth.auth_flow._retrieve_redis_state") + @patch("core.mcp.auth.auth_flow.exchange_authorization") + def test_handle_callback(self, mock_exchange, mock_retrieve): + state = Mock(spec=OAuthCallbackState) + state.server_url = "https://api" + state.metadata = None + state.client_information = Mock() + state.code_verifier = "cv" + state.redirect_uri = "https://re" + mock_retrieve.return_value = state + + tokens = Mock(spec=OAuthTokens) + mock_exchange.return_value = tokens + + s, t = handle_callback("key", "code") + assert s == state + assert t == tokens + + @patch("core.helper.ssrf_proxy.get") + def test_check_support_resource_discovery(self, mock_get): + # Case 1: authorization_servers (plural) + res = Mock() + res.status_code = 200 + res.json.return_value = {"authorization_servers": ["https://auth1"]} + mock_get.return_value = res + supported, url = check_support_resource_discovery("https://api") + assert supported is True + assert url == "https://auth1" + + # Case 2: authorization_server_url (singular alias) + res.json.return_value = {"authorization_server_url": ["https://auth2"]} + supported, url = check_support_resource_discovery("https://api") + assert supported is True + assert url == "https://auth2" + + # Case 3: Missing fields + res.json.return_value = {"nothing": []} + supported, url = check_support_resource_discovery("https://api") + assert supported is False + + # Case 4: 404 + res.status_code = 404 + supported, url = check_support_resource_discovery("https://api") + assert supported is False + + # Case 5: RequestError + mock_get.side_effect = httpx.RequestError("Error") + supported, url = check_support_resource_discovery("https://api") + assert supported is False + + def test_discover_oauth_metadata(self): + with patch("core.mcp.auth.auth_flow.discover_protected_resource_metadata") as mock_prm: + with patch("core.mcp.auth.auth_flow.discover_oauth_authorization_server_metadata") as mock_asm: + mock_prm.return_value = ProtectedResourceMetadata( + resource="https://api", authorization_servers=["https://auth"] + ) + mock_asm.return_value = Mock(spec=OAuthMetadata) + + asm, prm, hint = discover_oauth_metadata("https://api") + assert asm == mock_asm.return_value + assert prm == mock_prm.return_value + mock_asm.assert_called_with("https://auth", "https://api", None) + + def test_start_authorization(self): + metadata = OAuthMetadata( + authorization_endpoint="https://auth/authorize", + token_endpoint="https://auth/token", + response_types_supported=["code"], + ) + client_info = OAuthClientInformation(client_id="c1") + + with patch("core.mcp.auth.auth_flow._create_secure_redis_state") as mock_create: + mock_create.return_value = "state-key" + + # Success with scope + url, verifier = start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1", "read") + assert "scope=read" in url + assert "state=state-key" in url + + # Success without metadata + url, verifier = start_authorization("https://api", None, client_info, "https://re", "p1", "t1") + assert "https://api/authorize" in url + + # Failure: incompatible auth server + metadata.response_types_supported = ["implicit"] + with pytest.raises(ValueError, match="Incompatible auth server"): + start_authorization("https://api", metadata, client_info, "https://re", "p1", "t1") + + def test_parse_token_response(self): + # Case 1: JSON + res = Mock() + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at", "token_type": "Bearer"} + tokens = _parse_token_response(res) + assert tokens.access_token == "at" + + # Case 2: Form-urlencoded + res.headers = {"content-type": "application/x-www-form-urlencoded"} + res.text = "access_token=at2&token_type=Bearer" + tokens = _parse_token_response(res) + assert tokens.access_token == "at2" + + # Case 3: No content-type, but JSON + res.headers = {} + res.json.return_value = {"access_token": "at3", "token_type": "Bearer"} + tokens = _parse_token_response(res) + assert tokens.access_token == "at3" + + # Case 4: No content-type, not JSON, but Form + res.json.side_effect = json.JSONDecodeError("msg", "doc", 0) + res.text = "access_token=at4&token_type=Bearer" + tokens = _parse_token_response(res) + assert tokens.access_token == "at4" + + # Case 5: Validation Error fallback + res.json.side_effect = ValidationError.from_exception_data("error", []) + res.text = "access_token=at5&token_type=Bearer" + tokens = _parse_token_response(res) + assert tokens.access_token == "at5" + + @patch("core.helper.ssrf_proxy.post") + def test_exchange_authorization(self, mock_post): + client_info = OAuthClientInformation(client_id="c1", client_secret="s1") + metadata = OAuthMetadata( + authorization_endpoint="https://auth/authorize", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + + # Success + res = Mock() + res.is_success = True + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at", "token_type": "Bearer"} + mock_post.return_value = res + + tokens = exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re") + assert tokens.access_token == "at" + + # Failure: Unsupported grant type + metadata.grant_types_supported = ["client_credentials"] + with pytest.raises(ValueError, match="Incompatible auth server"): + exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re") + + # Failure: HTTP error + metadata.grant_types_supported = ["authorization_code"] + res.is_success = False + res.status_code = 400 + with pytest.raises(ValueError, match="Token exchange failed"): + exchange_authorization("https://api", metadata, client_info, "code", "verifier", "https://re") + + @patch("core.helper.ssrf_proxy.post") + def test_refresh_authorization(self, mock_post): + # Case 1: with client_secret + client_info = OAuthClientInformation(client_id="c1", client_secret="s1") + + # Success + res = Mock() + res.is_success = True + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at_new", "token_type": "Bearer"} + mock_post.return_value = res + + tokens = refresh_authorization("https://api", None, client_info, "rt") + assert tokens.access_token == "at_new" + assert mock_post.call_args[1]["data"]["client_secret"] == "s1" + + # Failure: MaxRetriesExceededError + mock_post.side_effect = ssrf_proxy.MaxRetriesExceededError("Too many retries") + with pytest.raises(MCPRefreshTokenError): + refresh_authorization("https://api", None, client_info, "rt") + + # Failure: HTTP error + mock_post.side_effect = None + res.is_success = False + res.text = "error_msg" + with pytest.raises(MCPRefreshTokenError, match="error_msg"): + refresh_authorization("https://api", None, client_info, "rt") + + # Failure: Incompatible metadata + metadata = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + with pytest.raises(ValueError, match="Incompatible auth server"): + refresh_authorization("https://api", metadata, client_info, "rt") + + @patch("core.helper.ssrf_proxy.post") + def test_client_credentials_flow(self, mock_post): + client_info = OAuthClientInformation(client_id="c1", client_secret="s1") + + # Success with secret + res = Mock() + res.is_success = True + res.headers = {"content-type": "application/json"} + res.json.return_value = {"access_token": "at_cc", "token_type": "Bearer"} + mock_post.return_value = res + + tokens = client_credentials_flow("https://api", None, client_info, "read") + assert tokens.access_token == "at_cc" + args, kwargs = mock_post.call_args + assert "Authorization" in kwargs["headers"] + + # Success without secret + client_info_no_secret = OAuthClientInformation(client_id="c2") + tokens = client_credentials_flow("https://api", None, client_info_no_secret) + args, kwargs = mock_post.call_args + assert kwargs["data"]["client_id"] == "c2" + + # Failure: Incompatible metadata + metadata = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + with pytest.raises(ValueError, match="Incompatible auth server"): + client_credentials_flow("https://api", metadata, client_info) + + # Failure: HTTP error + res.is_success = False + res.status_code = 401 + res.text = "Unauthorized" + with pytest.raises(ValueError, match="Client credentials token request failed"): + client_credentials_flow("https://api", None, client_info) + + @patch("core.helper.ssrf_proxy.post") + def test_register_client(self, mock_post): + # Case 1: Success with metadata + metadata = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + registration_endpoint="https://auth/register", + response_types_supported=["code"], + ) + client_metadata = OAuthClientMetadata(client_name="Dify", redirect_uris=["https://re"]) + + res = Mock() + res.is_success = True + res.json.return_value = { + "client_id": "c_new", + "client_secret": "s_new", + "client_name": "Dify", + "redirect_uris": ["https://re"], + } + mock_post.return_value = res + + info = register_client("https://api", metadata, client_metadata) + assert info.client_id == "c_new" + + # Case 2: Success without metadata + info = register_client("https://api", None, client_metadata) + assert mock_post.call_args[0][0] == "https://api/register" + + # Case 3: Metadata provided but no endpoint + metadata.registration_endpoint = None + with pytest.raises(ValueError, match="does not support dynamic client registration"): + register_client("https://api", metadata, client_metadata) + + # Failure: HTTP + res.is_success = False + res.raise_for_status = Mock() + res.status_code = 400 + # If is_success is false, it should call raise_for_status + register_client("https://api", None, client_metadata) + res.raise_for_status.assert_called_once() + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_failures(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + + # Case 1: No server metadata + mock_discover.return_value = (None, None, None) + with pytest.raises(ValueError, match="Failed to discover OAuth metadata"): + auth(provider) + + # Case 2: No client info, exchange code provided + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + ) + mock_discover.return_value = (asm, None, None) + provider.retrieve_client_information.return_value = None + with pytest.raises(ValueError, match="Existing OAuth client information is required"): + auth(provider, authorization_code="code") + + # Case 3: CLIENT_CREDENTIALS but client must provide info + asm.grant_types_supported = ["client_credentials"] + with pytest.raises(ValueError, match="requires client_id and client_secret"): + auth(provider) + + # Case 4: Client registration fails + asm.grant_types_supported = ["authorization_code"] + with patch("core.mcp.auth.auth_flow.register_client") as mock_reg: + mock_reg.side_effect = httpx.RequestError("Reg failed") + with pytest.raises(ValueError, match="Could not register OAuth client"): + auth(provider) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_client_credentials(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1", client_secret="s1") + provider.decrypt_credentials.return_value = {"scope": "read"} + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["client_credentials"], + ) + mock_discover.return_value = (asm, None, None) + + with patch("core.mcp.auth.auth_flow.client_credentials_flow") as mock_cc: + mock_cc.return_value = OAuthTokens(access_token="at_cc", token_type="Bearer") + + result = auth(provider) + assert result.response == {"result": "success"} + assert result.actions[0].action_type == AuthActionType.SAVE_TOKENS + assert result.actions[0].data["grant_type"] == "client_credentials" + + # Failure in CC flow + mock_cc.side_effect = ValueError("CC Failed") + with pytest.raises(ValueError, match="Client credentials flow failed"): + auth(provider) + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_authorization_code(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1") + provider.decrypt_credentials.return_value = {} + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_discover.return_value = (asm, None, None) + + # Case 1: Exchange code + with patch("core.mcp.auth.auth_flow._retrieve_redis_state") as mock_retrieve: + state = Mock(spec=OAuthCallbackState) + state.code_verifier = "cv" + state.redirect_uri = "https://re" + mock_retrieve.return_value = state + + with patch("core.mcp.auth.auth_flow.exchange_authorization") as mock_exchange: + mock_exchange.return_value = OAuthTokens(access_token="at_code", token_type="Bearer") + + # Success + result = auth(provider, authorization_code="code", state_param="sp") + assert result.response == {"result": "success"} + + # Missing state_param + with pytest.raises(ValueError, match="State parameter is required"): + auth(provider, authorization_code="code") + + # Missing verifier in state + state.code_verifier = None + with pytest.raises(ValueError, match="Missing code_verifier"): + auth(provider, authorization_code="code", state_param="sp") + + # Invalid state + mock_retrieve.side_effect = ValueError("Invalid") + with pytest.raises(ValueError, match="Invalid state parameter"): + auth(provider, authorization_code="code", state_param="sp") + + @patch("core.mcp.auth.auth_flow.discover_oauth_metadata") + def test_auth_orchestration_refresh_failure(self, mock_discover): + provider = Mock(spec=MCPProviderEntity) + provider.decrypt_server_url.return_value = "https://api" + provider.id = "p1" + provider.tenant_id = "t1" + provider.retrieve_client_information.return_value = OAuthClientInformation(client_id="c1") + provider.decrypt_credentials.return_value = {} + provider.retrieve_tokens.return_value = OAuthTokens(access_token="at", token_type="Bearer", refresh_token="rt") + + asm = OAuthMetadata( + authorization_endpoint="https://auth/auth", + token_endpoint="https://auth/token", + response_types_supported=["code"], + grant_types_supported=["authorization_code"], + ) + mock_discover.return_value = (asm, None, None) + + with patch("core.mcp.auth.auth_flow.refresh_authorization") as mock_refresh: + mock_refresh.side_effect = ValueError("Refresh Failed") + with pytest.raises(ValueError, match="Could not refresh OAuth tokens"): + auth(provider) diff --git a/api/tests/unit_tests/core/mcp/client/test_sse.py b/api/tests/unit_tests/core/mcp/client/test_sse.py index 490a647025..e6eeb6cd59 100644 --- a/api/tests/unit_tests/core/mcp/client/test_sse.py +++ b/api/tests/unit_tests/core/mcp/client/test_sse.py @@ -322,3 +322,475 @@ def test_sse_client_concurrent_access(): assert len(received_messages) == 10 for i in range(10): assert f"message_{i}" in received_messages + + +class TestStatusClasses: + """Tests for _StatusReady and _StatusError data containers.""" + + def test_status_ready_stores_endpoint(self): + from core.mcp.client.sse_client import _StatusReady + + status = _StatusReady("http://example.com/messages/") + assert status.endpoint_url == "http://example.com/messages/" + + def test_status_error_stores_exception(self): + from core.mcp.client.sse_client import _StatusError + + exc = ValueError("bad endpoint") + status = _StatusError(exc) + assert status.exc is exc + + +class TestSSETransportInit: + """Tests for SSETransport default and explicit init values.""" + + def test_defaults(self): + from core.mcp.client.sse_client import SSETransport + + t = SSETransport("http://example.com/sse") + assert t.url == "http://example.com/sse" + assert t.headers == {} + assert t.timeout == 5.0 + assert t.sse_read_timeout == 60.0 + assert t.endpoint_url is None + assert t.event_source is None + + def test_explicit_headers_not_mutated(self): + from core.mcp.client.sse_client import SSETransport + + hdrs = {"X-Foo": "bar"} + t = SSETransport("http://example.com/sse", headers=hdrs) + assert t.headers is hdrs + + +class TestHandleEndpointEvent: + """Tests for SSETransport._handle_endpoint_event covering the invalid-origin branch.""" + + def test_invalid_origin_puts_status_error(self): + from core.mcp.client.sse_client import SSETransport, _StatusError + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + # Provide a full URL with a different origin so urljoin keeps it as-is + transport._handle_endpoint_event("http://evil.com/messages/", status_queue) + + result = status_queue.get_nowait() + assert isinstance(result, _StatusError) + assert "does not match" in str(result.exc) + + def test_valid_origin_puts_status_ready(self): + from core.mcp.client.sse_client import SSETransport, _StatusReady + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + transport._handle_endpoint_event("/messages/?session_id=abc", status_queue) + + result = status_queue.get_nowait() + assert isinstance(result, _StatusReady) + assert "example.com" in result.endpoint_url + + +class TestHandleSSEEvent: + """Tests for SSETransport._handle_sse_event covering all match branches.""" + + def _make_sse(self, event_type: str, data: str): + sse = Mock() + sse.event = event_type + sse.data = data + return sse + + def test_message_event_dispatched(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + valid_msg = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + transport._handle_sse_event(self._make_sse("message", valid_msg), read_queue, status_queue) + + item = read_queue.get_nowait() + assert hasattr(item, "message") + + def test_unknown_event_logs_warning_and_does_nothing(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + transport._handle_sse_event(self._make_sse("ping", "{}"), read_queue, status_queue) + + assert read_queue.empty() + assert status_queue.empty() + + +class TestSSEReader: + """Tests for SSETransport.sse_reader exception branches.""" + + def test_read_error_closes_cleanly(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + event_source = Mock() + event_source.iter_sse.side_effect = httpx.ReadError("connection reset") + + transport.sse_reader(event_source, read_queue, status_queue) + + # Finally block always puts None as sentinel + sentinel = read_queue.get_nowait() + assert sentinel is None + + def test_generic_exception_puts_exc_then_none(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + read_queue: queue.Queue = queue.Queue() + status_queue: queue.Queue = queue.Queue() + + boom = RuntimeError("unexpected!") + event_source = Mock() + event_source.iter_sse.side_effect = boom + + transport.sse_reader(event_source, read_queue, status_queue) + + exc_item = read_queue.get_nowait() + assert exc_item is boom + + sentinel = read_queue.get_nowait() + assert sentinel is None + + +class TestSendMessage: + """Tests for SSETransport._send_message.""" + + def _make_session_message(self): + msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + msg = types.JSONRPCMessage.model_validate_json(msg_json) + return types.SessionMessage(msg) + + def test_sends_post_and_raises_for_status(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + + mock_response = Mock() + mock_response.status_code = 200 + mock_client = Mock() + mock_client.post.return_value = mock_response + + session_msg = self._make_session_message() + transport._send_message(mock_client, "http://example.com/messages/", session_msg) + + mock_client.post.assert_called_once() + mock_response.raise_for_status.assert_called_once() + + +class TestPostWriter: + """Tests for SSETransport.post_writer exception branches.""" + + def _make_session_message(self): + msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + msg = types.JSONRPCMessage.model_validate_json(msg_json) + return types.SessionMessage(msg) + + def test_none_message_exits_loop(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + write_queue.put(None) # Signal shutdown immediately + + mock_client = Mock() + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # Should put final None sentinel + sentinel = write_queue.get_nowait() + assert sentinel is None + + def test_exception_in_message_put_back_to_queue(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + exc = ValueError("some error") + write_queue.put(exc) # Exception goes in first + write_queue.put(None) # Then shutdown signal + + mock_client = Mock() + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # The exception should be re-queued, then None from loop exit, then None from finally + item1 = write_queue.get_nowait() + assert isinstance(item1, Exception) + + def test_read_error_shuts_down_cleanly(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + session_msg = self._make_session_message() + write_queue.put(session_msg) + + mock_response = Mock() + mock_response.status_code = 200 + mock_client = Mock() + mock_client.post.side_effect = httpx.ReadError("connection dropped") + + # post_writer calls _send_message which calls client.post → ReadError propagates + # The ReadError is raised inside _send_message → propagates out of the while loop + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # finally always puts None + sentinel = write_queue.get_nowait() + assert sentinel is None + + def test_generic_exception_puts_exc_in_queue(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + session_msg = self._make_session_message() + write_queue.put(session_msg) + + mock_client = Mock() + boom = RuntimeError("boom") + mock_client.post.side_effect = boom + + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + exc_item = write_queue.get_nowait() + assert isinstance(exc_item, Exception) + + sentinel = write_queue.get_nowait() + assert sentinel is None + + def test_queue_empty_timeout_continues_loop(self): + """Cover the 'except queue.Empty: continue' branch (line 188) in post_writer.""" + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + write_queue: queue.Queue = queue.Queue() + + mock_client = Mock() + + # Patch queue.Queue.get so it raises Empty first, then returns None (shutdown) + call_count = {"n": 0} + original_get = write_queue.get + + def patched_get(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + raise queue.Empty + + write_queue.get = patched_get # type: ignore[method-assign] + + transport.post_writer(mock_client, "http://example.com/messages/", write_queue) + + # finally always puts None sentinel + sentinel = write_queue.get_nowait() + assert sentinel is None + assert call_count["n"] >= 2 # Empty on first, None on second (and possibly more retries) + + +class TestWaitForEndpoint: + """Tests for SSETransport._wait_for_endpoint edge cases.""" + + def test_raises_on_empty_queue(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() # empty + + with pytest.raises(ValueError, match="failed to get endpoint URL"): + transport._wait_for_endpoint(status_queue) + + def test_raises_status_error_exception(self): + from core.mcp.client.sse_client import SSETransport, _StatusError + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + exc = ValueError("malicious endpoint") + status_queue.put(_StatusError(exc)) + + with pytest.raises(ValueError, match="malicious endpoint"): + transport._wait_for_endpoint(status_queue) + + def test_raises_on_unknown_status_type(self): + from core.mcp.client.sse_client import SSETransport + + transport = SSETransport("http://example.com/sse") + status_queue: queue.Queue = queue.Queue() + + # Put an object that is neither _StatusReady nor _StatusError + status_queue.put("unexpected_value") + + with pytest.raises(ValueError, match="failed to get endpoint URL"): + transport._wait_for_endpoint(status_queue) + + +class TestSSEClientRuntimeError: + """Test sse_client context manager handles RuntimeError on close().""" + + def test_runtime_error_on_close_is_suppressed(self): + """Ensure RuntimeError raised by event_source.response.close() is caught.""" + test_url = "http://test.example/sse" + + class MockSSEEvent: + def __init__(self, event_type: str, data: str): + self.event = event_type + self.data = data + + endpoint_event = MockSSEEvent("endpoint", "/messages/?session_id=test-123") + + with patch("core.mcp.client.sse_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + with patch("core.mcp.client.sse_client.ssrf_proxy_sse_connect") as mock_sc: + mock_client = Mock() + mock_cf.return_value.__enter__.return_value = mock_client + + mock_es = Mock() + mock_es.response.raise_for_status.return_value = None + mock_es.iter_sse.return_value = [endpoint_event] + # Make close() raise RuntimeError to exercise line 307-308 + mock_es.response.close.side_effect = RuntimeError("already closed") + mock_sc.return_value.__enter__.return_value = mock_es + + # Should NOT raise even though close() raises RuntimeError + with contextlib.suppress(Exception): + with sse_client(test_url) as (rq, wq): + pass + + +class TestStandaloneSendMessage: + """Tests for the module-level send_message() function.""" + + def _make_session_message(self): + msg_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + msg = types.JSONRPCMessage.model_validate_json(msg_json) + return types.SessionMessage(msg) + + def test_send_message_success(self): + from core.mcp.client.sse_client import send_message + + mock_response = Mock() + mock_response.status_code = 200 + mock_http_client = Mock() + mock_http_client.post.return_value = mock_response + + session_msg = self._make_session_message() + send_message(mock_http_client, "http://example.com/messages/", session_msg) + + mock_http_client.post.assert_called_once() + mock_response.raise_for_status.assert_called_once() + + def test_send_message_raises_on_http_error(self): + from core.mcp.client.sse_client import send_message + + mock_http_client = Mock() + mock_http_client.post.side_effect = httpx.ConnectError("refused") + + session_msg = self._make_session_message() + + with pytest.raises(httpx.ConnectError): + send_message(mock_http_client, "http://example.com/messages/", session_msg) + + def test_send_message_raises_for_status_failure(self): + from core.mcp.client.sse_client import send_message + + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Not Found", request=Mock(), response=Mock(status_code=404) + ) + mock_http_client = Mock() + mock_http_client.post.return_value = mock_response + + session_msg = self._make_session_message() + + with pytest.raises(httpx.HTTPStatusError): + send_message(mock_http_client, "http://example.com/messages/", session_msg) + + +class TestReadMessages: + """Tests for the module-level read_messages() generator.""" + + def _make_mock_sse_event(self, event_type: str, data: str): + ev = Mock() + ev.event = event_type + ev.data = data + return ev + + def test_valid_message_event_yields_session_message(self): + from core.mcp.client.sse_client import read_messages + + valid_json = '{"jsonrpc": "2.0", "id": 1, "method": "ping"}' + mock_sse_event = self._make_mock_sse_event("message", valid_json) + + mock_client = Mock() + mock_client.events.return_value = [mock_sse_event] + + results = list(read_messages(mock_client)) + assert len(results) == 1 + assert hasattr(results[0], "message") + + def test_invalid_json_yields_exception(self): + from core.mcp.client.sse_client import read_messages + + mock_sse_event = self._make_mock_sse_event("message", "{not valid json}") + + mock_client = Mock() + mock_client.events.return_value = [mock_sse_event] + + results = list(read_messages(mock_client)) + assert len(results) == 1 + assert isinstance(results[0], Exception) + + def test_non_message_event_is_skipped(self): + from core.mcp.client.sse_client import read_messages + + mock_sse_event = self._make_mock_sse_event("endpoint", "/messages/") + + mock_client = Mock() + mock_client.events.return_value = [mock_sse_event] + + results = list(read_messages(mock_client)) + # Non-message events produce no output + assert results == [] + + def test_outer_exception_yields_exc(self): + from core.mcp.client.sse_client import read_messages + + boom = RuntimeError("stream broken") + mock_client = Mock() + mock_client.events.side_effect = boom + + results = list(read_messages(mock_client)) + assert len(results) == 1 + assert results[0] is boom + + def test_multiple_events_mixed(self): + from core.mcp.client.sse_client import read_messages + + valid_json = '{"jsonrpc": "2.0", "id": 2, "result": {}}' + events = [ + self._make_mock_sse_event("endpoint", "/messages/"), + self._make_mock_sse_event("message", valid_json), + self._make_mock_sse_event("message", "{bad json}"), + ] + + mock_client = Mock() + mock_client.events.return_value = events + + results = list(read_messages(mock_client)) + # endpoint is skipped; 1 valid SessionMessage + 1 Exception + assert len(results) == 2 + assert hasattr(results[0], "message") + assert isinstance(results[1], Exception) diff --git a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py index 9a30a35a49..81f8da9a62 100644 --- a/api/tests/unit_tests/core/mcp/client/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/client/test_streamable_http.py @@ -4,14 +4,39 @@ Tests for the StreamableHTTP client transport. Contains tests for only the client side of the StreamableHTTP transport. """ +import json import queue import threading import time +from contextlib import contextmanager +from datetime import timedelta from typing import Any -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch + +import httpx +import pytest +from httpx_sse import ServerSentEvent from core.mcp import types -from core.mcp.client.streamable_client import streamablehttp_client +from core.mcp.client.streamable_client import ( + LAST_EVENT_ID, + MCP_SESSION_ID, + RequestContext, + ResumptionError, + StreamableHTTPError, + StreamableHTTPTransport, + streamablehttp_client, +) +from core.mcp.types import ( + ClientMessageMetadata, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, + JSONRPCResponse, + SessionMessage, +) # Test constants SERVER_NAME = "test_streamable_http_server" @@ -448,3 +473,1169 @@ def test_streamablehttp_client_resumption_token_handling(): assert write_queue is not None except Exception: pass # Expected due to mocking + + +# ── helpers ─────────────────────────────────────────────────────────────────── + + +def _make_request_msg(method: str = "ping", req_id: int = 1) -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=req_id, method=method)) + + +def _make_response_msg(req_id: int = 1, result: dict | None = None) -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=req_id, result=result or {})) + + +def _make_error_msg(req_id: int = 1, code: int = -32600) -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=code, message="err"))) + + +def _make_notification_msg(method: str = "notifications/initialized") -> JSONRPCMessage: + return JSONRPCMessage(root=JSONRPCNotification(jsonrpc="2.0", method=method)) + + +def _make_sse_mock(event: str = "message", data: str = "", sse_id: str = "") -> ServerSentEvent: + # Use real ServerSentEvent since StreamableHTTPTransport requires its structure + return ServerSentEvent(event=event, data=data, id=sse_id, retry=None) + + +def _new_transport(url: str = "http://example.com/mcp", **kwargs) -> StreamableHTTPTransport: + return StreamableHTTPTransport(url, **kwargs) + + +# ── StreamableHTTPTransport.__init__ ───────────────────────────────────────── + + +class TestStreamableHTTPTransportInit: + def test_defaults(self): + t = _new_transport() + assert t.url == "http://example.com/mcp" + assert t.headers == {} + assert t.timeout == 30 + assert t.sse_read_timeout == 300 + assert t.session_id is None + assert t.stop_event is not None + assert t._active_responses == [] + + def test_timedelta_timeout_and_sse_read_timeout(self): + t = _new_transport(timeout=timedelta(seconds=10), sse_read_timeout=timedelta(seconds=120)) + assert t.timeout == 10.0 + assert t.sse_read_timeout == 120.0 + + def test_custom_headers_merged_into_request_headers(self): + t = _new_transport(headers={"Authorization": "Bearer tok"}) + assert t.request_headers["Authorization"] == "Bearer tok" + assert "Accept" in t.request_headers + assert "content-type" in t.request_headers + + +# ── _update_headers_with_session ───────────────────────────────────────────── + + +class TestUpdateHeadersWithSession: + def test_no_session_id_returns_copy_without_session_header(self): + t = _new_transport() + t.session_id = None + result = t._update_headers_with_session({"X-Foo": "bar"}) + assert result == {"X-Foo": "bar"} + assert MCP_SESSION_ID not in result + + def test_with_session_id_adds_header(self): + t = _new_transport() + t.session_id = "sess-abc" + result = t._update_headers_with_session({"X-Foo": "bar"}) + assert result[MCP_SESSION_ID] == "sess-abc" + assert result["X-Foo"] == "bar" + + +# ── _register_response / _unregister_response / close_active_responses ──────── + + +class TestResponseRegistry: + def test_register_and_unregister(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + t._register_response(resp) + assert resp in t._active_responses + t._unregister_response(resp) + assert resp not in t._active_responses + + def test_unregister_not_registered_does_not_raise(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + t._unregister_response(resp) # Should swallow ValueError silently + + def test_close_active_responses_calls_close(self): + t = _new_transport() + resp1 = MagicMock(spec=httpx.Response) + resp2 = MagicMock(spec=httpx.Response) + t._register_response(resp1) + t._register_response(resp2) + t.close_active_responses() + resp1.close.assert_called_once() + resp2.close.assert_called_once() + assert t._active_responses == [] + + def test_close_active_responses_swallows_runtime_error(self): + t = _new_transport() + resp = MagicMock(spec=httpx.Response) + resp.close.side_effect = RuntimeError("already closed") + t._register_response(resp) + t.close_active_responses() # Should not raise + + +# ── _is_initialization_request / _is_initialized_notification ──────────────── + + +class TestMessageClassifiers: + def test_is_initialization_request_true(self): + t = _new_transport() + assert t._is_initialization_request(_make_request_msg("initialize")) is True + + def test_is_initialization_request_false_other_method(self): + t = _new_transport() + assert t._is_initialization_request(_make_request_msg("tools/list")) is False + + def test_is_initialization_request_false_not_request(self): + t = _new_transport() + assert t._is_initialization_request(_make_response_msg()) is False + + def test_is_initialized_notification_true(self): + t = _new_transport() + assert t._is_initialized_notification(_make_notification_msg("notifications/initialized")) is True + + def test_is_initialized_notification_false_other_method(self): + t = _new_transport() + assert t._is_initialized_notification(_make_notification_msg("notifications/cancelled")) is False + + def test_is_initialized_notification_false_not_notification(self): + t = _new_transport() + assert t._is_initialized_notification(_make_request_msg("notifications/initialized")) is False + + +# ── _maybe_extract_session_id_from_response ─────────────────────────────────── + + +class TestMaybeExtractSessionIdNew: + def test_extracts_session_id_when_present(self): + t = _new_transport() + resp = MagicMock() + resp.headers = {MCP_SESSION_ID: "new-session-99"} + t._maybe_extract_session_id_from_response(resp) + assert t.session_id == "new-session-99" + + def test_no_session_id_header_leaves_none(self): + t = _new_transport() + resp = MagicMock() + resp.headers = MagicMock() + resp.headers.get = MagicMock(return_value=None) + t._maybe_extract_session_id_from_response(resp) + assert t.session_id is None + + +# ── _handle_sse_event ───────────────────────────────────────────────────────── + + +class TestHandleSseEventNew: + def test_message_event_response_returns_true(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}})) + assert t._handle_sse_event(sse, q) is True + assert isinstance(q.get_nowait(), SessionMessage) + + def test_message_event_error_returns_true(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "error": {"code": -32600, "message": "bad"}}) + sse = _make_sse_mock("message", data) + assert t._handle_sse_event(sse, q) is True + + def test_message_event_notification_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "method": "notifications/something"}) + sse = _make_sse_mock("message", data) + assert t._handle_sse_event(sse, q) is False + assert isinstance(q.get_nowait(), SessionMessage) + + def test_message_event_empty_data_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", " ") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + def test_message_event_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("message", "{bad json}") + assert t._handle_sse_event(sse, q) is False + assert isinstance(q.get_nowait(), Exception) + + def test_message_event_replaces_original_request_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="") + t._handle_sse_event(sse, q, original_request_id=999) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert item.message.root.id == 999 + + def test_message_event_calls_resumption_callback_when_sse_id_present(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="token-abc") + callback = MagicMock() + t._handle_sse_event(sse, q, resumption_callback=callback) + callback.assert_called_once_with("token-abc") + + def test_message_event_no_callback_when_no_sse_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="") + callback = MagicMock() + t._handle_sse_event(sse, q, resumption_callback=callback) + callback.assert_not_called() + + def test_ping_event_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("ping", "") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + def test_unknown_event_returns_false(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + sse = _make_sse_mock("custom_event", "{}") + assert t._handle_sse_event(sse, q) is False + assert q.empty() + + +# ── handle_get_stream ───────────────────────────────────────────────────────── + + +class TestHandleGetStreamNew: + def test_skips_when_no_session_id(self): + t = _new_transport() + t.session_id = None + q: queue.Queue = queue.Queue() + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + t.handle_get_stream(MagicMock(), q) + mock_connect.assert_not_called() + + def test_handles_messages_via_sse(self): + t = _new_transport() + t.session_id = "sess-1" + q: queue.Queue = queue.Queue() + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t.handle_get_stream(MagicMock(), q) + + assert isinstance(q.get_nowait(), SessionMessage) + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.session_id = "sess-1" + t.stop_event.set() + q: queue.Queue = queue.Queue() + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t.handle_get_stream(MagicMock(), q) + + assert q.empty() + + def test_exception_when_not_stopped_is_logged(self): + t = _new_transport() + t.session_id = "sess-1" + q: queue.Queue = queue.Queue() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.side_effect = Exception("connection error") + t.handle_get_stream(MagicMock(), q) # Should not raise + + def test_exception_when_stopped_is_suppressed(self): + t = _new_transport() + t.session_id = "sess-1" + t.stop_event.set() + q: queue.Queue = queue.Queue() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.side_effect = Exception("connection error") + t.handle_get_stream(MagicMock(), q) # Should not raise or log + + +# ── _handle_resumption_request ──────────────────────────────────────────────── + + +class TestHandleResumptionRequestNew: + def _make_ctx(self, transport, q, resumption_token="token-123", message=None) -> RequestContext: + if message is None: + message = _make_request_msg("tools/list", req_id=42) + session_msg = SessionMessage(message) + metadata = None + if resumption_token: + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = resumption_token + metadata.on_resumption_token_update = MagicMock() + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=transport.session_id, + session_message=session_msg, + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def test_raises_resumption_error_without_token(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = None + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + with pytest.raises(ResumptionError): + t._handle_resumption_request(ctx) + + def test_raises_resumption_error_without_metadata(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + with pytest.raises(ResumptionError): + t._handle_resumption_request(ctx) + + def test_sets_last_event_id_header(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q, resumption_token="resume-999") + + captured_headers: dict = {} + data = json.dumps({"jsonrpc": "2.0", "id": 42, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + def fake_connect(url, headers, **kwargs): + captured_headers.update(headers) + + @contextmanager + def _ctx(): + yield mock_event_source + + return _ctx() + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect", side_effect=fake_connect): + t._handle_resumption_request(ctx) + + assert captured_headers.get(LAST_EVENT_ID) == "resume-999" + + def test_stops_when_response_complete(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q, message=_make_request_msg("tools/list", 42)) + + data1 = json.dumps({"jsonrpc": "2.0", "id": 42, "result": {}}) + data2 = json.dumps({"jsonrpc": "2.0", "id": 43, "result": {}}) + sse1 = _make_sse_mock("message", data1) + sse2 = _make_sse_mock("message", data2) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [sse1, sse2] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t._handle_resumption_request(ctx) + + # Only the first event was processed (loop breaks on completion) + assert q.qsize() == 1 + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + mock_response.raise_for_status.return_value = None + mock_event_source = MagicMock() + mock_event_source.response = mock_response + mock_event_source.iter_sse.return_value = [mock_sse_event] + + with patch("core.mcp.client.streamable_client.ssrf_proxy_sse_connect") as mock_connect: + mock_connect.return_value.__enter__.return_value = mock_event_source + t._handle_resumption_request(ctx) + + assert q.empty() + + +# ── _handle_post_request ────────────────────────────────────────────────────── + + +class TestHandlePostRequestNew: + def _make_ctx(self, transport, q, message=None) -> RequestContext: + if message is None: + message = _make_request_msg("tools/list", 1) + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=transport.session_id, + session_message=SessionMessage(message), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def _stream_ctx(self, mock_response): + @contextmanager + def _stream(*args, **kwargs): + yield mock_response + + return _stream + + def test_202_returns_immediately_no_queue(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + mock_resp = MagicMock() + mock_resp.status_code = 202 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + assert q.empty() + + def test_204_returns_immediately_no_queue(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + mock_resp = MagicMock() + mock_resp.status_code = 204 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + assert q.empty() + + def test_404_sends_session_terminated_error_for_request(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_request_msg("tools/list", 77) + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.id == 77 + + def test_404_for_notification_no_error_sent(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_notification_msg("some/notification") + ctx = self._make_ctx(t, q, message=msg) + mock_resp = MagicMock() + mock_resp.status_code = 404 + ctx.client.stream = self._stream_ctx(mock_resp) + t._handle_post_request(ctx) + assert q.empty() + + def test_json_response_puts_session_message(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {"ok": True}}).encode() + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = response_data + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert isinstance(q.get_nowait(), SessionMessage) + + def test_json_response_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = b"{bad json!" + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert isinstance(q.get_nowait(), Exception) + + def test_unexpected_content_type_puts_value_error(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "text/plain"} + mock_resp.raise_for_status.return_value = None + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + item = q.get_nowait() + assert isinstance(item, ValueError) + assert "Unexpected content type" in str(item) + + def test_initialization_request_extracts_session_id(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_request_msg("initialize", 1) + ctx = self._make_ctx(t, q, message=msg) + + response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = MagicMock() + headers_dict = {"content-type": "application/json", MCP_SESSION_ID: "new-sid"} + mock_resp.headers.__getitem__ = lambda self, k: headers_dict[k] + mock_resp.headers.get = lambda k, default=None: headers_dict.get(k, default) + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = response_data + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert t.session_id == "new-sid" + + def test_notification_skips_response_processing(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + msg = _make_notification_msg("notifications/something") + ctx = self._make_ctx(t, q, message=msg) + + response_data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "application/json"} + mock_resp.raise_for_status.return_value = None + mock_resp.read.return_value = response_data + ctx.client.stream = self._stream_ctx(mock_resp) + + t._handle_post_request(ctx) + assert q.empty() + + def test_sse_response_handles_stream(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._make_ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.headers = {"content-type": "text/event-stream"} + mock_resp.raise_for_status.return_value = None + ctx.client.stream = self._stream_ctx(mock_resp) + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [mock_sse_event] + MockEventSource.return_value = mock_es_instance + t._handle_post_request(ctx) + + assert isinstance(q.get_nowait(), SessionMessage) + + +# ── _handle_json_response ───────────────────────────────────────────────────── + + +class TestHandleJsonResponseNew: + def test_valid_json_puts_session_message(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}).encode() + mock_response = MagicMock() + mock_response.read.return_value = data + t._handle_json_response(mock_response, q) + assert isinstance(q.get_nowait(), SessionMessage) + + def test_invalid_json_puts_exception(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + mock_response = MagicMock() + mock_response.read.return_value = b"{ invalid }" + t._handle_json_response(mock_response, q) + assert isinstance(q.get_nowait(), Exception) + + +# ── _handle_sse_response ────────────────────────────────────────────────────── + + +class TestHandleSseResponseNew: + def _ctx(self, transport, q) -> RequestContext: + return RequestContext( + client=MagicMock(), + headers=transport.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + def test_processes_sse_events(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [mock_sse_event] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + assert isinstance(q.get_nowait(), SessionMessage) + + def test_stops_when_stop_event_set(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + mock_sse_event = _make_sse_mock("message", data) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [mock_sse_event] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + assert q.empty() + + def test_stops_when_complete(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + + data1 = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + data2 = json.dumps({"jsonrpc": "2.0", "id": 2, "result": {}}) + sse1 = _make_sse_mock("message", data1) + sse2 = _make_sse_mock("message", data2) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [sse1, sse2] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + assert q.qsize() == 1 # Only the first completion item + + def test_exception_outside_stop_puts_to_queue(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + MockEventSource.side_effect = RuntimeError("EventSource error") + t._handle_sse_response(mock_response, ctx) + + assert isinstance(q.get_nowait(), Exception) + + def test_exception_suppressed_when_stopped(self): + t = _new_transport() + t.stop_event.set() + q: queue.Queue = queue.Queue() + ctx = self._ctx(t, q) + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + MockEventSource.side_effect = RuntimeError("EventSource error") + t._handle_sse_response(mock_response, ctx) + + assert q.empty() + + def test_with_metadata_resumption_callback(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + metadata = MagicMock(spec=ClientMessageMetadata) + callback = MagicMock() + metadata.on_resumption_token_update = callback + + ctx = RequestContext( + client=MagicMock(), + headers=t.request_headers, + session_id=None, + session_message=SessionMessage(_make_request_msg()), + metadata=metadata, + server_to_client_queue=q, + sse_read_timeout=60, + ) + + data = json.dumps({"jsonrpc": "2.0", "id": 1, "result": {}}) + sse = _make_sse_mock("message", data, sse_id="resume-token") + mock_response = MagicMock() + + with patch("core.mcp.client.streamable_client.EventSource") as MockEventSource: + mock_es_instance = MagicMock() + mock_es_instance.iter_sse.return_value = [sse] + MockEventSource.return_value = mock_es_instance + t._handle_sse_response(mock_response, ctx) + + callback.assert_called_once_with("resume-token") + + +# ── _handle_unexpected_content_type ────────────────────────────────────────── + + +class TestHandleUnexpectedContentTypeNew: + def test_puts_value_error_with_message(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + t._handle_unexpected_content_type("text/html", q) + item = q.get_nowait() + assert isinstance(item, ValueError) + assert "text/html" in str(item) + + +# ── _send_session_terminated_error ──────────────────────────────────────────── + + +class TestSendSessionTerminatedErrorNew: + def test_puts_jsonrpc_error(self): + t = _new_transport() + q: queue.Queue = queue.Queue() + t._send_session_terminated_error(q, 42) + item = q.get_nowait() + assert isinstance(item, SessionMessage) + assert isinstance(item.message.root, JSONRPCError) + assert item.message.root.id == 42 + assert item.message.root.error.code == 32600 + assert "terminated" in item.message.root.error.message.lower() + + +# ── post_writer ─────────────────────────────────────────────────────────────── + + +class TestPostWriterNew: + def test_none_message_exits_loop(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + c2s.put(None) + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + def test_stop_event_exits_loop(self): + t = _new_transport() + t.stop_event.set() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + def test_initialized_notification_calls_start_get_stream(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + start_get_stream = MagicMock() + + notif_msg = _make_notification_msg("notifications/initialized") + c2s.put(SessionMessage(notif_msg)) + c2s.put(None) + + with patch.object(t, "_handle_post_request"): + t.post_writer(MagicMock(), c2s, s2c, start_get_stream) + + start_get_stream.assert_called_once() + + def test_resumption_message_calls_handle_resumption_request(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + start_get_stream = MagicMock() + + msg = SessionMessage(_make_request_msg("tools/list", 10)) + metadata = MagicMock(spec=ClientMessageMetadata) + metadata.resumption_token = "resume-abc" + msg.metadata = metadata + c2s.put(msg) + c2s.put(None) + + with patch.object(t, "_handle_resumption_request") as mock_resumption: + t.post_writer(MagicMock(), c2s, s2c, start_get_stream) + + mock_resumption.assert_called_once() + + def test_regular_message_calls_handle_post_request(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + c2s.put(msg) + c2s.put(None) + + with patch.object(t, "_handle_post_request") as mock_post: + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + mock_post.assert_called_once() + + def test_exception_in_handler_put_to_s2c_when_not_stopped(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + c2s.put(msg) + c2s.put(None) + + boom = RuntimeError("oops") + with patch.object(t, "_handle_post_request", side_effect=boom): + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + item = s2c.get_nowait() + assert item is boom + + def test_exception_suppressed_when_stopped(self): + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + c2s.put(msg) + c2s.put(None) + t.stop_event.set() + + boom = RuntimeError("oops") + with patch.object(t, "_handle_post_request", side_effect=boom): + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + assert s2c.empty() + + def test_queue_empty_timeout_continues_loop(self): + """Cover the 'except queue.Empty: continue' branch in post_writer.""" + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + call_count = {"n": 0} + + original_get = c2s.get + + def patched_get(*args, **kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + raise queue.Empty + + c2s.get = patched_get # type: ignore[method-assign] + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + assert call_count["n"] >= 2 + + def test_non_client_metadata_treated_as_none(self): + """session_message.metadata that's not ClientMessageMetadata → metadata is None.""" + t = _new_transport() + c2s: queue.Queue = queue.Queue() + s2c: queue.Queue = queue.Queue() + + msg = SessionMessage(_make_request_msg("tools/list", 5)) + msg.metadata = "not-a-client-metadata" + c2s.put(msg) + c2s.put(None) + + with patch.object(t, "_handle_post_request") as mock_post: + t.post_writer(MagicMock(), c2s, s2c, MagicMock()) + + ctx = mock_post.call_args[0][0] + assert ctx.metadata is None + + +# ── terminate_session ───────────────────────────────────────────────────────── + + +class TestTerminateSessionNew: + def test_no_session_id_skips(self): + t = _new_transport() + t.session_id = None + mock_client = MagicMock() + t.terminate_session(mock_client) + mock_client.delete.assert_not_called() + + def test_200_response_is_success(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + mock_client.delete.return_value = mock_response + t.terminate_session(mock_client) + mock_client.delete.assert_called_once() + + def test_405_does_not_raise(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 405 + mock_client.delete.return_value = mock_response + t.terminate_session(mock_client) # Should not raise + + def test_non_200_logs_warning_does_not_raise(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 500 + mock_client.delete.return_value = mock_response + t.terminate_session(mock_client) # Should not raise + + def test_exception_is_swallowed(self): + t = _new_transport() + t.session_id = "sess-1" + mock_client = MagicMock() + mock_client.delete.side_effect = httpx.ConnectError("refused") + t.terminate_session(mock_client) # Should not raise + + +# ── get_session_id ──────────────────────────────────────────────────────────── + + +class TestGetSessionIdNew: + def test_returns_none_when_no_session(self): + t = _new_transport() + assert t.get_session_id() is None + + def test_returns_session_id_when_set(self): + t = _new_transport() + t.session_id = "my-session" + assert t.get_session_id() == "my-session" + + +# ── streamablehttp_client context manager ───────────────────────────────────── + + +class TestStreamablehttpClientContextManagerNew: + def test_yields_queues_and_callback(self): + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid): + assert s2c is not None + assert c2s is not None + assert callable(get_sid) + + def test_terminate_on_close_false_does_not_delete(self): + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp", terminate_on_close=False) as (s2c, c2s, get_sid): + pass + mock_client.delete.assert_not_called() + + def test_queue_cleanup_on_outer_exception(self): + """Verify cleanup in finally block runs even when create_ssrf raises.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_cf.side_effect = RuntimeError("connection failed") + + with pytest.raises(RuntimeError): + with streamablehttp_client("http://example.com/mcp"): + pass # pragma: no cover + + def test_timedelta_args_accepted(self): + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client( + "http://example.com/mcp", + timeout=timedelta(seconds=15), + sse_read_timeout=timedelta(seconds=60), + ) as (s2c, c2s, get_sid): + assert callable(get_sid) + + def test_start_get_stream_submits_to_executor(self): + """When context starts, post_writer is submitted to executor.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + submitted_calls = [] + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + + def capture_submit(fn, *args, **kwargs): + submitted_calls.append((fn, args)) + + mock_executor.submit.side_effect = capture_submit + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid): + pass + + # post_writer was submitted + assert len(submitted_calls) >= 1 + + def test_cleanup_puts_none_sentinels_to_queues(self): + """After context exit, None sentinels are put into both queues.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with streamablehttp_client("http://example.com/mcp") as (s2c, c2s, get_sid): + pass + + # After context exit, None sentinel should be in c2s queue from cleanup + val = c2s.get_nowait() + assert val is None + + def test_terminate_called_when_session_id_set(self): + """When session_id is set and terminate_on_close=True, terminate_session is called.""" + from core.mcp.client.streamable_client import streamablehttp_client + + with patch("core.mcp.client.streamable_client.create_ssrf_proxy_mcp_http_client") as mock_cf: + mock_client = MagicMock() + mock_cf.return_value.__enter__.return_value = mock_client + + mock_delete_resp = MagicMock() + mock_delete_resp.status_code = 200 + mock_client.delete.return_value = mock_delete_resp + + with patch("core.mcp.client.streamable_client.ThreadPoolExecutor") as mock_exec: + mock_executor = MagicMock() + mock_exec.return_value = mock_executor + + with patch("core.mcp.client.streamable_client.StreamableHTTPTransport") as MockTransport: + mock_transport = MockTransport.return_value + mock_transport.request_headers = { + "Accept": "application/json, text/event-stream", + "content-type": "application/json", + } + mock_transport.timeout = 30 + mock_transport.sse_read_timeout = 300 + mock_transport.session_id = "active-session" + mock_transport.stop_event = MagicMock() + mock_transport.get_session_id = MagicMock(return_value="active-session") + + with streamablehttp_client("http://example.com/mcp", terminate_on_close=True) as ( + s2c, + c2s, + get_sid, + ): + pass + + mock_transport.terminate_session.assert_called_once_with(mock_client) + + +# ── Exception hierarchy ─────────────────────────────────────────────────────── + + +class TestExceptionHierarchyNew: + def test_streamable_http_error_is_exception(self): + err = StreamableHTTPError("test") + assert isinstance(err, Exception) + + def test_resumption_error_is_streamable_http_error(self): + err = ResumptionError("test") + assert isinstance(err, StreamableHTTPError) + assert isinstance(err, Exception) + + +# ── RequestContext dataclass ────────────────────────────────────────────────── + + +class TestRequestContextNew: + def test_creation(self): + import queue + + q: queue.Queue = queue.Queue() + ctx = RequestContext( + client=MagicMock(), + headers={"X-Test": "val"}, + session_id="sid", + session_message=SessionMessage(_make_request_msg()), + metadata=None, + server_to_client_queue=q, + sse_read_timeout=30.0, + ) + assert ctx.session_id == "sid" + assert ctx.sse_read_timeout == 30.0 + assert ctx.metadata is None diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index 40a7700394..f982765b1a 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -18,7 +18,7 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/mcp/session/test_base_session.py b/api/tests/unit_tests/core/mcp/session/test_base_session.py new file mode 100644 index 0000000000..1dd916bcf1 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/session/test_base_session.py @@ -0,0 +1,617 @@ +import queue +import time +from concurrent.futures import Future, ThreadPoolExecutor +from datetime import timedelta +from typing import Union +from unittest.mock import MagicMock, patch + +import pytest +from httpx import HTTPStatusError, Request, Response +from pydantic import BaseModel, ConfigDict, RootModel + +from core.mcp.error import MCPAuthError, MCPConnectionError +from core.mcp.session.base_session import BaseSession, RequestResponder +from core.mcp.types import ( + CancelledNotification, + ClientNotification, + ClientRequest, + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCResponse, + Notification, + RequestParams, + SessionMessage, +) +from core.mcp.types import ( + Request as MCPRequest, +) + + +class MockRequestParams(RequestParams): + name: str = "default" + model_config = ConfigDict(extra="allow") + + +class MockRequest(MCPRequest[MockRequestParams, str]): + method: str = "test/request" + params: MockRequestParams = MockRequestParams() + + +class MockResult(BaseModel): + result: str + + +class MockNotificationParams(BaseModel): + message: str + + +class MockNotification(Notification[MockNotificationParams, str]): + method: str = "test/notification" + params: MockNotificationParams + + +class ReceiveRequest(RootModel[Union[MockRequest, ClientRequest]]): + pass + + +class ReceiveNotification(RootModel[Union[CancelledNotification, MockNotification, JSONRPCNotification]]): + pass + + +class MockSession(BaseSession[MockRequest, MockNotification, MockResult, ReceiveRequest, ReceiveNotification]): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.received_requests = [] + self.received_notifications = [] + self.handled_incoming = [] + + def _received_request(self, responder): + self.received_requests.append(responder) + + def _received_notification(self, notification): + self.received_notifications.append(notification) + + def _handle_incoming(self, item): + self.handled_incoming.append(item) + + +@pytest.fixture +def streams(): + return queue.Queue(), queue.Queue() + + +@pytest.mark.timeout(5) +def test_request_responder_respond(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + on_complete = MagicMock() + request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test"))) + + responder = RequestResponder( + request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete + ) + + with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"): + responder.respond(MockResult(result="ok")) + + with responder as r: + r.respond(MockResult(result="ok")) + with pytest.raises(AssertionError, match="Request already responded to"): + r.respond(MockResult(result="error")) + + assert responder.completed is True + on_complete.assert_called_once_with(responder) + + msg = write_stream.get_nowait() + assert isinstance(msg.message.root, JSONRPCResponse) + assert msg.message.root.result == {"result": "ok"} + + +@pytest.mark.timeout(5) +def test_request_responder_cancel(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + on_complete = MagicMock() + request = ReceiveRequest(MockRequest(method="test", params=MockRequestParams(name="test"))) + + responder = RequestResponder( + request_id=1, request_meta=None, request=request, session=session, on_complete=on_complete + ) + + with pytest.raises(RuntimeError, match="RequestResponder must be used as a context manager"): + responder.cancel() + + with responder as r: + r.cancel() + + assert responder.completed is True + on_complete.assert_called_once_with(responder) + + msg = write_stream.get_nowait() + assert isinstance(msg.message.root, JSONRPCError) + assert msg.message.root.error.message == "Request cancelled" + + +@pytest.mark.timeout(10) +def test_base_session_lifecycle(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session as s: + assert isinstance(s, MockSession) + assert s._executor is not None + assert s._receiver_future is not None + + session._receiver_future.result(timeout=5.0) + assert session._receiver_future.done() + + +@pytest.mark.timeout(5) +def test_send_request_success(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_response(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "hello world"}) + read_stream.put(SessionMessage(message=JSONRPCMessage(response))) + except Exception: + pass + + import threading + + t = threading.Thread(target=mock_response, daemon=True) + t.start() + + with session: + result = session.send_request(request, MockResult) + assert result.result == "hello world" + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_retry_loop_coverage(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_delayed_response(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + time.sleep(0.2) + response = JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"result": "slow"}) + read_stream.put(SessionMessage(message=JSONRPCMessage(response))) + except: + pass + + import threading + + t = threading.Thread(target=mock_delayed_response, daemon=True) + t.start() + + with session: + result = session.send_request(request, MockResult, request_read_timeout_seconds=timedelta(seconds=0.1)) + assert result.result == "slow" + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_jsonrpc_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=-32000, message="Error")) + read_stream.put(SessionMessage(message=JSONRPCMessage(error))) + except: + pass + + import threading + + t = threading.Thread(target=mock_error, daemon=True) + t.start() + + with session: + with pytest.raises(MCPConnectionError) as exc: + session.send_request(request, MockResult) + assert exc.value.args[0].message == "Error" + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_auth_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + error = JSONRPCError(jsonrpc="2.0", id=req_id, error=ErrorData(code=401, message="Unauthorized")) + read_stream.put(SessionMessage(message=JSONRPCMessage(error))) + except: + pass + + import threading + + t = threading.Thread(target=mock_error, daemon=True) + t.start() + + with session: + with pytest.raises(MCPAuthError): + session.send_request(request, MockResult) + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_http_status_error_coverage(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_direct_http_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + # To cover line 263 in base_session.py, we MUST put non-401 HTTPStatusError + # DIRECTLY into response_streams, as _receive_loop would convert it to JSONRPCError. + response = Response(status_code=403, request=Request("GET", "http://test")) + error = HTTPStatusError("Forbidden", request=response.request, response=response) + session._response_streams[req_id].put(error) + except: + pass + + import threading + + t = threading.Thread(target=mock_direct_http_error, daemon=True) + t.start() + + # We still need the session for request ID generation and queue setup + with session: + with pytest.raises(MCPConnectionError) as exc: + session.send_request(request, MockResult) + assert exc.value.args[0].code == 403 + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_request_http_status_auth_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_error(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + response = Response(status_code=401, request=Request("GET", "http://test")) + error = HTTPStatusError("Unauthorized", request=response.request, response=response) + read_stream.put(error) + except: + pass + + import threading + + t = threading.Thread(target=mock_error, daemon=True) + t.start() + + with session: + with pytest.raises(MCPAuthError): + session.send_request(request, MockResult) + t.join(timeout=1) + + +@pytest.mark.timeout(5) +def test_send_notification(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + notification = MockNotification(method="notify", params=MockNotificationParams(message="hi")) + + session.send_notification(notification, related_request_id="rel-1") + + msg = write_stream.get_nowait() + assert isinstance(msg.message.root, JSONRPCNotification) + assert msg.message.root.method == "notify" + assert msg.message.root.params == {"message": "hi"} + assert msg.metadata.related_request_id == "rel-1" + + +@pytest.mark.timeout(10) +def test_receive_loop_request(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + req_payload = {"jsonrpc": "2.0", "id": 1, "method": "test/request", "params": {"name": "test"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload))) + + for _ in range(30): + if session.received_requests: + break + time.sleep(0.1) + + assert len(session.received_requests) == 1 + responder = session.received_requests[0] + assert responder.request_id == 1 + assert responder.request.root.method == "test/request" + + +@pytest.mark.timeout(10) +def test_receive_loop_notification(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + notif_payload = {"jsonrpc": "2.0", "method": "test/notification", "params": {"message": "hello"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload))) + + for _ in range(30): + if session.received_notifications: + break + time.sleep(0.1) + + assert len(session.received_notifications) == 1 + assert isinstance(session.received_notifications[0].root, MockNotification) + assert session.received_notifications[0].root.method == "test/notification" + + +@pytest.mark.timeout(15) +def test_receive_loop_cancel_notification(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ClientNotification) + + with session: + req_payload = {"jsonrpc": "2.0", "id": "req-1", "method": "test/request", "params": {"name": "test"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(req_payload))) + + for _ in range(30): + if "req-1" in session._in_flight: + break + time.sleep(0.1) + + assert "req-1" in session._in_flight + responder = session._in_flight["req-1"] + + with responder: + cancel_payload = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": {"requestId": "req-1"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(cancel_payload))) + + for _ in range(30): + if responder.completed: + break + time.sleep(0.1) + + assert responder.completed is True + msg = write_stream.get(timeout=2) + assert isinstance(msg.message.root, JSONRPCError) + assert msg.message.root.id == "req-1" + + +@pytest.mark.timeout(10) +def test_receive_loop_exception(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + read_stream.put(Exception("Unexpected error")) + for _ in range(30): + if any(isinstance(x, Exception) for x in session.handled_incoming): + break + time.sleep(0.1) + + assert any(isinstance(x, Exception) and str(x) == "Unexpected error" for x in session.handled_incoming) + + +@pytest.mark.timeout(10) +def test_receive_loop_http_status_error(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + session._request_id = 1 + resp_queue = queue.Queue() + session._response_streams[0] = resp_queue + + response = Response(status_code=401, request=Request("GET", "http://test")) + # Using 401 specifically as _receive_loop preserves it + error = HTTPStatusError("Unauthorized", request=response.request, response=response) + read_stream.put(error) + + got = resp_queue.get(timeout=2) + assert isinstance(got, HTTPStatusError) + + +@pytest.mark.timeout(10) +def test_receive_loop_http_status_error_non_401(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + session._request_id = 1 + resp_queue = queue.Queue() + session._response_streams[0] = resp_queue + + response = Response(status_code=500, request=Request("GET", "http://test")) + error = HTTPStatusError("Server Error", request=response.request, response=response) + read_stream.put(error) + + got = resp_queue.get(timeout=2) + assert isinstance(got, JSONRPCError) + assert got.error.code == 500 + + +@pytest.mark.timeout(5) +def test_check_receiver_status_fail(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + executor = ThreadPoolExecutor(max_workers=1) + + def raise_err(): + raise RuntimeError("Receiver failed") + + future = executor.submit(raise_err) + session._receiver_future = future + + try: + future.result() + except: + pass + + with pytest.raises(RuntimeError, match="Receiver failed"): + session.check_receiver_status() + executor.shutdown() + + +@pytest.mark.timeout(10) +def test_receive_loop_unknown_request_id(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + resp = JSONRPCResponse(jsonrpc="2.0", id=999, result={"ok": True}) + read_stream.put(SessionMessage(message=JSONRPCMessage(resp))) + + for _ in range(30): + if any(isinstance(x, RuntimeError) and "Server Error" in str(x) for x in session.handled_incoming): + break + time.sleep(0.1) + + assert any("Server Error" in str(x) for x in session.handled_incoming) + + +@pytest.mark.timeout(10) +def test_receive_loop_http_error_unknown_id(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with session: + response = Response(status_code=401, request=Request("GET", "http://test")) + error = HTTPStatusError("Unauthorized", request=response.request, response=response) + read_stream.put(error) + + for _ in range(30): + if any(isinstance(x, RuntimeError) and "unknown request ID" in str(x) for x in session.handled_incoming): + break + time.sleep(0.1) + + assert any("unknown request ID" in str(x) for x in session.handled_incoming) + + +@pytest.mark.timeout(10) +def test_receive_loop_validation_error_notification(streams): + from core.mcp.session.base_session import logger + + with patch.object(logger, "warning") as mock_warning: + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, RootModel[MockNotification]) + + with session: + notif_payload = {"jsonrpc": "2.0", "method": "bad", "params": {"some": "data"}} + read_stream.put(SessionMessage(message=JSONRPCMessage.model_validate(notif_payload))) + time.sleep(1.0) + + assert mock_warning.called + + +@pytest.mark.timeout(5) +def test_send_request_none_response(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + def mock_none(): + try: + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + session._response_streams[req_id].put(None) + except: + pass + + import threading + + t = threading.Thread(target=mock_none, daemon=True) + t.start() + + with session: + with pytest.raises(MCPConnectionError) as exc: + session.send_request(request, MockResult) + assert exc.value.args[0].message == "No response received" + t.join(timeout=1) + + +@pytest.mark.timeout(15) +def test_session_exit_timeout(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + mock_future = MagicMock(spec=Future) + mock_future.result.side_effect = TimeoutError() + mock_future.done.return_value = False + + session._receiver_future = mock_future + session._executor = MagicMock(spec=ThreadPoolExecutor) + + session.__exit__(None, None, None) + + mock_future.cancel.assert_called_once() + session._executor.shutdown.assert_called_once_with(wait=False) + + +@pytest.mark.timeout(10) +def test_receive_loop_fatal_exception(streams): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + with patch.object(read_stream, "get", side_effect=RuntimeError("Fatal loop error")): + with patch("core.mcp.session.base_session.logger") as mock_logger: + with pytest.raises(RuntimeError, match="Fatal loop error"): + with session: + pass + mock_logger.exception.assert_called_with("Error in message processing loop") + + +@pytest.mark.timeout(5) +def test_receive_loop_empty_coverage(streams): + with patch("core.mcp.session.base_session.DEFAULT_RESPONSE_READ_TIMEOUT", 0.1): + read_stream, write_stream = streams + session = MockSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + with session: + time.sleep(0.3) + + +@pytest.mark.timeout(2) +def test_base_methods_noop(streams): + read_stream, write_stream = streams + session = BaseSession(read_stream, write_stream, ReceiveRequest, ReceiveNotification) + + session._received_request(MagicMock()) + session._received_notification(MagicMock()) + session.send_progress_notification("token", 0.5) + session._handle_incoming(MagicMock()) + + +@pytest.mark.timeout(5) +def test_send_request_session_timeout_retry_6(streams): + read_stream, write_stream = streams + session = MockSession( + read_stream, write_stream, ReceiveRequest, ReceiveNotification, read_timeout_seconds=timedelta(seconds=0.1) + ) + + request = MockRequest(method="test", params=MockRequestParams(name="world")) + + with patch.object(session, "check_receiver_status", side_effect=[None, RuntimeError("timeout_broken")]): + with pytest.raises(RuntimeError, match="timeout_broken"): + session.send_request(request, MockResult) diff --git a/api/tests/unit_tests/core/mcp/session/test_client_session.py b/api/tests/unit_tests/core/mcp/session/test_client_session.py new file mode 100644 index 0000000000..c7b9d3cfa9 --- /dev/null +++ b/api/tests/unit_tests/core/mcp/session/test_client_session.py @@ -0,0 +1,576 @@ +import queue +from unittest.mock import MagicMock + +import pytest +from pydantic import AnyUrl + +from core.mcp import types +from core.mcp.session.base_session import RequestResponder, SessionMessage +from core.mcp.session.client_session import ( + ClientSession, + _default_list_roots_callback, + _default_logging_callback, + _default_message_handler, + _default_sampling_callback, +) + + +@pytest.fixture +def streams(): + return queue.Queue(), queue.Queue() + + +def test_client_session_init(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + assert session._client_info.name == "Dify" + assert session._sampling_callback == _default_sampling_callback + assert session._list_roots_callback == _default_list_roots_callback + assert session._logging_callback == _default_logging_callback + assert session._message_handler == _default_message_handler + + +def test_client_session_init_custom(streams): + read_stream, write_stream = streams + sampling_cb = MagicMock() + list_roots_cb = MagicMock() + logging_cb = MagicMock() + msg_handler = MagicMock() + client_info = types.Implementation(name="Custom", version="1.0") + + session = ClientSession( + read_stream, + write_stream, + sampling_callback=sampling_cb, + list_roots_callback=list_roots_cb, + logging_callback=logging_cb, + message_handler=msg_handler, + client_info=client_info, + ) + + assert session._client_info == client_info + assert session._sampling_callback == sampling_cb + assert session._list_roots_callback == list_roots_cb + assert session._logging_callback == logging_cb + assert session._message_handler == msg_handler + + +def test_initialize_success(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + expected_result = types.InitializeResult( + protocolVersion=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ServerCapabilities(), + serverInfo=types.Implementation(name="test-server", version="1.0"), + ) + + def mock_server(): + # Handle initialize request + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result=expected_result.model_dump()) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + # Expect initialized notification + notif = write_stream.get(timeout=2) + assert notif.message.root.method == "notifications/initialized" + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.initialize() + assert result.protocolVersion == types.LATEST_PROTOCOL_VERSION + assert result.serverInfo.name == "test-server" + + t.join(timeout=1) + + +def test_initialize_custom_capabilities(streams): + read_stream, write_stream = streams + session = ClientSession( + read_stream, write_stream, sampling_callback=lambda c, p: None, list_roots_callback=lambda c: None + ) + + def mock_server(): + msg = write_stream.get(timeout=2) + params = msg.message.root.params + # Check that capabilities are set because we provided custom callbacks + assert params["capabilities"]["sampling"] is not None + assert params["capabilities"]["roots"]["listChanged"] is True + + req_id = msg.message.root.id + resp = types.JSONRPCResponse( + jsonrpc="2.0", + id=req_id, + result={ + "protocolVersion": types.LATEST_PROTOCOL_VERSION, + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"}, + }, + ) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + write_stream.get(timeout=2) # initialized notif + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.initialize() + t.join(timeout=1) + + +def test_initialize_unsupported_version(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + req_id = msg.message.root.id + resp = types.JSONRPCResponse( + jsonrpc="2.0", + id=req_id, + result={ + "protocolVersion": "0.0.1", # Unsupported + "capabilities": {}, + "serverInfo": {"name": "test", "version": "1.0"}, + }, + ) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + with pytest.raises(RuntimeError, match="Unsupported protocol version"): + session.initialize() + t.join(timeout=1) + + +def test_send_ping(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "ping" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.send_ping() + t.join(timeout=1) + + +def test_send_progress_notification(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + session.send_progress_notification(progress_token="token", progress=50.0, total=100.0) + + msg = write_stream.get_nowait() + assert msg.message.root.method == "notifications/progress" + assert msg.message.root.params["progressToken"] == "token" + assert msg.message.root.params["progress"] == 50.0 + assert msg.message.root.params["total"] == 100.0 + + +def test_set_logging_level(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "logging/setLevel" + assert msg.message.root.params["level"] == "debug" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.set_logging_level("debug") + t.join(timeout=1) + + +def test_list_resources(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resources": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_resources() + assert result.resources == [] + t.join(timeout=1) + + +def test_list_resource_templates(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/templates/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"resourceTemplates": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_resource_templates() + assert result.resourceTemplates == [] + t.join(timeout=1) + + +def test_read_resource(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + uri = AnyUrl("file:///test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/read" + assert msg.message.root.params["uri"] == str(uri) + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"contents": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.read_resource(uri) + assert result.contents == [] + t.join(timeout=1) + + +def test_subscribe_resource(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + uri = AnyUrl("file:///test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/subscribe" + assert msg.message.root.params["uri"] == str(uri) + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.subscribe_resource(uri) + t.join(timeout=1) + + +def test_unsubscribe_resource(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + uri = AnyUrl("file:///test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "resources/unsubscribe" + assert msg.message.root.params["uri"] == str(uri) + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + session.unsubscribe_resource(uri) + t.join(timeout=1) + + +def test_call_tool(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "tools/call" + assert msg.message.root.params["name"] == "test-tool" + assert msg.message.root.params["arguments"] == {"arg": 1} + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"content": [], "isError": False}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.call_tool("test-tool", arguments={"arg": 1}) + assert result.isError is False + t.join(timeout=1) + + +def test_list_prompts(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "prompts/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"prompts": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_prompts() + assert result.prompts == [] + t.join(timeout=1) + + +def test_get_prompt(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "prompts/get" + assert msg.message.root.params["name"] == "test-prompt" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"messages": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.get_prompt("test-prompt") + assert result.messages == [] + t.join(timeout=1) + + +def test_complete(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + ref = types.PromptReference(type="ref/prompt", name="test") + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "completion/complete" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"completion": {"values": [], "hasMore": False}}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.complete(ref, argument={"name": "val", "value": "x"}) + assert result.completion.hasMore is False + t.join(timeout=1) + + +def test_list_tools(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + def mock_server(): + msg = write_stream.get(timeout=2) + assert msg.message.root.method == "tools/list" + req_id = msg.message.root.id + resp = types.JSONRPCResponse(jsonrpc="2.0", id=req_id, result={"tools": []}) + read_stream.put(SessionMessage(message=types.JSONRPCMessage(resp))) + + import threading + + t = threading.Thread(target=mock_server, daemon=True) + t.start() + + with session: + result = session.list_tools() + assert result.tools == [] + t.join(timeout=1) + + +def test_send_roots_list_changed(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + session.send_roots_list_changed() + + msg = write_stream.get_nowait() + assert msg.message.root.method == "notifications/roots/list_changed" + + +def test_received_request_sampling(streams): + read_stream, write_stream = streams + sampling_cb = MagicMock( + return_value=types.CreateMessageResult( + role="assistant", content=types.TextContent(type="text", text="hello"), model="gpt-4" + ) + ) + session = ClientSession(read_stream, write_stream, sampling_callback=sampling_cb) + + req = types.ServerRequest( + root=types.CreateMessageRequest( + method="sampling/createMessage", params=types.CreateMessageRequestParams(messages=[], maxTokens=100) + ) + ) + + responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock()) + + session._received_request(responder) + + msg = write_stream.get_nowait() + assert msg.message.root.result["model"] == "gpt-4" + sampling_cb.assert_called_once() + + +def test_received_request_list_roots(streams): + read_stream, write_stream = streams + list_roots_cb = MagicMock(return_value=types.ListRootsResult(roots=[])) + session = ClientSession(read_stream, write_stream, list_roots_callback=list_roots_cb) + + req = types.ServerRequest(root=types.ListRootsRequest(method="roots/list")) + + responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock()) + + session._received_request(responder) + + msg = write_stream.get_nowait() + assert msg.message.root.result["roots"] == [] + list_roots_cb.assert_called_once() + + +def test_received_request_ping(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + req = types.ServerRequest(root=types.PingRequest(method="ping")) + + responder = RequestResponder(request_id=1, request_meta=None, request=req, session=session, on_complete=MagicMock()) + + session._received_request(responder) + + msg = write_stream.get_nowait() + assert msg.message.root.result == {} + + +def test_handle_incoming(streams): + read_stream, write_stream = streams + msg_handler = MagicMock() + session = ClientSession(read_stream, write_stream, message_handler=msg_handler) + + item = MagicMock() + session._handle_incoming(item) + msg_handler.assert_called_once_with(item) + + +def test_received_notification_logging(streams): + read_stream, write_stream = streams + logging_cb = MagicMock() + session = ClientSession(read_stream, write_stream, logging_callback=logging_cb) + + notif = types.ServerNotification( + root=types.LoggingMessageNotification( + method="notifications/message", + params=types.LoggingMessageNotificationParams(level="info", data={"msg": "test"}), + ) + ) + + session._received_notification(notif) + logging_cb.assert_called_once() + assert logging_cb.call_args[0][0].level == "info" + + +def test_default_message_handler(): + # Exception case + with pytest.raises(ValueError, match="test error"): + _default_message_handler(Exception("test error")) + + # Notification case - should do nothing + _default_message_handler(MagicMock(spec=types.ServerNotification)) + + # RequestResponder case - should do nothing + _default_message_handler(MagicMock(spec=RequestResponder)) + + +def test_default_sampling_callback(): + ctx = MagicMock() + params = MagicMock() + res = _default_sampling_callback(ctx, params) + assert res.code == types.INVALID_REQUEST + assert "not supported" in res.message + + +def test_default_list_roots_callback(): + ctx = MagicMock() + res = _default_list_roots_callback(ctx) + assert res.code == types.INVALID_REQUEST + assert "not supported" in res.message + + +def test_default_logging_callback(): + params = MagicMock() + _default_logging_callback(params) # Should do nothing + + +def test_received_notification_unknown(streams): + read_stream, write_stream = streams + session = ClientSession(read_stream, write_stream) + + # Use a notification type that is NOT LoggingMessageNotification + notif = types.ServerNotification( + root=types.ResourceListChangedNotification(method="notifications/resources/list_changed") + ) + + session._received_notification(notif) + # Should just pass (case _:) diff --git a/api/tests/unit_tests/core/mcp/test_mcp_client.py b/api/tests/unit_tests/core/mcp/test_mcp_client.py index c0420d3371..c245b4a77e 100644 --- a/api/tests/unit_tests/core/mcp/test_mcp_client.py +++ b/api/tests/unit_tests/core/mcp/test_mcp_client.py @@ -2,13 +2,16 @@ from contextlib import ExitStack from types import TracebackType -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from sqlalchemy.orm import Session -from core.mcp.error import MCPConnectionError +from core.entities.mcp_provider import MCPProviderEntity +from core.mcp.auth_client import MCPClientWithAuthRetry +from core.mcp.error import MCPAuthError, MCPConnectionError from core.mcp.mcp_client import MCPClient -from core.mcp.types import CallToolResult, ListToolsResult, TextContent, Tool, ToolAnnotations +from core.mcp.types import CallToolResult, ListToolsResult, OAuthTokens, TextContent, Tool, ToolAnnotations class TestMCPClient: @@ -380,3 +383,256 @@ class TestMCPClient: timeout=30.0, sse_read_timeout=60.0, ) + + +class TestMCPClientWithAuthRetry: + """Test suite for MCPClientWithAuthRetry.""" + + @pytest.fixture + def mock_provider(self): + provider = MagicMock(spec=MCPProviderEntity) + provider.id = "test-provider-id" + provider.tenant_id = "test-tenant-id" + provider.retrieve_tokens.return_value = OAuthTokens( + access_token="new-token", + token_type="Bearer", + expires_in=3600, + refresh_token="refresh-token", + ) + return provider + + @pytest.fixture + def auth_client(self, mock_provider): + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer old-token"}, + provider_entity=mock_provider, + authorization_code="test-code", + by_server_id=True, + ) + return client + + def test_init(self, mock_provider): + """Test initialization.""" + client = MCPClientWithAuthRetry( + server_url="http://test.example.com", + headers={"Authorization": "Bearer test"}, + timeout=30.0, + provider_entity=mock_provider, + authorization_code="initial-code", + by_server_id=True, + ) + + assert client.server_url == "http://test.example.com" + assert client.headers == {"Authorization": "Bearer test"} + assert client.timeout == 30.0 + assert client.provider_entity == mock_provider + assert client.authorization_code == "initial-code" + assert client.by_server_id is True + assert client._has_retried is False + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_success( + self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider + ): + mock_session = MagicMock(spec=Session) + mock_session_class.return_value.__enter__.return_value = mock_session + + mock_service = mock_service_class.return_value + new_provider = MagicMock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = OAuthTokens( + access_token="new-access-token", + token_type="Bearer", + expires_in=3600, + refresh_token="new-refresh-token", + ) + mock_service.get_provider_entity.return_value = new_provider + + # MCPAuthError parses resource_metadata and scope from www_authenticate_header + www_auth = 'Bearer resource_metadata="http://meta", scope="read"' + error = MCPAuthError("Auth failed", www_authenticate_header=www_auth) + + auth_client._handle_auth_error(error) + + # Verify service calls - error.resource_metadata_url and error.scope_hint are parsed from header + mock_service.auth_with_actions.assert_called_once_with( + mock_provider, + "test-code", + resource_metadata_url="http://meta", + scope_hint="read", + ) + mock_service.get_provider_entity.assert_called_once_with( + mock_provider.id, mock_provider.tenant_id, by_server_id=True + ) + + # Verify client updates + assert auth_client.headers["Authorization"] == "Bearer new-access-token" + assert auth_client.authorization_code is None + assert auth_client._has_retried is True + assert auth_client.provider_entity == new_provider + + def test_handle_auth_error_no_provider(self, auth_client): + """Test auth error handling when no provider entity is set.""" + auth_client.provider_entity = None + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert exc_info.value == error + + def test_handle_auth_error_already_retried(self, auth_client): + """Test auth error handling when already retried.""" + auth_client._has_retried = True + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert exc_info.value == error + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_no_token( + self, mock_service_class, mock_session_class, mock_db, auth_client, mock_provider + ): + """Test auth error handling when no token is received.""" + mock_session_class.return_value.__enter__.return_value = MagicMock() + mock_service = mock_service_class.return_value + + new_provider = MagicMock(spec=MCPProviderEntity) + new_provider.retrieve_tokens.return_value = None + mock_service.get_provider_entity.return_value = new_provider + + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert "Authentication failed - no token received" in str(exc_info.value) + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_generic_exception(self, mock_service_class, mock_session_class, mock_db, auth_client): + """Test auth error handling when a generic exception occurs.""" + mock_session_class.side_effect = Exception("DB error") + + error = MCPAuthError("Auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert "Authentication retry failed: DB error" in str(exc_info.value) + + @patch("core.mcp.auth_client.db") + @patch("core.mcp.auth_client.Session") + @patch("services.tools.mcp_tools_manage_service.MCPToolManageService") + def test_handle_auth_error_mcp_auth_error_propagation( + self, mock_service_class, mock_session_class, mock_db, auth_client + ): + """Test that MCPAuthError during refresh is propagated as is.""" + mock_session_class.return_value.__enter__.return_value = MagicMock() + mock_service = mock_service_class.return_value + mock_service.auth_with_actions.side_effect = MCPAuthError("Refresh failed") + + error = MCPAuthError("Initial auth failed") + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._handle_auth_error(error) + + assert "Refresh failed" in str(exc_info.value) + + def test_execute_with_retry_success_first_try(self, auth_client): + """Test execution success on first try.""" + mock_func = MagicMock(return_value="success") + + result = auth_client._execute_with_retry(mock_func, "arg1", kwarg1="val1") + + assert result == "success" + mock_func.assert_called_once_with("arg1", kwarg1="val1") + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_handle_auth_error") + @patch.object(MCPClientWithAuthRetry, "_initialize") + def test_execute_with_retry_success_on_retry_initialized(self, mock_initialize, mock_handle_auth, auth_client): + """Test execution success on retry after auth error when client was already initialized.""" + mock_func = MagicMock() + mock_func.side_effect = [MCPAuthError("Auth failed"), "success"] + + auth_client._initialized = True + auth_client._exit_stack = MagicMock() + + result = auth_client._execute_with_retry(mock_func, "arg") + + assert result == "success" + assert mock_func.call_count == 2 + mock_handle_auth.assert_called_once() + mock_initialize.assert_called_once() + auth_client._exit_stack.close.assert_called_once() + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_handle_auth_error") + @patch.object(MCPClientWithAuthRetry, "_initialize") + def test_execute_with_retry_success_on_retry_not_initialized(self, mock_initialize, mock_handle_auth, auth_client): + """Test retry when client was NOT initialized (skips cleanup/re-init).""" + mock_func = MagicMock() + mock_func.side_effect = [MCPAuthError("Auth failed"), "result"] + + auth_client._initialized = False + + result = auth_client._execute_with_retry(mock_func, "arg") + + assert result == "result" + assert mock_func.call_count == 2 + mock_handle_auth.assert_called_once() + mock_initialize.assert_not_called() + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_handle_auth_error") + def test_execute_with_retry_failure_on_retry(self, mock_handle_auth, auth_client): + """Test execution failure even after retry.""" + mock_func = MagicMock() + mock_func.side_effect = [MCPAuthError("First fail"), MCPAuthError("Second fail")] + + with pytest.raises(MCPAuthError) as exc_info: + auth_client._execute_with_retry(mock_func, "arg") + + assert "Second fail" in str(exc_info.value) + assert mock_func.call_count == 2 + mock_handle_auth.assert_called_once() + assert auth_client._has_retried is False + + @patch.object(MCPClientWithAuthRetry, "_execute_with_retry") + def test_auth_client_context_manager_enter(self, mock_execute_retry, auth_client): + """Test context manager __enter__.""" + auth_client.__enter__() + + mock_execute_retry.assert_called_once() + func = mock_execute_retry.call_args[0][0] + + with patch("core.mcp.mcp_client.MCPClient.__enter__") as mock_base_enter: + result = func() + assert result == auth_client + mock_base_enter.assert_called_once() + + @patch.object(MCPClientWithAuthRetry, "_execute_with_retry") + def test_auth_client_list_tools(self, mock_execute_retry, auth_client): + """Test list_tools with retry.""" + auth_client.list_tools() + + mock_execute_retry.assert_called_once() + assert mock_execute_retry.call_args[0][0].__name__ == "list_tools" + + @patch.object(MCPClientWithAuthRetry, "_execute_with_retry") + def test_auth_client_invoke_tool(self, mock_execute_retry, auth_client): + """Test invoke_tool with retry.""" + auth_client.invoke_tool("test-tool", {"arg": "val"}) + + mock_execute_retry.assert_called_once() + assert mock_execute_retry.call_args[0][0].__name__ == "invoke_tool" + assert mock_execute_retry.call_args[0][1] == "test-tool" + assert mock_execute_retry.call_args[0][2] == {"arg": "val"} diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py new file mode 100644 index 0000000000..5ecfe01808 --- /dev/null +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -0,0 +1,969 @@ +"""Comprehensive unit tests for core/memory/token_buffer_memory.py""" + +from unittest.mock import MagicMock, patch +from uuid import uuid4 + +import pytest + +from core.memory.token_buffer_memory import TokenBufferMemory +from dify_graph.model_runtime.entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, + UserPromptMessage, +) +from models.model import AppMode + +# --------------------------------------------------------------------------- +# Helpers / shared fixtures +# --------------------------------------------------------------------------- + + +def _make_conversation(mode: AppMode = AppMode.CHAT) -> MagicMock: + """Return a minimal Conversation mock.""" + conv = MagicMock() + conv.id = str(uuid4()) + conv.mode = mode + conv.model_config = {} + return conv + + +def _make_model_instance() -> MagicMock: + """Return a ModelInstance mock whose token counter returns a constant.""" + mi = MagicMock() + mi.get_llm_num_tokens.return_value = 100 + return mi + + +def _make_message(answer: str = "hello", answer_tokens: int = 5) -> MagicMock: + msg = MagicMock() + msg.id = str(uuid4()) + msg.query = "user query" + msg.answer = answer + msg.answer_tokens = answer_tokens + msg.workflow_run_id = str(uuid4()) + msg.created_at = MagicMock() + return msg + + +# =========================================================================== +# Tests for __init__ and workflow_run_repo property +# =========================================================================== + + +class TestInit: + def test_init_stores_conversation_and_model_instance(self): + conv = _make_conversation() + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + assert mem.conversation is conv + assert mem.model_instance is mi + assert mem._workflow_run_repo is None + + def test_workflow_run_repo_is_created_lazily(self): + conv = _make_conversation() + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + mock_repo = MagicMock() + with ( + patch("core.memory.token_buffer_memory.sessionmaker") as mock_sm, + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=mock_repo, + ), + ): + mock_db.engine = MagicMock() + repo = mem.workflow_run_repo + assert repo is mock_repo + assert mem._workflow_run_repo is mock_repo + + def test_workflow_run_repo_cached_after_first_access(self): + conv = _make_conversation() + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + existing_repo = MagicMock() + mem._workflow_run_repo = existing_repo + + with patch( + "core.memory.token_buffer_memory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" + ) as mock_factory: + repo = mem.workflow_run_repo + mock_factory.assert_not_called() + assert repo is existing_repo + + +# =========================================================================== +# Tests for _build_prompt_message_with_files +# =========================================================================== + + +class TestBuildPromptMessageWithFiles: + """Tests for the private _build_prompt_message_with_files method.""" + + # ------------------------------------------------------------------ + # Mode: CHAT / AGENT_CHAT / COMPLETION (simple branch) + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_no_files_user_message(self, mode): + """When file_extra_config is falsy or app_record is None → plain UserPromptMessage.""" + conv = _make_conversation(mode) + mi = _make_model_instance() + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + with patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, # falsy → file_objs = [] + ): + result = mem._build_prompt_message_with_files( + message_files=[], + text_content="hello", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert result.content == "hello" + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_no_files_assistant_message(self, mode): + """Plain AssistantPromptMessage when no files and is_user_message=False.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ): + result = mem._build_prompt_message_with_files( + message_files=[], + text_content="ai reply", + message=_make_message(), + app_record=None, + is_user_message=False, + ) + + assert isinstance(result, AssistantPromptMessage) + assert result.content == "ai reply" + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_with_files_user_message(self, mode): + """When files are present, returns UserPromptMessage with list content.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = None # no detail override + + mock_file_obj = MagicMock() + # Must be a real entity so Pydantic's tagged union discriminator can validate it + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + + mock_message_file = MagicMock() + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=mock_file_obj, + ), + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ), + ): + result = mem._build_prompt_message_with_files( + message_files=[mock_message_file], + text_content="user text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert isinstance(result.content, list) + # Last element should be TextPromptMessageContent + assert isinstance(result.content[-1], TextPromptMessageContent) + assert result.content[-1].data == "user text" + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_with_files_assistant_message(self, mode): + """When files are present, returns AssistantPromptMessage with list content.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = None + + mock_file_obj = MagicMock() + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=mock_file_obj, + ), + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ), + ): + result = mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="ai text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=False, + ) + + assert isinstance(result, AssistantPromptMessage) + assert isinstance(result.content, list) + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_with_files_image_detail_overridden(self, mode): + """When image_config.detail is set, detail is taken from config.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_image_config = MagicMock() + mock_image_config.detail = ImagePromptMessageContent.DETAIL.LOW + + mock_file_extra_config = MagicMock() + mock_file_extra_config.image_config = mock_image_config + + mock_app_record = MagicMock() + mock_app_record.tenant_id = "tenant-1" + + real_image_content = ImagePromptMessageContent( + url="http://example.com/img.png", format="png", mime_type="image/png" + ) + + with ( + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ), + patch( + "core.memory.token_buffer_memory.file_factory.build_from_message_file", + return_value=MagicMock(), + ), + patch( + "core.memory.token_buffer_memory.file_manager.to_prompt_message_content", + return_value=real_image_content, + ) as mock_to_prompt, + ): + mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="user text", + message=_make_message(), + app_record=mock_app_record, + is_user_message=True, + ) + # Ensure the LOW detail was passed through + mock_to_prompt.assert_called_once_with( + mock_to_prompt.call_args[0][0], image_detail_config=ImagePromptMessageContent.DETAIL.LOW + ) + + @pytest.mark.parametrize("mode", [AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.COMPLETION]) + def test_chat_mode_app_record_none_returns_empty_file_objs(self, mode): + """app_record=None path → file_objs stays empty → plain messages.""" + conv = _make_conversation(mode) + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + mock_file_extra_config = MagicMock() + + with patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=mock_file_extra_config, + ): + result = mem._build_prompt_message_with_files( + message_files=[MagicMock()], + text_content="hello", + message=_make_message(), + app_record=None, # <-- forces the else branch → file_objs = [] + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert result.content == "hello" + + # ------------------------------------------------------------------ + # Mode: ADVANCED_CHAT / WORKFLOW + # ------------------------------------------------------------------ + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_no_app_raises(self, mode): + """Raises ValueError when conversation.app is falsy.""" + conv = _make_conversation(mode) + conv.app = None + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with pytest.raises(ValueError, match="App not found for conversation"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_no_workflow_run_id_raises(self, mode): + """Raises ValueError when message.workflow_run_id is falsy.""" + conv = _make_conversation(mode) + conv.app = MagicMock() + + message = _make_message() + message.workflow_run_id = None # force missing + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with pytest.raises(ValueError, match="Workflow run ID not found"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=message, + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_workflow_run_not_found_raises(self, mode): + """Raises ValueError when workflow_run_repo returns None.""" + conv = _make_conversation(mode) + mock_app = MagicMock() + conv.app = mock_app + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + mem._workflow_run_repo = MagicMock() + mem._workflow_run_repo.get_workflow_run_by_id.return_value = None + + with pytest.raises(ValueError, match="Workflow run not found"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_workflow_not_found_raises(self, mode): + """Raises ValueError when Workflow lookup returns None.""" + conv = _make_conversation(mode) + conv.app = MagicMock() + + mock_workflow_run = MagicMock() + mock_workflow_run.workflow_id = str(uuid4()) + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + mem._workflow_run_repo = MagicMock() + mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + ): + mock_db.session.scalar.return_value = None # workflow not found + + with pytest.raises(ValueError, match="Workflow not found"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + @pytest.mark.parametrize("mode", [AppMode.ADVANCED_CHAT, AppMode.WORKFLOW]) + def test_workflow_mode_success_no_files_user(self, mode): + """Happy path: workflow mode, no message files → plain UserPromptMessage.""" + conv = _make_conversation(mode) + conv.app = MagicMock() + + mock_workflow_run = MagicMock() + mock_workflow_run.workflow_id = str(uuid4()) + + mock_workflow = MagicMock() + mock_workflow.features_dict = {} + + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + mem._workflow_run_repo = MagicMock() + mem._workflow_run_repo.get_workflow_run_by_id.return_value = mock_workflow_run + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalar.return_value = mock_workflow + + result = mem._build_prompt_message_with_files( + message_files=[], + text_content="wf text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + assert isinstance(result, UserPromptMessage) + assert result.content == "wf text" + + # ------------------------------------------------------------------ + # Invalid mode + # ------------------------------------------------------------------ + + def test_invalid_mode_raises_assertion(self): + """Any unknown AppMode raises AssertionError.""" + conv = _make_conversation() + conv.mode = "unknown_mode" # not in any set + mem = TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + with pytest.raises(AssertionError, match="Invalid app mode"): + mem._build_prompt_message_with_files( + message_files=[], + text_content="text", + message=_make_message(), + app_record=MagicMock(), + is_user_message=True, + ) + + +# =========================================================================== +# Tests for get_history_prompt_messages +# =========================================================================== + + +class TestGetHistoryPromptMessages: + """Tests for get_history_prompt_messages.""" + + def _make_memory(self, mode: AppMode = AppMode.CHAT) -> TokenBufferMemory: + conv = _make_conversation(mode) + conv.app = MagicMock() + return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + def test_returns_empty_when_no_messages(self): + mem = self._make_memory() + with patch("core.memory.token_buffer_memory.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [] + result = mem.get_history_prompt_messages() + assert result == [] + + def test_skips_first_message_without_answer(self): + """The newest message (index 0 after extraction) without answer and tokens==0 is skipped.""" + mem = self._make_memory() + + msg_no_answer = _make_message(answer="", answer_tokens=0) + msg_no_answer.parent_message_id = None # ensures extract_thread_messages returns it + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg_no_answer], + ), + ): + mock_db.session.scalars.return_value.all.side_effect = [ + [msg_no_answer], # first call: messages query + [], # second call: user files query (never hit, but safe) + ] + result = mem.get_history_prompt_messages() + + assert result == [] + + def test_message_with_answer_not_skipped(self): + """A message with a non-empty answer is NOT popped.""" + mem = self._make_memory() + + msg = _make_message(answer="some answer", answer_tokens=10) + msg.parent_message_id = None + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + # user files query → empty; assistant files query → empty + mock_db.session.scalars.return_value.all.return_value = [] + result = mem.get_history_prompt_messages() + + assert len(result) == 2 # one user + one assistant + + def test_message_limit_default_is_500(self): + """When message_limit is None the stmt is limited to 500.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=None) + mock_stmt.limit.assert_called_with(500) + + def test_message_limit_clipped_to_500(self): + """A message_limit > 500 is clamped to 500.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=9999) + mock_stmt.limit.assert_called_with(500) + + def test_message_limit_positive_used(self): + """A positive message_limit < 500 is used as-is.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=10) + mock_stmt.limit.assert_called_with(10) + + def test_message_limit_zero_uses_default(self): + """message_limit=0 triggers the else branch → default 500.""" + mem = self._make_memory() + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch("core.memory.token_buffer_memory.select") as mock_select, + patch("core.memory.token_buffer_memory.extract_thread_messages", return_value=[]), + ): + mock_stmt = MagicMock() + mock_select.return_value.where.return_value.order_by.return_value = mock_stmt + mock_stmt.limit.return_value = mock_stmt + mock_db.session.scalars.return_value.all.return_value = [] + + mem.get_history_prompt_messages(message_limit=0) + mock_stmt.limit.assert_called_with(500) + + def test_user_files_cause_build_with_files_call(self): + """When user_files is non-empty _build_prompt_message_with_files is invoked.""" + mem = self._make_memory() + msg = _make_message() + msg.parent_message_id = None + + mock_user_file = MagicMock() + mock_user_prompt = UserPromptMessage(content="from build") + mock_assistant_prompt = AssistantPromptMessage(content="answer") + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + # messages query + r.all.return_value = [msg] + elif call_count["n"] == 1: + # user files + r.all.return_value = [mock_user_file] + else: + # assistant files + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch.object( + mem, + "_build_prompt_message_with_files", + side_effect=[mock_user_prompt, mock_assistant_prompt], + ) as mock_build, + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages() + + assert mock_build.call_count >= 1 + # First call should be user message + first_call_kwargs = mock_build.call_args_list[0][1] + assert first_call_kwargs["is_user_message"] is True + + def test_assistant_files_cause_build_with_files_call(self): + """When assistant_files is non-empty, build is called with is_user_message=False.""" + mem = self._make_memory() + msg = _make_message() + msg.parent_message_id = None + + mock_assistant_file = MagicMock() + mock_user_prompt = UserPromptMessage(content="query") + mock_assistant_prompt = AssistantPromptMessage(content="built") + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + elif call_count["n"] == 1: + r.all.return_value = [] # no user files + else: + r.all.return_value = [mock_assistant_file] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch.object( + mem, + "_build_prompt_message_with_files", + return_value=mock_assistant_prompt, + ) as mock_build, + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages() + + mock_build.assert_called_once() + call_kwargs = mock_build.call_args[1] + assert call_kwargs["is_user_message"] is False + + def test_token_pruning_removes_oldest_messages(self): + """If tokens exceed limit, oldest messages are removed until within limit.""" + conv = _make_conversation() + conv.app = MagicMock() + + # Model returns tokens that decrease only after removing pairs + token_values = [3000, 1500] # first call over limit, second within + mi = MagicMock() + mi.get_llm_num_tokens.side_effect = token_values + + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + msg = _make_message() + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages(max_token_limit=2000) + + # After pruning, we should have fewer than the 2 initial messages + assert len(result) <= 1 + + def test_token_pruning_stops_at_single_message(self): + """Pruning stops when only 1 message remains (to prevent empty list).""" + conv = _make_conversation() + conv.app = MagicMock() + + # Always over limit + mi = MagicMock() + mi.get_llm_num_tokens.return_value = 99999 + + mem = TokenBufferMemory(conversation=conv, model_instance=mi) + + msg = _make_message() + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages(max_token_limit=1) + + # At least 1 message should remain + assert len(result) >= 1 + + def test_no_pruning_when_within_limit(self): + """When tokens ≤ limit, no pruning occurs.""" + mem = self._make_memory() + mem.model_instance.get_llm_num_tokens.return_value = 50 # well under default 2000 + + msg = _make_message() + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages(max_token_limit=2000) + + assert len(result) == 2 # user + assistant + + def test_plain_user_and_assistant_messages_returned(self): + """Without files, plain UserPromptMessage and AssistantPromptMessage appear.""" + mem = self._make_memory() + + msg = _make_message(answer="My answer") + msg.query = "My query" + msg.parent_message_id = None + + call_count = {"n": 0} + + def scalars_side_effect(stmt): + r = MagicMock() + if call_count["n"] == 0: + r.all.return_value = [msg] + else: + r.all.return_value = [] + call_count["n"] += 1 + return r + + with ( + patch("core.memory.token_buffer_memory.db") as mock_db, + patch( + "core.memory.token_buffer_memory.extract_thread_messages", + return_value=[msg], + ), + patch( + "core.memory.token_buffer_memory.FileUploadConfigManager.convert", + return_value=None, + ), + ): + mock_db.session.scalars.side_effect = scalars_side_effect + result = mem.get_history_prompt_messages() + + assert len(result) == 2 + user_msg, ai_msg = result + assert isinstance(user_msg, UserPromptMessage) + assert user_msg.content == "My query" + assert isinstance(ai_msg, AssistantPromptMessage) + assert ai_msg.content == "My answer" + + +# =========================================================================== +# Tests for get_history_prompt_text +# =========================================================================== + + +class TestGetHistoryPromptText: + """Tests for get_history_prompt_text.""" + + def _make_memory(self) -> TokenBufferMemory: + conv = _make_conversation() + conv.app = MagicMock() + return TokenBufferMemory(conversation=conv, model_instance=_make_model_instance()) + + def test_empty_messages_returns_empty_string(self): + mem = self._make_memory() + with patch.object(mem, "get_history_prompt_messages", return_value=[]): + result = mem.get_history_prompt_text() + assert result == "" + + def test_user_and_assistant_messages_formatted(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content="Hello"), + AssistantPromptMessage(content="World"), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text(human_prefix="H", ai_prefix="A") + assert result == "H: Hello\nA: World" + + def test_custom_prefixes_applied(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content="Hi"), + AssistantPromptMessage(content="Bye"), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text(human_prefix="Human", ai_prefix="Bot") + assert "Human: Hi" in result + assert "Bot: Bye" in result + + def test_list_content_with_text_and_image(self): + """List content: TextPromptMessageContent → text; ImagePromptMessageContent → [image].""" + mem = self._make_memory() + messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="caption"), + ImagePromptMessageContent(url="http://img", format="png", mime_type="image/png"), + ] + ), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "caption" in result + assert "[image]" in result + + def test_list_content_text_only(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content=[TextPromptMessageContent(data="just text")]), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "just text" in result + + def test_list_content_image_only(self): + mem = self._make_memory() + messages = [ + UserPromptMessage( + content=[ + ImagePromptMessageContent(url="http://img", format="jpg", mime_type="image/jpeg"), + ] + ), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "[image]" in result + + def test_unknown_role_skipped(self): + """Messages with a role that is not USER or ASSISTANT are skipped.""" + mem = self._make_memory() + + # Create a mock message with a SYSTEM role + system_msg = MagicMock() + system_msg.role = PromptMessageRole.SYSTEM + system_msg.content = "system instruction" + + user_msg = UserPromptMessage(content="hi") + + with patch.object(mem, "get_history_prompt_messages", return_value=[system_msg, user_msg]): + result = mem.get_history_prompt_text() + + assert "system instruction" not in result + assert "Human: hi" in result + + def test_passes_max_token_limit_and_message_limit(self): + """Parameters are forwarded to get_history_prompt_messages.""" + mem = self._make_memory() + with patch.object(mem, "get_history_prompt_messages", return_value=[]) as mock_get: + mem.get_history_prompt_text(max_token_limit=500, message_limit=10) + mock_get.assert_called_once_with(max_token_limit=500, message_limit=10) + + def test_multiple_messages_joined_by_newline(self): + mem = self._make_memory() + messages = [ + UserPromptMessage(content="Q1"), + AssistantPromptMessage(content="A1"), + UserPromptMessage(content="Q2"), + AssistantPromptMessage(content="A2"), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + lines = result.split("\n") + assert len(lines) == 4 + assert lines[0] == "Human: Q1" + assert lines[1] == "Assistant: A1" + assert lines[2] == "Human: Q2" + assert lines[3] == "Assistant: A2" + + def test_assistant_list_content_formatted(self): + """AssistantPromptMessage with list content is also handled.""" + mem = self._make_memory() + messages = [ + AssistantPromptMessage( + content=[ + TextPromptMessageContent(data="response text"), + ImagePromptMessageContent(url="http://img2", format="png", mime_type="image/png"), + ] + ), + ] + with patch.object(mem, "get_history_prompt_messages", return_value=messages): + result = mem.get_history_prompt_text() + assert "response text" in result + assert "[image]" in result diff --git a/api/tests/unit_tests/core/moderation/api/test_api.py b/api/tests/unit_tests/core/moderation/api/test_api.py new file mode 100644 index 0000000000..558b20e5f8 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/api/test_api.py @@ -0,0 +1,181 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError + +from core.extension.api_based_extension_requestor import APIBasedExtensionPoint +from core.moderation.api.api import ApiModeration, ModerationInputParams, ModerationOutputParams +from core.moderation.base import ModerationAction, ModerationInputsResult, ModerationOutputsResult +from models.api_based_extension import APIBasedExtension + + +class TestApiModeration: + @pytest.fixture + def api_config(self): + return { + "inputs_config": { + "enabled": True, + }, + "outputs_config": { + "enabled": True, + }, + "api_based_extension_id": "test-extension-id", + } + + @pytest.fixture + def api_moderation(self, api_config): + return ApiModeration(app_id="test-app-id", tenant_id="test-tenant-id", config=api_config) + + def test_moderation_input_params(self): + params = ModerationInputParams(app_id="app-1", inputs={"key": "val"}, query="test query") + assert params.app_id == "app-1" + assert params.inputs == {"key": "val"} + assert params.query == "test query" + + # Test defaults + params_default = ModerationInputParams() + assert params_default.app_id == "" + assert params_default.inputs == {} + assert params_default.query == "" + + def test_moderation_output_params(self): + params = ModerationOutputParams(app_id="app-1", text="test text") + assert params.app_id == "app-1" + assert params.text == "test text" + + with pytest.raises(ValidationError): + ModerationOutputParams() + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_validate_config_success(self, mock_get_extension, api_config): + mock_get_extension.return_value = MagicMock(spec=APIBasedExtension) + ApiModeration.validate_config("test-tenant-id", api_config) + mock_get_extension.assert_called_once_with("test-tenant-id", "test-extension-id") + + def test_validate_config_missing_extension_id(self): + config = { + "inputs_config": {"enabled": True}, + "outputs_config": {"enabled": True}, + } + with pytest.raises(ValueError, match="api_based_extension_id is required"): + ApiModeration.validate_config("test-tenant-id", config) + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_validate_config_extension_not_found(self, mock_get_extension, api_config): + mock_get_extension.return_value = None + with pytest.raises(ValueError, match="API-based Extension not found"): + ApiModeration.validate_config("test-tenant-id", api_config) + + @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor") + def test_moderation_for_inputs_enabled(self, mock_get_config, api_moderation): + mock_get_config.return_value = {"flagged": True, "action": "direct_output", "preset_response": "Blocked by API"} + + result = api_moderation.moderation_for_inputs(inputs={"q": "a"}, query="hello") + + assert isinstance(result, ModerationInputsResult) + assert result.flagged is True + assert result.action == ModerationAction.DIRECT_OUTPUT + assert result.preset_response == "Blocked by API" + + mock_get_config.assert_called_once_with( + APIBasedExtensionPoint.APP_MODERATION_INPUT, + {"app_id": "test-app-id", "inputs": {"q": "a"}, "query": "hello"}, + ) + + def test_moderation_for_inputs_disabled(self): + config = { + "inputs_config": {"enabled": False}, + "outputs_config": {"enabled": True}, + "api_based_extension_id": "ext-id", + } + moderation = ApiModeration("app-id", "tenant-id", config) + result = moderation.moderation_for_inputs(inputs={}, query="") + + assert result.flagged is False + assert result.action == ModerationAction.DIRECT_OUTPUT + assert result.preset_response == "" + + def test_moderation_for_inputs_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation.moderation_for_inputs({}, "") + + @patch("core.moderation.api.api.ApiModeration._get_config_by_requestor") + def test_moderation_for_outputs_enabled(self, mock_get_config, api_moderation): + mock_get_config.return_value = {"flagged": False, "action": "direct_output", "preset_response": ""} + + result = api_moderation.moderation_for_outputs(text="hello world") + + assert isinstance(result, ModerationOutputsResult) + assert result.flagged is False + + mock_get_config.assert_called_once_with( + APIBasedExtensionPoint.APP_MODERATION_OUTPUT, {"app_id": "test-app-id", "text": "hello world"} + ) + + def test_moderation_for_outputs_disabled(self): + config = { + "inputs_config": {"enabled": True}, + "outputs_config": {"enabled": False}, + "api_based_extension_id": "ext-id", + } + moderation = ApiModeration("app-id", "tenant-id", config) + result = moderation.moderation_for_outputs(text="test") + + assert result.flagged is False + assert result.action == ModerationAction.DIRECT_OUTPUT + + def test_moderation_for_outputs_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation.moderation_for_outputs("test") + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + @patch("core.moderation.api.api.decrypt_token") + @patch("core.moderation.api.api.APIBasedExtensionRequestor") + def test_get_config_by_requestor_success(self, mock_requestor_cls, mock_decrypt, mock_get_ext, api_moderation): + mock_ext = MagicMock(spec=APIBasedExtension) + mock_ext.api_endpoint = "http://api.test" + mock_ext.api_key = "encrypted-key" + mock_get_ext.return_value = mock_ext + + mock_decrypt.return_value = "decrypted-key" + + mock_requestor = MagicMock() + mock_requestor.request.return_value = {"flagged": True} + mock_requestor_cls.return_value = mock_requestor + + params = {"some": "params"} + result = api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params) + + assert result == {"flagged": True} + mock_get_ext.assert_called_once_with("test-tenant-id", "test-extension-id") + mock_decrypt.assert_called_once_with("test-tenant-id", "encrypted-key") + mock_requestor_cls.assert_called_once_with("http://api.test", "decrypted-key") + mock_requestor.request.assert_called_once_with(APIBasedExtensionPoint.APP_MODERATION_INPUT, params) + + def test_get_config_by_requestor_no_config(self): + moderation = ApiModeration("app-id", "tenant-id", None) + with pytest.raises(ValueError, match="The config is not set"): + moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {}) + + @patch("core.moderation.api.api.ApiModeration._get_api_based_extension") + def test_get_config_by_requestor_extension_not_found(self, mock_get_ext, api_moderation): + mock_get_ext.return_value = None + with pytest.raises(ValueError, match="API-based Extension not found"): + api_moderation._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, {}) + + @patch("core.moderation.api.api.db.session.scalar") + def test_get_api_based_extension(self, mock_scalar): + mock_ext = MagicMock(spec=APIBasedExtension) + mock_scalar.return_value = mock_ext + + result = ApiModeration._get_api_based_extension("tenant-1", "ext-1") + + assert result == mock_ext + mock_scalar.assert_called_once() + # Verify the call has the correct filters + args, kwargs = mock_scalar.call_args + stmt = args[0] + # We can't easily inspect the statement without complex sqlalchemy tricks, + # but calling it is usually enough for unit tests if we mock the result. diff --git a/api/tests/unit_tests/core/moderation/test_input_moderation.py b/api/tests/unit_tests/core/moderation/test_input_moderation.py new file mode 100644 index 0000000000..2dbc80cf14 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/test_input_moderation.py @@ -0,0 +1,207 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.app.app_config.entities import AppConfig, SensitiveWordAvoidanceEntity +from core.moderation.base import ModerationAction, ModerationError, ModerationInputsResult +from core.moderation.input_moderation import InputModeration +from core.ops.entities.trace_entity import TraceTaskName +from core.ops.ops_trace_manager import TraceQueueManager + + +class TestInputModeration: + @pytest.fixture + def app_config(self): + config = MagicMock(spec=AppConfig) + config.sensitive_word_avoidance = None + return config + + @pytest.fixture + def input_moderation(self): + return InputModeration() + + def test_check_no_sensitive_word_avoidance(self, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is False + assert final_inputs == inputs + assert final_query == query + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_not_flagged(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {"keywords": ["bad"]} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is False + assert final_inputs == inputs + assert final_query == query + mock_factory_cls.assert_called_once_with( + name="keywords", app_id=app_id, tenant_id=tenant_id, config={"keywords": ["bad"]} + ) + mock_factory.moderation_for_inputs.assert_called_once_with(dict(inputs), query) + + @patch("core.moderation.input_moderation.ModerationFactory") + @patch("core.moderation.input_moderation.TraceTask") + def test_check_with_trace_manager(self, mock_trace_task, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + trace_manager = MagicMock(spec=TraceQueueManager) + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_inputs.return_value = mock_result + + input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + trace_manager=trace_manager, + ) + + trace_manager.add_trace_task.assert_called_once_with(mock_trace_task.return_value) + mock_trace_task.assert_called_once() + call_kwargs = mock_trace_task.call_args.kwargs + call_args = mock_trace_task.call_args.args + assert call_args[0] == TraceTaskName.MODERATION_TRACE + assert call_kwargs["message_id"] == message_id + assert call_kwargs["moderation_result"] == mock_result + assert call_kwargs["inputs"] == inputs + assert "timer" in call_kwargs + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_direct_output(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="Blocked content" + ) + mock_factory.moderation_for_inputs.return_value = mock_result + + with pytest.raises(ModerationError) as excinfo: + input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + ) + + assert str(excinfo.value) == "Blocked content" + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_overridden(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = ModerationInputsResult( + flagged=True, + action=ModerationAction.OVERRIDDEN, + inputs={"input_key": "overridden_value"}, + query="overridden query", + ) + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, tenant_id=tenant_id, app_config=app_config, inputs=inputs, query=query, message_id=message_id + ) + + assert flagged is True + assert final_inputs == {"input_key": "overridden_value"} + assert final_query == "overridden query" + + @patch("core.moderation.input_moderation.ModerationFactory") + def test_check_flagged_other_action(self, mock_factory_cls, app_config, input_moderation): + app_id = "test_app_id" + tenant_id = "test_tenant_id" + inputs = {"input_key": "input_value"} + query = "test query" + message_id = "test_message_id" + + # Setup config + sensitive_word_config = MagicMock(spec=SensitiveWordAvoidanceEntity) + sensitive_word_config.type = "keywords" + sensitive_word_config.config = {} + app_config.sensitive_word_avoidance = sensitive_word_config + + # Setup factory mock + mock_factory = mock_factory_cls.return_value + mock_result = MagicMock() + mock_result.flagged = True + mock_result.action = "NONE" # Some other action + mock_factory.moderation_for_inputs.return_value = mock_result + + flagged, final_inputs, final_query = input_moderation.check( + app_id=app_id, + tenant_id=tenant_id, + app_config=app_config, + inputs=inputs, + query=query, + message_id=message_id, + ) + + assert flagged is True + assert final_inputs == inputs + assert final_query == query diff --git a/api/tests/unit_tests/core/moderation/test_output_moderation.py b/api/tests/unit_tests/core/moderation/test_output_moderation.py new file mode 100644 index 0000000000..c6a7cd3f61 --- /dev/null +++ b/api/tests/unit_tests/core/moderation/test_output_moderation.py @@ -0,0 +1,234 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.entities.queue_entities import QueueMessageReplaceEvent +from core.moderation.base import ModerationAction, ModerationOutputsResult +from core.moderation.output_moderation import ModerationRule, OutputModeration + + +class TestOutputModeration: + @pytest.fixture + def mock_queue_manager(self): + return MagicMock(spec=AppQueueManager) + + @pytest.fixture + def moderation_rule(self): + return ModerationRule(type="keywords", config={"keywords": "badword"}) + + @pytest.fixture + def output_moderation(self, mock_queue_manager, moderation_rule): + return OutputModeration( + tenant_id="test_tenant", app_id="test_app", rule=moderation_rule, queue_manager=mock_queue_manager + ) + + def test_should_direct_output(self, output_moderation): + assert output_moderation.should_direct_output() is False + output_moderation.final_output = "blocked" + assert output_moderation.should_direct_output() is True + + def test_get_final_output(self, output_moderation): + assert output_moderation.get_final_output() == "" + output_moderation.final_output = "blocked" + assert output_moderation.get_final_output() == "blocked" + + def test_append_new_token(self, output_moderation): + with patch.object(OutputModeration, "start_thread") as mock_start: + output_moderation.append_new_token("hello") + assert output_moderation.buffer == "hello" + mock_start.assert_called_once() + + output_moderation.thread = MagicMock() + output_moderation.append_new_token(" world") + assert output_moderation.buffer == "hello world" + assert mock_start.call_count == 1 + + def test_moderation_completion_no_flag(self, output_moderation): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + + output, flagged = output_moderation.moderation_completion("safe content") + + assert output == "safe content" + assert flagged is False + assert output_moderation.is_final_chunk is True + + def test_moderation_completion_flagged_direct_output(self, output_moderation, mock_queue_manager): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" + ) + + output, flagged = output_moderation.moderation_completion("badword content", public_event=True) + + assert output == "preset" + assert flagged is True + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert isinstance(args[0], QueueMessageReplaceEvent) + assert args[0].text == "preset" + assert args[1] == PublishFrom.TASK_PIPELINE + + def test_moderation_completion_flagged_overridden(self, output_moderation, mock_queue_manager): + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.OVERRIDDEN, text="masked content" + ) + + output, flagged = output_moderation.moderation_completion("badword content", public_event=True) + + assert output == "masked content" + assert flagged is True + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert args[0].text == "masked content" + + def test_start_thread(self, output_moderation): + mock_app = MagicMock(spec=Flask) + with patch("core.moderation.output_moderation.current_app") as mock_current_app: + mock_current_app._get_current_object.return_value = mock_app + with patch("threading.Thread") as mock_thread_class: + mock_thread_instance = MagicMock() + mock_thread_class.return_value = mock_thread_instance + + thread = output_moderation.start_thread() + + assert thread == mock_thread_instance + mock_thread_class.assert_called_once() + mock_thread_instance.start.assert_called_once() + + def test_stop_thread(self, output_moderation): + mock_thread = MagicMock() + mock_thread.is_alive.return_value = True + output_moderation.thread = mock_thread + + output_moderation.stop_thread() + assert output_moderation.thread_running is False + + output_moderation.thread_running = True + mock_thread.is_alive.return_value = False + output_moderation.stop_thread() + assert output_moderation.thread_running is True + + @patch("core.moderation.output_moderation.ModerationFactory") + def test_moderation_success(self, mock_factory_class, output_moderation): + mock_factory = mock_factory_class.return_value + mock_result = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + mock_factory.moderation_for_outputs.return_value = mock_result + + result = output_moderation.moderation("tenant", "app", "buffer") + + assert result == mock_result + mock_factory_class.assert_called_once_with( + name="keywords", app_id="app", tenant_id="tenant", config={"keywords": "badword"} + ) + + @patch("core.moderation.output_moderation.ModerationFactory") + def test_moderation_exception(self, mock_factory_class, output_moderation): + mock_factory_class.side_effect = Exception("error") + + result = output_moderation.moderation("tenant", "app", "buffer") + assert result is None + + def test_worker_loop_and_exit(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + # Test exit on thread_running=False + output_moderation.thread_running = False + output_moderation.worker(mock_app, 10) + # Should exit immediately + + def test_worker_no_flag(self, output_moderation): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT) + + output_moderation.buffer = "safe" + output_moderation.is_final_chunk = True + + # To avoid infinite loop, we'll set thread_running to False after one iteration + def side_effect(*args, **kwargs): + output_moderation.thread_running = False + return mock_moderation.return_value + + mock_moderation.side_effect = side_effect + + output_moderation.worker(mock_app, 10) + + assert mock_moderation.called + + def test_worker_flagged_direct_output(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + mock_moderation.return_value = ModerationOutputsResult( + flagged=True, action=ModerationAction.DIRECT_OUTPUT, preset_response="preset" + ) + + output_moderation.buffer = "badword" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + assert output_moderation.final_output == "preset" + mock_queue_manager.publish.assert_called_once() + # It breaks on DIRECT_OUTPUT + + def test_worker_flagged_overridden(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + + with patch.object(OutputModeration, "moderation") as mock_moderation: + # Use side_effect to change thread_running on second call + def side_effect(*args, **kwargs): + if mock_moderation.call_count > 1: + output_moderation.thread_running = False + return None + return ModerationOutputsResult(flagged=True, action=ModerationAction.OVERRIDDEN, text="masked") + + mock_moderation.side_effect = side_effect + + output_moderation.buffer = "badword" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + mock_queue_manager.publish.assert_called_once() + args, _ = mock_queue_manager.publish.call_args + assert args[0].text == "masked" + + def test_worker_chunk_too_small(self, output_moderation): + mock_app = MagicMock(spec=Flask) + with patch("time.sleep") as mock_sleep: + # chunk_length < buffer_size and not is_final_chunk + output_moderation.buffer = "123" # length 3 + output_moderation.is_final_chunk = False + + def sleep_side_effect(seconds): + output_moderation.thread_running = False + + mock_sleep.side_effect = sleep_side_effect + + output_moderation.worker(mock_app, 10) # buffer_size 10 + + mock_sleep.assert_called_once_with(1) + + def test_worker_empty_not_flagged(self, output_moderation, mock_queue_manager): + mock_app = MagicMock(spec=Flask) + with patch.object(OutputModeration, "moderation") as mock_moderation: + # Return None (exception or no rule) + mock_moderation.return_value = None + + def side_effect(*args, **kwargs): + output_moderation.thread_running = False + + mock_moderation.side_effect = side_effect + + output_moderation.buffer = "something" + output_moderation.is_final_chunk = True + + output_moderation.worker(mock_app, 10) + + mock_queue_manager.publish.assert_not_called() diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py b/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py new file mode 100644 index 0000000000..acb43d4036 --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/data_exporter/test_traceclient.py @@ -0,0 +1,326 @@ +import time +import uuid +from datetime import datetime +from unittest.mock import MagicMock, patch + +import httpx +import pytest +from opentelemetry.sdk.trace import ReadableSpan +from opentelemetry.trace import SpanKind, Status, StatusCode + +from core.ops.aliyun_trace.data_exporter.traceclient import ( + INVALID_SPAN_ID, + SpanBuilder, + TraceClient, + build_endpoint, + convert_datetime_to_nanoseconds, + convert_string_to_id, + convert_to_span_id, + convert_to_trace_id, + create_link, + generate_span_id, +) +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData + + +@pytest.fixture +def trace_client_factory(): + """Factory fixture for creating TraceClient instances with automatic cleanup.""" + clients_to_shutdown = [] + + def _factory(**kwargs): + client = TraceClient(**kwargs) + clients_to_shutdown.append(client) + return client + + yield _factory + + # Cleanup: shutdown all created clients + for client in clients_to_shutdown: + client.shutdown() + + +class TestTraceClient: + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.socket.gethostname") + def test_init(self, mock_gethostname, mock_exporter_class, trace_client_factory): + mock_gethostname.return_value = "test-host" + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + + assert client.endpoint == "http://test-endpoint" + assert client.max_queue_size == 1000 + assert client.schedule_delay_sec == 5 + assert client.done is False + assert client.worker_thread.is_alive() + + client.shutdown() + assert client.done is True + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_export(self, mock_exporter_class, trace_client_factory): + mock_exporter = mock_exporter_class.return_value + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + spans = [MagicMock(spec=ReadableSpan)] + client.export(spans) + mock_exporter.export.assert_called_once_with(spans) + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_api_check_success(self, mock_exporter_class, mock_head, trace_client_factory): + mock_response = MagicMock() + mock_response.status_code = 405 + mock_head.return_value = mock_response + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + assert client.api_check() is True + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_api_check_failure_status(self, mock_exporter_class, mock_head, trace_client_factory): + mock_response = MagicMock() + mock_response.status_code = 500 + mock_head.return_value = mock_response + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + assert client.api_check() is False + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.httpx.head") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_api_check_exception(self, mock_exporter_class, mock_head, trace_client_factory): + mock_head.side_effect = httpx.RequestError("Connection error") + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + with pytest.raises(ValueError, match="AliyunTrace API check failed: Connection error"): + client.api_check() + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_get_project_url(self, mock_exporter_class, trace_client_factory): + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + assert client.get_project_url() == "https://arms.console.aliyun.com/#/llm" + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_add_span(self, mock_exporter_class, trace_client_factory): + client = trace_client_factory( + service_name="test-service", + endpoint="http://test-endpoint", + max_export_batch_size=2, + ) + + # Test add None + client.add_span(None) + assert len(client.queue) == 0 + + # Test add valid SpanData + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=None, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + ) + + mock_span = MagicMock(spec=ReadableSpan) + client.span_builder.build_span = MagicMock(return_value=mock_span) + + with patch.object(client.condition, "notify") as mock_notify: + client.add_span(span_data) + assert len(client.queue) == 1 + mock_notify.assert_not_called() + + client.add_span(span_data) + assert len(client.queue) == 2 + mock_notify.assert_called_once() + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + @patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") + def test_add_span_queue_full(self, mock_logger, mock_exporter_class, trace_client_factory): + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint", max_queue_size=1) + + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=None, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + ) + mock_span = MagicMock(spec=ReadableSpan) + client.span_builder.build_span = MagicMock(return_value=mock_span) + + client.add_span(span_data) + assert len(client.queue) == 1 + + client.add_span(span_data) + assert len(client.queue) == 1 + mock_logger.warning.assert_called_with("Queue is full, likely spans will be dropped.") + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_export_batch_error(self, mock_exporter_class, trace_client_factory): + mock_exporter = mock_exporter_class.return_value + mock_exporter.export.side_effect = Exception("Export failed") + + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + mock_span = MagicMock(spec=ReadableSpan) + client.queue.append(mock_span) + + with patch("core.ops.aliyun_trace.data_exporter.traceclient.logger") as mock_logger: + client._export_batch() + mock_logger.warning.assert_called() + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_worker_loop(self, mock_exporter_class, trace_client_factory): + # We need to test the wait timeout in _worker + # But _worker runs in a thread. Let's mock condition.wait. + client = trace_client_factory( + service_name="test-service", + endpoint="http://test-endpoint", + schedule_delay_sec=0.1, + ) + + with patch.object(client.condition, "wait") as mock_wait: + # Let it run for a bit then shut down + time.sleep(0.2) + client.shutdown() + # mock_wait might have been called + assert mock_wait.called or client.done + + @patch("core.ops.aliyun_trace.data_exporter.traceclient.OTLPSpanExporter") + def test_shutdown_flushes(self, mock_exporter_class, trace_client_factory): + mock_exporter = mock_exporter_class.return_value + client = trace_client_factory(service_name="test-service", endpoint="http://test-endpoint") + + mock_span = MagicMock(spec=ReadableSpan) + client.queue.append(mock_span) + + client.shutdown() + # Should have called export twice (once in worker/export_batch, once in shutdown) + # or at least once if worker was waiting + assert mock_exporter.export.called + assert mock_exporter.shutdown.called + + +class TestSpanBuilder: + def test_build_span(self): + resource = MagicMock() + builder = SpanBuilder(resource) + + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=789, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + attributes={"attr1": "val1"}, + events=[], + links=[], + ) + + span = builder.build_span(span_data) + assert isinstance(span, ReadableSpan) + assert span.name == "test-span" + assert span.context.trace_id == 123 + assert span.context.span_id == 456 + assert span.parent.span_id == 789 + assert span.resource == resource + assert span.attributes == {"attr1": "val1"} + + def test_build_span_no_parent(self): + resource = MagicMock() + builder = SpanBuilder(resource) + + span_data = SpanData( + name="test-span", + trace_id=123, + span_id=456, + parent_span_id=None, + start_time=1000, + end_time=2000, + status=Status(StatusCode.OK), + span_kind=SpanKind.INTERNAL, + ) + + span = builder.build_span(span_data) + assert span.parent is None + + +def test_create_link(): + trace_id_str = "0123456789abcdef0123456789abcdef" + link = create_link(trace_id_str) + assert link.context.trace_id == int(trace_id_str, 16) + assert link.context.span_id == INVALID_SPAN_ID + + with pytest.raises(ValueError, match="Invalid trace ID format"): + create_link("invalid-hex") + + +def test_generate_span_id(): + # Test normal generation + span_id = generate_span_id() + assert isinstance(span_id, int) + assert span_id != INVALID_SPAN_ID + + # Test retry loop + with patch("core.ops.aliyun_trace.data_exporter.traceclient.random.getrandbits") as mock_rand: + mock_rand.side_effect = [INVALID_SPAN_ID, 999] + span_id = generate_span_id() + assert span_id == 999 + assert mock_rand.call_count == 2 + + +def test_convert_to_trace_id(): + uid = str(uuid.uuid4()) + trace_id = convert_to_trace_id(uid) + assert trace_id == uuid.UUID(uid).int + + with pytest.raises(ValueError, match="UUID cannot be None"): + convert_to_trace_id(None) + + with pytest.raises(ValueError, match="Invalid UUID input"): + convert_to_trace_id("not-a-uuid") + + +def test_convert_string_to_id(): + assert convert_string_to_id("test") > 0 + # Test with None string + with patch("core.ops.aliyun_trace.data_exporter.traceclient.generate_span_id") as mock_gen: + mock_gen.return_value = 12345 + assert convert_string_to_id(None) == 12345 + + +def test_convert_to_span_id(): + uid = str(uuid.uuid4()) + span_id = convert_to_span_id(uid, "test-type") + assert isinstance(span_id, int) + + with pytest.raises(ValueError, match="UUID cannot be None"): + convert_to_span_id(None, "test") + + with pytest.raises(ValueError, match="Invalid UUID input"): + convert_to_span_id("not-a-uuid", "test") + + +def test_convert_datetime_to_nanoseconds(): + dt = datetime(2023, 1, 1, 12, 0, 0) + ns = convert_datetime_to_nanoseconds(dt) + assert ns == int(dt.timestamp() * 1e9) + assert convert_datetime_to_nanoseconds(None) is None + + +def test_build_endpoint(): + license_key = "abc" + + # CMS 2.0 endpoint + url1 = "https://log.aliyuncs.com" + assert build_endpoint(url1, license_key) == "https://log.aliyuncs.com/adapt_abc/api/v1/traces" + + # XTrace endpoint + url2 = "https://example.com" + assert build_endpoint(url2, license_key) == "https://example.com/adapt_abc/api/otlp/traces" diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py new file mode 100644 index 0000000000..2fcb927e0c --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_aliyun_trace_entity.py @@ -0,0 +1,88 @@ +import pytest +from opentelemetry import trace as trace_api +from opentelemetry.sdk.trace import Event +from opentelemetry.trace import SpanKind, Status, StatusCode +from pydantic import ValidationError + +from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData, TraceMetadata + + +class TestTraceMetadata: + def test_trace_metadata_init(self): + links = [trace_api.Link(context=trace_api.SpanContext(0, 0, False))] + metadata = TraceMetadata( + trace_id=123, workflow_span_id=456, session_id="session_1", user_id="user_1", links=links + ) + assert metadata.trace_id == 123 + assert metadata.workflow_span_id == 456 + assert metadata.session_id == "session_1" + assert metadata.user_id == "user_1" + assert metadata.links == links + + +class TestSpanData: + def test_span_data_init_required_fields(self): + span_data = SpanData(trace_id=123, span_id=456, name="test_span", start_time=1000, end_time=2000) + assert span_data.trace_id == 123 + assert span_data.span_id == 456 + assert span_data.name == "test_span" + assert span_data.start_time == 1000 + assert span_data.end_time == 2000 + + # Check defaults + assert span_data.parent_span_id is None + assert span_data.attributes == {} + assert span_data.events == [] + assert span_data.links == [] + assert span_data.status.status_code == StatusCode.UNSET + assert span_data.span_kind == SpanKind.INTERNAL + + def test_span_data_with_optional_fields(self): + event = Event(name="event_1", timestamp=1500) + link = trace_api.Link(context=trace_api.SpanContext(0, 0, False)) + status = Status(StatusCode.OK) + + span_data = SpanData( + trace_id=123, + parent_span_id=111, + span_id=456, + name="test_span", + attributes={"key": "value"}, + events=[event], + links=[link], + status=status, + start_time=1000, + end_time=2000, + span_kind=SpanKind.SERVER, + ) + + assert span_data.parent_span_id == 111 + assert span_data.attributes == {"key": "value"} + assert span_data.events == [event] + assert span_data.links == [link] + assert span_data.status.status_code == status.status_code + assert span_data.span_kind == SpanKind.SERVER + + def test_span_data_missing_required_fields(self): + with pytest.raises(ValidationError): + SpanData( + trace_id=123, + # span_id missing + name="test_span", + start_time=1000, + end_time=2000, + ) + + def test_span_data_arbitrary_types_allowed(self): + # opentelemetry.trace.Status and Event are "arbitrary types" for Pydantic + # This test ensures they are accepted thanks to model_config + status = Status(StatusCode.ERROR, description="error occurred") + event = Event(name="exception", timestamp=1234, attributes={"exception.type": "ValueError"}) + + span_data = SpanData( + trace_id=123, span_id=456, name="test_span", status=status, events=[event], start_time=1000, end_time=2000 + ) + + assert span_data.status.status_code == status.status_code + assert span_data.status.description == status.description + assert span_data.events == [event] diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py new file mode 100644 index 0000000000..3961555b9a --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/entities/test_semconv.py @@ -0,0 +1,68 @@ +from core.ops.aliyun_trace.entities.semconv import ( + ACS_ARMS_SERVICE_FEATURE, + GEN_AI_COMPLETION, + GEN_AI_FRAMEWORK, + GEN_AI_INPUT_MESSAGE, + GEN_AI_OUTPUT_MESSAGE, + GEN_AI_PROMPT, + GEN_AI_PROVIDER_NAME, + GEN_AI_REQUEST_MODEL, + GEN_AI_RESPONSE_FINISH_REASON, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_USAGE_INPUT_TOKENS, + GEN_AI_USAGE_OUTPUT_TOKENS, + GEN_AI_USAGE_TOTAL_TOKENS, + GEN_AI_USER_ID, + GEN_AI_USER_NAME, + INPUT_VALUE, + OUTPUT_VALUE, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) + + +def test_constants(): + assert ACS_ARMS_SERVICE_FEATURE == "acs.arms.service.feature" + assert GEN_AI_SESSION_ID == "gen_ai.session.id" + assert GEN_AI_USER_ID == "gen_ai.user.id" + assert GEN_AI_USER_NAME == "gen_ai.user.name" + assert GEN_AI_SPAN_KIND == "gen_ai.span.kind" + assert GEN_AI_FRAMEWORK == "gen_ai.framework" + assert INPUT_VALUE == "input.value" + assert OUTPUT_VALUE == "output.value" + assert RETRIEVAL_QUERY == "retrieval.query" + assert RETRIEVAL_DOCUMENT == "retrieval.document" + assert GEN_AI_REQUEST_MODEL == "gen_ai.request.model" + assert GEN_AI_PROVIDER_NAME == "gen_ai.provider.name" + assert GEN_AI_USAGE_INPUT_TOKENS == "gen_ai.usage.input_tokens" + assert GEN_AI_USAGE_OUTPUT_TOKENS == "gen_ai.usage.output_tokens" + assert GEN_AI_USAGE_TOTAL_TOKENS == "gen_ai.usage.total_tokens" + assert GEN_AI_PROMPT == "gen_ai.prompt" + assert GEN_AI_COMPLETION == "gen_ai.completion" + assert GEN_AI_RESPONSE_FINISH_REASON == "gen_ai.response.finish_reason" + assert GEN_AI_INPUT_MESSAGE == "gen_ai.input.messages" + assert GEN_AI_OUTPUT_MESSAGE == "gen_ai.output.messages" + assert TOOL_NAME == "tool.name" + assert TOOL_DESCRIPTION == "tool.description" + assert TOOL_PARAMETERS == "tool.parameters" + + +def test_gen_ai_span_kind_enum(): + assert GenAISpanKind.CHAIN == "CHAIN" + assert GenAISpanKind.RETRIEVER == "RETRIEVER" + assert GenAISpanKind.RERANKER == "RERANKER" + assert GenAISpanKind.LLM == "LLM" + assert GenAISpanKind.EMBEDDING == "EMBEDDING" + assert GenAISpanKind.TOOL == "TOOL" + assert GenAISpanKind.AGENT == "AGENT" + assert GenAISpanKind.TASK == "TASK" + + # Verify iteration works (covers the class definition) + kinds = list(GenAISpanKind) + assert len(kinds) == 8 + assert "LLM" in kinds diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py new file mode 100644 index 0000000000..dfd61acfa7 --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags + +import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module +from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_COMPLETION, + GEN_AI_INPUT_MESSAGE, + GEN_AI_OUTPUT_MESSAGE, + GEN_AI_PROMPT, + GEN_AI_REQUEST_MODEL, + GEN_AI_RESPONSE_FINISH_REASON, + GEN_AI_USAGE_TOTAL_TOKENS, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) +from core.ops.entities.config_entity import AliyunConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey + + +class RecordingTraceClient: + def __init__(self, service_name: str = "service", endpoint: str = "endpoint"): + self.service_name = service_name + self.endpoint = endpoint + self.added_spans: list[object] = [] + + def add_span(self, span) -> None: + self.added_spans.append(span) + + def api_check(self) -> bool: + return True + + def get_project_url(self) -> str: + return "project-url" + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_link(trace_id: int = 1, span_id: int = 2) -> Link: + context = SpanContext( + trace_id=trace_id, + span_id=span_id, + is_remote=False, + trace_flags=TraceFlags.SAMPLED, + ) + return Link(context) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "workflow-id", + "tenant_id": "tenant-id", + "workflow_run_id": "00000000-0000-0000-0000-000000000001", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"sys.query": "hello"}, + "workflow_run_outputs": {"answer": "world"}, + "workflow_run_version": "v1", + "total_tokens": 1, + "file_list": [], + "query": "hello", + "metadata": {"conversation_id": "conv", "user_id": "u", "app_id": "app"}, + "message_id": None, + "start_time": _dt(), + "end_time": _dt(), + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + defaults = { + "conversation_model": "chat", + "message_tokens": 1, + "answer_tokens": 2, + "total_tokens": 3, + "conversation_mode": "chat", + "metadata": {"conversation_id": "conv", "ls_model_name": "m", "ls_provider": "p"}, + "message_id": "00000000-0000-0000-0000-000000000002", + "message_data": SimpleNamespace(from_account_id="acc", from_end_user_id=None), + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt(), + "error": None, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults = { + "metadata": {"conversation_id": "conv", "user_id": "u"}, + "message_id": "00000000-0000-0000-0000-000000000003", + "message_data": SimpleNamespace(), + "inputs": "q", + "documents": [SimpleNamespace()], + "start_time": _dt(), + "end_time": _dt(), + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "out", + "tool_config": {"desc": "d"}, + "tool_parameters": {}, + "time_cost": 0.1, + "metadata": {"conversation_id": "conv", "user_id": "u"}, + "message_id": "00000000-0000-0000-0000-000000000004", + "message_data": SimpleNamespace(), + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 1, + "metadata": {"conversation_id": "conv", "user_id": "u", "ls_model_name": "m", "ls_provider": "p"}, + "message_id": "00000000-0000-0000-0000-000000000005", + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + "trace_id": "550e8400-e29b-41d4-a716-446655440000", + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +@pytest.fixture +def trace_instance(monkeypatch: pytest.MonkeyPatch) -> AliyunDataTrace: + monkeypatch.setattr(aliyun_trace_module, "build_endpoint", lambda base_url, license_key: "built-endpoint") + monkeypatch.setattr(aliyun_trace_module, "TraceClient", RecordingTraceClient) + # Mock get_service_account_with_tenant to avoid DB errors + monkeypatch.setattr(AliyunDataTrace, "get_service_account_with_tenant", lambda self, app_id: MagicMock()) + + config = AliyunConfig(app_name="app", license_key="k", endpoint="https://example.com") + trace = AliyunDataTrace(config) + return trace + + +def test_init_builds_endpoint_and_client(monkeypatch: pytest.MonkeyPatch): + build_endpoint = MagicMock(return_value="built") + trace_client_cls = MagicMock() + monkeypatch.setattr(aliyun_trace_module, "build_endpoint", build_endpoint) + monkeypatch.setattr(aliyun_trace_module, "TraceClient", trace_client_cls) + + config = AliyunConfig(app_name="my-app", license_key="license", endpoint="https://example.com") + trace = AliyunDataTrace(config) + + build_endpoint.assert_called_once_with("https://example.com", "license") + trace_client_cls.assert_called_once_with(service_name="my-app", endpoint="built") + assert trace.trace_config == config + + +def test_trace_dispatches_to_correct_methods(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + workflow_trace = MagicMock() + message_trace = MagicMock() + suggested_question_trace = MagicMock() + dataset_retrieval_trace = MagicMock() + tool_trace = MagicMock() + monkeypatch.setattr(trace_instance, "workflow_trace", workflow_trace) + monkeypatch.setattr(trace_instance, "message_trace", message_trace) + monkeypatch.setattr(trace_instance, "suggested_question_trace", suggested_question_trace) + monkeypatch.setattr(trace_instance, "dataset_retrieval_trace", dataset_retrieval_trace) + monkeypatch.setattr(trace_instance, "tool_trace", tool_trace) + + trace_instance.trace(_make_workflow_trace_info()) + workflow_trace.assert_called_once() + + trace_instance.trace(_make_message_trace_info()) + message_trace.assert_called_once() + + trace_instance.trace(_make_suggested_question_trace_info()) + suggested_question_trace.assert_called_once() + + trace_instance.trace(_make_dataset_retrieval_trace_info()) + dataset_retrieval_trace.assert_called_once() + + trace_instance.trace(_make_tool_trace_info()) + tool_trace.assert_called_once() + + # Branches that do nothing but should be covered + trace_instance.trace(ModerationTraceInfo(flagged=False, action="allow", preset_response="", query="", metadata={})) + trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={})) + + +def test_api_check_delegates(trace_instance: AliyunDataTrace): + trace_instance.trace_client.api_check = MagicMock(return_value=False) + assert trace_instance.api_check() is False + + +def test_get_project_url_success(trace_instance: AliyunDataTrace): + assert trace_instance.get_project_url() == "project-url" + + +def test_get_project_url_error(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(trace_instance.trace_client, "get_project_url", MagicMock(side_effect=Exception("boom"))) + logger_mock = MagicMock() + monkeypatch.setattr(aliyun_trace_module, "logger", logger_mock) + + with pytest.raises(ValueError, match=r"Aliyun get project url failed: boom"): + trace_instance.get_project_url() + logger_mock.info.assert_called() + + +def test_workflow_trace_adds_workflow_and_node_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 111) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"workflow": 222}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + + add_workflow_span = MagicMock() + get_workflow_node_executions = MagicMock(return_value=[MagicMock(), MagicMock()]) + build_workflow_node_span = MagicMock(side_effect=["span-1", "span-2"]) + monkeypatch.setattr(trace_instance, "add_workflow_span", add_workflow_span) + monkeypatch.setattr(trace_instance, "get_workflow_node_executions", get_workflow_node_executions) + monkeypatch.setattr(trace_instance, "build_workflow_node_span", build_workflow_node_span) + + trace_info = _make_workflow_trace_info( + trace_id="abcd", metadata={"conversation_id": "c", "user_id": "u", "app_id": "app"} + ) + trace_instance.workflow_trace(trace_info) + + add_workflow_span.assert_called_once() + passed_trace_metadata = add_workflow_span.call_args.args[1] + assert passed_trace_metadata.trace_id == 111 + assert passed_trace_metadata.workflow_span_id == 222 + assert passed_trace_metadata.session_id == "c" + assert passed_trace_metadata.user_id == "u" + assert passed_trace_metadata.links == [] + + assert trace_instance.trace_client.added_spans == ["span-1", "span-2"] + + +def test_message_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_message_trace_info(message_data=None) + trace_instance.message_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_message_trace_creates_message_and_llm_spans(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, + "convert_to_span_id", + lambda _, span_type: {"message": 20, "llm": 30}.get(span_type, 0), + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "get_user_id_from_message_data", lambda _: "user") + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_info = _make_message_trace_info( + metadata={"conversation_id": "conv", "ls_model_name": "model", "ls_provider": "provider"}, + message_tokens=7, + answer_tokens=11, + total_tokens=18, + outputs="completion", + ) + trace_instance.message_trace(trace_info) + + assert len(trace_instance.trace_client.added_spans) == 2 + message_span, llm_span = trace_instance.trace_client.added_spans + + assert message_span.name == "message" + assert message_span.trace_id == 10 + assert message_span.parent_span_id is None + assert message_span.span_id == 20 + assert message_span.span_kind == SpanKind.SERVER + assert message_span.status == status + assert message_span.attributes["gen_ai.span.kind"] == GenAISpanKind.CHAIN + + assert llm_span.name == "llm" + assert llm_span.parent_span_id == 20 + assert llm_span.span_id == 30 + assert llm_span.status == status + assert llm_span.attributes[GEN_AI_REQUEST_MODEL] == "model" + assert llm_span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "18" + + +def test_dataset_retrieval_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.dataset_retrieval_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_dataset_retrieval_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 1) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 2}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 3) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + monkeypatch.setattr(aliyun_trace_module, "extract_retrieval_documents", lambda _: [{"doc": "d"}]) + + trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info(inputs="query")) + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "dataset_retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "query" + assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"doc": "d"}]' + + +def test_tool_trace_returns_early_if_no_message_data(trace_instance: AliyunDataTrace): + trace_info = _make_tool_trace_info(message_data=None) + trace_instance.tool_trace(trace_info) + assert trace_instance.trace_client.added_spans == [] + + +def test_tool_trace_creates_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "generate_span_id", lambda: 30) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_instance.tool_trace( + _make_tool_trace_info( + tool_name="my-tool", + tool_inputs={"a": 1}, + tool_config={"description": "x"}, + inputs={"i": 1}, + ) + ) + + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "my-tool" + assert span.status == status + assert span.attributes[TOOL_NAME] == "my-tool" + assert span.attributes[TOOL_DESCRIPTION] == '{"description": "x"}' + + +def test_get_workflow_node_executions_requires_app_id(trace_instance: AliyunDataTrace): + trace_info = _make_workflow_trace_info(metadata={"conversation_id": "c"}) + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.get_workflow_node_executions(trace_info) + + +def test_get_workflow_node_executions_builds_repo_and_fetches( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch +): + trace_info = _make_workflow_trace_info(metadata={"app_id": "app", "conversation_id": "c", "user_id": "u"}) + + account = object() + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", MagicMock(return_value=account)) + monkeypatch.setattr(aliyun_trace_module, "sessionmaker", MagicMock()) + monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) + + repo = MagicMock() + repo.get_by_workflow_run.return_value = ["node1"] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) + + result = trace_instance.get_workflow_node_executions(trace_info) + assert result == ["node1"] + repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + + +def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_llm_span", MagicMock(return_value="llm")) + + node_execution.node_type = BuiltinNodeTypes.LLM + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "llm" + + +def test_build_workflow_node_span_routes_knowledge_retrieval_type( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch +): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_retrieval_span", MagicMock(return_value="retrieval")) + + node_execution.node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "retrieval" + + +def test_build_workflow_node_span_routes_tool_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_tool_span", MagicMock(return_value="tool")) + + node_execution.node_type = BuiltinNodeTypes.TOOL + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "tool" + + +def test_build_workflow_node_span_routes_code_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(return_value="task")) + + node_execution.node_type = BuiltinNodeTypes.CODE + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) == "task" + + +def test_build_workflow_node_span_handles_errors( + trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +): + node_execution = MagicMock(spec=WorkflowNodeExecution) + trace_info = _make_workflow_trace_info() + trace_metadata = MagicMock() + + monkeypatch.setattr(trace_instance, "build_workflow_task_span", MagicMock(side_effect=RuntimeError("boom"))) + node_execution.node_type = BuiltinNodeTypes.CODE + + assert trace_instance.build_workflow_node_span(node_execution, trace_info, trace_metadata) is None + assert "Error occurred in build_workflow_node_span" in caplog.text + + +def test_build_workflow_task_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "title" + node_execution.inputs = {"a": 1} + node_execution.outputs = {"b": 2} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_task_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.trace_id == 1 + assert span.span_id == 9 + assert span.status.status_code == StatusCode.OK + assert span.attributes["gen_ai.span.kind"] == GenAISpanKind.TASK + + +def test_build_workflow_tool_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[_make_link()]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "my-tool" + node_execution.inputs = {"a": 1} + node_execution.outputs = {"b": 2} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"k": "v"}} + + span = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[TOOL_NAME] == "my-tool" + assert span.attributes[TOOL_DESCRIPTION] == '{"k": "v"}' + assert span.attributes[TOOL_PARAMETERS] == '{"a": 1}' + assert span.status.status_code == StatusCode.OK + + # Cover metadata is None and inputs is None + node_execution.metadata = None + node_execution.inputs = None + span2 = trace_instance.build_workflow_tool_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[TOOL_DESCRIPTION] == "{}" + assert span2.attributes[TOOL_PARAMETERS] == "{}" + + +def test_build_workflow_retrieval_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + monkeypatch.setattr( + aliyun_trace_module, "format_retrieval_documents", lambda docs: [{"formatted": True}] if docs else [] + ) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "retrieval" + node_execution.inputs = {"query": "q"} + node_execution.outputs = {"result": [{"doc": "d"}]} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[RETRIEVAL_QUERY] == "q" + assert span.attributes[RETRIEVAL_DOCUMENT] == '[{"formatted": true}]' + + # Cover empty inputs/outputs + node_execution.inputs = None + node_execution.outputs = None + span2 = trace_instance.build_workflow_retrieval_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[RETRIEVAL_QUERY] == "" + assert span2.attributes[RETRIEVAL_DOCUMENT] == "[]" + + +def test_build_workflow_llm_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_span_id", lambda _, __: 9) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "get_workflow_node_status", lambda _: status) + monkeypatch.setattr(aliyun_trace_module, "format_input_messages", lambda _: "in") + monkeypatch.setattr(aliyun_trace_module, "format_output_messages", lambda _: "out") + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node-id" + node_execution.title = "llm" + node_execution.process_data = { + "usage": {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}, + "prompts": ["p"], + "model_name": "m", + "model_provider": "p1", + } + node_execution.outputs = {"text": "t", "finish_reason": "stop"} + node_execution.created_at = _dt() + node_execution.finished_at = _dt() + + span = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "3" + assert span.attributes[GEN_AI_REQUEST_MODEL] == "m" + assert span.attributes[GEN_AI_PROMPT] == '["p"]' + assert span.attributes[GEN_AI_COMPLETION] == "t" + assert span.attributes[GEN_AI_RESPONSE_FINISH_REASON] == "stop" + assert span.attributes[GEN_AI_INPUT_MESSAGE] == "in" + assert span.attributes[GEN_AI_OUTPUT_MESSAGE] == "out" + + # Cover usage from outputs if not in process_data + node_execution.process_data = {"prompts": []} + node_execution.outputs = {"usage": {"total_tokens": 10}, "text": ""} + span2 = trace_instance.build_workflow_llm_span(_make_workflow_trace_info(), node_execution, trace_metadata) + assert span2.attributes[GEN_AI_USAGE_TOTAL_TOKENS] == "10" + + +def test_add_workflow_span(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr( + aliyun_trace_module, "convert_to_span_id", lambda _, span_type: {"message": 20}.get(span_type, 0) + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_metadata = SimpleNamespace(trace_id=1, workflow_span_id=2, session_id="s", user_id="u", links=[]) + + # CASE 1: With message_id + trace_info = _make_workflow_trace_info( + message_id="msg-1", workflow_run_inputs={"sys.query": "hi"}, workflow_run_outputs={"ans": "ok"} + ) + trace_instance.add_workflow_span(trace_info, trace_metadata) + + assert len(trace_instance.trace_client.added_spans) == 2 + message_span = trace_instance.trace_client.added_spans[0] + workflow_span = trace_instance.trace_client.added_spans[1] + + assert message_span.name == "message" + assert message_span.span_kind == SpanKind.SERVER + assert message_span.parent_span_id is None + + assert workflow_span.name == "workflow" + assert workflow_span.span_kind == SpanKind.INTERNAL + assert workflow_span.parent_span_id == 20 + + trace_instance.trace_client.added_spans.clear() + + # CASE 2: Without message_id + trace_info_no_msg = _make_workflow_trace_info(message_id=None) + trace_instance.add_workflow_span(trace_info_no_msg, trace_metadata) + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "workflow" + assert span.span_kind == SpanKind.SERVER + assert span.parent_span_id is None + + +def test_suggested_question_trace(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(aliyun_trace_module, "convert_to_trace_id", lambda _: 10) + monkeypatch.setattr( + aliyun_trace_module, + "convert_to_span_id", + lambda _, span_type: {"message": 20, "suggested_question": 21}.get(span_type, 0), + ) + monkeypatch.setattr(aliyun_trace_module, "convert_datetime_to_nanoseconds", lambda _: 123) + monkeypatch.setattr(aliyun_trace_module, "create_links_from_trace_id", lambda _: []) + status = Status(StatusCode.OK) + monkeypatch.setattr(aliyun_trace_module, "create_status_from_error", lambda _: status) + + trace_info = _make_suggested_question_trace_info(suggested_question=["how?"]) + trace_instance.suggested_question_trace(trace_info) + + assert len(trace_instance.trace_client.added_spans) == 1 + span = trace_instance.trace_client.added_spans[0] + assert span.name == "suggested_question" + assert span.attributes[GEN_AI_COMPLETION] == '["how?"]' diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py new file mode 100644 index 0000000000..763fc90710 --- /dev/null +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -0,0 +1,275 @@ +import json +from unittest.mock import MagicMock + +from opentelemetry.trace import Link, StatusCode + +from core.ops.aliyun_trace.entities.semconv import ( + GEN_AI_FRAMEWORK, + GEN_AI_SESSION_ID, + GEN_AI_SPAN_KIND, + GEN_AI_USER_ID, + INPUT_VALUE, + OUTPUT_VALUE, +) +from core.ops.aliyun_trace.utils import ( + create_common_span_attributes, + create_links_from_trace_id, + create_status_from_error, + extract_retrieval_documents, + format_input_messages, + format_output_messages, + format_retrieval_documents, + get_user_id_from_message_data, + get_workflow_node_status, + serialize_json_data, +) +from core.rag.models.document import Document +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionStatus +from models import EndUser + + +def test_get_user_id_from_message_data_no_end_user(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = None + + assert get_user_id_from_message_data(message_data) == "account_id" + + +def test_get_user_id_from_message_data_with_end_user(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = "end_user_id" + + end_user_data = MagicMock(spec=EndUser) + end_user_data.session_id = "session_id" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = end_user_data + + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + from core.ops.aliyun_trace.utils import db + + monkeypatch.setattr(db, "session", mock_session) + + assert get_user_id_from_message_data(message_data) == "session_id" + + +def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): + message_data = MagicMock() + message_data.from_account_id = "account_id" + message_data.from_end_user_id = "end_user_id" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = None + + mock_session = MagicMock() + mock_session.query.return_value = mock_query + + from core.ops.aliyun_trace.utils import db + + monkeypatch.setattr(db, "session", mock_session) + + assert get_user_id_from_message_data(message_data) == "account_id" + + +def test_create_status_from_error(): + # Case OK + status_ok = create_status_from_error(None) + assert status_ok.status_code == StatusCode.OK + + # Case Error + status_err = create_status_from_error("some error") + assert status_err.status_code == StatusCode.ERROR + assert status_err.description == "some error" + + +def test_get_workflow_node_status(): + node_execution = MagicMock(spec=WorkflowNodeExecution) + + # SUCCEEDED + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.OK + + # FAILED + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = "node fail" + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.ERROR + assert status.description == "node fail" + + # EXCEPTION + node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION + node_execution.error = "node exception" + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.ERROR + assert status.description == "node exception" + + # UNSET/OTHER + node_execution.status = WorkflowNodeExecutionStatus.RUNNING + status = get_workflow_node_status(node_execution) + assert status.status_code == StatusCode.UNSET + + +def test_create_links_from_trace_id(monkeypatch): + # Mock create_link + mock_link = MagicMock(spec=Link) + import core.ops.aliyun_trace.data_exporter.traceclient + + monkeypatch.setattr(core.ops.aliyun_trace.data_exporter.traceclient, "create_link", lambda trace_id_str: mock_link) + + # Trace ID None + assert create_links_from_trace_id(None) == [] + + # Trace ID Present + links = create_links_from_trace_id("trace_id") + assert len(links) == 1 + assert links[0] == mock_link + + +def test_extract_retrieval_documents(): + doc1 = MagicMock(spec=Document) + doc1.page_content = "content1" + doc1.metadata = {"dataset_id": "ds1", "doc_id": "di1", "document_id": "dd1", "score": 0.9} + + doc2 = MagicMock(spec=Document) + doc2.page_content = "content2" + doc2.metadata = {"dataset_id": "ds2"} # Missing some keys + + documents = [doc1, doc2] + extracted = extract_retrieval_documents(documents) + + assert len(extracted) == 2 + assert extracted[0]["content"] == "content1" + assert extracted[0]["metadata"]["dataset_id"] == "ds1" + assert extracted[0]["score"] == 0.9 + + assert extracted[1]["content"] == "content2" + assert extracted[1]["metadata"]["dataset_id"] == "ds2" + assert extracted[1]["metadata"]["doc_id"] is None + assert extracted[1]["score"] is None + + +def test_serialize_json_data(): + data = {"a": 1} + # Test ensure_ascii default (False) + assert serialize_json_data(data) == json.dumps(data, ensure_ascii=False) + # Test ensure_ascii True + assert serialize_json_data(data, ensure_ascii=True) == json.dumps(data, ensure_ascii=True) + + +def test_create_common_span_attributes(): + attrs = create_common_span_attributes( + session_id="s1", user_id="u1", span_kind="kind1", framework="fw1", inputs="in1", outputs="out1" + ) + assert attrs[GEN_AI_SESSION_ID] == "s1" + assert attrs[GEN_AI_USER_ID] == "u1" + assert attrs[GEN_AI_SPAN_KIND] == "kind1" + assert attrs[GEN_AI_FRAMEWORK] == "fw1" + assert attrs[INPUT_VALUE] == "in1" + assert attrs[OUTPUT_VALUE] == "out1" + + +def test_format_retrieval_documents(): + # Not a list + assert format_retrieval_documents("not a list") == [] + + # Valid list + docs = [ + {"metadata": {"score": 0.8, "document_id": "doc1", "source": "src1"}, "content": "c1", "title": "t1"}, + { + "metadata": {"_source": "src2", "doc_metadata": {"extra": "val"}}, + "content": "c2", + # Missing title + }, + "not a dict", # Should be skipped + ] + formatted = format_retrieval_documents(docs) + + assert len(formatted) == 2 + assert formatted[0]["document"]["content"] == "c1" + assert formatted[0]["document"]["metadata"]["title"] == "t1" + assert formatted[0]["document"]["metadata"]["source"] == "src1" + assert formatted[0]["document"]["score"] == 0.8 + assert formatted[0]["document"]["id"] == "doc1" + + assert formatted[1]["document"]["content"] == "c2" + assert formatted[1]["document"]["metadata"]["source"] == "src2" + assert formatted[1]["document"]["metadata"]["extra"] == "val" + assert "title" not in formatted[1]["document"]["metadata"] + assert formatted[1]["document"]["score"] == 0.0 # Default + + # Exception handling + # We can trigger an exception by passing something that causes an error in the loop logic, + # but the try/except covers the whole function. + # Passing a list that contains something that throws when calling .get() - though dicts won't. + # Let's mock a dict that raises on get. + class BadDict: + def get(self, *args, **kwargs): + raise Exception("boom") + + assert format_retrieval_documents([BadDict()]) == [] + + +def test_format_input_messages(): + # Not a dict + assert format_input_messages(None) == serialize_json_data([]) + + # No prompts + assert format_input_messages({}) == serialize_json_data([]) + + # Valid prompts + process_data = { + "prompts": [ + {"role": "user", "text": "hello"}, + {"role": "assistant", "text": "hi"}, + {"role": "system", "text": "be helpful"}, + {"role": "tool", "text": "result"}, + {"role": "invalid", "text": "skip me"}, + "not a dict", + {"role": "user", "text": ""}, # Empty text, should be skipped? Code says `if text: message = ...` + ] + } + result = format_input_messages(process_data) + result_list = json.loads(result) + + assert len(result_list) == 4 + assert result_list[0]["role"] == "user" + assert result_list[0]["parts"][0]["content"] == "hello" + assert result_list[1]["role"] == "assistant" + assert result_list[2]["role"] == "system" + assert result_list[3]["role"] == "tool" + + # Exception path + assert format_input_messages({"prompts": [None]}) == serialize_json_data([]) + + +def test_format_output_messages(): + # Not a dict + assert format_output_messages(None) == serialize_json_data([]) + + # No text + assert format_output_messages({"finish_reason": "stop"}) == serialize_json_data([]) + + # Valid + outputs = {"text": "done", "finish_reason": "length"} + result = format_output_messages(outputs) + result_list = json.loads(result) + assert len(result_list) == 1 + assert result_list[0]["role"] == "assistant" + assert result_list[0]["parts"][0]["content"] == "done" + assert result_list[0]["finish_reason"] == "length" + + # Invalid finish reason + outputs2 = {"text": "done", "finish_reason": "unknown"} + result2 = format_output_messages(outputs2) + result_list2 = json.loads(result2) + assert result_list2[0]["finish_reason"] == "stop" + + # Exception path + # Trigger exception in serialize_json_data by passing non-serializable + assert format_output_messages({"text": MagicMock()}) == serialize_json_data([]) diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py new file mode 100644 index 0000000000..1cee2f5b68 --- /dev/null +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -0,0 +1,398 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest +from opentelemetry.sdk.trace import Tracer +from opentelemetry.semconv.trace import SpanAttributes as OTELSpanAttributes +from opentelemetry.trace import StatusCode + +from core.ops.arize_phoenix_trace.arize_phoenix_trace import ( + ArizePhoenixDataTrace, + datetime_to_nanos, + error_to_string, + safe_json_dumps, + set_span_status, + setup_tracer, + wrap_span_metadata, +) +from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) + +# --- Helpers --- + + +def _dt(): + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_workflow_info(**kwargs): + defaults = { + "workflow_id": "w1", + "tenant_id": "t1", + "workflow_run_id": "r1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"in": "val"}, + "workflow_run_outputs": {"out": "val"}, + "workflow_run_version": "1.0", + "total_tokens": 10, + "file_list": ["f1"], + "query": "hi", + "metadata": {"app_id": "app1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(kwargs) + return WorkflowTraceInfo(**defaults) + + +def _make_message_info(**kwargs): + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 5, + "total_tokens": 10, + "conversation_mode": "chat", + "metadata": {"app_id": "app1"}, + "inputs": {"in": "val"}, + "outputs": "val", + "start_time": _dt(), + "end_time": _dt(), + "message_id": "m1", + } + defaults.update(kwargs) + return MessageTraceInfo(**defaults) + + +# --- Utility Function Tests --- + + +def test_datetime_to_nanos(): + dt = _dt() + expected = int(dt.timestamp() * 1_000_000_000) + assert datetime_to_nanos(dt) == expected + + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.datetime") as mock_dt: + mock_now = MagicMock() + mock_now.timestamp.return_value = 1704110400.0 + mock_dt.now.return_value = mock_now + assert datetime_to_nanos(None) == 1704110400000000000 + + +def test_error_to_string(): + try: + raise ValueError("boom") + except ValueError as e: + err = e + + res = error_to_string(err) + assert "ValueError: boom" in res + assert "traceback" in res.lower() or "line" in res.lower() + + assert error_to_string("str error") == "str error" + assert error_to_string(None) == "Empty Stack Trace" + + +def test_set_span_status(): + span = MagicMock() + # OK + set_span_status(span, None) + span.set_status.assert_called() + assert span.set_status.call_args[0][0].status_code == StatusCode.OK + + # Error Exception + span.reset_mock() + set_span_status(span, ValueError("fail")) + assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR + span.record_exception.assert_called() + + # Error String + span.reset_mock() + set_span_status(span, "fail-str") + assert span.set_status.call_args[0][0].status_code == StatusCode.ERROR + span.add_event.assert_called() + + # repr branch + class SilentError: + def __str__(self): + return "" + + def __repr__(self): + return "SilentErrorRepr" + + span.reset_mock() + set_span_status(span, SilentError()) + assert span.add_event.call_args[1]["attributes"][OTELSpanAttributes.EXCEPTION_MESSAGE] == "SilentErrorRepr" + + +def test_safe_json_dumps(): + assert safe_json_dumps({"a": _dt()}) == '{"a": "2024-01-01 00:00:00+00:00"}' + + +def test_wrap_span_metadata(): + res = wrap_span_metadata({"a": 1}, b=2) + assert res == {"a": 1, "b": 2, "created_from": "Dify"} + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.GrpcOTLPSpanExporter") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +def test_setup_tracer_arize(mock_provider, mock_exporter): + config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") + setup_tracer(config) + mock_exporter.assert_called_once() + assert mock_exporter.call_args[1]["endpoint"] == "http://a.com/v1" + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.HttpOTLPSpanExporter") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.trace_sdk.TracerProvider") +def test_setup_tracer_phoenix(mock_provider, mock_exporter): + config = PhoenixConfig(endpoint="http://p.com", project="p") + setup_tracer(config) + mock_exporter.assert_called_once() + assert mock_exporter.call_args[1]["endpoint"] == "http://p.com/v1/traces" + + +def test_setup_tracer_exception(): + config = ArizeConfig(endpoint="http://a.com", project="p") + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.urlparse", side_effect=Exception("boom")): + with pytest.raises(Exception, match="boom"): + setup_tracer(config) + + +# --- ArizePhoenixDataTrace Class Tests --- + + +@pytest.fixture +def trace_instance(): + with patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.setup_tracer") as mock_setup: + mock_tracer = MagicMock(spec=Tracer) + mock_processor = MagicMock() + mock_setup.return_value = (mock_tracer, mock_processor) + config = ArizeConfig(endpoint="http://a.com", api_key="k", space_id="s", project="p") + return ArizePhoenixDataTrace(config) + + +def test_trace_dispatch(trace_instance): + with ( + patch.object(trace_instance, "workflow_trace") as m1, + patch.object(trace_instance, "message_trace") as m2, + patch.object(trace_instance, "moderation_trace") as m3, + patch.object(trace_instance, "suggested_question_trace") as m4, + patch.object(trace_instance, "dataset_retrieval_trace") as m5, + patch.object(trace_instance, "tool_trace") as m6, + patch.object(trace_instance, "generate_name_trace") as m7, + ): + trace_instance.trace(_make_workflow_info()) + m1.assert_called() + + trace_instance.trace(_make_message_info()) + m2.assert_called() + + trace_instance.trace(ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={})) + m3.assert_called() + + trace_instance.trace(SuggestedQuestionTraceInfo(suggested_question=[], total_tokens=0, level="i", metadata={})) + m4.assert_called() + + trace_instance.trace(DatasetRetrievalTraceInfo(metadata={})) + m5.assert_called() + + trace_instance.trace( + ToolTraceInfo( + tool_name="t", + tool_inputs={}, + tool_outputs="o", + metadata={}, + tool_config={}, + time_cost=1, + tool_parameters={}, + ) + ) + m6.assert_called() + + trace_instance.trace(GenerateNameTraceInfo(tenant_id="t", metadata={})) + m7.assert_called() + + +def test_trace_exception(trace_instance): + with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("fail")): + with pytest.raises(RuntimeError): + trace_instance.trace(_make_workflow_info()) + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.sessionmaker") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.DifyCoreRepositoryFactory") +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trace_instance): + mock_db.engine = MagicMock() + info = _make_workflow_info() + repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = repo + + node1 = MagicMock() + node1.node_type = "llm" + node1.status = "succeeded" + node1.inputs = {"q": "hi"} + node1.outputs = {"a": "bye", "usage": {"total_tokens": 5}} + node1.created_at = _dt() + node1.elapsed_time = 1.0 + node1.process_data = { + "prompts": [{"role": "user", "content": "hi"}], + "model_provider": "openai", + "model_name": "gpt-4", + } + node1.metadata = {"k": "v"} + node1.title = "title" + node1.id = "n1" + node1.error = None + + repo.get_by_workflow_run.return_value = [node1] + + with patch.object(trace_instance, "get_service_account_with_tenant"): + trace_instance.workflow_trace(info) + + assert trace_instance.tracer.start_span.call_count >= 2 + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_workflow_trace_no_app_id(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_workflow_info() + info.metadata = {} + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(info) + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_message_trace_success(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info() + info.message_data = MagicMock() + info.message_data.from_account_id = "acc1" + info.message_data.from_end_user_id = None + info.message_data.query = "q" + info.message_data.answer = "a" + info.message_data.status = "s" + info.message_data.model_id = "m" + info.message_data.model_provider = "p" + info.message_data.message_metadata = "{}" + info.message_data.error = None + info.error = None + + trace_instance.message_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +@patch("core.ops.arize_phoenix_trace.arize_phoenix_trace.db") +def test_message_trace_with_error(mock_db, trace_instance): + mock_db.engine = MagicMock() + info = _make_message_info() + info.message_data = MagicMock() + info.message_data.from_account_id = "acc1" + info.message_data.from_end_user_id = None + info.message_data.query = "q" + info.message_data.answer = "a" + info.message_data.status = "s" + info.message_data.model_id = "m" + info.message_data.model_provider = "p" + info.message_data.message_metadata = "{}" + info.message_data.error = "processing failed" + info.error = "message error" + + trace_instance.message_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_trace_methods_return_early_with_no_message_data(trace_instance): + info = MagicMock() + info.message_data = None + + trace_instance.moderation_trace(info) + trace_instance.suggested_question_trace(info) + trace_instance.dataset_retrieval_trace(info) + trace_instance.tool_trace(info) + trace_instance.generate_name_trace(info) + + assert trace_instance.tracer.start_span.call_count == 0 + + +def test_moderation_trace_ok(trace_instance): + info = ModerationTraceInfo(flagged=True, action="a", preset_response="p", query="q", metadata={}) + info.message_data = MagicMock() + info.message_data.error = None + trace_instance.moderation_trace(info) + # root span (1) + moderation span (1) = 2 + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_suggested_question_trace_ok(trace_instance): + info = SuggestedQuestionTraceInfo(suggested_question=["?"], total_tokens=1, level="i", metadata={}) + info.message_data = MagicMock() + info.error = None + trace_instance.suggested_question_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_dataset_retrieval_trace_ok(trace_instance): + info = DatasetRetrievalTraceInfo(documents=[], metadata={}) + info.message_data = MagicMock() + info.error = None + trace_instance.dataset_retrieval_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_tool_trace_ok(trace_instance): + info = ToolTraceInfo( + tool_name="t", tool_inputs={}, tool_outputs="o", metadata={}, tool_config={}, time_cost=1, tool_parameters={} + ) + info.message_data = MagicMock() + info.error = None + trace_instance.tool_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_generate_name_trace_ok(trace_instance): + info = GenerateNameTraceInfo(tenant_id="t", metadata={}) + info.message_data = MagicMock() + info.message_data.error = None + trace_instance.generate_name_trace(info) + assert trace_instance.tracer.start_span.call_count >= 1 + + +def test_get_project_url_phoenix(trace_instance): + trace_instance.arize_phoenix_config = PhoenixConfig(endpoint="http://p.com", project="p") + assert "p.com/projects/" in trace_instance.get_project_url() + + +def test_set_attribute_none_logic(trace_instance): + # Test role can be None + attrs = trace_instance._construct_llm_attributes([{"role": None, "content": "hi"}]) + assert "llm.input_messages.0.message.role" not in attrs + + # Test tool call id can be None + tool_call_none_id = {"id": None, "function": {"name": "f1"}} + attrs = trace_instance._construct_llm_attributes([{"role": "assistant", "tool_calls": [tool_call_none_id]}]) + assert "llm.input_messages.0.message.tool_calls.0.tool_call.id" not in attrs + + +def test_construct_llm_attributes_dict_branch(trace_instance): + attrs = trace_instance._construct_llm_attributes({"prompt": "hi"}) + assert '"prompt": "hi"' in attrs["llm.input_messages.0.message.content"] + assert attrs["llm.input_messages.0.message.role"] == "user" + + +def test_api_check_success(trace_instance): + assert trace_instance.api_check() is True + + +def test_ensure_root_span_basic(trace_instance): + trace_instance.ensure_root_span("tid") + assert "tid" in trace_instance.dify_trace_ids diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py new file mode 100644 index 0000000000..0ff135562c --- /dev/null +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -0,0 +1,698 @@ +import collections +import logging +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import LangfuseConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( + LangfuseGeneration, + LangfuseSpan, + LangfuseTrace, + LevelEnum, + UnitEnum, +) +from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace +from dify_graph.enums import BuiltinNodeTypes +from models import EndUser +from models.enums import MessageStatus + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +@pytest.fixture +def langfuse_config(): + return LangfuseConfig(public_key="pk-123", secret_key="sk-123", host="https://cloud.langfuse.com") + + +@pytest.fixture +def trace_instance(langfuse_config, monkeypatch): + # Mock Langfuse client to avoid network calls + mock_client = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", lambda **kwargs: mock_client) + + instance = LangFuseDataTrace(langfuse_config) + return instance + + +def test_init(langfuse_config, monkeypatch): + mock_langfuse = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.Langfuse", mock_langfuse) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = LangFuseDataTrace(langfuse_config) + + mock_langfuse.assert_called_once_with( + public_key=langfuse_config.public_key, + secret_key=langfuse_config.secret_key, + host=langfuse_config.host, + ) + assert instance.file_base_url == "http://test.url" + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace_with_message_id(trace_instance, monkeypatch): + # Setup trace info + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_version="1.0", + message_id="msg-1", + conversation_id="conv-1", + total_tokens=100, + file_list=[], + query="hi", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"app_id": "app-1", "user_id": "user-1"}, + workflow_app_log_id="log-1", + error="", + ) + + # Mock DB and Repositories + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + + # Mock node executions + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = BuiltinNodeTypes.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {"foo": "bar"} + + node_other = MagicMock() + node_other.id = "node-other" + node_other.title = "Other Node" + node_other.node_type = BuiltinNodeTypes.CODE + node_other.status = "failed" + node_other.process_data = None + node_other.inputs = {"code": "print"} + node_other.outputs = {"result": "ok"} + node_other.created_at = None # Trigger datetime.now() branch + node_other.elapsed_time = 0.2 + node_other.metadata = None + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + # Track calls to add_trace, add_span, add_generation + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.workflow_trace(trace_info) + + # Verify add_trace (Workflow Level) + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.id == "trace-1" + assert trace_data.name == TraceTaskName.MESSAGE_TRACE + assert "message" in trace_data.tags + assert "workflow" in trace_data.tags + + # Verify add_span (Workflow Run Span) + assert trace_instance.add_span.call_count >= 1 + # First span should be workflow run span because message_id is present + workflow_span = trace_instance.add_span.call_args_list[0][1]["langfuse_span_data"] + assert workflow_span.id == "run-1" + assert workflow_span.name == TraceTaskName.WORKFLOW_TRACE + + # Verify Generation for LLM node + trace_instance.add_generation.assert_called_once() + gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"] + assert gen_data.id == "node-llm" + assert gen_data.usage.input == 10 + assert gen_data.usage.output == 20 + + # Verify normal span for Other node + # Second add_span call + other_span = trace_instance.add_span.call_args_list[1][1]["langfuse_span_data"] + assert other_span.id == "node-other" + assert other_span.level == LevelEnum.ERROR + + +def test_workflow_trace_no_message_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + trace_id=None, # Should fallback to workflow_run_id + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + ) + + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.id == "run-1" + assert trace_data.name == TraceTaskName.WORKFLOW_TRACE + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="tenant-1", + workflow_run_id="run-1", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + metadata={}, # Missing app_id + workflow_app_log_id="log-1", + error="", + ) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace_basic(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = None + message_data.provider_response_latency = 0.5 + message_data.conversation_id = "conv-1" + message_data.total_price = 0.01 + message_data.model_id = "gpt-4" + message_data.answer = "hello" + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"query": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + ) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_generation.assert_called_once() + + gen_data = trace_instance.add_generation.call_args[0][0] + assert gen_data.name == "llm" + assert gen_data.usage.total == 30 + + +def test_message_trace_with_end_user(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.conversation_id = "conv-1" + message_data.status = MessageStatus.NORMAL + message_data.model_id = "gpt-4" + message_data.error = "" + message_data.answer = "hello" + message_data.total_price = 0.0 + message_data.provider_response_latency = 0.1 + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={}, + outputs={}, + message_tokens=0, + answer_tokens=0, + total_tokens=0, + start_time=_dt(), + end_time=_dt(), + metadata={}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + ) + + # Mock DB session for EndUser lookup + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.user_id == "session-id-123" + assert trace_data.metadata["user_id"] == "session-id-123" + + +def test_message_trace_none_data(trace_instance): + trace_info = SimpleNamespace(message_data=None, file_list=[], metadata={}) + trace_instance.add_trace = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_trace.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={"foo": "bar"}, + trace_id="trace-1", + query="hi", + ) + + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == TraceTaskName.MODERATION_TRACE + assert span_data.output["flagged"] is True + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = SuggestedQuestionTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_generation = MagicMock() + trace_instance.suggested_question_trace(trace_info) + + trace_instance.add_generation.assert_called_once() + gen_data = trace_instance.add_generation.call_args[1]["langfuse_generation_data"] + assert gen_data.name == TraceTaskName.SUGGESTED_QUESTION_TRACE + assert gen_data.usage.unit == UnitEnum.CHARACTERS + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == TraceTaskName.DATASET_RETRIEVAL_TRACE + assert span_data.output["documents"] == [{"id": "doc1"}] + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="msg-1", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result_string", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + tool_config={}, + tool_parameters={}, + error="some error", + ) + + trace_instance.add_span = MagicMock() + trace_instance.tool_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.name == "my_tool" + assert span_data.level == LevelEnum.ERROR + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + metadata={"m": 1}, + ) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + trace_instance.generate_name_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + trace_data = trace_instance.add_trace.call_args[1]["langfuse_trace_data"] + assert trace_data.name == TraceTaskName.GENERATE_NAME_TRACE + assert trace_data.user_id == "tenant-1" + + span_data = trace_instance.add_span.call_args[1]["langfuse_span_data"] + assert span_data.trace_id == "conv-1" + + +def test_add_trace_success(trace_instance): + data = LangfuseTrace(id="t1", name="trace") + trace_instance.add_trace(data) + trace_instance.langfuse_client.trace.assert_called_once() + + +def test_add_trace_error(trace_instance): + trace_instance.langfuse_client.trace.side_effect = Exception("error") + data = LangfuseTrace(id="t1", name="trace") + with pytest.raises(ValueError, match="LangFuse Failed to create trace: error"): + trace_instance.add_trace(data) + + +def test_add_span_success(trace_instance): + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + trace_instance.add_span(data) + trace_instance.langfuse_client.span.assert_called_once() + + +def test_add_span_error(trace_instance): + trace_instance.langfuse_client.span.side_effect = Exception("error") + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + with pytest.raises(ValueError, match="LangFuse Failed to create span: error"): + trace_instance.add_span(data) + + +def test_update_span(trace_instance): + span = MagicMock() + data = LangfuseSpan(id="s1", name="span", trace_id="t1") + trace_instance.update_span(span, data) + span.end.assert_called_once() + + +def test_add_generation_success(trace_instance): + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + trace_instance.add_generation(data) + trace_instance.langfuse_client.generation.assert_called_once() + + +def test_add_generation_error(trace_instance): + trace_instance.langfuse_client.generation.side_effect = Exception("error") + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + with pytest.raises(ValueError, match="LangFuse Failed to create generation: error"): + trace_instance.add_generation(data) + + +def test_update_generation(trace_instance): + gen = MagicMock() + data = LangfuseGeneration(id="g1", name="gen", trace_id="t1") + trace_instance.update_generation(gen, data) + gen.end.assert_called_once() + + +def test_api_check_success(trace_instance): + trace_instance.langfuse_client.auth_check.return_value = True + assert trace_instance.api_check() is True + + +def test_api_check_error(trace_instance): + trace_instance.langfuse_client.auth_check.side_effect = Exception("fail") + with pytest.raises(ValueError, match="LangFuse API check failed: fail"): + trace_instance.api_check() + + +def test_get_project_key_success(trace_instance): + mock_data = MagicMock() + mock_data.id = "proj-1" + trace_instance.langfuse_client.client.projects.get.return_value = MagicMock(data=[mock_data]) + assert trace_instance.get_project_key() == "proj-1" + + +def test_get_project_key_error(trace_instance): + trace_instance.langfuse_client.client.projects.get.side_effect = Exception("fail") + with pytest.raises(ValueError, match="LangFuse get project key failed: fail"): + trace_instance.get_project_key() + + +def test_moderation_trace_none(trace_instance): + trace_info = ModerationTraceInfo( + message_id="m", + message_data=None, + inputs={}, + action="s", + flagged=False, + preset_response="", + query="", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_suggested_question_trace_none(trace_instance): + trace_info = SuggestedQuestionTraceInfo( + message_id="m", message_data=None, inputs={}, suggested_question=[], total_tokens=0, level="i", metadata={} + ) + trace_instance.add_generation = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_generation.assert_not_called() + + +def test_dataset_retrieval_trace_none(trace_instance): + trace_info = DatasetRetrievalTraceInfo(message_id="m", message_data=None, inputs={}, documents=[], metadata={}) + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_langfuse_trace_entity_with_list_dict_input(): + # To cover lines 29-31 in langfuse_trace_entity.py + # We need to mock replace_text_with_content or just check if it works + # Actually replace_text_with_content is imported from core.ops.utils + data = LangfuseTrace(id="t1", name="n", input=[{"text": "hello"}]) + assert isinstance(data.input, list) + assert data.input[0]["content"] == "hello" + + +def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypatch, caplog): + # Setup trace info to trigger LLM node usage extraction + trace_info = WorkflowTraceInfo( + workflow_id="wf-1", + tenant_id="t", + workflow_run_id="r", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="c", + start_time=_dt(), + end_time=_dt(), + metadata={"app_id": "app-1"}, + workflow_app_log_id="l", + error="", + ) + + node = MagicMock() + node.id = "n1" + node.title = "LLM Node" + node.node_type = BuiltinNodeTypes.LLM + node.status = "succeeded" + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node.created_at = _dt() + node.elapsed_time = 0.1 + node.metadata = {} + node.outputs = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_generation = MagicMock() + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + trace_instance.add_generation.assert_called_once() diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py new file mode 100644 index 0000000000..f656f7435f --- /dev/null +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -0,0 +1,608 @@ +import collections +from datetime import datetime, timedelta +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import LangSmithConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( + LangSmithRunModel, + LangSmithRunType, + LangSmithRunUpdateModel, +) +from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from models import EndUser + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0) + + +@pytest.fixture +def langsmith_config(): + return LangSmithConfig(api_key="ls-123", project="default", endpoint="https://api.smith.langchain.com") + + +@pytest.fixture +def trace_instance(langsmith_config, monkeypatch): + # Mock LangSmith client + mock_client = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", lambda **kwargs: mock_client) + + instance = LangSmithDataTrace(langsmith_config) + return instance + + +def test_init(langsmith_config, monkeypatch): + mock_client_class = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.Client", mock_client_class) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = LangSmithDataTrace(langsmith_config) + + mock_client_class.assert_called_once_with(api_key=langsmith_config.api_key, api_url=langsmith_config.endpoint) + assert instance.langsmith_key == langsmith_config.api_key + assert instance.project_name == langsmith_config.project + assert instance.file_base_url == "http://test.url" + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace(trace_instance, monkeypatch): + # Setup trace info + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=100, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + # Mock dependencies + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + + # Mock node executions + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = BuiltinNodeTypes.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 30} + + node_other = MagicMock() + node_other.id = "node-other" + node_other.title = "Tool Node" + node_other.node_type = BuiltinNodeTypes.TOOL + node_other.status = "succeeded" + node_other.process_data = None + node_other.inputs = {"tool_input": "val"} + node_other.outputs = {"tool_output": "val"} + node_other.created_at = None # Trigger datetime.now() + node_other.elapsed_time = 0.2 + node_other.metadata = {} + + node_retrieval = MagicMock() + node_retrieval.id = "node-retrieval" + node_retrieval.title = "Retrieval Node" + node_retrieval.node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL + node_retrieval.status = "succeeded" + node_retrieval.process_data = None + node_retrieval.inputs = {"query": "val"} + node_retrieval.outputs = {"results": "val"} + node_retrieval.created_at = _dt() + node_retrieval.elapsed_time = 0.2 + node_retrieval.metadata = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + + trace_instance.workflow_trace(trace_info) + + # Verify add_run calls + # 1. message run (id="msg-1") + # 2. workflow run (id="run-1") + # 3. node llm run (id="node-llm") + # 4. node other run (id="node-other") + # 5. node retrieval run (id="node-retrieval") + assert trace_instance.add_run.call_count == 5 + + call_args = [call[0][0] for call in trace_instance.add_run.call_args_list] + + assert call_args[0].id == "msg-1" + assert call_args[0].name == TraceTaskName.MESSAGE_TRACE + + assert call_args[1].id == "run-1" + assert call_args[1].name == TraceTaskName.WORKFLOW_TRACE + assert call_args[1].parent_run_id == "msg-1" + + assert call_args[2].id == "node-llm" + assert call_args[2].run_type == LangSmithRunType.llm + + assert call_args[3].id == "node-other" + assert call_args[3].run_type == LangSmithRunType.tool + + assert call_args[4].id == "node-retrieval" + assert call_args[4].run_type == LangSmithRunType.retriever + + +def test_workflow_trace_no_start_time(trace_instance, monkeypatch): + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=10, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=None, + end_time=None, + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + trace_instance.workflow_trace(trace_info) + assert trace_instance.add_run.called + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.trace_id = "trace-1" + trace_info.message_id = None + trace_info.workflow_run_id = "run-1" + trace_info.start_time = None + trace_info.workflow_data = MagicMock() + trace_info.workflow_data.created_at = _dt() + trace_info.metadata = {} # Empty metadata + trace_info.workflow_app_log_id = "log-1" + trace_info.file_list = [] + trace_info.total_tokens = 0 + trace_info.workflow_run_inputs = {} + trace_info.workflow_run_outputs = {} + trace_info.error = "" + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "msg-1" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.answer = "hello answer" + + trace_info = MessageTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"input": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id="trace-1", + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + message_file_data=MagicMock(url="file-url"), + ) + + # Mock EndUser lookup + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_run = MagicMock() + + trace_instance.message_trace(trace_info) + + # 1. message run + # 2. llm run + assert trace_instance.add_run.call_count == 2 + + call_args = [call[0][0] for call in trace_instance.add_run.call_args_list] + assert call_args[0].id == "msg-1" + assert call_args[0].extra["metadata"]["end_user_id"] == "session-id-123" + assert call_args[1].parent_run_id == "msg-1" + assert call_args[1].name == "llm" + + +def test_message_trace_no_data(trace_instance): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_data = None + trace_info.file_list = [] + trace_info.message_file_data = None + trace_info.metadata = {} + trace_instance.add_run = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_moderation_trace_no_data(trace_instance): + trace_info = MagicMock(spec=ModerationTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_suggested_question_trace_no_data(trace_instance): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_dataset_retrieval_trace_no_data(trace_instance): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_data = None + trace_instance.add_run = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_run.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + query="hi", + ) + + trace_instance.add_run = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.MODERATION_TRACE + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = SuggestedQuestionTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.SUGGESTED_QUESTION_TRACE + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="msg-1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.DATASET_RETRIEVAL_TRACE + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="msg-1", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="trace-1", + tool_config={}, + tool_parameters={}, + file_url="http://file", + ) + + trace_instance.add_run = MagicMock() + trace_instance.tool_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == "my_tool" + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="conv-1", + start_time=None, + end_time=None, + metadata={}, + trace_id="trace-1", + ) + + trace_instance.add_run = MagicMock() + trace_instance.generate_name_trace(trace_info) + trace_instance.add_run.assert_called_once() + assert trace_instance.add_run.call_args[0][0].name == TraceTaskName.GENERATE_NAME_TRACE + + +def test_add_run_success(trace_instance): + run_data = LangSmithRunModel( + id="run-1", name="test", inputs={}, outputs={}, run_type=LangSmithRunType.tool, start_time=_dt() + ) + trace_instance.project_id = "proj-1" + trace_instance.add_run(run_data) + trace_instance.langsmith_client.create_run.assert_called_once() + args, kwargs = trace_instance.langsmith_client.create_run.call_args + assert kwargs["session_id"] == "proj-1" + + +def test_add_run_error(trace_instance): + run_data = LangSmithRunModel(id="run-1", name="test", run_type=LangSmithRunType.tool, start_time=_dt()) + trace_instance.langsmith_client.create_run.side_effect = Exception("failed") + with pytest.raises(ValueError, match="LangSmith Failed to create run: failed"): + trace_instance.add_run(run_data) + + +def test_update_run_success(trace_instance): + update_data = LangSmithRunUpdateModel(run_id="run-1", outputs={"out": "val"}) + trace_instance.update_run(update_data) + trace_instance.langsmith_client.update_run.assert_called_once() + + +def test_update_run_error(trace_instance): + update_data = LangSmithRunUpdateModel(run_id="run-1") + trace_instance.langsmith_client.update_run.side_effect = Exception("failed") + with pytest.raises(ValueError, match="LangSmith Failed to update run: failed"): + trace_instance.update_run(update_data) + + +def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, caplog): + workflow_data = MagicMock() + workflow_data.created_at = _dt() + workflow_data.finished_at = _dt() + timedelta(seconds=1) + + trace_info = WorkflowTraceInfo( + tenant_id="tenant-1", + workflow_id="wf-1", + workflow_run_id="run-1", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_status="succeeded", + workflow_run_version="1.0", + workflow_run_elapsed_time=1.0, + total_tokens=100, + file_list=[], + query="hi", + message_id="msg-1", + conversation_id="conv-1", + start_time=_dt(), + end_time=_dt(), + trace_id="trace-1", + metadata={"app_id": "app-1"}, + workflow_app_log_id="log-1", + error="", + workflow_data=workflow_data, + ) + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node_llm = MagicMock() + node_llm.id = "node-llm" + node_llm.title = "LLM Node" + node_llm.node_type = BuiltinNodeTypes.LLM + node_llm.status = "succeeded" + node_llm.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node_llm.inputs = {} + node_llm.outputs = {} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_run = MagicMock() + + import logging + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + + +def test_api_check_success(trace_instance): + assert trace_instance.api_check() is True + assert trace_instance.langsmith_client.create_project.called + assert trace_instance.langsmith_client.delete_project.called + + +def test_api_check_error(trace_instance): + trace_instance.langsmith_client.create_project.side_effect = Exception("error") + with pytest.raises(ValueError, match="LangSmith API check failed: error"): + trace_instance.api_check() + + +def test_get_project_url_success(trace_instance): + trace_instance.langsmith_client.get_run_url.return_value = "https://smith.langchain.com/o/org/p/proj/r/run" + url = trace_instance.get_project_url() + assert url == "https://smith.langchain.com/o/org/p/proj" + + +def test_get_project_url_error(trace_instance): + trace_instance.langsmith_client.get_run_url.side_effect = Exception("error") + with pytest.raises(ValueError, match="LangSmith get run url failed: error"): + trace_instance.get_project_url() diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py new file mode 100644 index 0000000000..cccedaa08c --- /dev/null +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -0,0 +1,1019 @@ +"""Comprehensive tests for core.ops.mlflow_trace.mlflow_trace module.""" + +from __future__ import annotations + +import json +import os +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds +from dify_graph.enums import BuiltinNodeTypes + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "wf-id", + "tenant_id": "tenant", + "workflow_run_id": "run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"key": "val"}, + "workflow_run_outputs": {"answer": "42"}, + "workflow_run_version": "v1", + "total_tokens": 10, + "file_list": [], + "query": "hello", + "metadata": {"user_id": "u1", "conversation_id": "c1"}, + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 10, + "total_tokens": 15, + "conversation_mode": "chat", + "metadata": {"conversation_id": "c1", "from_account_id": "a1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace( + model_provider="openai", + model_id="gpt-4", + total_price=0.01, + answer="response text", + ), + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "my_tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "output", + "tool_config": {"desc": "d"}, + "tool_parameters": {"p": "v"}, + "time_cost": 0.5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_moderation_trace_info(**overrides) -> ModerationTraceInfo: + defaults = { + "flagged": False, + "action": "allow", + "preset_response": "", + "query": "test", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + defaults = { + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(), + "inputs": "query", + "documents": [{"content": "doc"}], + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(created_at=_dt(), updated_at=_dt()), + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt(), + "error": None, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def _make_generate_name_trace_info(**overrides) -> GenerateNameTraceInfo: + defaults = { + "tenant_id": "t1", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": 1}, + "outputs": {"name": "test"}, + "start_time": _dt(), + "end_time": _dt(), + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def _make_node(**overrides): + """Create a mock workflow node execution row.""" + defaults = { + "id": "node-1", + "tenant_id": "t1", + "app_id": "app-1", + "title": "Node Title", + "node_type": BuiltinNodeTypes.CODE, + "status": "succeeded", + "inputs": '{"key": "value"}', + "outputs": '{"result": "ok"}', + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": None, + "execution_metadata": None, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_mlflow(): + with patch("core.ops.mlflow_trace.mlflow_trace.mlflow") as mock: + yield mock + + +@pytest.fixture +def mock_tracing(): + """Patch all MLflow tracing functions used by the module.""" + with ( + patch("core.ops.mlflow_trace.mlflow_trace.start_span_no_context") as mock_start, + patch("core.ops.mlflow_trace.mlflow_trace.update_current_trace") as mock_update, + patch("core.ops.mlflow_trace.mlflow_trace.set_span_in_context") as mock_set, + patch("core.ops.mlflow_trace.mlflow_trace.detach_span_from_context") as mock_detach, + ): + yield { + "start": mock_start, + "update": mock_update, + "set": mock_set, + "detach": mock_detach, + } + + +@pytest.fixture +def mock_db(): + with patch("core.ops.mlflow_trace.mlflow_trace.db") as mock: + yield mock + + +@pytest.fixture +def trace_instance(mock_mlflow): + """Create an MLflowDataTrace using a basic MLflowConfig (no auth).""" + config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0") + return MLflowDataTrace(config) + + +# ── datetime_to_nanoseconds ───────────────────────────────────────────────── + + +class TestDatetimeToNanoseconds: + def test_none_returns_none(self): + assert datetime_to_nanoseconds(None) is None + + def test_converts_datetime(self): + dt = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + expected = int(dt.timestamp() * 1_000_000_000) + assert datetime_to_nanoseconds(dt) == expected + + +# ── __init__ / setup ───────────────────────────────────────────────────────── + + +class TestInit: + def test_mlflow_config_no_auth(self, mock_mlflow): + config = MLflowConfig(tracking_uri="http://localhost:5000", experiment_id="0") + trace = MLflowDataTrace(config) + mock_mlflow.set_tracking_uri.assert_called_with("http://localhost:5000") + mock_mlflow.set_experiment.assert_called_with(experiment_id="0") + assert trace.get_project_url() == "http://localhost:5000/#/experiments/0/traces" + assert os.environ["MLFLOW_ENABLE_ASYNC_TRACE_LOGGING"] == "true" + + def test_mlflow_config_with_auth(self, mock_mlflow): + config = MLflowConfig( + tracking_uri="http://localhost:5000", + experiment_id="1", + username="user", + password="pass", + ) + MLflowDataTrace(config) + assert os.environ["MLFLOW_TRACKING_USERNAME"] == "user" + assert os.environ["MLFLOW_TRACKING_PASSWORD"] == "pass" + + def test_databricks_oauth(self, mock_mlflow): + config = DatabricksConfig( + host="https://db.com/", + experiment_id="42", + client_id="cid", + client_secret="csec", + ) + trace = MLflowDataTrace(config) + assert os.environ["DATABRICKS_HOST"] == "https://db.com/" + assert os.environ["DATABRICKS_CLIENT_ID"] == "cid" + assert os.environ["DATABRICKS_CLIENT_SECRET"] == "csec" + mock_mlflow.set_tracking_uri.assert_called_with("databricks") + # Trailing slash stripped + assert trace.get_project_url() == "https://db.com/ml/experiments/42/traces" + + def test_databricks_pat(self, mock_mlflow): + config = DatabricksConfig( + host="https://db.com", + experiment_id="1", + personal_access_token="pat", + ) + trace = MLflowDataTrace(config) + assert os.environ["DATABRICKS_TOKEN"] == "pat" + assert "db.com/ml/experiments/1/traces" in trace.get_project_url() + + def test_databricks_no_creds_raises(self, mock_mlflow): + config = DatabricksConfig(host="https://db.com", experiment_id="1") + with pytest.raises(ValueError, match="Either Databricks token"): + MLflowDataTrace(config) + + +# ── trace dispatcher ──────────────────────────────────────────────────────── + + +class TestTraceDispatcher: + def test_dispatches_workflow(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "workflow_trace") as mock_wt: + trace_instance.trace(_make_workflow_trace_info()) + mock_wt.assert_called_once() + + def test_dispatches_message(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "message_trace") as mock_mt: + trace_instance.trace(_make_message_trace_info()) + mock_mt.assert_called_once() + + def test_dispatches_tool(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "tool_trace") as mock_tt: + trace_instance.trace(_make_tool_trace_info()) + mock_tt.assert_called_once() + + def test_dispatches_moderation(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "moderation_trace") as mock_mod: + trace_instance.trace(_make_moderation_trace_info(message_data=SimpleNamespace(created_at=_dt()))) + mock_mod.assert_called_once() + + def test_dispatches_dataset_retrieval(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "dataset_retrieval_trace") as mock_dr: + trace_instance.trace(_make_dataset_retrieval_trace_info()) + mock_dr.assert_called_once() + + def test_dispatches_suggested_question(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "suggested_question_trace") as mock_sq: + trace_instance.trace(_make_suggested_question_trace_info()) + mock_sq.assert_called_once() + + def test_dispatches_generate_name(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "generate_name_trace") as mock_gn: + trace_instance.trace(_make_generate_name_trace_info()) + mock_gn.assert_called_once() + + def test_reraises_exception(self, trace_instance, mock_tracing, mock_db): + with patch.object(trace_instance, "workflow_trace", side_effect=RuntimeError("boom")): + with pytest.raises(RuntimeError, match="boom"): + trace_instance.trace(_make_workflow_trace_info()) + + +# ── workflow_trace ─────────────────────────────────────────────────────────── + + +class TestWorkflowTrace: + def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(conversation_id="sess-1") + trace_instance.workflow_trace(trace_info) + + # Workflow span started and ended + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info( + workflow_run_inputs={"sys.app_id": "x", "user_input": "hi"}, + query="hello", + ) + trace_instance.workflow_trace(trace_info) + + call_kwargs = mock_tracing["start"].call_args + inputs = call_kwargs.kwargs["inputs"] + assert "sys.app_id" not in inputs + assert inputs["user_input"] == "hi" + assert inputs["query"] == "hello" + + def test_workflow_with_llm_node(self, trace_instance, mock_tracing, mock_db): + llm_node = _make_node( + node_type=BuiltinNodeTypes.LLM, + process_data=json.dumps( + { + "prompts": [{"role": "user", "text": "hi"}], + "model_name": "gpt-4", + "model_provider": "openai", + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ), + outputs='{"text": "hello world"}', + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + assert mock_tracing["start"].call_count == 2 + node_span.end.assert_called_once() + workflow_span.end.assert_called_once() + + def test_workflow_with_question_classifier_node(self, trace_instance, mock_tracing, mock_db): + qc_node = _make_node( + node_type=BuiltinNodeTypes.QUESTION_CLASSIFIER, + process_data=json.dumps( + { + "prompts": "classify this", + "model_name": "gpt-4", + "model_provider": "openai", + } + ), + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + assert mock_tracing["start"].call_count == 2 + + def test_workflow_with_http_request_node(self, trace_instance, mock_tracing, mock_db): + http_node = _make_node( + node_type=BuiltinNodeTypes.HTTP_REQUEST, + process_data='{"url": "https://api.com"}', + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + # HTTP_REQUEST uses process_data as inputs + node_start_call = mock_tracing["start"].call_args_list[1] + assert node_start_call.kwargs["inputs"] == '{"url": "https://api.com"}' + + def test_workflow_with_knowledge_retrieval_node(self, trace_instance, mock_tracing, mock_db): + kr_node = _make_node( + node_type=BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + outputs=json.dumps( + { + "result": [ + {"content": "doc1", "metadata": {"source": "s1"}}, + {"content": "doc2", "metadata": {}}, + ] + } + ), + ) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + # outputs should be parsed to Document objects + end_call = node_span.end.call_args + outputs = end_call.kwargs["outputs"] + assert len(outputs) == 2 + + def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): + failed_node = _make_node(status="failed") + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + node_span.set_status.assert_called_once() + node_span.add_event.assert_called_once() + + def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + workflow_span = MagicMock() + mock_tracing["start"].return_value = workflow_span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(error="workflow failed") + trace_instance.workflow_trace(trace_info) + workflow_span.set_status.assert_called_once() + workflow_span.add_event.assert_called_once() + # Still ends the span via finally + workflow_span.end.assert_called_once() + + def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): + node = _make_node(inputs=None, outputs=None) + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + workflow_span = MagicMock() + node_span = MagicMock() + mock_tracing["start"].side_effect = [workflow_span, node_span] + mock_tracing["set"].return_value = "token" + + trace_instance.workflow_trace(_make_workflow_trace_info()) + node_call = mock_tracing["start"].call_args_list[1] + assert node_call.kwargs["inputs"] == {} + end_call = node_span.end.call_args + assert end_call.kwargs["outputs"] == {} + + def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info( + metadata={}, + conversation_id=None, + ) + trace_instance.workflow_trace(trace_info) + # _set_trace_metadata still called with empty metadata + mock_tracing["update"].assert_called_once() + + def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): + """When query is empty string, it's falsy so no query key added.""" + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + trace_info = _make_workflow_trace_info(query="") + trace_instance.workflow_trace(trace_info) + call_kwargs = mock_tracing["start"].call_args + inputs = call_kwargs.kwargs["inputs"] + assert "query" not in inputs + + +# ── _parse_llm_inputs_and_attributes ───────────────────────────────────────── + + +class TestParseLlmInputsAndAttributes: + def test_none_process_data(self, trace_instance): + node = _make_node(process_data=None) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == {} + assert attrs == {} + + def test_invalid_json(self, trace_instance): + node = _make_node(process_data="not json") + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == {} + assert attrs == {} + + def test_valid_process_data_with_usage(self, trace_instance): + node = _make_node( + process_data=json.dumps( + { + "prompts": [{"role": "user", "text": "hi"}], + "model_name": "gpt-4", + "model_provider": "openai", + "finish_reason": "stop", + "usage": {"prompt_tokens": 5, "completion_tokens": 10, "total_tokens": 15}, + } + ) + ) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert isinstance(inputs, list) + assert attrs["model_name"] == "gpt-4" + assert "usage" in attrs + + def test_valid_process_data_without_usage(self, trace_instance): + node = _make_node( + process_data=json.dumps( + { + "prompts": "simple prompt", + "model_name": "gpt-3.5", + } + ) + ) + inputs, attrs = trace_instance._parse_llm_inputs_and_attributes(node) + assert inputs == "simple prompt" + assert attrs["model_name"] == "gpt-3.5" + + +# ── _parse_knowledge_retrieval_outputs ─────────────────────────────────────── + + +class TestParseKnowledgeRetrievalOutputs: + def test_with_results(self, trace_instance): + outputs = {"result": [{"content": "c1", "metadata": {"s": "1"}}]} + docs = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert len(docs) == 1 + assert docs[0].page_content == "c1" + + def test_empty_result(self, trace_instance): + outputs = {"result": []} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + def test_no_result_key(self, trace_instance): + outputs = {"other": "data"} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + def test_result_not_list(self, trace_instance): + outputs = {"result": "not a list"} + result = trace_instance._parse_knowledge_retrieval_outputs(outputs) + assert result == outputs + + +# ── message_trace ──────────────────────────────────────────────────────────── + + +class TestMessageTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing, mock_db): + trace_info = _make_message_trace_info(message_data=None) + trace_instance.message_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_message_trace(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_instance.message_trace(_make_message_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_message_trace_with_error(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info(error="something broke") + trace_instance.message_trace(trace_info) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + + def test_message_trace_with_file_data(self, trace_instance, mock_tracing, mock_db, monkeypatch): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setenv("FILES_URL", "http://files.test") + + file_data = SimpleNamespace(url="path/to/file.png") + trace_info = _make_message_trace_info( + message_file_data=file_data, + file_list=["existing_file.txt"], + ) + trace_instance.message_trace(trace_info) + call_kwargs = mock_tracing["start"].call_args + attrs = call_kwargs.kwargs["attributes"] + assert "http://files.test/path/to/file.png" in attrs["file_list"] + assert "existing_file.txt" in attrs["file_list"] + + def test_message_trace_file_list_none(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info(file_list=None, message_file_data=None) + trace_instance.message_trace(trace_info) + mock_tracing["start"].assert_called_once() + + def test_message_trace_with_end_user(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + + end_user = MagicMock() + end_user.session_id = "session-xyz" + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + + trace_info = _make_message_trace_info( + metadata={"from_end_user_id": "eu-1", "conversation_id": "c1"}, + ) + trace_instance.message_trace(trace_info) + # update_current_trace called with user id from EndUser + mock_tracing["update"].assert_called_once() + + def test_message_trace_with_no_conversation_id(self, trace_instance, mock_tracing, mock_db): + span = MagicMock() + mock_tracing["start"].return_value = span + mock_tracing["set"].return_value = "token" + mock_db.session.query.return_value.where.return_value.first.return_value = None + + trace_info = _make_message_trace_info( + metadata={"from_account_id": "acc-1"}, + ) + trace_instance.message_trace(trace_info) + mock_tracing["update"].assert_called_once() + + +# ── _get_message_user_id ───────────────────────────────────────────────────── + + +class TestGetMessageUserId: + def test_returns_end_user_session_id(self, trace_instance, mock_db): + end_user = MagicMock() + end_user.session_id = "session-1" + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) + assert result == "session-1" + + def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): + mock_db.session.query.return_value.where.return_value.first.return_value = None + result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) + assert result == "acc-1" + + def test_returns_account_id_when_no_end_user_id(self, trace_instance, mock_db): + result = trace_instance._get_message_user_id({"from_account_id": "acc-1"}) + assert result == "acc-1" + + def test_returns_none_when_nothing(self, trace_instance, mock_db): + result = trace_instance._get_message_user_id({}) + assert result is None + + +# ── tool_trace ─────────────────────────────────────────────────────────────── + + +class TestToolTrace: + def test_basic_tool_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.tool_trace(_make_tool_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + span.set_status.assert_not_called() + + def test_tool_trace_with_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.tool_trace(_make_tool_trace_info(error="tool failed")) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + span.end.assert_called_once() + + +# ── moderation_trace ───────────────────────────────────────────────────────── + + +class TestModerationTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_moderation_trace_info(message_data=None) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_moderation_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_moderation_trace_info( + message_data=SimpleNamespace(created_at=_dt()), + start_time=_dt(), + end_time=_dt(), + ) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + end_kwargs = span.end.call_args.kwargs["outputs"] + assert end_kwargs["action"] == "allow" + assert end_kwargs["flagged"] is False + + def test_moderation_uses_message_data_created_at_if_no_start_time(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_moderation_trace_info( + message_data=SimpleNamespace(created_at=_dt()), + start_time=None, + end_time=_dt(), + ) + trace_instance.moderation_trace(trace_info) + mock_tracing["start"].assert_called_once() + + +# ── dataset_retrieval_trace ────────────────────────────────────────────────── + + +class TestDatasetRetrievalTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.dataset_retrieval_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_dataset_retrieval_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.dataset_retrieval_trace(_make_dataset_retrieval_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── suggested_question_trace ───────────────────────────────────────────────── + + +class TestSuggestedQuestionTrace: + def test_returns_early_if_no_message_data(self, trace_instance, mock_tracing): + trace_info = _make_suggested_question_trace_info(message_data=None) + trace_instance.suggested_question_trace(trace_info) + mock_tracing["start"].assert_not_called() + + def test_basic_suggested_question_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.suggested_question_trace(_make_suggested_question_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + def test_suggested_question_with_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_suggested_question_trace_info(error="failed") + trace_instance.suggested_question_trace(trace_info) + span.set_status.assert_called_once() + span.add_event.assert_called_once() + + def test_uses_message_data_times_when_no_start_end(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_info = _make_suggested_question_trace_info( + start_time=None, + end_time=None, + ) + trace_instance.suggested_question_trace(trace_info) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── generate_name_trace ────────────────────────────────────────────────────── + + +class TestGenerateNameTrace: + def test_basic_generate_name_trace(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["start"].return_value = span + + trace_instance.generate_name_trace(_make_generate_name_trace_info()) + mock_tracing["start"].assert_called_once() + span.end.assert_called_once() + + +# ── _get_workflow_nodes ────────────────────────────────────────────────────── + + +class TestGetWorkflowNodes: + def test_queries_db(self, trace_instance, mock_db): + mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + result = trace_instance._get_workflow_nodes("run-1") + assert result == ["n1", "n2"] + + +# ── _get_node_span_type ───────────────────────────────────────────────────── + + +class TestGetNodeSpanType: + @pytest.mark.parametrize( + ("node_type", "expected_contains"), + [ + (BuiltinNodeTypes.LLM, "LLM"), + (BuiltinNodeTypes.QUESTION_CLASSIFIER, "LLM"), + (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "RETRIEVER"), + (BuiltinNodeTypes.TOOL, "TOOL"), + (BuiltinNodeTypes.CODE, "TOOL"), + (BuiltinNodeTypes.HTTP_REQUEST, "TOOL"), + (BuiltinNodeTypes.AGENT, "AGENT"), + ], + ) + def test_mapped_types(self, trace_instance, node_type, expected_contains): + result = trace_instance._get_node_span_type(node_type) + assert expected_contains in str(result) + + def test_unknown_type_returns_chain(self, trace_instance): + result = trace_instance._get_node_span_type("unknown_node") + assert result == "CHAIN" + + +# ── _set_trace_metadata ───────────────────────────────────────────────────── + + +class TestSetTraceMetadata: + def test_sets_and_detaches(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = "token" + + trace_instance._set_trace_metadata(span, {"key": "val"}) + mock_tracing["set"].assert_called_once_with(span) + mock_tracing["update"].assert_called_once_with(metadata={"key": "val"}) + mock_tracing["detach"].assert_called_once_with("token") + + def test_detaches_even_on_error(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = "token" + mock_tracing["update"].side_effect = RuntimeError("fail") + + with pytest.raises(RuntimeError): + trace_instance._set_trace_metadata(span, {}) + mock_tracing["detach"].assert_called_once_with("token") + + def test_no_detach_when_token_is_none(self, trace_instance, mock_tracing): + span = MagicMock() + mock_tracing["set"].return_value = None + + trace_instance._set_trace_metadata(span, {}) + mock_tracing["detach"].assert_not_called() + + +# ── _parse_prompts ─────────────────────────────────────────────────────────── + + +class TestParsePrompts: + def test_string_input(self, trace_instance): + assert trace_instance._parse_prompts("hello") == "hello" + + def test_dict_input(self, trace_instance): + result = trace_instance._parse_prompts({"role": "user", "text": "hi"}) + assert result == {"role": "user", "content": "hi"} + + def test_list_input(self, trace_instance): + prompts = [ + {"role": "user", "text": "hi"}, + {"role": "assistant", "text": "hello"}, + ] + result = trace_instance._parse_prompts(prompts) + assert len(result) == 2 + assert result[0]["role"] == "user" + + def test_none_input(self, trace_instance): + assert trace_instance._parse_prompts(None) is None + + def test_int_passthrough(self, trace_instance): + assert trace_instance._parse_prompts(42) == 42 + + +# ── _parse_single_message ─────────────────────────────────────────────────── + + +class TestParseSingleMessage: + def test_basic_message(self, trace_instance): + result = trace_instance._parse_single_message({"role": "user", "text": "hello"}) + assert result == {"role": "user", "content": "hello"} + + def test_default_role(self, trace_instance): + result = trace_instance._parse_single_message({"text": "hello"}) + assert result["role"] == "user" + + def test_with_tool_calls(self, trace_instance): + item = { + "role": "assistant", + "text": "", + "tool_calls": [{"id": "tc1", "function": {"name": "fn"}}], + } + result = trace_instance._parse_single_message(item) + assert "tool_calls" in result + + def test_tool_role_ignores_tool_calls(self, trace_instance): + item = { + "role": "tool", + "text": "result", + "tool_calls": [{"id": "tc1"}], + } + result = trace_instance._parse_single_message(item) + assert "tool_calls" not in result + + def test_with_files(self, trace_instance): + item = {"role": "user", "text": "look", "files": ["f1.png"]} + result = trace_instance._parse_single_message(item) + assert result["files"] == ["f1.png"] + + def test_no_files(self, trace_instance): + result = trace_instance._parse_single_message({"role": "user", "text": "hi"}) + assert "files" not in result + + +# ── _resolve_tool_call_ids ─────────────────────────────────────────────────── + + +class TestResolveToolCallIds: + def test_resolves_tool_call_ids(self, trace_instance): + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [{"id": "tc1"}, {"id": "tc2"}], + }, + {"role": "tool", "content": "result1"}, + {"role": "tool", "content": "result2"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert result[1]["tool_call_id"] == "tc1" + assert result[2]["tool_call_id"] == "tc2" + + def test_no_tool_calls(self, trace_instance): + messages = [ + {"role": "user", "content": "hi"}, + {"role": "assistant", "content": "hello"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert "tool_call_id" not in result[0] + assert "tool_call_id" not in result[1] + + def test_tool_message_no_ids_available(self, trace_instance): + """Tool message with no preceding tool_calls should not crash.""" + messages = [ + {"role": "tool", "content": "result"}, + ] + result = trace_instance._resolve_tool_call_ids(messages) + assert "tool_call_id" not in result[0] + + +# ── api_check ──────────────────────────────────────────────────────────────── + + +class TestApiCheck: + def test_success(self, trace_instance, mock_mlflow): + mock_mlflow.search_experiments.return_value = [] + assert trace_instance.api_check() is True + + def test_failure(self, trace_instance, mock_mlflow): + mock_mlflow.search_experiments.side_effect = ConnectionError("refused") + with pytest.raises(ValueError, match="MLflow connection failed"): + trace_instance.api_check() + + +# ── get_project_url ────────────────────────────────────────────────────────── + + +class TestGetProjectUrl: + def test_returns_url(self, trace_instance): + assert "experiments" in trace_instance.get_project_url() diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py new file mode 100644 index 0000000000..b2cb7d5109 --- /dev/null +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -0,0 +1,678 @@ +import collections +import logging +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.entities.config_entity import OpikConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey +from models import EndUser +from models.enums import MessageStatus + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +@pytest.fixture +def opik_config(): + return OpikConfig( + project="test-project", workspace="test-workspace", url="https://cloud.opik.com/api/", api_key="api-key-123" + ) + + +@pytest.fixture +def trace_instance(opik_config, monkeypatch): + mock_client = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", lambda **kwargs: mock_client) + + instance = OpikDataTrace(opik_config) + return instance + + +def test_wrap_dict(): + assert wrap_dict("input", {"a": 1}) == {"a": 1} + assert wrap_dict("input", "hello") == {"input": "hello"} + + +def test_wrap_metadata(): + assert wrap_metadata({"a": 1}, b=2) == {"a": 1, "b": 2, "created_from": "dify"} + + +def test_prepare_opik_uuid(): + # Test with valid datetime and uuid string + dt = datetime(2024, 1, 1) + uuid_str = "b3e8e918-472e-4b69-8051-12502c34fc07" + result = prepare_opik_uuid(dt, uuid_str) + assert result is not None + # We won't test the exact uuid7 value but just that it returns a string id + + # Test with None dt and uuid_str + result = prepare_opik_uuid(None, None) + assert result is not None + + +def test_init(opik_config, monkeypatch): + mock_opik = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.Opik", mock_opik) + monkeypatch.setenv("FILES_URL", "http://test.url") + + instance = OpikDataTrace(opik_config) + + mock_opik.assert_called_once_with( + project_name=opik_config.project, + workspace=opik_config.workspace, + host=opik_config.url, + api_key=opik_config.api_key, + ) + assert instance.file_base_url == "http://test.url" + assert instance.project == opik_config.project + + +def test_trace_dispatch(trace_instance, monkeypatch): + methods = [ + "workflow_trace", + "message_trace", + "moderation_trace", + "suggested_question_trace", + "dataset_retrieval_trace", + "tool_trace", + "generate_name_trace", + ] + mocks = {method: MagicMock() for method in methods} + for method, m in mocks.items(): + monkeypatch.setattr(trace_instance, method, m) + + # WorkflowTraceInfo + info = MagicMock(spec=WorkflowTraceInfo) + trace_instance.trace(info) + mocks["workflow_trace"].assert_called_once_with(info) + + # MessageTraceInfo + info = MagicMock(spec=MessageTraceInfo) + trace_instance.trace(info) + mocks["message_trace"].assert_called_once_with(info) + + # ModerationTraceInfo + info = MagicMock(spec=ModerationTraceInfo) + trace_instance.trace(info) + mocks["moderation_trace"].assert_called_once_with(info) + + # SuggestedQuestionTraceInfo + info = MagicMock(spec=SuggestedQuestionTraceInfo) + trace_instance.trace(info) + mocks["suggested_question_trace"].assert_called_once_with(info) + + # DatasetRetrievalTraceInfo + info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_instance.trace(info) + mocks["dataset_retrieval_trace"].assert_called_once_with(info) + + # ToolTraceInfo + info = MagicMock(spec=ToolTraceInfo) + trace_instance.trace(info) + mocks["tool_trace"].assert_called_once_with(info) + + # GenerateNameTraceInfo + info = MagicMock(spec=GenerateNameTraceInfo) + trace_instance.trace(info) + mocks["generate_name_trace"].assert_called_once_with(info) + + +def test_workflow_trace_with_message_id(trace_instance, monkeypatch): + # Define constants for better readability + WORKFLOW_ID = "fb05c7cd-6cec-4add-8a84-df03a408b4ce" + WORKFLOW_RUN_ID = "33c67568-7a8a-450e-8916-a5f135baeaef" + MESSAGE_ID = "04ec3956-85f3-488a-8539-1017251dc8c6" + CONVERSATION_ID = "d3d01066-23ae-4830-9ce4-eb5640b42a7e" + TRACE_ID = "bf26d929-6f15-4c2f-9abc-761c217056f3" + WORKFLOW_APP_LOG_ID = "ca0e018e-edd4-43fb-a05a-ea001ca8ef4b" + LLM_NODE_ID = "80d7dfa8-08f4-4ab7-aa37-0ca7d27207e3" + CODE_NODE_ID = "b9cd9a7b-c534-4aa9-b5da-efd454140900" + + trace_info = WorkflowTraceInfo( + workflow_id=WORKFLOW_ID, + tenant_id="tenant-1", + workflow_run_id=WORKFLOW_RUN_ID, + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={"input": "hi"}, + workflow_run_outputs={"output": "hello"}, + workflow_run_version="1.0", + message_id=MESSAGE_ID, + conversation_id=CONVERSATION_ID, + total_tokens=100, + file_list=[], + query="hi", + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=TRACE_ID, + metadata={"app_id": "app-1", "user_id": "user-1"}, + workflow_app_log_id=WORKFLOW_APP_LOG_ID, + error="", + ) + + mock_session = MagicMock() + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: mock_session) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + + node_llm = MagicMock() + node_llm.id = LLM_NODE_ID + node_llm.title = "LLM Node" + node_llm.node_type = BuiltinNodeTypes.LLM + node_llm.status = "succeeded" + node_llm.process_data = { + "model_mode": "chat", + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + } + node_llm.inputs = {"prompts": "p"} + node_llm.outputs = {"text": "t"} + node_llm.created_at = _dt() + node_llm.elapsed_time = 0.5 + node_llm.metadata = {"foo": "bar"} + + node_other = MagicMock() + node_other.id = CODE_NODE_ID + node_other.title = "Other Node" + node_other.node_type = BuiltinNodeTypes.CODE + node_other.status = "failed" + node_other.process_data = None + node_other.inputs = {"code": "print"} + node_other.outputs = {"result": "ok"} + node_other.created_at = None + node_other.elapsed_time = 0.2 + node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node_llm, node_other] + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_data = trace_instance.add_trace.call_args[1].get("opik_trace_data", trace_instance.add_trace.call_args[0][0]) + assert trace_data["name"] == TraceTaskName.MESSAGE_TRACE + assert "message" in trace_data["tags"] + assert "workflow" in trace_data["tags"] + + assert trace_instance.add_span.call_count >= 1 + + +def test_workflow_trace_no_message_id(trace_instance, monkeypatch): + # Define constants for better readability + WORKFLOW_ID = "f0708b36-b1d7-42b3-a876-1d01b7d8f1a3" + WORKFLOW_RUN_ID = "d42ec285-c2fd-4248-8866-5c9386b101ac" + CONVERSATION_ID = "88a17f2e-9436-4472-bab9-4b1601d5af3c" + WORKFLOW_APP_LOG_ID = "41780d0d-ffba-4220-bc0c-401e4c89cdfb" + + trace_info = WorkflowTraceInfo( + workflow_id=WORKFLOW_ID, + tenant_id="tenant-1", + workflow_run_id=WORKFLOW_RUN_ID, + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id=CONVERSATION_ID, + start_time=_dt(), + end_time=_dt(), + trace_id=None, + metadata={"app_id": "app-1"}, + workflow_app_log_id=WORKFLOW_APP_LOG_ID, + error="", + ) + + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + repo = MagicMock() + repo.get_by_workflow_run.return_value = [] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.workflow_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + + +def test_workflow_trace_missing_app_id(trace_instance, monkeypatch): + trace_info = WorkflowTraceInfo( + workflow_id="5745f1b8-f8e6-4859-8110-996acb6c8d6a", + tenant_id="tenant-1", + workflow_run_id="46f53304-1659-464b-bee5-116585f0bec8", + workflow_run_elapsed_time=1.0, + workflow_run_status="succeeded", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1.0", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="83f86b89-caef-4de8-a0f9-f164eddae1ea", + start_time=_dt(), + end_time=_dt(), + metadata={}, + workflow_app_log_id="339760b2-4b94-4532-8c81-133a97e4680e", + error="", + ) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + +def test_message_trace_basic(trace_instance, monkeypatch): + # Define constants for better readability + MESSAGE_DATA_ID = "e3a26712-8cac-4a25-94a4-a3bff21ee3ab" + CONVERSATION_ID = "9d3f3751-7521-4c19-9307-20e3cf6789a3" + MESSAGE_TRACE_ID = "710ace2f-bca8-41be-858c-54da42742a77" + OPIT_TRACE_ID = "f7dfd978-0d10-4549-8abf-00f2cbc49d2c" + + message_data = MagicMock() + message_data.id = MESSAGE_DATA_ID + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = None + message_data.provider_response_latency = 0.5 + message_data.conversation_id = CONVERSATION_ID + message_data.total_price = 0.01 + message_data.model_id = "gpt-4" + message_data.answer = "hello" + message_data.status = MessageStatus.NORMAL + message_data.error = None + + trace_info = MessageTraceInfo( + message_id=MESSAGE_TRACE_ID, + message_data=message_data, + inputs={"query": "hi"}, + outputs={"answer": "hello"}, + message_tokens=10, + answer_tokens=20, + total_tokens=30, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + trace_id=OPIT_TRACE_ID, + metadata={"foo": "bar"}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=[], + error=None, + message_file_data=MagicMock(url="test.png"), + ) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_1")) + trace_instance.add_span = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + +def test_message_trace_with_end_user(trace_instance, monkeypatch): + message_data = MagicMock() + message_data.id = "85411059-79fb-4deb-a76c-c2e215f1b97e" + message_data.from_account_id = "acc-1" + message_data.from_end_user_id = "end-user-1" + message_data.conversation_id = "7d9f96d8-3be2-4e93-9c0e-922ff98dccc6" + message_data.status = MessageStatus.NORMAL + message_data.model_id = "gpt-4" + message_data.error = "" + message_data.answer = "hello" + message_data.total_price = 0.0 + message_data.provider_response_latency = 0.1 + + trace_info = MessageTraceInfo( + message_id="6bff35c7-33b7-4acb-ba21-44569a0327d0", + message_data=message_data, + inputs={}, + outputs={}, + message_tokens=0, + answer_tokens=0, + total_tokens=0, + start_time=_dt(), + end_time=_dt(), + metadata={}, + conversation_mode="chat", + conversation_model="gpt-4", + file_list=["url1"], + error=None, + ) + + mock_end_user = MagicMock(spec=EndUser) + mock_end_user.session_id = "session-id-123" + + mock_query = MagicMock() + mock_query.where.return_value.first.return_value = mock_end_user + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) + trace_instance.add_span = MagicMock() + + trace_instance.message_trace(trace_info) + + trace_data = trace_instance.add_trace.call_args[0][0] + assert trace_data["metadata"]["user_id"] == "acc-1" + assert trace_data["metadata"]["end_user_id"] == "session-id-123" + + +def test_message_trace_none_data(trace_instance): + trace_info = SimpleNamespace(message_data=None, file_list=[], message_file_data=None, metadata={}) + trace_instance.add_trace = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.add_trace.assert_not_called() + + +def test_moderation_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = ModerationTraceInfo( + message_id="489d0dfd-065c-4106-8f9c-daded296c92d", + message_data=message_data, + inputs={"q": "hi"}, + action="stop", + flagged=True, + preset_response="blocked", + start_time=None, + end_time=None, + metadata={"foo": "bar"}, + trace_id="6f16cf18-9f4b-4955-8b6b-43cfa10978fc", + query="hi", + ) + + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.MODERATION_TRACE + assert span_data["output"]["flagged"] is True + + +def test_moderation_trace_none(trace_instance): + trace_info = ModerationTraceInfo( + message_id="cd732e4e-37f1-4c7e-8c64-820308bedcbf", + message_data=None, + inputs={}, + action="s", + flagged=False, + preset_response="", + query="", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_suggested_question_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = SuggestedQuestionTraceInfo( + message_id="7de55bda-a91d-477e-98ab-85c53c438469", + message_data=message_data, + inputs="hi", + suggested_question=["q1"], + total_tokens=10, + level="info", + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="a6687292-68c7-42ba-ae51-285579944d7b", + ) + + trace_instance.add_span = MagicMock() + trace_instance.suggested_question_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.SUGGESTED_QUESTION_TRACE + + +def test_suggested_question_trace_none(trace_instance): + trace_info = SuggestedQuestionTraceInfo( + message_id="23696fc5-7e7f-46ec-bce8-1adc3c7f297d", + message_data=None, + inputs={}, + suggested_question=[], + total_tokens=0, + level="i", + metadata={}, + ) + trace_instance.add_span = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_dataset_retrieval_trace(trace_instance): + message_data = MagicMock() + message_data.created_at = _dt() + message_data.updated_at = _dt() + + trace_info = DatasetRetrievalTraceInfo( + message_id="3e1a819f-c391-4950-adfd-96f82e5419a1", + message_data=message_data, + inputs="query", + documents=[{"id": "doc1"}], + start_time=None, + end_time=None, + metadata={}, + trace_id="41361000-e9be-4d11-b5e4-ab27ce0817d6", + ) + + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == TraceTaskName.DATASET_RETRIEVAL_TRACE + + +def test_dataset_retrieval_trace_none(trace_instance): + trace_info = DatasetRetrievalTraceInfo( + message_id="35d6d44c-bccb-4e6e-8bd8-859257723ea8", message_data=None, inputs={}, documents=[], metadata={} + ) + trace_instance.add_span = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.add_span.assert_not_called() + + +def test_tool_trace(trace_instance): + trace_info = ToolTraceInfo( + message_id="99db92c4-2254-496a-b5cc-18153315ce35", + message_data=MagicMock(), + inputs={}, + outputs={}, + tool_name="my_tool", + tool_inputs={"a": 1}, + tool_outputs="result_string", + time_cost=0.1, + start_time=_dt(), + end_time=_dt(), + metadata={}, + trace_id="a15a5fcb-7ffd-4458-8330-208f4cb1f796", + tool_config={}, + tool_parameters={}, + error="some error", + ) + + trace_instance.add_span = MagicMock() + trace_instance.tool_trace(trace_info) + + trace_instance.add_span.assert_called_once() + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["name"] == "my_tool" + + +def test_generate_name_trace(trace_instance): + trace_info = GenerateNameTraceInfo( + inputs={"q": "hi"}, + outputs={"name": "new"}, + tenant_id="tenant-1", + conversation_id="271fe28f-6b86-416b-8d6b-bbbbfa9db791", + start_time=_dt(), + end_time=_dt(), + metadata={"921f010e-6878-4831-ae6b-271bf68c56fb": 1}, + ) + + trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_3")) + trace_instance.add_span = MagicMock() + + trace_instance.generate_name_trace(trace_info) + + trace_instance.add_trace.assert_called_once() + trace_instance.add_span.assert_called_once() + + trace_data = trace_instance.add_trace.call_args[0][0] + assert trace_data["name"] == TraceTaskName.GENERATE_NAME_TRACE + + span_data = trace_instance.add_span.call_args[0][0] + assert span_data["trace_id"] == "trace_id_3" + + +def test_add_trace_success(trace_instance): + trace_data = {"id": "t1", "name": "trace"} + trace_instance.opik_client.trace.return_value = MagicMock(id="t1") + trace = trace_instance.add_trace(trace_data) + trace_instance.opik_client.trace.assert_called_once() + assert trace.id == "t1" + + +def test_add_trace_error(trace_instance): + trace_instance.opik_client.trace.side_effect = Exception("error") + trace_data = {"id": "t1", "name": "trace"} + with pytest.raises(ValueError, match="Opik Failed to create trace: error"): + trace_instance.add_trace(trace_data) + + +def test_add_span_success(trace_instance): + span_data = {"id": "s1", "name": "span", "trace_id": "t1"} + trace_instance.add_span(span_data) + trace_instance.opik_client.span.assert_called_once() + + +def test_add_span_error(trace_instance): + trace_instance.opik_client.span.side_effect = Exception("error") + span_data = {"id": "s1", "name": "span", "trace_id": "t1"} + with pytest.raises(ValueError, match="Opik Failed to create span: error"): + trace_instance.add_span(span_data) + + +def test_api_check_success(trace_instance): + trace_instance.opik_client.auth_check.return_value = True + assert trace_instance.api_check() is True + + +def test_api_check_error(trace_instance): + trace_instance.opik_client.auth_check.side_effect = Exception("fail") + with pytest.raises(ValueError, match="Opik API check failed: fail"): + trace_instance.api_check() + + +def test_get_project_url_success(trace_instance): + trace_instance.opik_client.get_project_url.return_value = "http://project.url" + assert trace_instance.get_project_url() == "http://project.url" + trace_instance.opik_client.get_project_url.assert_called_once_with(project_name=trace_instance.project) + + +def test_get_project_url_error(trace_instance): + trace_instance.opik_client.get_project_url.side_effect = Exception("fail") + with pytest.raises(ValueError, match="Opik get run url failed: fail"): + trace_instance.get_project_url() + + +def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch, caplog): + trace_info = WorkflowTraceInfo( + workflow_id="86a52565-4a6b-4a1b-9bfd-98e4595e70de", + tenant_id="66e8e918-472e-4b69-8051-12502c34fc07", + workflow_run_id="8403965c-3344-4d22-a8fe-d8d55cee64d9", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="1", + total_tokens=0, + file_list=[], + query="", + message_id=None, + conversation_id="7a02cb9d-6949-4c59-a89d-f25bbc881e0e", + start_time=_dt(), + end_time=_dt(), + metadata={"app_id": "77e8e918-472e-4b69-8051-12502c34fc07"}, + workflow_app_log_id="82268424-e193-476c-a6db-f473388ee5fe", + error="", + ) + + node = MagicMock() + node.id = "88e8e918-472e-4b69-8051-12502c34fc07" + node.title = "LLM Node" + node.node_type = BuiltinNodeTypes.LLM + node.status = "succeeded" + + class BadDict(collections.UserDict): + def get(self, key, default=None): + if key == "usage": + raise Exception("Usage extraction failed") + return super().get(key, default) + + node.process_data = BadDict({"model_mode": "chat", "model_name": "gpt-4", "usage": True, "prompts": ["p"]}) + node.created_at = _dt() + node.elapsed_time = 0.1 + node.metadata = {} + node.outputs = {} + + repo = MagicMock() + repo.get_by_workflow_run.return_value = [node] + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.add_trace = MagicMock() + trace_instance.add_span = MagicMock() + + with caplog.at_level(logging.ERROR): + trace_instance.workflow_trace(trace_info) + + assert "Failed to extract usage" in caplog.text + assert trace_instance.add_span.call_count >= 1 + # Verify that at least one of the spans is for the LLM Node + span_names = [call.args[0]["name"] for call in trace_instance.add_span.call_args_list] + assert "LLM Node" in span_names diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_client.py b/api/tests/unit_tests/core/ops/tencent_trace/test_client.py new file mode 100644 index 0000000000..870c18e53e --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_client.py @@ -0,0 +1,583 @@ +"""Tests for the TencentTraceClient helpers that drive tracing and metrics.""" + +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from opentelemetry.sdk.trace import Event +from opentelemetry.trace import Status, StatusCode + +from core.ops.tencent_trace import client as client_module +from core.ops.tencent_trace.client import TencentTraceClient, _get_opentelemetry_sdk_version +from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData + +metric_reader_instances: list[DummyMetricReader] = [] +meter_provider_instances: list[DummyMeterProvider] = [] + + +class DummyHistogram: + """Placeholder histogram type used by the stubbed metric stack.""" + + +class AggregationTemporality: + DELTA = "delta" + + +class DummyMeter: + def __init__(self) -> None: + self.created: list[tuple[dict[str, object], MagicMock]] = [] + + def create_histogram(self, **kwargs: object) -> MagicMock: + hist = MagicMock(name=f"hist-{kwargs.get('name')}") + self.created.append((kwargs, hist)) + return hist + + +class DummyMeterProvider: + def __init__(self, resource: object, metric_readers: list[object]) -> None: + self.resource = resource + self.metric_readers = metric_readers + self.meter = DummyMeter() + self.shutdown = MagicMock(name="meter_provider_shutdown") + meter_provider_instances.append(self) + + def get_meter(self, name: str, version: str) -> DummyMeter: + return self.meter + + +class DummyMetricReader: + def __init__(self, exporter: object, export_interval_millis: int) -> None: + self.exporter = exporter + self.export_interval_millis = export_interval_millis + self.shutdown = MagicMock(name="metric_reader_shutdown") + metric_reader_instances.append(self) + + +class DummyGrpcMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyHttpMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyJsonMetricExporter: + def __init__(self, **kwargs: object) -> None: + self.kwargs = kwargs + + +class DummyJsonMetricExporterNoTemporality: + """Exporter that rejects preferred_temporality to exercise fallback.""" + + def __init__(self, **kwargs: object) -> None: + if "preferred_temporality" in kwargs: + raise RuntimeError("unsupported preferred_temporality") + self.kwargs = kwargs + + +def _add_stub_modules(monkeypatch: pytest.MonkeyPatch) -> None: + """Drop fake metric modules into sys.modules so the client imports resolve.""" + + metrics_module = types.ModuleType("opentelemetry.sdk.metrics") + metrics_module.Histogram = DummyHistogram + metrics_module.MeterProvider = DummyMeterProvider + monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics", metrics_module) + + metrics_export_module = types.ModuleType("opentelemetry.sdk.metrics.export") + metrics_export_module.AggregationTemporality = AggregationTemporality + metrics_export_module.PeriodicExportingMetricReader = DummyMetricReader + monkeypatch.setitem(sys.modules, "opentelemetry.sdk.metrics.export", metrics_export_module) + + grpc_module = types.ModuleType("opentelemetry.exporter.otlp.proto.grpc.metric_exporter") + grpc_module.OTLPMetricExporter = DummyGrpcMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.grpc.metric_exporter", grpc_module) + + http_module = types.ModuleType("opentelemetry.exporter.otlp.proto.http.metric_exporter") + http_module.OTLPMetricExporter = DummyHttpMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.proto.http.metric_exporter", http_module) + + http_json_module = types.ModuleType("opentelemetry.exporter.otlp.http.json.metric_exporter") + http_json_module.OTLPMetricExporter = DummyJsonMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.http.json.metric_exporter", http_json_module) + + legacy_json_module = types.ModuleType("opentelemetry.exporter.otlp.json.metric_exporter") + legacy_json_module.OTLPMetricExporter = DummyJsonMetricExporter + monkeypatch.setitem(sys.modules, "opentelemetry.exporter.otlp.json.metric_exporter", legacy_json_module) + + +@pytest.fixture(autouse=True) +def stub_metric_modules(monkeypatch: pytest.MonkeyPatch) -> None: + metric_reader_instances.clear() + meter_provider_instances.clear() + _add_stub_modules(monkeypatch) + + +@pytest.fixture(autouse=True) +def patch_core_components(monkeypatch: pytest.MonkeyPatch) -> dict[str, object]: + span_exporter = MagicMock(name="span_exporter") + monkeypatch.setattr(client_module, "OTLPSpanExporter", MagicMock(return_value=span_exporter)) + + span_processor = MagicMock(name="span_processor") + monkeypatch.setattr(client_module, "BatchSpanProcessor", MagicMock(return_value=span_processor)) + + tracer = MagicMock(name="tracer") + span = MagicMock(name="span") + tracer.start_span.return_value = span + + tracer_provider = MagicMock(name="tracer_provider") + tracer_provider.get_tracer.return_value = tracer + tracer_provider.shutdown = MagicMock(name="tracer_provider_shutdown") + monkeypatch.setattr(client_module, "TracerProvider", MagicMock(return_value=tracer_provider)) + + resource = MagicMock(name="resource") + monkeypatch.setattr(client_module, "Resource", MagicMock(return_value=resource)) + + logger_mock = MagicMock(name="tencent_logger") + monkeypatch.setattr(client_module, "logger", logger_mock) + + trace_api_stub = SimpleNamespace( + set_span_in_context=MagicMock(name="set_span_in_context", return_value="trace-context"), + NonRecordingSpan=MagicMock(name="non_recording_span", side_effect=lambda ctx: f"non-{ctx}"), + ) + monkeypatch.setattr(client_module, "trace_api", trace_api_stub) + + fake_config = SimpleNamespace( + project=SimpleNamespace(version="test"), + COMMIT_SHA="sha", + DEPLOY_ENV="dev", + EDITION="cloud", + ) + monkeypatch.setattr(client_module, "dify_config", fake_config) + + monkeypatch.setattr(client_module.socket, "gethostname", lambda: "fake-host") + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "") + + return { + "span_exporter": span_exporter, + "span_processor": span_processor, + "tracer": tracer, + "span": span, + "tracer_provider": tracer_provider, + "logger": logger_mock, + "trace_api": trace_api_stub, + } + + +def _build_client() -> TencentTraceClient: + return TencentTraceClient( + service_name="service", + endpoint="https://trace.example.com:4317", + token="token", + ) + + +def test_get_opentelemetry_sdk_version_reads_install(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(client_module, "version", lambda pkg: "2.0.0") + assert _get_opentelemetry_sdk_version() == "2.0.0" + + +def test_get_opentelemetry_sdk_version_falls_back(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(client_module, "version", MagicMock(side_effect=RuntimeError("boom"))) + assert _get_opentelemetry_sdk_version() == "1.27.0" + + +@pytest.mark.parametrize( + ("endpoint", "expected"), + [ + ( + "https://example.com:9090", + ("example.com:9090", False, "example.com", 9090), + ), + ( + "http://localhost", + ("localhost:4317", True, "localhost", 4317), + ), + ( + "example.com:bad", + ("example.com:4317", False, "example.com", 4317), + ), + ], +) +def test_resolve_grpc_target_parsable_variants(endpoint: str, expected: tuple[str, bool, str, int]) -> None: + assert TencentTraceClient._resolve_grpc_target(endpoint) == expected + + +def test_resolve_grpc_target_handles_errors() -> None: + assert TencentTraceClient._resolve_grpc_target(123) == ("localhost:4317", True, "localhost", 4317) + + +@pytest.mark.parametrize( + ("method", "attr_name", "args"), + [ + ("record_llm_duration", "hist_llm_duration", (0.3, {"foo": object()})), + ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")), + ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")), + ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")), + ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})), + ], +) +def test_record_methods_call_histograms(method: str, attr_name: str, args: tuple[object, ...]) -> None: + client = _build_client() + hist_mock = MagicMock(name=attr_name) + setattr(client, attr_name, hist_mock) + + getattr(client, method)(*args) + hist_mock.record.assert_called_once() + + +def test_record_methods_skip_when_histogram_missing() -> None: + client = _build_client() + client.hist_llm_duration = None + client.record_llm_duration(0.1) + + client.hist_token_usage = None + client.record_token_usage(1, "go", "chat", "model", "model", "addr", "provider") + + client.hist_time_to_first_token = None + client.record_time_to_first_token(0.2, "prov", "model") + + client.hist_time_to_generate = None + client.record_time_to_generate(0.3, "prov", "model") + + client.hist_trace_duration = None + client.record_trace_duration(0.5) + + +def test_record_llm_duration_handles_exceptions(patch_core_components: dict[str, object]) -> None: + client = _build_client() + client.hist_llm_duration = MagicMock(name="hist_llm_duration") + client.hist_llm_duration.record.side_effect = RuntimeError("boom") + + client.record_llm_duration(0.2) + logger = patch_core_components["logger"] + logger.debug.assert_called() + + +def test_create_and_export_span_sets_attributes(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + + data = SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={"key": "value"}, + events=[Event(name="evt", attributes={"k": "v"}, timestamp=123)], + status=Status(StatusCode.OK), + start_time=10, + end_time=20, + ) + + client._create_and_export_span(data) + span.set_attributes.assert_called_once() + span.add_event.assert_called_once() + span.set_status.assert_called_once() + span.end.assert_called_once_with(end_time=20) + assert client.span_contexts[2] == "ctx" + + +def test_create_and_export_span_uses_parent_context(patch_core_components: dict[str, object]) -> None: + client = _build_client() + client.span_contexts[10] = "existing" + span = patch_core_components["span"] + span.get_span_context.return_value = "child" + + data = SpanData( + trace_id=1, + parent_span_id=10, + span_id=11, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + + client._create_and_export_span(data) + trace_api = patch_core_components["trace_api"] + trace_api.NonRecordingSpan.assert_called_once_with("existing") + trace_api.set_span_in_context.assert_called_once() + + +def test_create_and_export_span_exception_logs_error(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + client.tracer.start_span.side_effect = RuntimeError("boom") + + client._create_and_export_span( + SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + ) + logger = patch_core_components["logger"] + logger.exception.assert_called_once() + + +def test_api_check_connects_successfully(monkeypatch: pytest.MonkeyPatch) -> None: + client = _build_client() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("host:123", False, "host", 123)), + ) + + socket_mock = MagicMock() + socket_instance = MagicMock() + socket_instance.connect_ex.return_value = 0 + socket_mock.return_value = socket_instance + monkeypatch.setattr(client_module.socket, "socket", socket_mock) + + assert client.api_check() + socket_instance.connect_ex.assert_called_once() + + +def test_api_check_returns_false_and_handles_local(monkeypatch: pytest.MonkeyPatch) -> None: + client = _build_client() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("host:123", False, "host", 123)), + ) + + socket_mock = MagicMock() + socket_instance = MagicMock() + socket_instance.connect_ex.return_value = 1 + socket_mock.return_value = socket_instance + monkeypatch.setattr(client_module.socket, "socket", socket_mock) + + assert not client.api_check() + + monkeypatch.setattr( + TencentTraceClient, + "_resolve_grpc_target", + MagicMock(return_value=("localhost:4317", True, "localhost", 4317)), + ) + socket_instance.connect_ex.return_value = 1 + assert client.api_check() + + +def test_api_check_handles_exceptions(monkeypatch: pytest.MonkeyPatch) -> None: + client = TencentTraceClient("svc", "https://localhost", "token") + + monkeypatch.setattr(client_module.socket, "socket", MagicMock(side_effect=RuntimeError("boom"))) + assert client.api_check() + + +def test_get_project_url() -> None: + client = _build_client() + assert client.get_project_url() == "https://console.cloud.tencent.com/apm" + + +def test_shutdown_flushes_all_components(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span_processor = patch_core_components["span_processor"] + tracer_provider = patch_core_components["tracer_provider"] + + client.shutdown() + span_processor.force_flush.assert_called_once() + span_processor.shutdown.assert_called_once() + tracer_provider.shutdown.assert_called_once() + + meter_provider = meter_provider_instances[-1] + metric_reader = metric_reader_instances[-1] + meter_provider.shutdown.assert_called_once() + metric_reader.shutdown.assert_called_once() + + +def test_shutdown_logs_when_meter_provider_fails(patch_core_components: dict[str, object]) -> None: + client = _build_client() + meter_provider = meter_provider_instances[-1] + meter_provider.shutdown.side_effect = RuntimeError("boom") + client.metric_reader.shutdown.side_effect = RuntimeError("boom") + + client.shutdown() + logger = patch_core_components["logger"] + logger.debug.assert_any_call( + "[Tencent APM] Error shutting down meter provider", + exc_info=True, + ) + logger.debug.assert_any_call( + "[Tencent APM] Error shutting down metric reader", + exc_info=True, + ) + + +def test_metrics_initialization_failure_sets_histogram_attributes(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(DummyMeterProvider, "__init__", MagicMock(side_effect=RuntimeError("err"))) + client = _build_client() + + assert client.meter is None + assert client.meter_provider is None + assert client.hist_llm_duration is None + assert client.hist_token_usage is None + assert client.hist_time_to_first_token is None + assert client.hist_time_to_generate is None + assert client.hist_trace_duration is None + assert client.metric_reader is None + + +def test_add_span_logs_exception(monkeypatch: pytest.MonkeyPatch, patch_core_components: dict[str, object]) -> None: + client = _build_client() + monkeypatch.setattr(client, "_create_and_export_span", MagicMock(side_effect=RuntimeError("boom"))) + + client.add_span( + SpanData( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={}, + events=[], + start_time=0, + end_time=1, + ) + ) + + logger = patch_core_components["logger"] + logger.exception.assert_called_once() + + +def test_create_and_export_span_converts_attribute_types(patch_core_components: dict[str, object]) -> None: + client = _build_client() + span = patch_core_components["span"] + span.get_span_context.return_value = "ctx" + + data = SpanData.model_construct( + trace_id=1, + parent_span_id=None, + span_id=2, + name="span", + attributes={"num": 5, "flag": True, "pi": 3.14, "text": "value"}, + events=[], + links=[], + status=Status(StatusCode.OK), + start_time=0, + end_time=1, + ) + + client._create_and_export_span(data) + (attrs,) = span.set_attributes.call_args.args + assert attrs["num"] == 5 + assert attrs["flag"] is True + assert attrs["pi"] == 3.14 + assert attrs["text"] == "value" + + +def test_record_llm_duration_converts_attributes() -> None: + client = _build_client() + hist_mock = MagicMock(name="hist_llm_duration") + client.hist_llm_duration = hist_mock + + client.record_llm_duration(0.3, {"foo": object(), "bar": 2}) + _, attrs = hist_mock.record.call_args.args + assert isinstance(attrs["foo"], str) + assert attrs["bar"] == 2 + + +def test_record_trace_duration_converts_attributes() -> None: + client = _build_client() + hist_mock = MagicMock(name="hist_trace_duration") + client.hist_trace_duration = hist_mock + + client.record_trace_duration(1.0, {"meta": object(), "ok": True}) + _, attrs = hist_mock.record.call_args.args + assert isinstance(attrs["meta"], str) + assert attrs["ok"] is True + + +@pytest.mark.parametrize( + ("method", "attr_name", "args"), + [ + ("record_token_usage", "hist_token_usage", (5, "input", "chat", "gpt", "gpt", "addr", "dify")), + ("record_time_to_first_token", "hist_time_to_first_token", (0.4, "dify", "gpt")), + ("record_time_to_generate", "hist_time_to_generate", (0.6, "dify", "gpt")), + ("record_trace_duration", "hist_trace_duration", (1.0, {"meta": object()})), + ], +) +def test_record_methods_handle_exceptions( + method: str, attr_name: str, args: tuple[object, ...], patch_core_components: dict[str, object] +) -> None: + client = _build_client() + hist_mock = MagicMock(name=attr_name) + hist_mock.record.side_effect = RuntimeError("boom") + setattr(client, attr_name, hist_mock) + + getattr(client, method)(*args) + logger = patch_core_components["logger"] + logger.debug.assert_called() + + +def test_metrics_initializes_grpc_metric_exporter() -> None: + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyGrpcMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == "trace.example.com:4317" + assert metric_reader.exporter.kwargs["insecure"] is False + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + + +def test_metrics_initializes_http_protobuf_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/protobuf") + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + + +def test_metrics_initializes_http_json_metric_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + client = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyJsonMetricExporter) + assert metric_reader.export_interval_millis == client.metrics_export_interval_sec * 1000 + assert metric_reader.exporter.kwargs["endpoint"] == client.endpoint + assert metric_reader.exporter.kwargs["headers"]["authorization"] == "Bearer token" + assert "preferred_temporality" in metric_reader.exporter.kwargs + + +def test_metrics_http_json_metric_exporter_falls_back_without_temporality(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + exporter_module = sys.modules["opentelemetry.exporter.otlp.http.json.metric_exporter"] + monkeypatch.setattr(exporter_module, "OTLPMetricExporter", DummyJsonMetricExporterNoTemporality) + _ = _build_client() + metric_reader = metric_reader_instances[-1] + + assert isinstance(metric_reader.exporter, DummyJsonMetricExporterNoTemporality) + assert "preferred_temporality" not in metric_reader.exporter.kwargs + + +def test_metrics_http_json_uses_http_fallback_when_no_json_exporter(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("OTEL_EXPORTER_OTLP_PROTOCOL", "http/json") + + def _fail_import(mod_path: str) -> types.ModuleType: + raise ModuleNotFoundError(mod_path) + + monkeypatch.setattr(client_module.importlib, "import_module", _fail_import) + + _ = _build_client() + metric_reader = metric_reader_instances[-1] + assert isinstance(metric_reader.exporter, DummyHttpMetricExporter) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py new file mode 100644 index 0000000000..a0b6d52720 --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -0,0 +1,359 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +from opentelemetry.trace import StatusCode + +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + MessageTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.tencent_trace.entities.semconv import ( + GEN_AI_IS_ENTRY, + GEN_AI_IS_STREAMING_REQUEST, + GEN_AI_MODEL_NAME, + GEN_AI_SPAN_KIND, + GEN_AI_USAGE_INPUT_TOKENS, + INPUT_VALUE, + RETRIEVAL_DOCUMENT, + RETRIEVAL_QUERY, + TOOL_DESCRIPTION, + TOOL_NAME, + TOOL_PARAMETERS, + GenAISpanKind, +) +from core.ops.tencent_trace.span_builder import TencentSpanBuilder +from core.rag.models.document import Document +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus + + +class TestTencentSpanBuilder: + def test_get_time_nanoseconds(self): + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_datetime_to_nanoseconds") as mock_convert: + mock_convert.return_value = 123456789 + dt = datetime.now() + result = TencentSpanBuilder._get_time_nanoseconds(dt) + assert result == 123456789 + mock_convert.assert_called_once_with(dt) + + def test_build_workflow_spans(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run_id" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.workflow_run_inputs = {"sys.query": "hello"} + trace_info.workflow_run_outputs = {"answer": "world"} + trace_info.metadata = {"conversation_id": "conv_id"} + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.side_effect = [1, 2] # workflow_span_id, message_span_id + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") + + assert len(spans) == 2 + assert spans[0].name == "message" + assert spans[0].span_id == 2 + assert spans[1].name == "workflow" + assert spans[1].span_id == 1 + assert spans[1].parent_span_id == 2 + + def test_build_workflow_spans_no_message(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run_id" + trace_info.error = "some error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.workflow_run_inputs = {} + trace_info.workflow_run_outputs = {} + trace_info.metadata = {} # No conversation_id + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 1 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + spans = TencentSpanBuilder.build_workflow_spans(trace_info, 123, "user_1") + + assert len(spans) == 1 + assert spans[0].name == "workflow" + assert spans[0].status.status_code == StatusCode.ERROR + assert spans[0].status.description == "some error" + assert spans[0].attributes[GEN_AI_IS_ENTRY] == "true" + + def test_build_workflow_llm_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.process_data = { + "model_name": "gpt-4", + "model_provider": "openai", + "usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30, "time_to_first_token": 0.5}, + "prompts": ["hello"], + } + node_execution.outputs = {"text": "world"} + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 456 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) + + assert span.name == "GENERATION" + assert span.attributes[GEN_AI_MODEL_NAME] == "gpt-4" + assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true" + assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "10" + + def test_build_workflow_llm_span_usage_in_outputs(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.process_data = {} + node_execution.outputs = { + "text": "world", + "usage": {"prompt_tokens": 15, "completion_tokens": 25, "total_tokens": 40}, + } + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 456 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_llm_span(123, 1, trace_info, node_execution) + + assert span.attributes[GEN_AI_USAGE_INPUT_TOKENS] == "15" + assert GEN_AI_IS_STREAMING_REQUEST not in span.attributes + + def test_build_message_span_standalone(self): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg_id" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.inputs = {"q": "hi"} + trace_info.outputs = "hello" + trace_info.metadata = {"conversation_id": "conv_id"} + trace_info.is_streaming_request = True + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 789 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") + + assert span.name == "message" + assert span.attributes[GEN_AI_IS_STREAMING_REQUEST] == "true" + assert span.attributes[INPUT_VALUE] == str(trace_info.inputs) + + def test_build_message_span_standalone_with_error(self): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg_id" + trace_info.error = "some error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.inputs = None + trace_info.outputs = None + trace_info.metadata = {} + trace_info.is_streaming_request = False + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 789 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_message_span(trace_info, 123, "user_1") + + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "some error" + assert span.attributes[INPUT_VALUE] == "" + + def test_build_tool_span(self): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg_id" + trace_info.tool_name = "search" + trace_info.error = "tool error" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.tool_parameters = {"p": 1} + trace_info.tool_inputs = {"i": 2} + trace_info.tool_outputs = "result" + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 101 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_tool_span(trace_info, 123, 1) + + assert span.name == "search" + assert span.status.status_code == StatusCode.ERROR + assert span.attributes[TOOL_NAME] == "search" + + def test_build_retrieval_span(self): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg_id" + trace_info.inputs = "query" + trace_info.error = None + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + + doc = Document( + page_content="content", metadata={"dataset_id": "d1", "doc_id": "di1", "document_id": "du1", "score": 0.9} + ) + trace_info.documents = [doc] + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 202 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) + + assert span.name == "retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "query" + assert "content" in span.attributes[RETRIEVAL_DOCUMENT] + + def test_build_retrieval_span_with_error(self): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg_id" + trace_info.inputs = "" + trace_info.error = "retrieval failed" + trace_info.start_time = datetime.now() + trace_info.end_time = datetime.now() + trace_info.documents = [] + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 202 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_retrieval_span(trace_info, 123, 1) + + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "retrieval failed" + + def test_get_workflow_node_status(self): + node = MagicMock(spec=WorkflowNodeExecution) + + node.status = WorkflowNodeExecutionStatus.SUCCEEDED + assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.OK + + node.status = WorkflowNodeExecutionStatus.FAILED + node.error = "fail" + status = TencentSpanBuilder._get_workflow_node_status(node) + assert status.status_code == StatusCode.ERROR + assert status.description == "fail" + + node.status = WorkflowNodeExecutionStatus.EXCEPTION + node.error = "exc" + status = TencentSpanBuilder._get_workflow_node_status(node) + assert status.status_code == StatusCode.ERROR + assert status.description == "exc" + + node.status = WorkflowNodeExecutionStatus.RUNNING + assert TencentSpanBuilder._get_workflow_node_status(node).status_code == StatusCode.UNSET + + def test_build_workflow_retrieval_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my retrieval" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.inputs = {"query": "q1"} + node_execution.outputs = {"result": [{"content": "c1"}]} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 303 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) + + assert span.name == "my retrieval" + assert span.attributes[RETRIEVAL_QUERY] == "q1" + assert "c1" in span.attributes[RETRIEVAL_DOCUMENT] + + def test_build_workflow_retrieval_span_empty(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my retrieval" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.inputs = {} + node_execution.outputs = {} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 303 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_retrieval_span(123, 1, trace_info, node_execution) + + assert span.attributes[RETRIEVAL_QUERY] == "" + assert span.attributes[RETRIEVAL_DOCUMENT] == "" + + def test_build_workflow_tool_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my tool" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.metadata = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"info": "some"}} + node_execution.inputs = {"param": "val"} + node_execution.outputs = {"res": "ok"} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 404 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) + + assert span.name == "my tool" + assert span.attributes[TOOL_NAME] == "my tool" + assert "some" in span.attributes[TOOL_DESCRIPTION] + + def test_build_workflow_tool_span_no_metadata(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my tool" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.metadata = None + node_execution.inputs = None + node_execution.outputs = {"res": "ok"} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 404 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_tool_span(123, 1, trace_info, node_execution) + + assert span.attributes[TOOL_DESCRIPTION] == "{}" + assert span.attributes[TOOL_PARAMETERS] == "{}" + + def test_build_workflow_task_span(self): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"conversation_id": "conv_id"} + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "node_id" + node_execution.title = "my task" + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_execution.inputs = {"in": 1} + node_execution.outputs = {"out": 2} + node_execution.created_at = datetime.now() + node_execution.finished_at = datetime.now() + + with patch("core.ops.tencent_trace.utils.TencentTraceUtils.convert_to_span_id") as mock_convert_id: + mock_convert_id.return_value = 505 + with patch.object(TencentSpanBuilder, "_get_time_nanoseconds", return_value=100): + span = TencentSpanBuilder.build_workflow_task_span(123, 1, trace_info, node_execution) + + assert span.name == "my task" + assert span.attributes[GEN_AI_SPAN_KIND] == GenAISpanKind.TASK.value diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py new file mode 100644 index 0000000000..f259e4639f --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -0,0 +1,647 @@ +import logging +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import TencentConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowTraceInfo, +) +from core.ops.tencent_trace.tencent_trace import TencentDataTrace +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import BuiltinNodeTypes +from models import Account, App, TenantAccountJoin + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def tencent_config(): + return TencentConfig(service_name="test-service", endpoint="https://test-endpoint", token="test-token") + + +@pytest.fixture +def mock_trace_client(): + with patch("core.ops.tencent_trace.tencent_trace.TencentTraceClient") as mock: + yield mock + + +@pytest.fixture +def mock_span_builder(): + with patch("core.ops.tencent_trace.tencent_trace.TencentSpanBuilder") as mock: + yield mock + + +@pytest.fixture +def mock_trace_utils(): + with patch("core.ops.tencent_trace.tencent_trace.TencentTraceUtils") as mock: + yield mock + + +@pytest.fixture +def tencent_data_trace(tencent_config, mock_trace_client): + return TencentDataTrace(tencent_config) + + +class TestTencentDataTrace: + def test_init(self, tencent_config, mock_trace_client): + trace = TencentDataTrace(tencent_config) + mock_trace_client.assert_called_once_with( + service_name=tencent_config.service_name, + endpoint=tencent_config.endpoint, + token=tencent_config.token, + metrics_export_interval_sec=5, + ) + assert trace.trace_client == mock_trace_client.return_value + + def test_trace_dispatch(self, tencent_data_trace): + methods = [ + ( + WorkflowTraceInfo( + workflow_id="wf", + tenant_id="t", + workflow_run_id="run", + workflow_run_elapsed_time=1.0, + workflow_run_status="s", + workflow_run_inputs={}, + workflow_run_outputs={}, + workflow_run_version="v", + total_tokens=0, + file_list=[], + query="", + metadata={}, + ), + "workflow_trace", + ), + ( + MessageTraceInfo( + message_id="msg", + message_data={}, + inputs={}, + outputs={}, + start_time=None, + end_time=None, + conversation_mode="chat", + conversation_model="gpt-3.5-turbo", + message_tokens=0, + answer_tokens=0, + total_tokens=0, + metadata={}, + ), + "message_trace", + ), + ( + ModerationTraceInfo( + flagged=False, action="a", preset_response="p", query="q", metadata={}, message_id="m" + ), + None, + ), # Pass + ( + SuggestedQuestionTraceInfo( + suggested_question=[], + level="l", + total_tokens=0, + metadata={}, + message_id="m", + message_data={}, + inputs={}, + start_time=None, + end_time=None, + ), + "suggested_question_trace", + ), + ( + DatasetRetrievalTraceInfo( + metadata={}, + message_id="m", + message_data={}, + inputs={}, + documents=[], + start_time=None, + end_time=None, + ), + "dataset_retrieval_trace", + ), + ( + ToolTraceInfo( + tool_name="t", + tool_inputs={}, + tool_outputs="", + tool_config={}, + tool_parameters={}, + time_cost=0, + metadata={}, + message_id="m", + inputs={}, + outputs={}, + start_time=None, + end_time=None, + ), + "tool_trace", + ), + ( + GenerateNameTraceInfo( + tenant_id="t", metadata={}, message_id="m", inputs={}, outputs={}, start_time=None, end_time=None + ), + None, + ), # Pass + ] + + for trace_info, method_name in methods: + if method_name: + with patch.object(tencent_data_trace, method_name) as mock_method: + tencent_data_trace.trace(trace_info) + mock_method.assert_called_once_with(trace_info) + else: + tencent_data_trace.trace(trace_info) + + def test_api_check(self, tencent_data_trace): + tencent_data_trace.trace_client.api_check.return_value = True + assert tencent_data_trace.api_check() is True + tencent_data_trace.trace_client.api_check.assert_called_once() + + def test_get_project_url(self, tencent_data_trace): + tencent_data_trace.trace_client.get_project_url.return_value = "http://url" + assert tencent_data_trace.get_project_url() == "http://url" + tencent_data_trace.trace_client.get_project_url.assert_called_once() + + def test_workflow_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + trace_info.trace_id = "parent-trace-id" + + mock_trace_utils.convert_to_trace_id.return_value = 123 + mock_trace_utils.create_link.return_value = "link" + + with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): + with patch.object(tencent_data_trace, "_process_workflow_nodes") as mock_proc: + with patch.object(tencent_data_trace, "_record_workflow_trace_duration") as mock_dur: + mock_span_builder.build_workflow_spans.return_value = [MagicMock(), MagicMock()] + + tencent_data_trace.workflow_trace(trace_info) + + mock_trace_utils.convert_to_trace_id.assert_called_once_with("run-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_workflow_spans.assert_called_once() + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_proc.assert_called_once_with(trace_info, 123) + mock_dur.assert_called_once_with(trace_info) + + def test_workflow_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.workflow_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow trace") + + def test_message_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.message_id = "msg-id" + trace_info.trace_id = "parent-trace-id" + + mock_trace_utils.convert_to_trace_id.return_value = 123 + mock_trace_utils.create_link.return_value = "link" + + with patch.object(tencent_data_trace, "_get_user_id", return_value="user-1"): + with patch.object(tencent_data_trace, "_record_message_llm_metrics") as mock_metrics: + with patch.object(tencent_data_trace, "_record_message_trace_duration") as mock_dur: + mock_span_builder.build_message_span.return_value = MagicMock() + + tencent_data_trace.message_trace(trace_info) + + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_trace_utils.create_link.assert_called_once_with("parent-trace-id") + mock_span_builder.build_message_span.assert_called_once() + tencent_data_trace.trace_client.add_span.assert_called_once() + mock_metrics.assert_called_once_with(trace_info) + mock_dur.assert_called_once_with(trace_info) + + def test_message_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_trace_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.message_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process message trace") + + def test_tool_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg-id" + + mock_trace_utils.convert_to_span_id.return_value = 456 + mock_trace_utils.convert_to_trace_id.return_value = 123 + + tencent_data_trace.tool_trace(trace_info) + + mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message") + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_span_builder.build_tool_span.assert_called_once_with(trace_info, 123, 456) + tencent_data_trace.trace_client.add_span.assert_called_once() + + def test_tool_trace_no_msg_id(self, tencent_data_trace): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = None + + tencent_data_trace.tool_trace(trace_info) + tencent_data_trace.trace_client.add_span.assert_not_called() + + def test_tool_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=ToolTraceInfo) + trace_info.message_id = "msg-id" + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.tool_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process tool trace") + + def test_dataset_retrieval_trace(self, tencent_data_trace, mock_trace_utils, mock_span_builder): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg-id" + + mock_trace_utils.convert_to_span_id.return_value = 456 + mock_trace_utils.convert_to_trace_id.return_value = 123 + + tencent_data_trace.dataset_retrieval_trace(trace_info) + + mock_trace_utils.convert_to_span_id.assert_called_once_with("msg-id", "message") + mock_trace_utils.convert_to_trace_id.assert_called_once_with("msg-id") + mock_span_builder.build_retrieval_span.assert_called_once_with(trace_info, 123, 456) + tencent_data_trace.trace_client.add_span.assert_called_once() + + def test_dataset_retrieval_trace_no_msg_id(self, tencent_data_trace): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = None + + tencent_data_trace.dataset_retrieval_trace(trace_info) + tencent_data_trace.trace_client.add_span.assert_not_called() + + def test_dataset_retrieval_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=DatasetRetrievalTraceInfo) + trace_info.message_id = "msg-id" + + with patch( + "core.ops.tencent_trace.tencent_trace.TencentTraceUtils.convert_to_span_id", side_effect=Exception("error") + ): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.dataset_retrieval_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process dataset retrieval trace") + + def test_suggested_question_trace(self, tencent_data_trace): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + with patch("core.ops.tencent_trace.tencent_trace.logger.info") as mock_log: + tencent_data_trace.suggested_question_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Processing suggested question trace") + + def test_suggested_question_trace_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=SuggestedQuestionTraceInfo) + with patch("core.ops.tencent_trace.tencent_trace.logger.info", side_effect=Exception("error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.suggested_question_trace(trace_info) + mock_log.assert_called_once_with("[Tencent APM] Failed to process suggested question trace") + + def test_process_workflow_nodes(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.workflow_run_id = "run-id" + mock_trace_utils.convert_to_span_id.return_value = 111 + + node1 = MagicMock(spec=WorkflowNodeExecution) + node1.id = "n1" + node1.node_type = BuiltinNodeTypes.LLM + node2 = MagicMock(spec=WorkflowNodeExecution) + node2.id = "n2" + node2.node_type = BuiltinNodeTypes.TOOL + + with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node1, node2]): + with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=["span1", "span2"]): + with patch.object(tencent_data_trace, "_record_llm_metrics") as mock_metrics: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + + assert tencent_data_trace.trace_client.add_span.call_count == 2 + mock_metrics.assert_called_once_with(node1) + + def test_process_workflow_nodes_node_exception(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + mock_trace_utils.convert_to_span_id.return_value = 111 + + node = MagicMock(spec=WorkflowNodeExecution) + node.id = "n1" + + with patch.object(tencent_data_trace, "_get_workflow_node_executions", return_value=[node]): + with patch.object(tencent_data_trace, "_build_workflow_node_span", side_effect=Exception("node error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + # The exception should be caught by the outer handler since convert_to_span_id is called first + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + + def test_process_workflow_nodes_exception(self, tencent_data_trace, mock_trace_utils): + trace_info = MagicMock(spec=WorkflowTraceInfo) + mock_trace_utils.convert_to_span_id.side_effect = Exception("outer error") + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace._process_workflow_nodes(trace_info, 123) + mock_log.assert_called_once_with("[Tencent APM] Failed to process workflow nodes") + + def test_build_workflow_node_span(self, tencent_data_trace, mock_span_builder): + trace_info = MagicMock(spec=WorkflowTraceInfo) + + nodes = [ + (BuiltinNodeTypes.LLM, mock_span_builder.build_workflow_llm_span), + (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, mock_span_builder.build_workflow_retrieval_span), + (BuiltinNodeTypes.TOOL, mock_span_builder.build_workflow_tool_span), + (BuiltinNodeTypes.CODE, mock_span_builder.build_workflow_task_span), + ] + + for node_type, builder_method in nodes: + node = MagicMock(spec=WorkflowNodeExecution) + node.node_type = node_type + builder_method.return_value = "span" + + result = tencent_data_trace._build_workflow_node_span(node, 123, trace_info, 456) + + assert result == "span" + builder_method.assert_called_once_with(123, 456, trace_info, node) + + def test_build_workflow_node_span_exception(self, tencent_data_trace, mock_span_builder): + node = MagicMock(spec=WorkflowNodeExecution) + node.node_type = BuiltinNodeTypes.LLM + node.id = "n1" + mock_span_builder.build_workflow_llm_span.side_effect = Exception("error") + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + result = tencent_data_trace._build_workflow_node_span(node, 123, MagicMock(), 456) + assert result is None + mock_log.assert_called_once() + + def test_get_workflow_node_executions(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"app_id": "app-1"} + trace_info.workflow_run_id = "run-1" + + app = MagicMock(spec=App) + app.id = "app-1" + app.created_by = "user-1" + + account = MagicMock(spec=Account) + account.id = "user-1" + + tenant_join = MagicMock(spec=TenantAccountJoin) + tenant_join.tenant_id = "tenant-1" + + mock_executions = [MagicMock()] + + with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + mock_db.engine = "engine" + with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + session = mock_session_ctx.return_value.__enter__.return_value + session.scalar.side_effect = [app, account] + session.query.return_value.filter_by.return_value.first.return_value = tenant_join + + with patch( + "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" + ) as mock_repo: + mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + + results = tencent_data_trace._get_workflow_node_executions(trace_info) + + assert results == mock_executions + account.set_tenant_id.assert_called_once_with("tenant-1") + + def test_get_workflow_node_executions_no_app_id(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {} + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + results = tencent_data_trace._get_workflow_node_executions(trace_info) + assert results == [] + mock_log.assert_called_once() + + def test_get_workflow_node_executions_app_not_found(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.metadata = {"app_id": "app-1"} + + with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + mock_db.init_app = MagicMock() # Ensure init_app is mocked + mock_db.engine = "engine" + with patch("core.ops.tencent_trace.tencent_trace.Session") as mock_session_ctx: + session = mock_session_ctx.return_value.__enter__.return_value + session.scalar.return_value = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + results = tencent_data_trace._get_workflow_node_executions(trace_info) + assert results == [] + mock_log.assert_called_once() + + def test_get_user_id_workflow(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.tenant_id = "tenant-1" + trace_info.metadata = {"user_id": "user-1"} + + with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("Database error")): + with patch("core.ops.tencent_trace.tencent_trace.db") as mock_db: + mock_db.init_app = MagicMock() + mock_db.engine = MagicMock() + + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" + + def test_get_user_id_only_user_id(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {"user_id": "user-1"} + + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "user-1" + + def test_get_user_id_anonymous(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {} + + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "anonymous" + + def test_get_user_id_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.tenant_id = "t" + trace_info.metadata = {"user_id": "u"} + + with patch("core.ops.tencent_trace.tencent_trace.sessionmaker", side_effect=Exception("error")): + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + user_id = tencent_data_trace._get_user_id(trace_info) + assert user_id == "unknown" + mock_log.assert_called_once_with("[Tencent APM] Failed to get user ID") + + def test_record_llm_metrics_usage_in_process_data(self, tencent_data_trace): + node = MagicMock(spec=WorkflowNodeExecution) + node.process_data = { + "usage": { + "latency": 2.5, + "time_to_first_token": 0.5, + "time_to_generate": 2.0, + "prompt_tokens": 10, + "completion_tokens": 20, + }, + "model_provider": "openai", + "model_name": "gpt-4", + "model_mode": "chat", + } + node.outputs = {} + + tencent_data_trace._record_llm_metrics(node) + + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once() + tencent_data_trace.trace_client.record_time_to_generate.assert_called_once() + assert tencent_data_trace.trace_client.record_token_usage.call_count == 2 + + def test_record_llm_metrics_usage_in_outputs(self, tencent_data_trace): + node = MagicMock(spec=WorkflowNodeExecution) + node.process_data = {} + node.outputs = {"usage": {"latency": 1.0, "prompt_tokens": 5}} + + tencent_data_trace._record_llm_metrics(node) + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + tencent_data_trace.trace_client.record_token_usage.assert_called_once() + + def test_record_llm_metrics_exception(self, tencent_data_trace): + node = MagicMock(spec=WorkflowNodeExecution) + node.process_data = None + node.outputs = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_llm_metrics(node) + # Should not crash + + def test_record_message_llm_metrics(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {"ls_provider": "openai", "ls_model_name": "gpt-4"} + trace_info.message_data = {"provider_response_latency": 1.1} + trace_info.is_streaming_request = True + trace_info.gen_ai_server_time_to_first_token = 0.2 + trace_info.llm_streaming_time_to_generate = 0.9 + trace_info.message_tokens = 15 + trace_info.answer_tokens = 25 + + tencent_data_trace._record_message_llm_metrics(trace_info) + + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + tencent_data_trace.trace_client.record_time_to_first_token.assert_called_once() + tencent_data_trace.trace_client.record_time_to_generate.assert_called_once() + assert tencent_data_trace.trace_client.record_token_usage.call_count == 2 + + def test_record_message_llm_metrics_object_data(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = {} + msg_data = MagicMock() + msg_data.provider_response_latency = 1.1 + msg_data.model_provider = "anthropic" + msg_data.model_id = "claude" + trace_info.message_data = msg_data + trace_info.is_streaming_request = False + + tencent_data_trace._record_message_llm_metrics(trace_info) + tencent_data_trace.trace_client.record_llm_duration.assert_called_once() + + def test_record_message_llm_metrics_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.metadata = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_message_llm_metrics(trace_info) + # Should not crash + + def test_record_workflow_trace_duration(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + from datetime import datetime, timedelta + + now = datetime.now() + trace_info.start_time = now + trace_info.end_time = now + timedelta(seconds=3) + trace_info.workflow_run_status = "succeeded" + trace_info.conversation_id = "conv-1" + + # Mock the record_trace_duration method to capture arguments + with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record: + tencent_data_trace._record_workflow_trace_duration(trace_info) + + # Assert the method was called once + mock_record.assert_called_once() + + # Extract arguments passed to the method + args, kwargs = mock_record.call_args + + # Validate the duration argument + assert args[0] == 3.0 + + # Validate the attributes dict in kwargs + attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {} + assert attributes["conversation_mode"] == "workflow" + assert attributes["has_conversation"] == "true" + + def test_record_workflow_trace_duration_fallback(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.start_time = None + trace_info.workflow_run_elapsed_time = 4.5 + trace_info.workflow_run_status = "failed" + trace_info.conversation_id = None + + with patch.object(tencent_data_trace.trace_client, "record_trace_duration") as mock_record: + tencent_data_trace._record_workflow_trace_duration(trace_info) + mock_record.assert_called_once() + args, kwargs = mock_record.call_args + assert args[0] == 4.5 + # Check attributes dict (either in kwargs or as second positional arg) + attributes = kwargs["attributes"] if "attributes" in kwargs else args[1] if len(args) > 1 else {} + assert attributes["has_conversation"] == "false" + + def test_record_workflow_trace_duration_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=WorkflowTraceInfo) + trace_info.start_time = MagicMock() # This might cause total_seconds() to fail if not mocked right + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_workflow_trace_duration(trace_info) + + def test_record_message_trace_duration(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + from datetime import datetime, timedelta + + now = datetime.now() + trace_info.start_time = now + trace_info.end_time = now + timedelta(seconds=2) + trace_info.conversation_mode = "chat" + trace_info.is_streaming_request = True + + tencent_data_trace._record_message_trace_duration(trace_info) + tencent_data_trace.trace_client.record_trace_duration.assert_called_once_with( + 2.0, {"conversation_mode": "chat", "stream": "true"} + ) + + def test_record_message_trace_duration_exception(self, tencent_data_trace): + trace_info = MagicMock(spec=MessageTraceInfo) + trace_info.start_time = None + + with patch("core.ops.tencent_trace.tencent_trace.logger.debug") as mock_log: + tencent_data_trace._record_message_trace_duration(trace_info) + + def test_del(self, tencent_data_trace): + client = tencent_data_trace.trace_client + tencent_data_trace.__del__() + client.shutdown.assert_called_once() + + def test_del_exception(self, tencent_data_trace): + tencent_data_trace.trace_client.shutdown.side_effect = Exception("error") + with patch("core.ops.tencent_trace.tencent_trace.logger.exception") as mock_log: + tencent_data_trace.__del__() + mock_log.assert_called_once_with("[Tencent APM] Failed to shutdown trace client during cleanup") diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py new file mode 100644 index 0000000000..ef28d18e20 --- /dev/null +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace_utils.py @@ -0,0 +1,106 @@ +"""Unit tests for Tencent APM tracing utilities.""" + +from __future__ import annotations + +import hashlib +import uuid +from datetime import UTC, datetime +from unittest.mock import patch + +import pytest +from opentelemetry.trace import Link, TraceFlags + +from core.ops.tencent_trace.utils import TencentTraceUtils + + +def test_convert_to_trace_id_with_valid_uuid() -> None: + uuid_str = "12345678-1234-5678-1234-567812345678" + assert TencentTraceUtils.convert_to_trace_id(uuid_str) == uuid.UUID(uuid_str).int + + +def test_convert_to_trace_id_uses_uuid4_when_none() -> None: + expected_uuid = uuid.UUID("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") + with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + assert TencentTraceUtils.convert_to_trace_id(None) == expected_uuid.int + uuid4_mock.assert_called_once() + + +def test_convert_to_trace_id_raises_value_error_for_invalid_uuid() -> None: + with pytest.raises(ValueError, match=r"^Invalid UUID input:"): + TencentTraceUtils.convert_to_trace_id("not-a-uuid") + + +def test_convert_to_span_id_is_deterministic_and_sensitive_to_type() -> None: + uuid_str = "12345678-1234-5678-1234-567812345678" + span_type = "llm" + + uuid_obj = uuid.UUID(uuid_str) + combined_key = f"{uuid_obj.hex}-{span_type}" + hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest() + expected = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False) + + assert TencentTraceUtils.convert_to_span_id(uuid_str, span_type) == expected + assert TencentTraceUtils.convert_to_span_id(uuid_str, "other") != expected + + +def test_convert_to_span_id_uses_uuid4_when_none() -> None: + expected_uuid = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb") + with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=expected_uuid) as uuid4_mock: + span_id = TencentTraceUtils.convert_to_span_id(None, "workflow") + assert isinstance(span_id, int) + uuid4_mock.assert_called_once() + + +def test_convert_to_span_id_raises_value_error_for_invalid_uuid() -> None: + with pytest.raises(ValueError, match=r"^Invalid UUID input:"): + TencentTraceUtils.convert_to_span_id("bad-uuid", "span") + + +def test_generate_span_id_skips_invalid_span_id() -> None: + with patch( + "core.ops.tencent_trace.utils.random.getrandbits", + side_effect=[TencentTraceUtils.INVALID_SPAN_ID, 42], + ) as bits_mock: + assert TencentTraceUtils.generate_span_id() == 42 + assert bits_mock.call_count == 2 + + +def test_convert_datetime_to_nanoseconds_accepts_datetime() -> None: + start_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + expected = int(start_time.timestamp() * 1e9) + assert TencentTraceUtils.convert_datetime_to_nanoseconds(start_time) == expected + + +def test_convert_datetime_to_nanoseconds_uses_now_when_none() -> None: + fixed = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC) + expected = int(fixed.timestamp() * 1e9) + + with patch("core.ops.tencent_trace.utils.datetime") as datetime_mock: + datetime_mock.now.return_value = fixed + assert TencentTraceUtils.convert_datetime_to_nanoseconds(None) == expected + datetime_mock.now.assert_called_once() + + +@pytest.mark.parametrize( + ("trace_id_str", "expected_trace_id"), + [ + ("0" * 31 + "1", int("0" * 31 + "1", 16)), + (str(uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc")), uuid.UUID("cccccccc-cccc-cccc-cccc-cccccccccccc").int), + ], +) +def test_create_link_accepts_hex_or_uuid(trace_id_str: str, expected_trace_id: int) -> None: + link = TencentTraceUtils.create_link(trace_id_str) + assert isinstance(link, Link) + assert link.context.trace_id == expected_trace_id + assert link.context.span_id == TencentTraceUtils.INVALID_SPAN_ID + assert link.context.is_remote is False + assert link.context.trace_flags == TraceFlags(TraceFlags.SAMPLED) + + +@pytest.mark.parametrize("trace_id_str", ["g" * 32, "not-a-uuid", None]) +def test_create_link_falls_back_to_uuid4(trace_id_str: object) -> None: + fallback_uuid = uuid.UUID("dddddddd-dddd-dddd-dddd-dddddddddddd") + with patch("core.ops.tencent_trace.utils.uuid.uuid4", return_value=fallback_uuid) as uuid4_mock: + link = TencentTraceUtils.create_link(trace_id_str) # type: ignore[arg-type] + assert link.context.trace_id == fallback_uuid.int + uuid4_mock.assert_called_once() diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py new file mode 100644 index 0000000000..49d6b698ef --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -0,0 +1,36 @@ +from openinference.semconv.trace import OpenInferenceSpanKindValues + +from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind +from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes + + +class TestGetNodeSpanKind: + """Tests for _get_node_span_kind helper.""" + + def test_all_node_types_are_mapped_correctly(self): + """Ensure every built-in node type is mapped to the correct span kind.""" + # Mappings for node types that have a specialised span kind. + special_mappings = { + BuiltinNodeTypes.LLM: OpenInferenceSpanKindValues.LLM, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER, + BuiltinNodeTypes.TOOL: OpenInferenceSpanKindValues.TOOL, + BuiltinNodeTypes.AGENT: OpenInferenceSpanKindValues.AGENT, + } + + # Test that every built-in node type is mapped to the correct span kind. + # Node types not in `special_mappings` should default to CHAIN. + for node_type in BUILT_IN_NODE_TYPES: + expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN) + actual_span_kind = _get_node_span_kind(node_type) + assert actual_span_kind == expected_span_kind, ( + f"Node type {node_type!r} was mapped to {actual_span_kind}, but {expected_span_kind} was expected." + ) + + def test_unknown_string_defaults_to_chain(self): + """An unrecognised node type string should still return CHAIN.""" + assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN + + def test_stale_dataset_retrieval_not_in_mapping(self): + """The old 'dataset_retrieval' string was never a valid NodeType value; + make sure it is not present in the mapping dictionary.""" + assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND diff --git a/api/tests/unit_tests/core/ops/test_base_trace_instance.py b/api/tests/unit_tests/core/ops/test_base_trace_instance.py new file mode 100644 index 0000000000..a8bee7dfa7 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_base_trace_instance.py @@ -0,0 +1,112 @@ +from unittest.mock import MagicMock + +import pytest +from sqlalchemy.orm import Session + +from core.ops.base_trace_instance import BaseTraceInstance +from core.ops.entities.config_entity import BaseTracingConfig +from core.ops.entities.trace_entity import BaseTraceInfo +from models import Account, App, TenantAccountJoin + + +class ConcreteTraceInstance(BaseTraceInstance): + def __init__(self, trace_config: BaseTracingConfig): + super().__init__(trace_config) + + def trace(self, trace_info: BaseTraceInfo): + super().trace(trace_info) + + +@pytest.fixture +def mock_db_session(monkeypatch): + mock_session = MagicMock(spec=Session) + mock_session.__enter__.return_value = mock_session + mock_session.__exit__.return_value = None + + mock_session_class = MagicMock(return_value=mock_session) + + monkeypatch.setattr("core.ops.base_trace_instance.Session", mock_session_class) + monkeypatch.setattr("core.ops.base_trace_instance.db", MagicMock()) + return mock_session + + +def test_get_service_account_with_tenant_app_not_found(mock_db_session): + mock_db_session.scalar.return_value = None + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="App with id some_app_id not found"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_no_creator(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = None + mock_db_session.scalar.return_value = mock_app + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="App with id some_app_id has no creator"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_creator_not_found(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = "creator_id" + + # First call to scalar returns app, second returns None (for account) + mock_db_session.scalar.side_effect = [mock_app, None] + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="Creator account with id creator_id not found for app some_app_id"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_tenant_not_found(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = "creator_id" + + mock_account = MagicMock(spec=Account) + mock_account.id = "creator_id" + + mock_db_session.scalar.side_effect = [mock_app, mock_account] + + # session.query(TenantAccountJoin).filter_by(...).first() returns None + mock_db_session.query.return_value.filter_by.return_value.first.return_value = None + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + with pytest.raises(ValueError, match="Current tenant not found for account creator_id"): + instance.get_service_account_with_tenant("some_app_id") + + +def test_get_service_account_with_tenant_success(mock_db_session): + mock_app = MagicMock(spec=App) + mock_app.id = "some_app_id" + mock_app.created_by = "creator_id" + + mock_account = MagicMock(spec=Account) + mock_account.id = "creator_id" + mock_account.set_tenant_id = MagicMock() + + mock_db_session.scalar.side_effect = [mock_app, mock_account] + + mock_tenant_join = MagicMock(spec=TenantAccountJoin) + mock_tenant_join.tenant_id = "tenant_id" + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_tenant_join + + config = MagicMock(spec=BaseTracingConfig) + instance = ConcreteTraceInstance(config) + + result = instance.get_service_account_with_tenant("some_app_id") + + assert result == mock_account + mock_account.set_tenant_id.assert_called_once_with("tenant_id") diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py new file mode 100644 index 0000000000..7660967183 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -0,0 +1,329 @@ +"""Tests for OpikDataTrace workflow_trace changes. + +Covers: +- _seed_to_uuid4 helper: produces valid UUID4 strings deterministically +- prepare_opik_uuid helper: basic contract +- workflow_trace without message_id now creates a root span parented to None +- workflow_trace without message_id: node spans parent to root_span_id (not workflow_app_log_id) +- workflow_trace with message_id still creates root span keyed on workflow_run_id (unchanged path) +""" + +from __future__ import annotations + +import uuid +from datetime import datetime +from unittest.mock import MagicMock, patch + +from core.ops.entities.trace_entity import TraceTaskName, WorkflowTraceInfo +from core.ops.opik_trace.opik_trace import OpikDataTrace, _seed_to_uuid4, prepare_opik_uuid + +# A stable UUID4 used as the workflow_run_id throughout all tests. +_WORKFLOW_RUN_ID = "a3f1b2c4-d5e6-4f78-9a0b-c1d2e3f4a5b6" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_workflow_trace_info( + *, + message_id: str | None = None, + workflow_app_log_id: str | None = None, + workflow_run_id: str = _WORKFLOW_RUN_ID, +) -> WorkflowTraceInfo: + """Return a minimal WorkflowTraceInfo suitable for unit testing.""" + return WorkflowTraceInfo( + message_id=message_id, + workflow_id="wf-id", + tenant_id="tenant-id", + workflow_run_id=workflow_run_id, + workflow_app_log_id=workflow_app_log_id, + workflow_run_elapsed_time=1.5, + workflow_run_status="succeeded", + workflow_run_inputs={"query": "hello"}, + workflow_run_outputs={"result": "world"}, + workflow_run_version="1", + total_tokens=42, + file_list=[], + query="hello", + start_time=datetime(2025, 1, 1, 12, 0, 0), + end_time=datetime(2025, 1, 1, 12, 0, 1), + metadata={"app_id": "app-abc"}, + conversation_id=None, + ) + + +def _make_opik_trace_instance() -> OpikDataTrace: + """Construct an OpikDataTrace with the Opik SDK client mocked out.""" + with patch("core.ops.opik_trace.opik_trace.Opik"): + from core.ops.entities.config_entity import OpikConfig + + config = OpikConfig(api_key="key", project="test-project", url="https://www.comet.com/opik/api/") + instance = OpikDataTrace(config) + + instance.add_trace = MagicMock(return_value=MagicMock(id="mock-trace-id")) + instance.add_span = MagicMock() + instance.get_service_account_with_tenant = MagicMock(return_value=MagicMock()) + return instance + + +# --------------------------------------------------------------------------- +# _seed_to_uuid4 +# --------------------------------------------------------------------------- + + +class TestSeedToUuid4: + def test_returns_valid_uuid4_string(self): + result = _seed_to_uuid4("some-arbitrary-seed") + parsed = uuid.UUID(result) + assert parsed.version == 4 + + def test_is_deterministic(self): + assert _seed_to_uuid4("seed-abc") == _seed_to_uuid4("seed-abc") + + def test_different_seeds_give_different_results(self): + assert _seed_to_uuid4("seed-1") != _seed_to_uuid4("seed-2") + + def test_workflow_run_id_with_root_suffix_is_valid_uuid4(self): + """The primary use-case: deriving a root-span UUID from workflow_run_id + '-root'.""" + seed = _WORKFLOW_RUN_ID + "-root" + result = _seed_to_uuid4(seed) + parsed = uuid.UUID(result) + assert parsed.version == 4 + + def test_seed_and_seed_root_produce_different_uuids(self): + """Root span UUID must differ from the base workflow UUID to avoid ID collisions.""" + base = _seed_to_uuid4(_WORKFLOW_RUN_ID) + with_root = _seed_to_uuid4(_WORKFLOW_RUN_ID + "-root") + assert base != with_root + + +# --------------------------------------------------------------------------- +# prepare_opik_uuid +# --------------------------------------------------------------------------- + + +class TestPrepareOpikUuid: + def test_is_deterministic(self): + dt = datetime(2025, 6, 15, 10, 30, 0) + uid = str(uuid.uuid4()) + assert prepare_opik_uuid(dt, uid) == prepare_opik_uuid(dt, uid) + + def test_different_uuids_give_different_results(self): + dt = datetime(2025, 6, 15, 10, 30, 0) + assert prepare_opik_uuid(dt, str(uuid.uuid4())) != prepare_opik_uuid(dt, str(uuid.uuid4())) + + def test_none_datetime_does_not_raise(self): + assert prepare_opik_uuid(None, str(uuid.uuid4())) is not None + + def test_none_uuid_does_not_raise(self): + assert prepare_opik_uuid(datetime(2025, 1, 1), None) is not None + + +# --------------------------------------------------------------------------- +# workflow_trace — no message_id (new code path) +# --------------------------------------------------------------------------- + + +class TestWorkflowTraceWithoutMessageId: + def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): + instance = _make_opik_trace_instance() + fake_repo = MagicMock() + fake_repo.get_by_workflow_run.return_value = node_executions or [] + + with ( + patch("core.ops.opik_trace.opik_trace.db") as mock_db, + patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch( + "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=fake_repo, + ), + ): + mock_db.engine = MagicMock() + instance.workflow_trace(trace_info) + + return instance + + def _expected_root_span_id(self, trace_info: WorkflowTraceInfo): + return prepare_opik_uuid( + trace_info.start_time, + _seed_to_uuid4(trace_info.workflow_run_id + "-root"), + ) + + def test_root_span_is_created(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + assert instance.add_span.called + + def test_root_span_id_matches_expected(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + expected = self._expected_root_span_id(trace_info) + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["id"] == expected + + def test_root_span_has_no_parent(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["parent_span_id"] is None + + def test_trace_name_is_workflow_trace(self): + """Without message_id, the Opik trace itself should be named WORKFLOW_TRACE.""" + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + trace_kwargs = instance.add_trace.call_args_list[0][0][0] + assert trace_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE + + def test_root_span_name_is_workflow_trace(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["name"] == TraceTaskName.WORKFLOW_TRACE + + def test_root_span_has_workflow_tag(self): + trace_info = _make_workflow_trace_info(message_id=None) + instance = self._run(trace_info) + + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert "workflow" in root_span_kwargs["tags"] + + def test_node_execution_spans_are_parented_to_root(self): + """Node spans must use root_span_id as parent, not any other ID.""" + trace_info = _make_workflow_trace_info(message_id=None) + expected_root_span_id = self._expected_root_span_id(trace_info) + + node_exec = MagicMock() + node_exec.id = str(uuid.uuid4()) + node_exec.title = "LLM Node" + node_exec.node_type = "llm" + node_exec.status = "succeeded" + node_exec.process_data = {} + node_exec.inputs = {"prompt": "hi"} + node_exec.outputs = {"text": "hello"} + node_exec.created_at = datetime(2025, 1, 1, 12, 0, 0) + node_exec.elapsed_time = 0.5 + node_exec.metadata = {} + + instance = self._run(trace_info, node_executions=[node_exec]) + + # call_args_list[0] = root span, [1] = node execution span + assert instance.add_span.call_count == 2 + node_span_kwargs = instance.add_span.call_args_list[1][0][0] + assert node_span_kwargs["parent_span_id"] == expected_root_span_id + + def test_node_span_not_parented_to_workflow_app_log_id(self): + """Old behaviour derived parent from workflow_app_log_id; that must no longer apply.""" + trace_info = _make_workflow_trace_info( + message_id=None, + workflow_app_log_id=str(uuid.uuid4()), + ) + + node_exec = MagicMock() + node_exec.id = str(uuid.uuid4()) + node_exec.title = "Tool Node" + node_exec.node_type = "tool" + node_exec.status = "succeeded" + node_exec.process_data = {} + node_exec.inputs = {} + node_exec.outputs = {} + node_exec.created_at = datetime(2025, 1, 1, 12, 0, 0) + node_exec.elapsed_time = 0.2 + node_exec.metadata = {} + + instance = self._run(trace_info, node_executions=[node_exec]) + + old_parent_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_app_log_id) + node_span_kwargs = instance.add_span.call_args_list[1][0][0] + assert node_span_kwargs["parent_span_id"] != old_parent_id + + def test_root_span_id_differs_from_trace_id(self): + """The root span must have a different ID from the Opik trace to maintain correct hierarchy.""" + trace_info = _make_workflow_trace_info(message_id=None) + dify_trace_id = trace_info.trace_id or trace_info.workflow_run_id + opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id) + root_span_id = self._expected_root_span_id(trace_info) + assert root_span_id != opik_trace_id + + +# --------------------------------------------------------------------------- +# workflow_trace — with message_id (unchanged path, guard against regression) +# --------------------------------------------------------------------------- + + +class TestWorkflowTraceWithMessageId: + _MESSAGE_ID = str(uuid.uuid4()) + + def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): + instance = _make_opik_trace_instance() + fake_repo = MagicMock() + fake_repo.get_by_workflow_run.return_value = node_executions or [] + + with ( + patch("core.ops.opik_trace.opik_trace.db") as mock_db, + patch("core.ops.opik_trace.opik_trace.sessionmaker"), + patch( + "core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory.create_workflow_node_execution_repository", + return_value=fake_repo, + ), + ): + mock_db.engine = MagicMock() + instance.workflow_trace(trace_info) + + return instance + + def test_trace_name_is_message_trace(self): + """With message_id, the Opik trace should be named MESSAGE_TRACE.""" + trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) + instance = self._run(trace_info) + + trace_kwargs = instance.add_trace.call_args_list[0][0][0] + assert trace_kwargs["name"] == TraceTaskName.MESSAGE_TRACE + + def test_root_span_uses_workflow_run_id_directly(self): + """When message_id is set, root_span_id = prepare_opik_uuid(start_time, workflow_run_id).""" + trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) + instance = self._run(trace_info) + + expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) + root_span_kwargs = instance.add_span.call_args_list[0][0][0] + assert root_span_kwargs["id"] == expected_root_span_id + + def test_root_span_id_differs_from_no_message_id_case(self): + """The two branches must produce different root span IDs for the same workflow_run_id.""" + id_with_message = prepare_opik_uuid( + datetime(2025, 1, 1, 12, 0, 0), + _WORKFLOW_RUN_ID, + ) + id_without_message = prepare_opik_uuid( + datetime(2025, 1, 1, 12, 0, 0), + _seed_to_uuid4(_WORKFLOW_RUN_ID + "-root"), + ) + assert id_with_message != id_without_message + + def test_node_spans_parented_to_workflow_run_root_span(self): + """Node spans must still parent to root_span_id derived from workflow_run_id.""" + trace_info = _make_workflow_trace_info(message_id=self._MESSAGE_ID) + expected_root_span_id = prepare_opik_uuid(trace_info.start_time, trace_info.workflow_run_id) + + node_exec = MagicMock() + node_exec.id = str(uuid.uuid4()) + node_exec.title = "LLM" + node_exec.node_type = "llm" + node_exec.status = "succeeded" + node_exec.process_data = {} + node_exec.inputs = {} + node_exec.outputs = {} + node_exec.created_at = datetime(2025, 1, 1, 12, 0, 0) + node_exec.elapsed_time = 0.3 + node_exec.metadata = {} + + instance = self._run(trace_info, node_executions=[node_exec]) + + node_span_kwargs = instance.add_span.call_args_list[1][0][0] + assert node_span_kwargs["parent_span_id"] == expected_root_span_id diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py new file mode 100644 index 0000000000..2d325ccb0e --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -0,0 +1,576 @@ +import contextlib +import json +import queue +from datetime import datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from core.ops.ops_trace_manager import ( + OpsTraceManager, + TraceQueueManager, + TraceTask, + TraceTaskName, +) + + +class DummyConfig: + def __init__(self, **kwargs): + self._data = kwargs + + def model_dump(self): + return dict(self._data) + + +class DummyTraceInstance: + instances: list["DummyTraceInstance"] = [] + + def __init__(self, config): + self.config = config + DummyTraceInstance.instances.append(self) + + def api_check(self): + return True + + def get_project_key(self): + return "fake-key" + + def get_project_url(self): + return "https://project.fake" + + +FAKE_PROVIDER_ENTRY = { + "config_class": DummyConfig, + "secret_keys": ["secret_value"], + "other_keys": ["other_value"], + "trace_instance": DummyTraceInstance, +} + + +class FakeProviderMap: + def __init__(self, data): + self._data = data + + def __getitem__(self, key): + if key in self._data: + return self._data[key] + raise KeyError(f"Unsupported tracing provider: {key}") + + +class DummyTimer: + def __init__(self, interval, function): + self.interval = interval + self.function = function + self.name = "" + self.daemon = False + self.started = False + + def start(self): + self.started = True + + def is_alive(self): + return False + + +class FakeMessageFile: + def __init__(self): + self.url = "path/to/file" + self.id = "file-id" + self.type = "document" + self.created_by_role = "role" + self.created_by = "user" + + +def make_message_data(**overrides): + created_at = datetime(2025, 2, 20, 12, 0, 0) + base = { + "id": "msg-id", + "conversation_id": "conv-id", + "created_at": created_at, + "updated_at": created_at + timedelta(seconds=3), + "message": "hello", + "provider_response_latency": 1, + "message_tokens": 5, + "answer_tokens": 7, + "answer": "world", + "error": "", + "status": "complete", + "model_provider": "provider", + "model_id": "model", + "from_end_user_id": "end-user", + "from_account_id": "account", + "agent_based": False, + "workflow_run_id": "workflow-run", + "from_source": "source", + "message_metadata": json.dumps({"usage": {"time_to_first_token": 1, "time_to_generate": 2}}), + "agent_thoughts": [], + "query": "sample-query", + "inputs": "sample-input", + } + base.update(overrides) + + class MessageData: + def __init__(self, data): + self.__dict__.update(data) + + def to_dict(self): + return dict(self.__dict__) + + return MessageData(base) + + +def make_agent_thought(tool_name, created_at): + return SimpleNamespace( + tools=[tool_name], + created_at=created_at, + tool_meta={ + tool_name: { + "tool_config": {"foo": "bar"}, + "time_cost": 5, + "error": "", + "tool_parameters": {"x": 1}, + } + }, + ) + + +def make_workflow_run(): + return SimpleNamespace( + workflow_id="wf-1", + tenant_id="tenant", + id="run-id", + elapsed_time=10, + status="finished", + inputs_dict={"sys.file": ["f1"], "query": "search"}, + outputs_dict={"out": "value"}, + version="3", + error=None, + total_tokens=12, + workflow_run_id="run-id", + created_at=datetime(2025, 2, 20, 10, 0, 0), + finished_at=datetime(2025, 2, 20, 10, 0, 5), + triggered_from="user", + app_id="app-id", + to_dict=lambda self=None: {"run": "value"}, + ) + + +def configure_db_query(session, *, message_file=None, workflow_app_log=None): + def _side_effect(model): + query = MagicMock() + query.filter_by.return_value.first.return_value = None + if message_file and model.__name__ == "MessageFile": + query.filter_by.return_value.first.return_value = message_file + if workflow_app_log and model.__name__ == "WorkflowAppLog": + query.filter_by.return_value.first.return_value = workflow_app_log + return query + + session.query.side_effect = _side_effect + + +class DummySessionContext: + scalar_values = [] + + def __init__(self, engine): + self._values = list(self.scalar_values) + self._index = 0 + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + return False + + def scalar(self, *args, **kwargs): + if self._index >= len(self._values): + return None + value = self._values[self._index] + self._index += 1 + return value + + +@pytest.fixture(autouse=True) +def patch_provider_map(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({"dummy": FAKE_PROVIDER_ENTRY}) + ) + OpsTraceManager.ops_trace_instances_cache.clear() + OpsTraceManager.decrypted_configs_cache.clear() + + +@pytest.fixture(autouse=True) +def patch_timer_and_current_app(monkeypatch): + monkeypatch.setattr("core.ops.ops_trace_manager.threading.Timer", DummyTimer) + monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_queue", queue.Queue()) + monkeypatch.setattr("core.ops.ops_trace_manager.trace_manager_timer", None) + + class FakeApp: + def app_context(self): + return contextlib.nullcontext() + + fake_current = MagicMock() + fake_current._get_current_object.return_value = FakeApp() + monkeypatch.setattr("core.ops.ops_trace_manager.current_app", fake_current) + + +@pytest.fixture(autouse=True) +def patch_sqlalchemy_session(monkeypatch): + monkeypatch.setattr("core.ops.ops_trace_manager.Session", DummySessionContext) + + +@pytest.fixture +def encryption_mocks(monkeypatch): + encrypt_mock = MagicMock(side_effect=lambda tenant, value: f"enc-{value}") + batch_decrypt_mock = MagicMock(side_effect=lambda tenant, values: [f"dec-{value}" for value in values]) + obfuscate_mock = MagicMock(side_effect=lambda value: f"ob-{value}") + monkeypatch.setattr("core.ops.ops_trace_manager.encrypt_token", encrypt_mock) + monkeypatch.setattr("core.ops.ops_trace_manager.batch_decrypt_token", batch_decrypt_mock) + monkeypatch.setattr("core.ops.ops_trace_manager.obfuscated_token", obfuscate_mock) + return encrypt_mock, batch_decrypt_mock, obfuscate_mock + + +@pytest.fixture +def mock_db(monkeypatch): + session = MagicMock() + session.scalars.return_value.all.return_value = ["chat"] + db_mock = MagicMock() + db_mock.session = session + db_mock.engine = MagicMock() + monkeypatch.setattr("core.ops.ops_trace_manager.db", db_mock) + return session + + +@pytest.fixture +def workflow_repo_fixture(monkeypatch): + repo = MagicMock() + repo.get_workflow_run_by_id_without_tenant.return_value = make_workflow_run() + monkeypatch.setattr(TraceTask, "_get_workflow_run_repo", classmethod(lambda cls: repo)) + return repo + + +@pytest.fixture +def trace_task_message(monkeypatch, mock_db): + message_data = make_message_data() + monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) + configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + return message_data + + +def test_encrypt_tracing_config_handles_star_and_encrypt(encryption_mocks): + encrypted = OpsTraceManager.encrypt_tracing_config( + "tenant", + "dummy", + {"secret_value": "value", "other_value": "info"}, + current_trace_config={"secret_value": "keep"}, + ) + assert encrypted["secret_value"] == "enc-value" + assert encrypted["other_value"] == "info" + + +def test_encrypt_tracing_config_preserves_star(encryption_mocks): + encrypted = OpsTraceManager.encrypt_tracing_config( + "tenant", + "dummy", + {"secret_value": "*", "other_value": "info"}, + current_trace_config={"secret_value": "keep"}, + ) + assert encrypted["secret_value"] == "keep" + + +def test_decrypt_tracing_config_caches(encryption_mocks): + _, decrypt_mock, _ = encryption_mocks + payload = {"secret_value": "enc", "other_value": "info"} + first = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload) + second = OpsTraceManager.decrypt_tracing_config("tenant", "dummy", payload) + assert first == second + assert decrypt_mock.call_count == 1 + + +def test_obfuscated_decrypt_token(encryption_mocks): + _, _, obfuscate_mock = encryption_mocks + result = OpsTraceManager.obfuscated_decrypt_token("dummy", {"secret_value": "value", "other_value": "info"}) + assert "secret_value" in result + assert result["secret_value"] == "ob-value" + obfuscate_mock.assert_called_once() + + +def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): + trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) + mock_db.query.return_value.where.return_value.first.return_value = trace_config_data + app = SimpleNamespace(id="app-id", tenant_id="tenant") + mock_db.scalar.return_value = app + + decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") + assert decrypted["other_value"] == "info" + + +def test_get_decrypted_tracing_config_missing_trace_config(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None + + +def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): + trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) + mock_db.query.return_value.where.return_value.first.return_value = trace_config_data + mock_db.scalar.return_value = None + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") + + +def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): + trace_config_data = SimpleNamespace(tracing_config=None) + mock_db.query.return_value.where.return_value.first.return_value = trace_config_data + mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + with pytest.raises(ValueError, match="Tracing config cannot be None"): + OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") + + +def test_get_ops_trace_instance_handles_none_app(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + assert OpsTraceManager.get_ops_trace_instance("app-id") is None + + +def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): + app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) + mock_db.query.return_value.where.return_value.first.return_value = app + assert OpsTraceManager.get_ops_trace_instance("app-id") is None + + +def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): + app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) + mock_db.query.return_value.where.return_value.first.return_value = app + monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) + assert OpsTraceManager.get_ops_trace_instance("app-id") is None + + +def test_get_ops_trace_instance_success(monkeypatch, mock_db): + app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) + mock_db.query.return_value.where.return_value.first.return_value = app + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", + classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), + ) + instance = OpsTraceManager.get_ops_trace_instance("app-id") + assert instance is not None + cached_instance = OpsTraceManager.get_ops_trace_instance("app-id") + assert instance is cached_instance + + +def test_get_app_config_through_message_id_returns_none(mock_db): + mock_db.scalar.return_value = None + assert OpsTraceManager.get_app_config_through_message_id("m") is None + + +def test_get_app_config_through_message_id_prefers_override(mock_db): + message = SimpleNamespace(conversation_id="conv") + conversation = SimpleNamespace(app_model_config_id=None, override_model_configs={"foo": "bar"}) + app_config = SimpleNamespace(id="config-id") + mock_db.scalar.side_effect = [message, conversation] + result = OpsTraceManager.get_app_config_through_message_id("m") + assert result == {"foo": "bar"} + + +def test_get_app_config_through_message_id_app_model_config(mock_db): + message = SimpleNamespace(conversation_id="conv") + conversation = SimpleNamespace(app_model_config_id="cfg", override_model_configs=None) + mock_db.scalar.side_effect = [message, conversation, SimpleNamespace(id="cfg")] + result = OpsTraceManager.get_app_config_through_message_id("m") + assert result.id == "cfg" + + +def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): + mock_db.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="Invalid tracing provider"): + OpsTraceManager.update_app_tracing_config("app", True, "bad") + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.update_app_tracing_config("app", True, None) + + +def test_update_app_tracing_config_success(mock_db): + app = SimpleNamespace(id="app-id", tracing="{}") + mock_db.query.return_value.where.return_value.first.return_value = app + OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") + assert app.tracing is not None + mock_db.commit.assert_called_once() + + +def test_get_app_tracing_config_errors_when_missing(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ValueError, match="App not found"): + OpsTraceManager.get_app_tracing_config("app") + + +def test_get_app_tracing_config_returns_defaults(mock_db): + mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} + + +def test_get_app_tracing_config_returns_payload(mock_db): + payload = {"enabled": True, "tracing_provider": "dummy"} + mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + assert OpsTraceManager.get_app_tracing_config("app-id") == payload + + +def test_check_and_project_helpers(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.provider_config_map", + FakeProviderMap( + { + "dummy": { + "config_class": DummyConfig, + "trace_instance": type( + "Trace", + (), + { + "__init__": lambda self, cfg: None, + "api_check": lambda self: True, + "get_project_key": lambda self: "key", + "get_project_url": lambda self: "url", + }, + ), + "secret_keys": [], + "other_keys": [], + } + } + ), + ) + assert OpsTraceManager.check_trace_config_is_effective({}, "dummy") + assert OpsTraceManager.get_trace_config_project_key({}, "dummy") == "key" + assert OpsTraceManager.get_trace_config_project_url({}, "dummy") == "url" + + +def test_trace_task_conversation_and_extract(monkeypatch): + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE, message_id="msg") + assert task.conversation_trace(foo="bar") == {"foo": "bar"} + assert task._extract_streaming_metrics(make_message_data(message_metadata="not json")) == {} + + +def test_trace_task_message_trace(trace_task_message, mock_db): + task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id") + result = task.message_trace("msg-id") + assert result.message_id == "msg-id" + + +def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): + DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] + execution = SimpleNamespace(id_="run-id") + task = TraceTask( + trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" + ) + result = task.workflow_trace(workflow_run_id="run-id", conversation_id="conv", user_id="user") + assert result.workflow_run_id == "run-id" + assert result.workflow_id == "wf-1" + + +def test_trace_task_moderation_trace(trace_task_message): + task = TraceTask(trace_type=TraceTaskName.MODERATION_TRACE, message_id="msg-id") + moderation_result = SimpleNamespace(action="block", preset_response="no", query="q", flagged=True) + timer = {"start": 1, "end": 2} + result = task.moderation_trace("msg-id", timer, moderation_result=moderation_result, inputs={"src": "payload"}) + assert result.flagged is True + assert result.message_id == "log-id" + + +def test_trace_task_suggested_question_trace(trace_task_message): + task = TraceTask(trace_type=TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id="msg-id") + timer = {"start": 1, "end": 2} + result = task.suggested_question_trace("msg-id", timer, suggested_question=["q1"]) + assert result.message_id == "log-id" + assert "suggested_question" in result.__dict__ + + +def test_trace_task_dataset_retrieval_trace(trace_task_message): + task = TraceTask(trace_type=TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id="msg-id") + timer = {"start": 1, "end": 2} + mock_doc = SimpleNamespace(model_dump=lambda: {"doc": "value"}) + result = task.dataset_retrieval_trace("msg-id", timer, documents=[mock_doc]) + assert result.documents == [{"doc": "value"}] + + +def test_trace_task_tool_trace(monkeypatch, mock_db): + custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 0))]) + monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) + configure_db_query(mock_db, message_file=FakeMessageFile()) + task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") + timer = {"start": 1, "end": 5} + result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") + assert result.tool_name == "tool-a" + assert result.time_cost == 5 + + +def test_trace_task_generate_name_trace(): + task = TraceTask(trace_type=TraceTaskName.GENERATE_NAME_TRACE, conversation_id="conv-id") + timer = {"start": 1, "end": 2} + assert task.generate_name_trace("conv-id", timer, tenant_id=None) == {} + result = task.generate_name_trace( + "conv-id", timer, tenant_id="tenant", generate_conversation_name="name", inputs="q" + ) + assert result.outputs == "name" + assert result.tenant_id == "tenant" + + +def test_extract_streaming_metrics_invalid_json(): + task = TraceTask(trace_type=TraceTaskName.MESSAGE_TRACE, message_id="msg-id") + fake_message = make_message_data(message_metadata="invalid") + assert task._extract_streaming_metrics(fake_message) == {} + + +def test_trace_queue_manager_add_and_collect(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + manager = TraceQueueManager(app_id="app-id", user_id="user") + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE) + manager.add_trace_task(task) + tasks = manager.collect_tasks() + assert tasks == [task] + + +def test_trace_queue_manager_run_invokes_send(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + manager = TraceQueueManager(app_id="app-id", user_id="user") + task = TraceTask(trace_type=TraceTaskName.CONVERSATION_TRACE) + called = {} + + def fake_collect(): + return [task] + + def fake_send(tasks): + called["tasks"] = tasks + + monkeypatch.setattr(TraceQueueManager, "collect_tasks", lambda self: fake_collect()) + monkeypatch.setattr(TraceQueueManager, "send_to_celery", lambda self, t: fake_send(t)) + manager.run() + assert called["tasks"] == [task] + + +def test_trace_queue_manager_send_to_celery(monkeypatch): + monkeypatch.setattr( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", classmethod(lambda cls, aid: True) + ) + storage_save = MagicMock() + process_delay = MagicMock() + monkeypatch.setattr("core.ops.ops_trace_manager.storage.save", storage_save) + monkeypatch.setattr("core.ops.ops_trace_manager.process_trace_tasks.delay", process_delay) + monkeypatch.setattr("core.ops.ops_trace_manager.uuid4", MagicMock(return_value=SimpleNamespace(hex="file-123"))) + + manager = TraceQueueManager(app_id="app-id", user_id="user") + + class DummyTraceInfo: + def model_dump(self): + return {"trace": "info"} + + class DummyTask: + def __init__(self): + self.app_id = "app-id" + + def execute(self): + return DummyTraceInfo() + + task = DummyTask() + manager.send_to_celery([task]) + storage_save.assert_called_once() + process_delay.assert_called_once_with({"file_id": "file-123", "app_id": "app-id"}) diff --git a/api/tests/unit_tests/core/ops/test_utils.py b/api/tests/unit_tests/core/ops/test_utils.py index e1084001b7..8a89422782 100644 --- a/api/tests/unit_tests/core/ops/test_utils.py +++ b/api/tests/unit_tests/core/ops/test_utils.py @@ -1,9 +1,20 @@ import re from datetime import datetime +from unittest.mock import MagicMock, patch import pytest -from core.ops.utils import generate_dotted_order, validate_project_name, validate_url, validate_url_with_path +from core.ops.utils import ( + filter_none_values, + generate_dotted_order, + get_message_data, + measure_time, + replace_text_with_content, + validate_integer_id, + validate_project_name, + validate_url, + validate_url_with_path, +) class TestValidateUrl: @@ -187,3 +198,92 @@ class TestGenerateDottedOrder: result = generate_dotted_order(run_id, start_time, None) assert "." not in result + + def test_dotted_order_with_string_start_time(self): + """Test dotted_order generation with string start_time.""" + start_time = "2025-12-23T04:19:55.111000" + run_id = "test-run-id" + result = generate_dotted_order(run_id, start_time) + + assert result == "20251223T041955111000Ztest-run-id" + + +class TestFilterNoneValues: + """Test cases for filter_none_values function""" + + def test_filter_none_values(self): + data = {"a": 1, "b": None, "c": "test", "d": datetime(2025, 1, 1, 12, 0, 0)} + result = filter_none_values(data) + assert result == {"a": 1, "c": "test", "d": "2025-01-01T12:00:00"} + + def test_filter_none_values_empty(self): + assert filter_none_values({}) == {} + + +class TestGetMessageData: + """Test cases for get_message_data function""" + + @patch("core.ops.utils.db") + @patch("core.ops.utils.Message") + @patch("core.ops.utils.select") + def test_get_message_data(self, mock_select, mock_message, mock_db): + mock_scalar = mock_db.session.scalar + mock_msg_instance = MagicMock() + mock_scalar.return_value = mock_msg_instance + + result = get_message_data("message-id") + + assert result == mock_msg_instance + mock_select.assert_called_once() + mock_scalar.assert_called_once() + + +class TestMeasureTime: + """Test cases for measure_time function""" + + def test_measure_time(self): + with measure_time() as timing_info: + assert "start" in timing_info + assert isinstance(timing_info["start"], datetime) + assert timing_info["end"] is None + + assert timing_info["end"] is not None + assert isinstance(timing_info["end"], datetime) + assert timing_info["end"] >= timing_info["start"] + + +class TestReplaceTextWithContent: + """Test cases for replace_text_with_content function""" + + def test_replace_text_with_content_dict(self): + data = {"text": "hello", "other": "world"} + assert replace_text_with_content(data) == {"content": "hello", "other": "world"} + + def test_replace_text_with_content_nested(self): + data = {"text": "v1", "nested": {"text": "v2", "list": [{"text": "v3"}]}} + expected = {"content": "v1", "nested": {"content": "v2", "list": [{"content": "v3"}]}} + assert replace_text_with_content(data) == expected + + def test_replace_text_with_content_list(self): + data = [{"text": "v1"}, "v2"] + assert replace_text_with_content(data) == [{"content": "v1"}, "v2"] + + def test_replace_text_with_content_primitive(self): + assert replace_text_with_content(123) == 123 + assert replace_text_with_content("text") == "text" + + +class TestValidateIntegerId: + """Test cases for validate_integer_id function""" + + def test_valid_integer_id(self): + assert validate_integer_id("123") == "123" + assert validate_integer_id(" 456 ") == "456" + + def test_invalid_integer_id_raises_error(self): + with pytest.raises(ValueError, match="ID must be a valid integer"): + validate_integer_id("abc") + + def test_empty_integer_id_raises_error(self): + with pytest.raises(ValueError, match="ID must be a valid integer"): + validate_integer_id("") diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py new file mode 100644 index 0000000000..8057bbbad5 --- /dev/null +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -0,0 +1,1196 @@ +"""Comprehensive tests for core.ops.weave_trace.weave_trace module.""" + +from __future__ import annotations + +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from weave.trace_server.trace_server_interface import TraceStatus + +from core.ops.entities.config_entity import WeaveConfig +from core.ops.entities.trace_entity import ( + DatasetRetrievalTraceInfo, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + TraceTaskName, + WorkflowTraceInfo, +) +from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel +from core.ops.weave_trace.weave_trace import WeaveDataTrace +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +def _dt() -> datetime: + return datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC) + + +def _make_weave_config(**overrides) -> WeaveConfig: + defaults = { + "api_key": "wv-api-key", + "project": "my-project", + "entity": "my-entity", + "host": None, + } + defaults.update(overrides) + return WeaveConfig(**defaults) + + +def _make_workflow_trace_info(**overrides) -> WorkflowTraceInfo: + defaults = { + "workflow_id": "wf-id", + "tenant_id": "tenant-1", + "workflow_run_id": "run-1", + "workflow_run_elapsed_time": 1.0, + "workflow_run_status": "succeeded", + "workflow_run_inputs": {"key": "val"}, + "workflow_run_outputs": {"answer": "42"}, + "workflow_run_version": "v1", + "total_tokens": 10, + "file_list": [], + "query": "hello", + "metadata": {"user_id": "u1", "app_id": "app-1"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return WorkflowTraceInfo(**defaults) + + +def _make_message_trace_info(**overrides) -> MessageTraceInfo: + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = None + defaults = { + "conversation_model": "chat", + "message_tokens": 5, + "answer_tokens": 10, + "total_tokens": 15, + "conversation_mode": "chat", + "metadata": {"conversation_id": "c1"}, + "message_id": "msg-1", + "message_data": msg_data, + "inputs": {"prompt": "hi"}, + "outputs": "ok", + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return MessageTraceInfo(**defaults) + + +def _make_moderation_trace_info(**overrides) -> ModerationTraceInfo: + defaults = { + "flagged": False, + "action": "allow", + "preset_response": "", + "query": "test", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + } + defaults.update(overrides) + return ModerationTraceInfo(**defaults) + + +def _make_suggested_question_trace_info(**overrides) -> SuggestedQuestionTraceInfo: + defaults = { + "suggested_question": ["q1", "q2"], + "level": "info", + "total_tokens": 5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": SimpleNamespace(created_at=_dt(), updated_at=_dt()), + "inputs": {"i": 1}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return SuggestedQuestionTraceInfo(**defaults) + + +def _make_dataset_retrieval_trace_info(**overrides) -> DatasetRetrievalTraceInfo: + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + defaults = { + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "message_data": msg_data, + "inputs": "query", + "documents": [{"content": "doc"}], + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return DatasetRetrievalTraceInfo(**defaults) + + +def _make_tool_trace_info(**overrides) -> ToolTraceInfo: + defaults = { + "tool_name": "my_tool", + "tool_inputs": {"x": 1}, + "tool_outputs": "output", + "tool_config": {"desc": "d"}, + "tool_parameters": {"p": "v"}, + "time_cost": 0.5, + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": "v"}, + "outputs": {"o": "v"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + "error": None, + } + defaults.update(overrides) + return ToolTraceInfo(**defaults) + + +def _make_generate_name_trace_info(**overrides) -> GenerateNameTraceInfo: + defaults = { + "tenant_id": "t1", + "metadata": {"user_id": "u1"}, + "message_id": "msg-1", + "inputs": {"i": 1}, + "outputs": {"name": "test"}, + "start_time": _dt(), + "end_time": _dt() + timedelta(seconds=1), + } + defaults.update(overrides) + return GenerateNameTraceInfo(**defaults) + + +def _make_node(**overrides): + """Create a mock workflow node execution object.""" + defaults = { + "id": "node-1", + "title": "Node Title", + "node_type": BuiltinNodeTypes.CODE, + "status": "succeeded", + "inputs": {"key": "value"}, + "outputs": {"result": "ok"}, + "created_at": _dt(), + "elapsed_time": 1.0, + "process_data": None, + "metadata": {}, + } + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +# ── Fixtures ───────────────────────────────────────────────────────────────── + + +@pytest.fixture +def mock_wandb(): + with patch("core.ops.weave_trace.weave_trace.wandb") as mock: + mock.login.return_value = True + yield mock + + +@pytest.fixture +def mock_weave(): + with patch("core.ops.weave_trace.weave_trace.weave") as mock: + client = MagicMock() + client.entity = "my-entity" + client.project = "my-project" + mock.init.return_value = client + yield mock, client + + +@pytest.fixture +def trace_instance(mock_wandb, mock_weave): + """Create a WeaveDataTrace instance with mocked wandb/weave.""" + _, weave_client = mock_weave + config = _make_weave_config() + instance = WeaveDataTrace(config) + return instance + + +@pytest.fixture +def trace_instance_with_host(mock_wandb, mock_weave): + """Create a WeaveDataTrace instance with host configured.""" + _, weave_client = mock_weave + config = _make_weave_config(host="https://my.wandb.host") + instance = WeaveDataTrace(config) + return instance + + +# ── TestInit ───────────────────────────────────────────────────────────────── + + +class TestInit: + def test_init_without_host(self, mock_wandb, mock_weave): + """Test __init__ calls wandb.login without host.""" + mock_w, weave_client = mock_weave + config = _make_weave_config(host=None) + instance = WeaveDataTrace(config) + + mock_wandb.login.assert_called_once_with(key="wv-api-key", verify=True, relogin=True) + mock_w.init.assert_called_once_with(project_name="my-entity/my-project") + assert instance.weave_api_key == "wv-api-key" + assert instance.project_name == "my-project" + assert instance.entity == "my-entity" + assert instance.calls == {} + + def test_init_with_host(self, mock_wandb, mock_weave): + """Test __init__ calls wandb.login with host.""" + config = _make_weave_config(host="https://my.wandb.host") + instance = WeaveDataTrace(config) + + mock_wandb.login.assert_called_once_with( + key="wv-api-key", verify=True, relogin=True, host="https://my.wandb.host" + ) + assert instance.host == "https://my.wandb.host" + + def test_init_without_entity(self, mock_wandb, mock_weave): + """Test __init__ initializes weave without entity prefix when entity is None.""" + mock_w, weave_client = mock_weave + config = _make_weave_config(entity=None) + instance = WeaveDataTrace(config) + + mock_w.init.assert_called_once_with(project_name="my-project") + + def test_init_login_failure_raises(self, mock_wandb, mock_weave): + """Test __init__ raises ValueError when wandb.login returns False.""" + mock_wandb.login.return_value = False + config = _make_weave_config() + + with pytest.raises(ValueError, match="Weave login failed"): + WeaveDataTrace(config) + + def test_init_files_url_from_env(self, mock_wandb, mock_weave, monkeypatch): + """Test FILES_URL is read from environment.""" + monkeypatch.setenv("FILES_URL", "http://files.example.com") + config = _make_weave_config() + instance = WeaveDataTrace(config) + assert instance.file_base_url == "http://files.example.com" + + def test_init_files_url_default(self, mock_wandb, mock_weave, monkeypatch): + """Test FILES_URL defaults to http://127.0.0.1:5001.""" + monkeypatch.delenv("FILES_URL", raising=False) + config = _make_weave_config() + instance = WeaveDataTrace(config) + assert instance.file_base_url == "http://127.0.0.1:5001" + + def test_project_id_set_correctly(self, trace_instance): + """Test that project_id is set from weave_client entity/project.""" + assert trace_instance.project_id == "my-entity/my-project" + + +# ── TestGetProjectUrl ───────────────────────────────────────────────────────── + + +class TestGetProjectUrl: + def test_get_project_url_with_entity(self, trace_instance): + """Returns wandb URL with entity/project.""" + url = trace_instance.get_project_url() + assert url == "https://wandb.ai/my-entity/my-project" + + def test_get_project_url_without_entity(self, mock_wandb, mock_weave): + """Returns wandb URL with project only when entity is None.""" + config = _make_weave_config(entity=None) + instance = WeaveDataTrace(config) + url = instance.get_project_url() + assert url == "https://wandb.ai/my-project" + + def test_get_project_url_exception_raises(self, trace_instance, monkeypatch): + """Raises ValueError when exception occurs in get_project_url.""" + monkeypatch.setattr(trace_instance, "entity", None) + monkeypatch.setattr(trace_instance, "project_name", None) + # Force an error by making string formatting fail + with patch("core.ops.weave_trace.weave_trace.logger") as mock_logger: + # Simulate exception via property + original_entity = trace_instance.entity + trace_instance.entity = None + trace_instance.project_name = None + url = trace_instance.get_project_url() + assert "https://wandb.ai/" in url + + +# ── TestTraceDispatcher ───────────────────────────────────────────────────── + + +class TestTraceDispatcher: + def test_dispatches_workflow_trace(self, trace_instance): + with patch.object(trace_instance, "workflow_trace") as mock_wt: + trace_instance.trace(_make_workflow_trace_info()) + mock_wt.assert_called_once() + + def test_dispatches_message_trace(self, trace_instance): + with patch.object(trace_instance, "message_trace") as mock_mt: + trace_instance.trace(_make_message_trace_info()) + mock_mt.assert_called_once() + + def test_dispatches_moderation_trace(self, trace_instance): + with patch.object(trace_instance, "moderation_trace") as mock_mod: + msg_data = MagicMock() + msg_data.created_at = _dt() + trace_instance.trace(_make_moderation_trace_info(message_data=msg_data)) + mock_mod.assert_called_once() + + def test_dispatches_suggested_question_trace(self, trace_instance): + with patch.object(trace_instance, "suggested_question_trace") as mock_sq: + trace_instance.trace(_make_suggested_question_trace_info()) + mock_sq.assert_called_once() + + def test_dispatches_dataset_retrieval_trace(self, trace_instance): + with patch.object(trace_instance, "dataset_retrieval_trace") as mock_dr: + trace_instance.trace(_make_dataset_retrieval_trace_info()) + mock_dr.assert_called_once() + + def test_dispatches_tool_trace(self, trace_instance): + with patch.object(trace_instance, "tool_trace") as mock_tool: + trace_instance.trace(_make_tool_trace_info()) + mock_tool.assert_called_once() + + def test_dispatches_generate_name_trace(self, trace_instance): + with patch.object(trace_instance, "generate_name_trace") as mock_gn: + trace_instance.trace(_make_generate_name_trace_info()) + mock_gn.assert_called_once() + + +# ── TestNormalizeTime ───────────────────────────────────────────────────────── + + +class TestNormalizeTime: + def test_none_returns_utc_now(self, trace_instance): + now_before = datetime.now(UTC) + result = trace_instance._normalize_time(None) + now_after = datetime.now(UTC) + assert result.tzinfo is not None + assert now_before <= result <= now_after + + def test_naive_datetime_gets_utc(self, trace_instance): + naive = datetime(2024, 6, 15, 12, 0, 0) + result = trace_instance._normalize_time(naive) + assert result.tzinfo == UTC + assert result.year == 2024 + assert result.month == 6 + + def test_aware_datetime_unchanged(self, trace_instance): + aware = datetime(2024, 6, 15, 12, 0, 0, tzinfo=UTC) + result = trace_instance._normalize_time(aware) + assert result == aware + assert result.tzinfo == UTC + + +# ── TestStartCall ───────────────────────────────────────────────────────────── + + +class TestStartCall: + def test_start_call_basic(self, trace_instance): + """Test basic start_call stores call metadata.""" + run = WeaveTraceModel( + id="run-1", + op="test-op", + inputs={"key": "val"}, + attributes={"trace_id": "t-1", "start_time": _dt()}, + ) + trace_instance.start_call(run) + + assert "run-1" in trace_instance.calls + assert trace_instance.calls["run-1"]["trace_id"] == "t-1" + assert trace_instance.calls["run-1"]["parent_id"] is None + trace_instance.weave_client.server.call_start.assert_called_once() + + def test_start_call_with_parent(self, trace_instance): + """Test start_call records parent_run_id.""" + run = WeaveTraceModel( + id="child-1", + op="child-op", + inputs={}, + attributes={"trace_id": "t-1", "start_time": _dt()}, + ) + trace_instance.start_call(run, parent_run_id="parent-1") + + assert trace_instance.calls["child-1"]["parent_id"] == "parent-1" + + def test_start_call_none_inputs_becomes_empty_dict(self, trace_instance): + """Test that None inputs is normalized to {}.""" + run = WeaveTraceModel( + id="run-2", + op="op", + inputs=None, + attributes={"trace_id": "t-2", "start_time": _dt()}, + ) + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + assert req.start.inputs == {} + + def test_start_call_non_dict_inputs_becomes_str_dict(self, trace_instance): + """Test that non-dict inputs is wrapped as string.""" + run = WeaveTraceModel( + id="run-3", + op="op", + inputs="some string input", + attributes={"trace_id": "t-3", "start_time": _dt()}, + ) + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + # String inputs gets converted by validator to a dict + assert isinstance(req.start.inputs, dict) + + def test_start_call_none_attributes_becomes_empty_dict(self, trace_instance): + """Test that None attributes is handled properly.""" + run = WeaveTraceModel( + id="run-4", + op="op", + inputs={}, + attributes=None, + ) + trace_instance.start_call(run) + # trace_id should fall back to run_data.id + assert trace_instance.calls["run-4"]["trace_id"] == "run-4" + + def test_start_call_non_dict_attributes_becomes_dict(self, trace_instance): + """Test that non-dict attributes is wrapped.""" + run = WeaveTraceModel( + id="run-5", + op="op", + inputs={}, + attributes=None, + ) + # Manually override after construction + run.attributes = "some-attr-string" + trace_instance.start_call(run) + call_args = trace_instance.weave_client.server.call_start.call_args + req = call_args[0][0] + assert isinstance(req.start.attributes, dict) + assert req.start.attributes == {"attributes": "some-attr-string"} + + def test_start_call_trace_id_falls_back_to_run_id(self, trace_instance): + """When trace_id not in attributes, falls back to run_data.id.""" + run = WeaveTraceModel( + id="run-6", + op="op", + inputs={}, + attributes={"start_time": _dt()}, + ) + trace_instance.start_call(run) + assert trace_instance.calls["run-6"]["trace_id"] == "run-6" + + +# ── TestFinishCall ────────────────────────────────────────────────────────── + + +class TestFinishCall: + def _setup_call(self, trace_instance, run_id="run-1", trace_id="t-1"): + """Helper: register a call so finish_call can find it.""" + trace_instance.calls[run_id] = {"trace_id": trace_id, "parent_id": None} + + def test_finish_call_success(self, trace_instance): + """Test finish_call sends call_end with SUCCESS status.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + outputs={"result": "ok"}, + attributes={"start_time": _dt(), "end_time": _dt() + timedelta(seconds=1)}, + exception=None, + ) + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["status_counts"][TraceStatus.SUCCESS] == 1 + assert req.end.summary["status_counts"][TraceStatus.ERROR] == 0 + assert req.end.exception is None + + def test_finish_call_with_error(self, trace_instance): + """Test finish_call sends call_end with ERROR status when exception is set.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + outputs={}, + attributes={"start_time": _dt(), "end_time": _dt() + timedelta(seconds=1)}, + exception="Something broke", + ) + trace_instance.finish_call(run) + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["status_counts"][TraceStatus.ERROR] == 1 + assert req.end.summary["status_counts"][TraceStatus.SUCCESS] == 0 + assert req.end.exception == "Something broke" + + def test_finish_call_missing_id_raises(self, trace_instance): + """Test finish_call raises ValueError when call id not found.""" + run = WeaveTraceModel( + id="nonexistent", + op="op", + inputs={}, + ) + with pytest.raises(ValueError, match="Call with id nonexistent not found"): + trace_instance.finish_call(run) + + def test_finish_call_elapsed_negative_clamped_to_zero(self, trace_instance): + """Test that negative elapsed time is clamped to 0.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes={ + "start_time": _dt() + timedelta(seconds=5), + "end_time": _dt(), # end before start + }, + ) + trace_instance.finish_call(run) + call_args = trace_instance.weave_client.server.call_end.call_args + req = call_args[0][0] + assert req.end.summary["weave"]["latency_ms"] == 0 + + def test_finish_call_none_attributes(self, trace_instance): + """Test finish_call handles None attributes.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes=None, + ) + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + + def test_finish_call_non_dict_attributes(self, trace_instance): + """Test finish_call handles non-dict attributes.""" + self._setup_call(trace_instance) + run = WeaveTraceModel( + id="run-1", + op="op", + inputs={}, + attributes=None, + ) + run.attributes = "some string attr" + trace_instance.finish_call(run) + trace_instance.weave_client.server.call_end.assert_called_once() + + +# ── TestWorkflowTrace ───────────────────────────────────────────────────────── + + +class TestWorkflowTrace: + def _setup_repo(self, monkeypatch, nodes=None): + """Helper to patch session/repo dependencies.""" + if nodes is None: + nodes = [] + + repo = MagicMock() + repo.get_by_workflow_run.return_value = nodes + + mock_factory = MagicMock() + mock_factory.create_workflow_node_execution_repository.return_value = repo + + monkeypatch.setattr("core.ops.weave_trace.weave_trace.DifyCoreRepositoryFactory", mock_factory) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + return repo + + def test_workflow_trace_no_nodes_no_message_id(self, trace_instance, monkeypatch): + """Workflow trace with no nodes and no message_id.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Only workflow run: start_call and finish_call each called once + assert trace_instance.start_call.call_count == 1 + assert trace_instance.finish_call.call_count == 1 + + def test_workflow_trace_with_message_id(self, trace_instance, monkeypatch): + """Workflow trace with message_id creates both message and workflow runs.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id="msg-1") + trace_instance.workflow_trace(trace_info) + + # message run + workflow run = 2 start_call / finish_call + assert trace_instance.start_call.call_count == 2 + assert trace_instance.finish_call.call_count == 2 + + def test_workflow_trace_with_node_execution(self, trace_instance, monkeypatch): + """Workflow trace iterates node executions and creates node runs.""" + node = _make_node( + id="node-1", + node_type=BuiltinNodeTypes.CODE, + inputs={"k": "v"}, + outputs={"r": "ok"}, + elapsed_time=0.5, + created_at=_dt(), + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 5}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # workflow run + node run = 2 calls + assert trace_instance.start_call.call_count == 2 + + def test_workflow_trace_with_llm_node(self, trace_instance, monkeypatch): + """LLM node uses process_data prompts as inputs.""" + node = _make_node( + node_type=BuiltinNodeTypes.LLM, + process_data={ + "prompts": [{"role": "user", "content": "hi"}], + "model_mode": "chat", + "model_provider": "openai", + "model_name": "gpt-4", + }, + inputs={"key": "val"}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Check node start_call was called with prompts input + node_call_args = trace_instance.start_call.call_args_list[-1] + node_run = node_call_args[0][0] + # WeaveTraceModel validator wraps list prompts into {"messages": [...]} + # The key "messages" should be present (validator transforms the list) + assert "messages" in node_run.inputs + + def test_workflow_trace_with_non_llm_node_uses_inputs(self, trace_instance, monkeypatch): + """Non-LLM node uses node_execution.inputs directly.""" + node = _make_node( + node_type=BuiltinNodeTypes.TOOL, + inputs={"tool_input": "val"}, + process_data=None, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # node run inputs should be from node.inputs; validator adds usage_metadata + file_list + node_call_args = trace_instance.start_call.call_args_list[-1] + node_run = node_call_args[0][0] + assert node_run.inputs.get("tool_input") == "val" + + def test_workflow_trace_missing_app_id_raises(self, trace_instance, monkeypatch): + """Raises ValueError when app_id is missing from metadata.""" + monkeypatch.setattr("core.ops.weave_trace.weave_trace.sessionmaker", lambda bind: MagicMock()) + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", MagicMock(engine="engine")) + + trace_info = _make_workflow_trace_info( + message_id=None, + metadata={"user_id": "u1"}, # no app_id + ) + + with pytest.raises(ValueError, match="No app_id found in trace_info metadata"): + trace_instance.workflow_trace(trace_info) + + def test_workflow_trace_start_time_none_defaults_to_now(self, trace_instance, monkeypatch): + """start_time defaults to datetime.now() when None.""" + self._setup_repo(monkeypatch, nodes=[]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None, start_time=None) + trace_instance.workflow_trace(trace_info) + + assert trace_instance.start_call.call_count == 1 + + def test_workflow_trace_node_created_at_none(self, trace_instance, monkeypatch): + """Node with created_at=None uses datetime.now().""" + node = _make_node(created_at=None, elapsed_time=0.5) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + def test_workflow_trace_chat_mode_llm_node_adds_provider(self, trace_instance, monkeypatch): + """Chat mode LLM node adds ls_provider and ls_model_name to attributes.""" + node = _make_node( + node_type=BuiltinNodeTypes.LLM, + process_data={"model_mode": "chat", "model_provider": "openai", "model_name": "gpt-4", "prompts": []}, + ) + self._setup_repo(monkeypatch, nodes=[node]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + start_calls = [] + + def capture_start(run, parent_run_id=None): + start_calls.append((run, parent_run_id)) + + trace_instance.start_call = capture_start + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # Last start call is the node run + node_run, _ = start_calls[-1] + assert node_run.attributes.get("ls_provider") == "openai" + assert node_run.attributes.get("ls_model_name") == "gpt-4" + + def test_workflow_trace_nodes_sorted_by_created_at(self, trace_instance, monkeypatch): + """Nodes are sorted by created_at before processing.""" + node1 = _make_node(id="node-b", created_at=_dt() + timedelta(seconds=2)) + node2 = _make_node(id="node-a", created_at=_dt()) + self._setup_repo(monkeypatch, nodes=[node1, node2]) + monkeypatch.setattr(trace_instance, "get_service_account_with_tenant", lambda app_id: MagicMock()) + + processed_ids = [] + + def capture_start(run, parent_run_id=None): + processed_ids.append(run.id) + + trace_instance.start_call = capture_start + trace_instance.finish_call = MagicMock() + + trace_info = _make_workflow_trace_info(message_id=None) + trace_instance.workflow_trace(trace_info) + + # First call = workflow run, then node-a, then node-b + assert processed_ids[1] == "node-a" + assert processed_ids[2] == "node-b" + + +# ── TestMessageTrace ────────────────────────────────────────────────────────── + + +class TestMessageTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """message_trace returns early when message_data is None.""" + trace_info = _make_message_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.message_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_message_trace(self, trace_instance, monkeypatch): + """message_trace creates message run and llm child run.""" + monkeypatch.setattr( + "core.ops.weave_trace.weave_trace.db.session.query", + lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + ) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info() + trace_instance.message_trace(trace_info) + + # message run + llm child run + assert trace_instance.start_call.call_count == 2 + assert trace_instance.finish_call.call_count == 2 + + def test_message_trace_with_file_data(self, trace_instance, monkeypatch): + """message_trace appends file URL to file_list.""" + file_data = MagicMock() + file_data.url = "path/to/file.png" + trace_instance.file_base_url = "http://files.test" + + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info( + message_file_data=file_data, + file_list=["existing.txt"], + ) + trace_instance.message_trace(trace_info) + + # The first start_call arg (the message run) should have file in outputs or inputs + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert "http://files.test/path/to/file.png" in message_run.file_list + + def test_message_trace_with_end_user(self, trace_instance, monkeypatch): + """message_trace looks up end user and sets end_user_id attribute.""" + end_user = MagicMock() + end_user.session_id = "session-xyz" + + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = end_user + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = "eu-1" + + trace_info = _make_message_trace_info(message_data=msg_data) + trace_instance.message_trace(trace_info) + + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert message_run.attributes.get("end_user_id") == "session-xyz" + + def test_message_trace_no_end_user(self, trace_instance, monkeypatch): + """message_trace handles when from_end_user_id is None.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + msg_data = MagicMock() + msg_data.id = "msg-1" + msg_data.from_account_id = "acc-1" + msg_data.from_end_user_id = None + + trace_info = _make_message_trace_info(message_data=msg_data) + trace_instance.message_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): + """trace_id falls back to message_id when trace_id is None.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info(trace_id=None) + trace_instance.message_trace(trace_info) + + message_run = trace_instance.start_call.call_args_list[0][0][0] + assert message_run.id == "msg-1" + + def test_message_trace_file_list_none(self, trace_instance, monkeypatch): + """message_trace handles file_list=None gracefully.""" + mock_db = MagicMock() + mock_db.session.query.return_value.where.return_value.first.return_value = None + monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_message_trace_info(file_list=None, message_file_data=None) + trace_instance.message_trace(trace_info) + assert trace_instance.start_call.call_count == 2 + + +# ── TestModerationTrace ─────────────────────────────────────────────────────── + + +class TestModerationTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """moderation_trace returns early when message_data is None.""" + trace_info = _make_moderation_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.moderation_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_moderation_trace(self, trace_instance): + """moderation_trace creates a run with correct outputs.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + start_time=_dt(), + end_time=_dt() + timedelta(seconds=1), + action="block", + flagged=True, + preset_response="blocked", + ) + trace_instance.moderation_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + run = trace_instance.start_call.call_args[0][0] + assert run.outputs["action"] == "block" + assert run.outputs["flagged"] is True + + def test_moderation_trace_with_no_times_uses_message_data_times(self, trace_instance): + """When start/end times are None, uses message_data created_at/updated_at.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + msg_data.updated_at = _dt() + timedelta(seconds=1) + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + start_time=None, + end_time=None, + ) + trace_instance.moderation_trace(trace_info) + trace_instance.start_call.assert_called_once() + + def test_moderation_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + msg_data = MagicMock() + msg_data.created_at = _dt() + + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_moderation_trace_info( + message_data=msg_data, + trace_id=None, + ) + trace_instance.moderation_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestSuggestedQuestionTrace ──────────────────────────────────────────────── + + +class TestSuggestedQuestionTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """suggested_question_trace returns early when message_data is None.""" + trace_info = _make_suggested_question_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.suggested_question_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_suggested_question_trace(self, trace_instance): + """suggested_question_trace creates a run parented to trace_id.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_suggested_question_trace_info(trace_id="t-1") + trace_instance.suggested_question_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "t-1" + + def test_suggested_question_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_suggested_question_trace_info(trace_id=None) + trace_instance.suggested_question_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestDatasetRetrievalTrace ───────────────────────────────────────────────── + + +class TestDatasetRetrievalTrace: + def test_returns_early_when_no_message_data(self, trace_instance): + """dataset_retrieval_trace returns early when message_data is None.""" + trace_info = _make_dataset_retrieval_trace_info(message_data=None) + trace_instance.start_call = MagicMock() + trace_instance.dataset_retrieval_trace(trace_info) + trace_instance.start_call.assert_not_called() + + def test_basic_dataset_retrieval_trace(self, trace_instance): + """dataset_retrieval_trace creates a run with documents as outputs.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_dataset_retrieval_trace_info( + documents=[{"id": "d1"}, {"id": "d2"}], + trace_id="t-1", + ) + trace_instance.dataset_retrieval_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + # WeaveTraceModel validator injects usage_metadata/file_list into dict outputs + assert run.outputs.get("documents") == [{"id": "d1"}, {"id": "d2"}] + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "t-1" + + def test_dataset_retrieval_trace_trace_id_fallback(self, trace_instance): + """trace_id falls back to message_id when trace_id is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_dataset_retrieval_trace_info(trace_id=None) + trace_instance.dataset_retrieval_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + +# ── TestToolTrace ───────────────────────────────────────────────────────────── + + +class TestToolTrace: + def test_basic_tool_trace(self, trace_instance): + """tool_trace creates a run with correct op as tool_name.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id="t-1") + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert run.op == "my_tool" + # WeaveTraceModel validator injects usage_metadata/file_list into dict inputs + assert run.inputs.get("x") == 1 + + def test_tool_trace_with_file_url(self, trace_instance): + """tool_trace adds file_url to file_list when provided.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(file_url="http://files/file.pdf") + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert "http://files/file.pdf" in run.file_list + + def test_tool_trace_without_file_url(self, trace_instance): + """tool_trace uses empty file_list when file_url is None.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(file_url=None) + trace_instance.tool_trace(trace_info) + + run = trace_instance.start_call.call_args[0][0] + assert run.file_list == [] + + def test_tool_trace_trace_id_from_message_id(self, trace_instance): + """trace_id uses message_id fallback.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id=None) + trace_instance.tool_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + assert kwargs.get("parent_run_id") == "msg-1" + + def test_tool_trace_message_id_none_uses_conversation_id(self, trace_instance): + """When message_id is None, tries conversation_id attribute.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_tool_trace_info(trace_id=None, message_id=None) + trace_instance.tool_trace(trace_info) + + # No crash; parent_run_id is None since no fallback + _, kwargs = trace_instance.start_call.call_args + # parent_run_id should be None when no message_id and no trace_id + assert kwargs.get("parent_run_id") is None + + +# ── TestGenerateNameTrace ───────────────────────────────────────────────────── + + +class TestGenerateNameTrace: + def test_basic_generate_name_trace(self, trace_instance): + """generate_name_trace creates a run with correct op.""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_generate_name_trace_info() + trace_instance.generate_name_trace(trace_info) + + trace_instance.start_call.assert_called_once() + trace_instance.finish_call.assert_called_once() + + run = trace_instance.start_call.call_args[0][0] + assert run.op == str(TraceTaskName.GENERATE_NAME_TRACE) + + def test_generate_name_trace_no_parent(self, trace_instance): + """generate_name_trace has no parent run (no parent_run_id).""" + trace_instance.start_call = MagicMock() + trace_instance.finish_call = MagicMock() + + trace_info = _make_generate_name_trace_info() + trace_instance.generate_name_trace(trace_info) + + _, kwargs = trace_instance.start_call.call_args + # No parent_run_id passed to generate_name start_call + assert kwargs == {} or kwargs.get("parent_run_id") is None + + +# ── TestApiCheck ────────────────────────────────────────────────────────────── + + +class TestApiCheck: + def test_api_check_success_without_host(self, trace_instance, mock_wandb): + """api_check returns True on successful login without host.""" + trace_instance.host = None + mock_wandb.login.return_value = True + + result = trace_instance.api_check() + + assert result is True + mock_wandb.login.assert_called_with(key=trace_instance.weave_api_key, verify=True, relogin=True) + + def test_api_check_success_with_host(self, trace_instance, mock_wandb): + """api_check returns True on successful login with host.""" + trace_instance.host = "https://my.wandb.host" + mock_wandb.login.return_value = True + + result = trace_instance.api_check() + + assert result is True + mock_wandb.login.assert_called_with( + key=trace_instance.weave_api_key, verify=True, relogin=True, host="https://my.wandb.host" + ) + + def test_api_check_login_failure_raises(self, trace_instance, mock_wandb): + """api_check raises ValueError when login returns False.""" + trace_instance.host = None + mock_wandb.login.return_value = False + + with pytest.raises(ValueError, match="Weave API check failed"): + trace_instance.api_check() + + def test_api_check_exception_raises_value_error(self, trace_instance, mock_wandb): + """api_check raises ValueError when wandb.login raises exception.""" + trace_instance.host = None + mock_wandb.login.side_effect = Exception("network error") + + with pytest.raises(ValueError, match="Weave API check failed: network error"): + trace_instance.api_check() diff --git a/api/tests/unit_tests/core/plugin/impl/__init__.py b/api/tests/unit_tests/core/plugin/impl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/plugin/impl/test_agent_client.py b/api/tests/unit_tests/core/plugin/impl/test_agent_client.py new file mode 100644 index 0000000000..1537ffacf5 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_agent_client.py @@ -0,0 +1,91 @@ +from types import SimpleNamespace + +from core.plugin.entities.request import PluginInvokeContext +from core.plugin.impl.agent import PluginAgentClient + + +def _agent_provider(name: str = "agent") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + strategies=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginAgentClient: + def test_fetch_agent_strategy_providers(self, mocker): + client = PluginAgentClient() + provider = _agent_provider("remote") + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "strategies": [{"identity": {"provider": "old"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["strategies"][0]["identity"]["provider"] == "remote" + return [provider] + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_agent_strategy_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.strategies[0].identity.provider == "org/plugin/remote" + + def test_fetch_agent_strategy_provider(self, mocker): + client = PluginAgentClient() + provider = _agent_provider("provider") + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + assert transformer({"data": None}) == {"data": None} + payload = {"data": {"declaration": {"strategies": [{"identity": {"provider": "old"}}]}}} + transformed = transformer(payload) + assert transformed["data"]["declaration"]["strategies"][0]["identity"]["provider"] == "provider" + return provider + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_agent_strategy_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.strategies[0].identity.provider == "org/plugin/provider" + + def test_invoke_merges_chunks_and_passes_context(self, mocker): + client = PluginAgentClient() + stream_mock = mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter(["raw"]) + ) + merge_mock = mocker.patch("core.plugin.impl.agent.merge_blob_chunks", return_value=["merged"]) + context = PluginInvokeContext() + + result = client.invoke( + tenant_id="tenant-1", + user_id="user-1", + agent_provider="org/plugin/provider", + agent_strategy="router", + agent_params={"k": "v"}, + conversation_id="conv-1", + app_id="app-1", + message_id="msg-1", + context=context, + ) + + assert result == ["merged"] + assert merge_mock.call_count == 1 + payload = stream_mock.call_args.kwargs["data"] + assert payload["data"]["agent_strategy_provider"] == "provider" + assert payload["context"] == context.model_dump() + assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin" diff --git a/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py b/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py new file mode 100644 index 0000000000..5f564062d5 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_asset_manager.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock + +import pytest + +from core.plugin.impl.asset import PluginAssetManager + + +class TestPluginAssetManager: + def test_fetch_asset_success(self, mocker): + manager = PluginAssetManager() + response = MagicMock(status_code=200, content=b"asset-bytes") + request_mock = mocker.patch.object(manager, "_request", return_value=response) + + result = manager.fetch_asset("tenant-1", "asset-1") + + assert result == b"asset-bytes" + request_mock.assert_called_once_with(method="GET", path="plugin/tenant-1/asset/asset-1") + + def test_fetch_asset_not_found_raises(self, mocker): + manager = PluginAssetManager() + mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b"")) + + with pytest.raises(ValueError, match="can not found asset asset-1"): + manager.fetch_asset("tenant-1", "asset-1") + + def test_extract_asset_success(self, mocker): + manager = PluginAssetManager() + response = MagicMock(status_code=200, content=b"file-content") + request_mock = mocker.patch.object(manager, "_request", return_value=response) + + result = manager.extract_asset("tenant-1", "org/plugin:1", "README.md") + + assert result == b"file-content" + request_mock.assert_called_once_with( + method="GET", + path="plugin/tenant-1/extract-asset/", + params={"plugin_unique_identifier": "org/plugin:1", "file_path": "README.md"}, + ) + + def test_extract_asset_not_found_raises(self, mocker): + manager = PluginAssetManager() + mocker.patch.object(manager, "_request", return_value=MagicMock(status_code=404, content=b"")) + + with pytest.raises(ValueError, match="can not found asset org/plugin:1, 404"): + manager.extract_asset("tenant-1", "org/plugin:1", "README.md") diff --git a/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py new file mode 100644 index 0000000000..c216906d68 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_base_client_impl.py @@ -0,0 +1,137 @@ +import json + +import pytest + +from core.plugin.endpoint.exc import EndpointSetupFailedError +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError +from core.plugin.impl.base import BasePluginClient +from core.trigger.errors import ( + EventIgnoreError, + TriggerInvokeError, + TriggerPluginInvokeError, + TriggerProviderCredentialValidationError, +) + + +class _ResponseStub: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +class _StreamContext: + def __init__(self, lines): + self._lines = lines + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def iter_lines(self): + return self._lines + + +class TestBasePluginClientImpl: + def test_inject_trace_headers(self, mocker): + client = BasePluginClient() + mocker.patch("core.plugin.impl.base.dify_config.ENABLE_OTEL", True) + trace_header = "00-abc-xyz-01" + mocker.patch("core.helper.trace_id_helper.generate_traceparent_header", return_value=trace_header) + + headers = {} + client._inject_trace_headers(headers) + + assert headers["traceparent"] == trace_header + + headers_with_existing = {"TraceParent": "exists"} + client._inject_trace_headers(headers_with_existing) + assert headers_with_existing["TraceParent"] == "exists" + + def test_stream_request_handles_data_lines_and_dict_payload(self, mocker): + client = BasePluginClient() + stream_mock = mocker.patch( + "core.plugin.impl.base.httpx.stream", + return_value=_StreamContext([b"", b"data: hello", "world"]), + ) + + result = list(client._stream_request("POST", "plugin/tenant/stream", data={"k": "v"})) + + assert result == ["hello", "world"] + assert stream_mock.call_args.kwargs["data"] == {"k": "v"} + + def test_request_with_plugin_daemon_response_handles_request_exception(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_request", side_effect=RuntimeError("boom")) + + with pytest.raises(ValueError, match="Failed to request plugin daemon"): + client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool) + + def test_request_with_plugin_daemon_response_applies_transformer(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_request", return_value=_ResponseStub({"code": 0, "message": "", "data": True})) + + transformed = {} + + def transformer(payload): + transformed.update(payload) + return payload + + result = client._request_with_plugin_daemon_response("GET", "plugin/tenant/path", bool, transformer=transformer) + + assert result is True + assert transformed == {"code": 0, "message": "", "data": True} + + def test_request_with_plugin_daemon_response_stream_malformed_json_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"error":"bad-line"}'])) + + with pytest.raises(ValueError, match="bad-line"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + def test_request_with_plugin_daemon_response_stream_plugin_daemon_inner_error(self, mocker): + client = BasePluginClient() + mocker.patch.object( + client, "_stream_request", return_value=iter(['{"code":-500,"message":"not-json","data":null}']) + ) + + with pytest.raises(PluginDaemonInnerError) as exc_info: + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + assert exc_info.value.message == "not-json" + + def test_request_with_plugin_daemon_response_stream_plugin_daemon_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":-1,"message":"err","data":null}'])) + + with pytest.raises(ValueError, match="plugin daemon: err, code: -1"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + def test_request_with_plugin_daemon_response_stream_empty_data_error(self, mocker): + client = BasePluginClient() + mocker.patch.object(client, "_stream_request", return_value=iter(['{"code":0,"message":"","data":null}'])) + + with pytest.raises(ValueError, match="got empty data"): + list(client._request_with_plugin_daemon_response_stream("GET", "p", bool)) + + @pytest.mark.parametrize( + ("error_type", "expected"), + [ + (EndpointSetupFailedError.__name__, EndpointSetupFailedError), + (TriggerProviderCredentialValidationError.__name__, TriggerProviderCredentialValidationError), + (TriggerPluginInvokeError.__name__, TriggerPluginInvokeError), + (TriggerInvokeError.__name__, TriggerInvokeError), + (EventIgnoreError.__name__, EventIgnoreError), + ], + ) + def test_handle_plugin_daemon_error_trigger_branches(self, error_type, expected): + client = BasePluginClient() + message = json.dumps({"error_type": error_type, "message": "m"}) + + with pytest.raises(expected): + client._handle_plugin_daemon_error("PluginInvokeError", message) diff --git a/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py b/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py new file mode 100644 index 0000000000..4c5987d759 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_datasource_manager.py @@ -0,0 +1,234 @@ +from types import SimpleNamespace + +from core.datasource.entities.datasource_entities import ( + GetOnlineDocumentPageContentRequest, + OnlineDriveBrowseFilesRequest, + OnlineDriveDownloadFileRequest, +) +from core.plugin.impl.datasource import PluginDatasourceManager + + +def _datasource_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + datasources=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginDatasourceManager: + def test_fetch_datasource_providers(self, mocker): + manager = PluginDatasourceManager() + provider = _datasource_provider("remote") + repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True} + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_datasource_providers("tenant-1") + + assert request_mock.call_count == 1 + assert len(result) == 2 + assert result[0].plugin_id == "langgenius/file" + assert result[1].declaration.identity.name == "org/plugin/remote" + assert result[1].declaration.datasources[0].identity.provider == "org/plugin/remote" + repack.assert_called_once_with(tenant_id="tenant-1", provider=provider) + + def test_fetch_installed_datasource_providers(self, mocker): + manager = PluginDatasourceManager() + provider = _datasource_provider("remote") + repack = mocker.patch("core.plugin.impl.datasource.ToolTransformService.repack_provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/doc"}}], + } + } + ] + } + transformer(payload) + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_installed_datasource_providers("tenant-1") + + assert request_mock.call_count == 1 + assert len(result) == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.datasources[0].identity.provider == "org/plugin/remote" + repack.assert_called_once_with(tenant_id="tenant-1", provider=provider) + + def test_fetch_datasource_provider_local_and_remote(self, mocker): + manager = PluginDatasourceManager() + + local = manager.fetch_datasource_provider("tenant-1", "langgenius/file/file") + assert local.plugin_id == "langgenius/file" + + remote = _datasource_provider("provider") + mocker.patch("core.plugin.impl.datasource.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": { + "declaration": { + "datasources": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}] + } + } + } + transformed = transformer(payload) + assert transformed["data"]["declaration"]["datasources"][0]["output_schema"] == {"resolved": True} + return remote + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_datasource_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.datasources[0].identity.provider == "org/plugin/provider" + + def test_get_website_crawl_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["crawl"]) + + assert list( + manager.get_website_crawl( + "tenant-1", + "user-1", + "org/plugin/provider", + "crawl", + {"k": "v"}, + {"url": "https://example.com"}, + "website", + ) + ) == ["crawl"] + + assert stream_mock.call_count == 1 + + def test_get_online_document_pages_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["pages"]) + + assert list( + manager.get_online_document_pages( + "tenant-1", + "user-1", + "org/plugin/provider", + "docs", + {"k": "v"}, + {"workspace": "w1"}, + "online_document", + ) + ) == ["pages"] + + assert stream_mock.call_count == 1 + + def test_get_online_document_page_content_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["content"]) + + assert list( + manager.get_online_document_page_content( + "tenant-1", + "user-1", + "org/plugin/provider", + "docs", + {"k": "v"}, + GetOnlineDocumentPageContentRequest(workspace_id="w", page_id="p", type="doc"), + "online_document", + ) + ) == ["content"] + + assert stream_mock.call_count == 1 + + def test_online_drive_browse_files_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["browse"]) + + assert list( + manager.online_drive_browse_files( + "tenant-1", + "user-1", + "org/plugin/provider", + "drive", + {"k": "v"}, + OnlineDriveBrowseFilesRequest(prefix="/"), + "online_drive", + ) + ) == ["browse"] + + assert stream_mock.call_count == 1 + + def test_online_drive_download_file_streaming(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter(["download"]) + + assert list( + manager.online_drive_download_file( + "tenant-1", + "user-1", + "org/plugin/provider", + "drive", + {"k": "v"}, + OnlineDriveDownloadFileRequest(id="file-1"), + "online_drive", + ) + ) == ["download"] + + assert stream_mock.call_count == 1 + + def test_validate_provider_credentials_returns_true_when_stream_yields_result(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + + assert manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is True + + def test_validate_provider_credentials_returns_false_when_stream_empty(self, mocker): + manager = PluginDatasourceManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + stream_mock.return_value = iter([]) + + assert ( + manager.validate_provider_credentials("tenant-1", "user-1", "provider", "org/plugin", {"k": "v"}) is False + ) + + def test_local_file_provider_template(self): + manager = PluginDatasourceManager() + + payload = manager._get_local_file_datasource_provider() + + assert payload["plugin_id"] == "langgenius/file" + assert payload["provider"] == "file" + assert payload["declaration"]["provider_type"] == "local_file" diff --git a/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py b/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py new file mode 100644 index 0000000000..c80785aee0 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_debugging_client.py @@ -0,0 +1,21 @@ +from types import SimpleNamespace + +from core.plugin.impl.debugging import PluginDebuggingClient + + +class TestPluginDebuggingClient: + def test_get_debugging_key(self, mocker): + client = PluginDebuggingClient() + request_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response", + return_value=SimpleNamespace(key="debug-key"), + ) + + result = client.get_debugging_key("tenant-1") + + assert result == "debug-key" + request_mock.assert_called_once() + args = request_mock.call_args.args + assert args[0] == "POST" + assert args[1] == "plugin/tenant-1/debugging/key" diff --git a/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py b/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py new file mode 100644 index 0000000000..4cf657a050 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_endpoint_client_impl.py @@ -0,0 +1,71 @@ +import pytest + +from core.plugin.impl.endpoint import PluginEndpointClient +from core.plugin.impl.exc import PluginDaemonInternalServerError + + +class TestPluginEndpointClientImpl: + def test_create_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + result = client.create_endpoint("tenant-1", "user-1", "org/plugin:1", "endpoint-a", {"k": "v"}) + + assert result is True + assert request_mock.call_count == 1 + args = request_mock.call_args.args + kwargs = request_mock.call_args.kwargs + assert args[:3] == ("POST", "plugin/tenant-1/endpoint/setup", bool) + assert kwargs["data"]["plugin_unique_identifier"] == "org/plugin:1" + + def test_list_endpoints(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"]) + + result = client.list_endpoints("tenant-1", "user-1", 2, 20) + + assert result == ["endpoint"] + assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list" + assert request_mock.call_args.kwargs["params"] == {"page": 2, "page_size": 20} + + def test_list_endpoints_for_single_plugin(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["endpoint"]) + + result = client.list_endpoints_for_single_plugin("tenant-1", "user-1", "org/plugin", 1, 10) + + assert result == ["endpoint"] + assert request_mock.call_args.args[1] == "plugin/tenant-1/endpoint/list/plugin" + assert request_mock.call_args.kwargs["params"] == {"plugin_id": "org/plugin", "page": 1, "page_size": 10} + + def test_update_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + result = client.update_endpoint("tenant-1", "user-1", "endpoint-1", "renamed", {"x": 1}) + + assert result is True + assert request_mock.call_args.args[:3] == ("POST", "plugin/tenant-1/endpoint/update", bool) + + def test_enable_and_disable_endpoint(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=True) + + assert client.enable_endpoint("tenant-1", "user-1", "endpoint-1") is True + assert client.disable_endpoint("tenant-1", "user-1", "endpoint-1") is True + + calls = request_mock.call_args_list + assert calls[0].args[1] == "plugin/tenant-1/endpoint/enable" + assert calls[1].args[1] == "plugin/tenant-1/endpoint/disable" + + def test_delete_endpoint_idempotent_and_re_raise(self, mocker): + client = PluginEndpointClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response") + + request_mock.side_effect = PluginDaemonInternalServerError("record not found") + assert client.delete_endpoint("tenant-1", "user-1", "endpoint-1") is True + + request_mock.side_effect = PluginDaemonInternalServerError("permission denied") + with pytest.raises(PluginDaemonInternalServerError) as exc_info: + client.delete_endpoint("tenant-1", "user-1", "endpoint-1") + assert "permission denied" in exc_info.value.description diff --git a/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py b/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py new file mode 100644 index 0000000000..8c6f1c6b7f --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_exc_impl.py @@ -0,0 +1,41 @@ +import json + +from core.plugin.impl import exc as exc_module +from core.plugin.impl.exc import PluginDaemonError, PluginInvokeError + + +class TestPluginImplExceptions: + def test_plugin_daemon_error_str_contains_request_id(self, mocker): + mocker.patch("core.plugin.impl.exc.get_request_id", return_value="req-123") + error = PluginDaemonError("bad") + + assert str(error) == "req_id: req-123 PluginDaemonError: bad" + + def test_plugin_invoke_error_with_json_payload(self): + err = PluginInvokeError(json.dumps({"error_type": "RateLimit", "message": "too many"})) + + assert err.get_error_type() == "RateLimit" + assert err.get_error_message() == "too many" + friendly = err.to_user_friendly_error("test-plugin") + assert "test-plugin" in friendly + assert "RateLimit" in friendly + assert "too many" in friendly + + def test_plugin_invoke_error_invalid_json_and_fallback(self, mocker): + err = PluginInvokeError("plain text") + + assert err._get_error_object() == {} + assert err.get_error_type() == "unknown" + assert err.get_error_message() == "unknown" + + mocker.patch.object(PluginInvokeError, "_get_error_object", side_effect=RuntimeError("boom")) + err2 = PluginInvokeError("plain text") + assert err2.get_error_message() == "plain text" + + def test_plugin_invoke_error_get_error_object_handles_adapter_exception(self, mocker): + adapter = mocker.patch.object(exc_module, "TypeAdapter") + adapter.return_value.validate_json.side_effect = RuntimeError("invalid") + + err = PluginInvokeError("not-json") + + assert err._get_error_object() == {} diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_client.py b/api/tests/unit_tests/core/plugin/impl/test_model_client.py new file mode 100644 index 0000000000..bcbebbb38b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_client.py @@ -0,0 +1,490 @@ +from __future__ import annotations + +import io +from types import SimpleNamespace + +import pytest + +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError +from core.plugin.impl.model import PluginModelClient + + +class TestPluginModelClient: + def test_fetch_model_providers(self, mocker): + client = PluginModelClient() + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", return_value=["provider-a"]) + + result = client.fetch_model_providers("tenant-1") + + assert result == ["provider-a"] + assert request_mock.call_args.args[:2] == ( + "GET", + "plugin/tenant-1/management/models", + ) + assert request_mock.call_args.kwargs["params"] == {"page": 1, "page_size": 256} + + def test_get_model_schema(self, mocker): + client = PluginModelClient() + schema = SimpleNamespace(name="schema") + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(model_schema=schema)]), + ) + + result = client.get_model_schema( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials={"api_key": "key"}, + ) + + assert result is schema + assert stream_mock.call_args.args[:2] == ("POST", "plugin/tenant-1/dispatch/model/schema") + + def test_get_model_schema_empty_stream_returns_none(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + result = client.get_model_schema("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}) + + assert result is None + + def test_validate_provider_credentials(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True, credentials={"api_key": "new"})]), + ) + credentials = {"api_key": "old"} + + result = client.validate_provider_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + credentials=credentials, + ) + + assert result is True + assert credentials["api_key"] == "new" + assert stream_mock.call_args.args[:2] == ( + "POST", + "plugin/tenant-1/dispatch/model/validate_provider_credentials", + ) + + def test_validate_provider_credentials_without_dict_update(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=False, credentials="not-a-dict")]), + ) + credentials = {"api_key": "same"} + + result = client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", credentials) + + assert result is False + assert credentials == {"api_key": "same"} + + def test_validate_provider_credentials_empty_returns_false(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", {}) is False + + def test_validate_model_credentials(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True, credentials={"token": "rotated"})]), + ) + credentials = {"token": "old"} + + result = client.validate_model_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials=credentials, + ) + + assert result is True + assert credentials["token"] == "rotated" + assert stream_mock.call_args.args[:2] == ( + "POST", + "plugin/tenant-1/dispatch/model/validate_model_credentials", + ) + + def test_validate_model_credentials_empty_returns_false(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.validate_model_credentials("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}) + is False + ) + + def test_invoke_llm(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk-1"]) + ) + + result = list( + client.invoke_llm( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="gpt-test", + credentials={"api_key": "key"}, + prompt_messages=[], + model_parameters={"temperature": 0.1}, + tools=[], + stop=["STOP"], + stream=False, + ) + ) + + assert result == ["chunk-1"] + call_kwargs = stream_mock.call_args.kwargs + assert call_kwargs["path"] == "plugin/tenant-1/dispatch/llm/invoke" + assert call_kwargs["data"]["data"]["stream"] is False + assert call_kwargs["data"]["data"]["model_parameters"] == {"temperature": 0.1} + + def test_invoke_llm_wraps_plugin_daemon_inner_error(self, mocker): + client = PluginModelClient() + + def _boom(): + raise PluginDaemonInnerError(code=-500, message="invoke failed") + yield # pragma: no cover + + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom()) + + with pytest.raises(ValueError, match="invoke failed-500"): + list( + client.invoke_llm( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="gpt-test", + credentials={}, + prompt_messages=[], + ) + ) + + def test_get_llm_num_tokens(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(num_tokens=42)]), + ) + + result = client.get_llm_num_tokens( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model_type="llm", + model="gpt-test", + credentials={}, + prompt_messages=[], + tools=[], + ) + + assert result == 42 + + def test_get_llm_num_tokens_empty_returns_zero(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.get_llm_num_tokens("tenant-1", "user-1", "org/plugin:1", "provider-a", "llm", "gpt-test", {}, []) + == 0 + ) + + def test_invoke_text_embedding(self, mocker): + client = PluginModelClient() + embedding_result = SimpleNamespace(data=[[0.1, 0.2]]) + mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result]) + ) + + result = client.invoke_text_embedding( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="embedding-a", + credentials={}, + texts=["hello"], + input_type="search_document", + ) + + assert result is embedding_result + + def test_invoke_text_embedding_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke text embedding"): + client.invoke_text_embedding( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["hello"], "x" + ) + + def test_invoke_multimodal_embedding(self, mocker): + client = PluginModelClient() + embedding_result = SimpleNamespace(data=[[0.3, 0.4]]) + mocker.patch.object( + client, "_request_with_plugin_daemon_response_stream", return_value=iter([embedding_result]) + ) + + result = client.invoke_multimodal_embedding( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="embedding-a", + credentials={}, + documents=[{"type": "image", "value": "abc"}], + input_type="search_document", + ) + + assert result is embedding_result + + def test_invoke_multimodal_embedding_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke file embedding"): + client.invoke_multimodal_embedding( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, [{"type": "image"}], "x" + ) + + def test_get_text_embedding_num_tokens(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(num_tokens=[1, 2, 3])]), + ) + + assert client.get_text_embedding_num_tokens( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"] + ) == [ + 1, + 2, + 3, + ] + + def test_get_text_embedding_num_tokens_empty_returns_list(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert ( + client.get_text_embedding_num_tokens( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "embedding-a", {}, ["a"] + ) + == [] + ) + + def test_invoke_rerank(self, mocker): + client = PluginModelClient() + rerank_result = SimpleNamespace(scores=[0.9]) + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result])) + + result = client.invoke_rerank( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="rerank-a", + credentials={}, + query="q", + docs=["doc-1"], + score_threshold=0.2, + top_n=5, + ) + + assert result is rerank_result + + def test_invoke_rerank_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke rerank"): + client.invoke_rerank("tenant-1", "user-1", "org/plugin:1", "provider-a", "rerank-a", {}, "q", ["doc-1"]) + + def test_invoke_multimodal_rerank(self, mocker): + client = PluginModelClient() + rerank_result = SimpleNamespace(scores=[0.8]) + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([rerank_result])) + + result = client.invoke_multimodal_rerank( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="rerank-a", + credentials={}, + query={"type": "text", "value": "q"}, + docs=[{"type": "image", "value": "doc"}], + score_threshold=0.1, + top_n=3, + ) + + assert result is rerank_result + + def test_invoke_multimodal_rerank_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke multimodal rerank"): + client.invoke_multimodal_rerank( + "tenant-1", + "user-1", + "org/plugin:1", + "provider-a", + "rerank-a", + {}, + {"type": "text"}, + [{"type": "image"}], + ) + + def test_invoke_tts(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result="68656c6c6f"), SimpleNamespace(result="21")]), + ) + + result = list( + client.invoke_tts( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="tts-a", + credentials={}, + content_text="hello", + voice="alloy", + ) + ) + + assert result == [b"hello", b"!"] + + def test_invoke_tts_wraps_plugin_daemon_inner_error(self, mocker): + client = PluginModelClient() + + def _boom(): + raise PluginDaemonInnerError(code=-400, message="tts error") + yield # pragma: no cover + + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=_boom()) + + with pytest.raises(ValueError, match="tts error-400"): + list(client.invoke_tts("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}, "hello", "alloy")) + + def test_get_tts_model_voices(self, mocker): + client = PluginModelClient() + mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter( + [ + SimpleNamespace( + voices=[ + SimpleNamespace(name="Alloy", value="alloy"), + SimpleNamespace(name="Echo", value="echo"), + ] + ) + ] + ), + ) + + result = client.get_tts_model_voices( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="tts-a", + credentials={}, + language="en", + ) + + assert result == [{"name": "Alloy", "value": "alloy"}, {"name": "Echo", "value": "echo"}] + + def test_get_tts_model_voices_empty_returns_list(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + assert client.get_tts_model_voices("tenant-1", "user-1", "org/plugin:1", "provider-a", "tts-a", {}) == [] + + def test_invoke_speech_to_text(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result="transcribed text")]), + ) + + result = client.invoke_speech_to_text( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="stt-a", + credentials={}, + file=io.BytesIO(b"abc"), + ) + + assert result == "transcribed text" + assert stream_mock.call_args.kwargs["data"]["data"]["file"] == "616263" + + def test_invoke_speech_to_text_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke speech to text"): + client.invoke_speech_to_text( + "tenant-1", "user-1", "org/plugin:1", "provider-a", "stt-a", {}, io.BytesIO(b"abc") + ) + + def test_invoke_moderation(self, mocker): + client = PluginModelClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(result=True)]), + ) + + result = client.invoke_moderation( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin:1", + provider="provider-a", + model="moderation-a", + credentials={}, + text="safe text", + ) + + assert result is True + assert stream_mock.call_args.kwargs["path"] == "plugin/tenant-1/dispatch/moderation/invoke" + + def test_invoke_moderation_empty_raises(self, mocker): + client = PluginModelClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Failed to invoke moderation"): + client.invoke_moderation("tenant-1", "user-1", "org/plugin:1", "provider-a", "moderation-a", {}, "unsafe") diff --git a/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py b/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py new file mode 100644 index 0000000000..6fb4c99432 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_oauth_handler.py @@ -0,0 +1,147 @@ +from io import BytesIO +from types import SimpleNamespace + +import pytest +from werkzeug import Request + +from core.plugin.impl.oauth import OAuthHandler + + +def _build_request(body: bytes = b"payload") -> Request: + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/oauth/callback", + "QUERY_STRING": "code=123", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": BytesIO(body), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": str(len(body)), + "HTTP_HOST": "localhost", + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_X_TEST": "yes", + } + return Request(environ) + + +class TestOAuthHandler: + def test_get_authorization_url(self, mocker): + handler = OAuthHandler() + stream_mock = mocker.patch.object( + handler, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(authorization_url="https://auth.example.com")]), + ) + + response = handler.get_authorization_url( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + ) + + assert response.authorization_url == "https://auth.example.com" + assert stream_mock.call_count == 1 + + def test_get_authorization_url_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error getting authorization URL"): + handler.get_authorization_url( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + ) + + def test_get_credentials(self, mocker): + handler = OAuthHandler() + captured_data = {} + + def fake_stream(*args, **kwargs): + captured_data.update(kwargs["data"]) + return iter([SimpleNamespace(credentials={"token": "abc"}, metadata={}, expires_at=1)]) + + stream_mock = mocker.patch.object( + handler, "_request_with_plugin_daemon_response_stream", side_effect=fake_stream + ) + + response = handler.get_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + request=_build_request(), + ) + + assert response.credentials == {"token": "abc"} + assert "raw_http_request" in captured_data["data"] + assert stream_mock.call_count == 1 + + def test_get_credentials_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error getting credentials"): + handler.get_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + request=_build_request(), + ) + + def test_refresh_credentials(self, mocker): + handler = OAuthHandler() + stream_mock = mocker.patch.object( + handler, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(credentials={"token": "new"}, metadata={}, expires_at=1)]), + ) + + response = handler.refresh_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={"client_id": "id"}, + credentials={"refresh_token": "r"}, + ) + + assert response.credentials == {"token": "new"} + assert stream_mock.call_count == 1 + + def test_refresh_credentials_no_response_raises(self, mocker): + handler = OAuthHandler() + mocker.patch.object(handler, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="Error refreshing credentials"): + handler.refresh_credentials( + tenant_id="tenant-1", + user_id="user-1", + plugin_id="org/plugin", + provider="provider", + redirect_uri="https://dify.example.com/callback", + system_credentials={}, + credentials={}, + ) + + def test_convert_request_to_raw_data(self): + handler = OAuthHandler() + request = _build_request(b"body-data") + + raw = handler._convert_request_to_raw_data(request) + + assert raw.startswith(b"POST /oauth/callback?code=123 HTTP/1.1\r\n") + assert b"X-Test: yes\r\n" in raw + assert raw.endswith(b"body-data") diff --git a/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py b/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py new file mode 100644 index 0000000000..80cf46f9bb --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_tool_manager.py @@ -0,0 +1,121 @@ +from types import SimpleNamespace + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.tool import PluginToolManager + + +def _tool_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + tools=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +class TestPluginToolManager: + def test_fetch_tool_providers(self, mocker): + manager = PluginToolManager() + provider = _tool_provider("remote") + mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "declaration": { + "identity": {"name": "remote"}, + "tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}], + } + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["tools"][0]["output_schema"] == {"resolved": True} + return [provider] + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_tool_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.tools[0].identity.provider == "org/plugin/remote" + + def test_fetch_tool_provider(self, mocker): + manager = PluginToolManager() + provider = _tool_provider("provider") + mocker.patch("core.plugin.impl.tool.resolve_dify_schema_refs", return_value={"resolved": True}) + + def fake_request(method, path, type_, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": { + "declaration": {"tools": [{"identity": {"provider": "old"}, "output_schema": {"$ref": "#/x"}}]} + } + } + transformed = transformer(payload) + assert transformed["data"]["declaration"]["tools"][0]["output_schema"] == {"resolved": True} + return provider + + request_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = manager.fetch_tool_provider("tenant-1", "org/plugin/provider") + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.tools[0].identity.provider == "org/plugin/provider" + + def test_invoke_merges_chunks(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object( + manager, "_request_with_plugin_daemon_response_stream", return_value=iter(["chunk"]) + ) + merge_mock = mocker.patch("core.plugin.impl.tool.merge_blob_chunks", return_value=["merged"]) + + result = manager.invoke( + tenant_id="tenant-1", + user_id="user-1", + tool_provider="org/plugin/provider", + tool_name="search", + credentials={"api_key": "k"}, + credential_type=CredentialType.API_KEY, + tool_parameters={"q": "python"}, + conversation_id="conv-1", + app_id="app-1", + message_id="msg-1", + ) + + assert result == ["merged"] + assert merge_mock.call_count == 1 + assert stream_mock.call_args.kwargs["headers"]["X-Plugin-ID"] == "org/plugin" + + def test_validate_credentials_paths(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + assert manager.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + assert manager.validate_datasource_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is False + + def test_get_runtime_parameters_paths(self, mocker): + manager = PluginToolManager() + stream_mock = mocker.patch.object(manager, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(parameters=[{"name": "p"}])]) + params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search") + assert params == [{"name": "p"}] + + stream_mock.return_value = iter([]) + params = manager.get_runtime_parameters("tenant-1", "user-1", "org/plugin/provider", {}, "search") + assert params == [] diff --git a/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py b/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py new file mode 100644 index 0000000000..76da51c2c8 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_trigger_client.py @@ -0,0 +1,226 @@ +from io import BytesIO +from types import SimpleNamespace + +import pytest +from werkzeug import Request + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.impl.trigger import PluginTriggerClient +from core.trigger.entities.entities import Subscription +from models.provider_ids import TriggerProviderID + + +def _request() -> Request: + environ = { + "REQUEST_METHOD": "POST", + "PATH_INFO": "/events", + "QUERY_STRING": "", + "SERVER_NAME": "localhost", + "SERVER_PORT": "80", + "wsgi.input": BytesIO(b"payload"), + "wsgi.url_scheme": "http", + "CONTENT_LENGTH": "7", + "HTTP_HOST": "localhost", + } + return Request(environ) + + +def _subscription() -> Subscription: + return Subscription(expires_at=123, endpoint="https://example.com/hook", parameters={"a": 1}, properties={"p": 1}) + + +def _trigger_provider(name: str = "provider") -> SimpleNamespace: + return SimpleNamespace( + plugin_id="org/plugin", + declaration=SimpleNamespace( + identity=SimpleNamespace(name=name), + events=[SimpleNamespace(identity=SimpleNamespace(provider=""))], + ), + ) + + +def _subscription_call_kwargs(method_name: str) -> dict: + if method_name == "subscribe": + return { + "tenant_id": "tenant-1", + "user_id": "user-1", + "provider": "org/plugin/provider", + "credentials": {"token": "x"}, + "credential_type": CredentialType.API_KEY, + "endpoint": "https://example.com/hook", + "parameters": {"k": "v"}, + } + + return { + "tenant_id": "tenant-1", + "user_id": "user-1", + "provider": "org/plugin/provider", + "subscription": _subscription(), + "credentials": {"token": "x"}, + "credential_type": CredentialType.API_KEY, + } + + +class TestPluginTriggerClient: + def test_fetch_trigger_providers(self, mocker): + client = PluginTriggerClient() + provider = _trigger_provider("remote") + + def fake_request(*args, **kwargs): + transformer = kwargs["transformer"] + payload = { + "data": [ + { + "plugin_id": "org/plugin", + "provider": "remote", + "declaration": {"events": [{"identity": {"provider": "old"}}]}, + } + ] + } + transformed = transformer(payload) + assert transformed["data"][0]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/remote" + return [provider] + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_trigger_providers("tenant-1") + + assert request_mock.call_count == 1 + assert result[0].declaration.identity.name == "org/plugin/remote" + assert result[0].declaration.events[0].identity.provider == "org/plugin/remote" + + def test_fetch_trigger_provider(self, mocker): + client = PluginTriggerClient() + provider = _trigger_provider("provider") + + def fake_request(*args, **kwargs): + transformer = kwargs["transformer"] + payload = {"data": {"declaration": {"events": [{"identity": {"provider": "old"}}]}}} + transformed = transformer(payload) + assert transformed["data"]["declaration"]["events"][0]["identity"]["provider"] == "org/plugin/provider" + return provider + + request_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response", side_effect=fake_request) + + result = client.fetch_trigger_provider("tenant-1", TriggerProviderID("org/plugin/provider")) + + assert request_mock.call_count == 1 + assert result.declaration.identity.name == "org/plugin/provider" + assert result.declaration.events[0].identity.provider == "org/plugin/provider" + + def test_invoke_trigger_event(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(variables={"ok": True}, cancelled=False)]), + ) + + result = client.invoke_trigger_event( + tenant_id="tenant-1", + user_id="user-1", + provider="org/plugin/provider", + event_name="created", + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + request=_request(), + parameters={"k": "v"}, + subscription=_subscription(), + payload={"payload": 1}, + ) + + assert result.variables == {"ok": True} + assert stream_mock.call_count == 1 + + def test_invoke_trigger_event_no_response_raises(self, mocker): + client = PluginTriggerClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + + with pytest.raises(ValueError, match="No response received from plugin daemon for invoke trigger"): + client.invoke_trigger_event( + tenant_id="tenant-1", + user_id="user-1", + provider="org/plugin/provider", + event_name="created", + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + request=_request(), + parameters={"k": "v"}, + subscription=_subscription(), + payload={"payload": 1}, + ) + + def test_validate_provider_credentials(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object(client, "_request_with_plugin_daemon_response_stream") + + stream_mock.return_value = iter([SimpleNamespace(result=True)]) + assert client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) is True + + stream_mock.return_value = iter([]) + with pytest.raises( + ValueError, match="No response received from plugin daemon for validate provider credentials" + ): + client.validate_provider_credentials("tenant-1", "user-1", "org/plugin/provider", {"k": "v"}) + + def test_dispatch_event(self, mocker): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(user_id="u", events=["e"])]), + ) + + result = client.dispatch_event( + tenant_id="tenant-1", + provider="org/plugin/provider", + subscription={"id": "sub"}, + request=_request(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result.user_id == "u" + assert stream_mock.call_count == 1 + + stream_mock.return_value = iter([]) + with pytest.raises(ValueError, match="No response received from plugin daemon for dispatch event"): + client.dispatch_event( + tenant_id="tenant-1", + provider="org/plugin/provider", + subscription={"id": "sub"}, + request=_request(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + @pytest.mark.parametrize("method_name", ["subscribe", "unsubscribe", "refresh"]) + def test_subscription_operations_success(self, mocker, method_name): + client = PluginTriggerClient() + stream_mock = mocker.patch.object( + client, + "_request_with_plugin_daemon_response_stream", + return_value=iter([SimpleNamespace(subscription={"id": "sub"})]), + ) + + method = getattr(client, method_name) + result = method(**_subscription_call_kwargs(method_name)) + + assert result.subscription == {"id": "sub"} + assert stream_mock.call_count == 1 + + @pytest.mark.parametrize( + ("method_name", "expected"), + [ + ("subscribe", "No response received from plugin daemon for subscribe"), + ("unsubscribe", "No response received from plugin daemon for unsubscribe"), + ("refresh", "No response received from plugin daemon for refresh"), + ], + ) + def test_subscription_operations_no_response(self, mocker, method_name, expected): + client = PluginTriggerClient() + mocker.patch.object(client, "_request_with_plugin_daemon_response_stream", return_value=iter([])) + method = getattr(client, method_name) + + with pytest.raises(ValueError, match=expected): + method(**_subscription_call_kwargs(method_name)) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index a380149554..c2778f082b 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -1,72 +1,359 @@ +import json from types import SimpleNamespace from unittest.mock import MagicMock +import pytest +from pydantic import BaseModel + from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation +from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from models.model import AppMode -def test_invoke_chat_app_advanced_chat_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" +class _Chunk(BaseModel): + value: int - app = MagicMock() - app.mode = AppMode.ADVANCED_CHAT - app.workflow = workflow - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), +class TestBaseBackwardsInvocation: + def test_convert_to_event_stream_with_generator_and_error(self): + def _stream(): + yield _Chunk(value=1) + yield {"x": 2} + yield "ignored" + raise RuntimeError("boom") + + chunks = list(BaseBackwardsInvocation.convert_to_event_stream(_stream())) + + assert len(chunks) == 3 + first = json.loads(chunks[0].decode()) + second = json.loads(chunks[1].decode()) + error = json.loads(chunks[2].decode()) + assert first["data"]["value"] == 1 + assert second["data"]["x"] == 2 + assert error["error"] == "boom" + + def test_convert_to_event_stream_with_non_generator(self): + chunks = list(BaseBackwardsInvocation.convert_to_event_stream({"ok": True})) + payload = json.loads(chunks[0].decode()) + assert payload["data"] == {"ok": True} + assert payload["error"] == "" + + +class TestPluginAppBackwardsInvocation: + def test_fetch_app_info_workflow_path(self, mocker): + workflow = MagicMock() + workflow.features_dict = {"feature": "v"} + workflow.user_input_form.return_value = [{"name": "foo"}] + app = MagicMock(mode=AppMode.WORKFLOW, workflow=workflow) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mapper = mocker.patch( + "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", + return_value={"mapped": True}, + ) + + result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1") + + assert result == {"data": {"mapped": True}} + mapper.assert_called_once_with(features_dict={"feature": "v"}, user_input_form=[{"name": "foo"}]) + + def test_fetch_app_info_model_config_path(self, mocker): + model_config = MagicMock() + model_config.to_dict.return_value = {"user_input_form": [{"name": "bar"}], "k": "v"} + app = MagicMock(mode=AppMode.COMPLETION, app_model_config=model_config) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch( + "core.plugin.backwards_invocation.app.get_parameters_from_feature_dict", + return_value={"mapped": True}, + ) + + result = PluginAppBackwardsInvocation.fetch_app_info("app-1", "tenant-1") + + assert result["data"] == {"mapped": True} + + @pytest.mark.parametrize( + ("mode", "route_method"), + [ + (AppMode.CHAT, "invoke_chat_app"), + (AppMode.ADVANCED_CHAT, "invoke_chat_app"), + (AppMode.AGENT_CHAT, "invoke_chat_app"), + (AppMode.WORKFLOW, "invoke_workflow_app"), + (AppMode.COMPLETION, "invoke_completion_app"), + ], ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, + def test_invoke_app_routes_by_mode(self, mocker, mode, route_method): + app = MagicMock(mode=mode) + user = MagicMock() + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=user) + route = mocker.patch.object(PluginAppBackwardsInvocation, route_method, return_value={"routed": True}) + + result = PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="hello", + stream=False, + inputs={"x": 1}, + files=[], + ) + + assert result == {"routed": True} + assert route.call_count == 1 + + def test_invoke_app_uses_end_user_when_user_id_missing(self, mocker): + app = MagicMock(mode=AppMode.WORKFLOW) + end_user = MagicMock() + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=app) + get_or_create = mocker.patch( + "core.plugin.backwards_invocation.app.EndUserService.get_or_create_end_user", + return_value=end_user, + ) + route = mocker.patch.object(PluginAppBackwardsInvocation, "invoke_workflow_app", return_value={"ok": True}) + + result = PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="", + tenant_id="tenant", + conversation_id="", + query=None, + stream=True, + inputs={}, + files=[], + ) + + assert result == {"ok": True} + get_or_create.assert_called_once_with(app) + assert route.call_args.args[1] is end_user + + def test_invoke_app_missing_query_for_chat_raises(self, mocker): + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode=AppMode.CHAT)) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + + with pytest.raises(ValueError, match="missing query"): + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_app_unexpected_mode_raises(self, mocker): + mocker.patch.object(PluginAppBackwardsInvocation, "_get_app", return_value=MagicMock(mode="other")) + mocker.patch.object(PluginAppBackwardsInvocation, "_get_user", return_value=MagicMock()) + + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_app( + app_id="app", + user_id="user", + tenant_id="tenant", + conversation_id=None, + query="q", + stream=False, + inputs={}, + files=[], + ) + + @pytest.mark.parametrize( + ("mode", "generator_path"), + [ + (AppMode.AGENT_CHAT, "core.plugin.backwards_invocation.app.AgentChatAppGenerator.generate"), + (AppMode.CHAT, "core.plugin.backwards_invocation.app.ChatAppGenerator.generate"), + ], ) + def test_invoke_chat_app_agent_and_chat(self, mocker, mode, generator_path): + app = MagicMock(mode=mode, workflow=None) + spy = mocker.patch(generator_path, return_value={"result": "ok"}) - result = PluginAppBackwardsInvocation.invoke_chat_app( - app=app, - user=MagicMock(), - conversation_id="conv-1", - query="hello", - stream=False, - inputs={"k": "v"}, - files=[], - ) + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" + assert result == {"result": "ok"} + assert spy.call_count == 1 + def test_invoke_chat_app_advanced_chat_injects_pause_state_config(self, mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" -def test_invoke_workflow_app_injects_pause_state_config(mocker): - workflow = MagicMock() - workflow.created_by = "owner-id" + app = MagicMock() + app.mode = AppMode.ADVANCED_CHAT + app.workflow = workflow - app = MagicMock() - app.mode = AppMode.WORKFLOW - app.workflow = workflow + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.AdvancedChatAppGenerator.generate", + return_value={"result": "ok"}, + ) - mocker.patch( - "core.plugin.backwards_invocation.app.db", - SimpleNamespace(engine=MagicMock()), - ) - generator_spy = mocker.patch( - "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) + result = PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={"k": "v"}, + files=[], + ) - result = PluginAppBackwardsInvocation.invoke_workflow_app( - app=app, - user=MagicMock(), - stream=False, - inputs={"k": "v"}, - files=[], - ) + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" - assert result == {"result": "ok"} - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert isinstance(pause_state_config, PauseStateLayerConfig) - assert pause_state_config.state_owner_user_id == "owner-id" + def test_invoke_chat_app_advanced_chat_without_workflow_raises(self): + app = MagicMock(mode=AppMode.ADVANCED_CHAT, workflow=None) + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_chat_app_unexpected_mode_raises(self): + app = MagicMock(mode="invalid") + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_chat_app( + app=app, + user=MagicMock(), + conversation_id="conv-1", + query="hello", + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_workflow_app_injects_pause_state_config(self, mocker): + workflow = MagicMock() + workflow.created_by = "owner-id" + + app = MagicMock() + app.mode = AppMode.WORKFLOW + app.workflow = workflow + + mocker.patch( + "core.plugin.backwards_invocation.app.db", + SimpleNamespace(engine=MagicMock()), + ) + generator_spy = mocker.patch( + "core.plugin.backwards_invocation.app.WorkflowAppGenerator.generate", + return_value={"result": "ok"}, + ) + + result = PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={"k": "v"}, + files=[], + ) + + assert result == {"result": "ok"} + call_kwargs = generator_spy.call_args.kwargs + pause_state_config = call_kwargs.get("pause_state_config") + assert isinstance(pause_state_config, PauseStateLayerConfig) + assert pause_state_config.state_owner_user_id == "owner-id" + + def test_invoke_workflow_app_without_workflow_raises(self): + app = MagicMock(mode=AppMode.WORKFLOW, workflow=None) + with pytest.raises(ValueError, match="unexpected app type"): + PluginAppBackwardsInvocation.invoke_workflow_app( + app=app, + user=MagicMock(), + stream=False, + inputs={}, + files=[], + ) + + def test_invoke_completion_app(self, mocker): + spy = mocker.patch( + "core.plugin.backwards_invocation.app.CompletionAppGenerator.generate", return_value={"ok": 1} + ) + app = MagicMock(mode=AppMode.COMPLETION) + + result = PluginAppBackwardsInvocation.invoke_completion_app(app, MagicMock(), False, {"x": 1}, []) + + assert result == {"ok": 1} + assert spy.call_count == 1 + + def test_get_user_returns_end_user(self, mocker): + session = MagicMock() + session.scalar.side_effect = [MagicMock(id="end-user")] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + user = PluginAppBackwardsInvocation._get_user("uid") + assert user.id == "end-user" + + def test_get_user_falls_back_to_account_user(self, mocker): + session = MagicMock() + session.scalar.side_effect = [None, MagicMock(id="account-user")] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + user = PluginAppBackwardsInvocation._get_user("uid") + assert user.id == "account-user" + + def test_get_user_raises_when_user_not_found(self, mocker): + session = MagicMock() + session.scalar.side_effect = [None, None] + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + mocker.patch("core.plugin.backwards_invocation.app.Session", return_value=session_ctx) + mocker.patch("core.plugin.backwards_invocation.app.db", SimpleNamespace(engine=MagicMock())) + + with pytest.raises(ValueError, match="user not found"): + PluginAppBackwardsInvocation._get_user("uid") + + def test_get_app_returns_app(self, mocker): + query_chain = MagicMock() + query_chain.where.return_value = query_chain + app_obj = MagicMock(id="app") + query_chain.first.return_value = app_obj + db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj + + def test_get_app_raises_when_missing(self, mocker): + query_chain = MagicMock() + query_chain.where.return_value = query_chain + query_chain.first.return_value = None + db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + with pytest.raises(ValueError, match="app not found"): + PluginAppBackwardsInvocation._get_app("app", "tenant") + + def test_get_app_raises_when_query_fails(self, mocker): + db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + mocker.patch("core.plugin.backwards_invocation.app.db", db) + + with pytest.raises(ValueError, match="app not found"): + PluginAppBackwardsInvocation._get_app("app", "tenant") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py new file mode 100644 index 0000000000..b0b64a601b --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -0,0 +1,347 @@ +import binascii +import datetime +from enum import StrEnum + +import pytest +from flask import Response +from pydantic import ValidationError + +from core.plugin.entities.endpoint import EndpointEntityWithInstance +from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot +from core.plugin.entities.parameters import ( + PluginParameter, + PluginParameterOption, + PluginParameterType, + as_normal_type, + cast_parameter_value, + init_frontend_parameter, +) +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import ( + RequestInvokeLLM, + RequestInvokeSpeech2Text, + TriggerDispatchResponse, + TriggerInvokeEventResponse, +) +from core.plugin.utils.http_parser import serialize_response +from core.tools.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) + + +class TestEndpointEntity: + def test_endpoint_entity_with_instance_renders_url(self, mocker): + mocker.patch("core.plugin.entities.endpoint.dify_config.ENDPOINT_URL_TEMPLATE", "https://dify.test/{hook_id}") + now = datetime.datetime.now(datetime.UTC) + + entity = EndpointEntityWithInstance.model_validate( + { + "id": "ep-1", + "created_at": now, + "updated_at": now, + "settings": {}, + "tenant_id": "tenant", + "plugin_id": "org/plugin", + "expired_at": now, + "name": "my-endpoint", + "enabled": True, + "hook_id": "hook-123", + } + ) + + assert entity.url == "https://dify.test/hook-123" + + def test_endpoint_entity_with_instance_keeps_existing_url(self): + now = datetime.datetime.now(datetime.UTC) + entity = EndpointEntityWithInstance.model_validate( + { + "id": "ep-1", + "created_at": now, + "updated_at": now, + "settings": {}, + "tenant_id": "tenant", + "plugin_id": "org/plugin", + "expired_at": now, + "name": "my-endpoint", + "enabled": True, + "hook_id": "hook-123", + "url": "https://preset.test/hook-123", + } + ) + assert entity.url == "https://preset.test/hook-123" + + +class TestMarketplaceEntities: + def test_marketplace_declaration_strips_empty_optional_fields(self): + declaration = MarketplacePluginDeclaration.model_validate( + { + "name": "plugin", + "org": "org", + "plugin_id": "org/plugin", + "icon": "icon.png", + "label": {"en_US": "Plugin"}, + "brief": {"en_US": "Brief"}, + "resource": {"memory": 256}, + "endpoint": {}, + "model": {}, + "tool": {}, + "latest_version": "1.0.0", + "latest_package_identifier": "org/plugin@1.0.0", + "status": "active", + "deprecated_reason": "", + "alternative_plugin_id": "", + } + ) + + assert declaration.endpoint is None + assert declaration.model is None + assert declaration.tool is None + + def test_marketplace_snapshot_computed_plugin_id(self): + snapshot = MarketplacePluginSnapshot( + org="langgenius", + name="search", + latest_version="1.0.0", + latest_package_identifier="langgenius/search@1.0.0", + latest_package_url="https://example.com/pkg", + ) + assert snapshot.plugin_id == "langgenius/search" + + +class TestPluginParameterEntities: + def _label(self) -> I18nObject: + return I18nObject(en_US="label") + + def test_parameter_option_value_casts_to_string(self): + option = PluginParameterOption(value=123, label=self._label()) + assert option.value == "123" + + def test_plugin_parameter_options_non_list_defaults_to_empty(self): + parameter = PluginParameter(name="p", label=self._label(), options="invalid") # type: ignore[arg-type] + assert parameter.options == [] + + @pytest.mark.parametrize( + ("parameter_type", "expected"), + [ + (PluginParameterType.SECRET_INPUT, "string"), + (PluginParameterType.SELECT, "string"), + (PluginParameterType.CHECKBOX, "string"), + (PluginParameterType.NUMBER, PluginParameterType.NUMBER.value), + ], + ) + def test_as_normal_type(self, parameter_type, expected): + assert as_normal_type(parameter_type) == expected + + @pytest.mark.parametrize( + ("value", "expected"), + [(None, ""), (1, "1"), ("abc", "abc")], + ) + def test_cast_parameter_value_string_like(self, value, expected): + assert cast_parameter_value(PluginParameterType.STRING, value) == expected + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (None, False), + ("true", True), + ("yes", True), + ("1", True), + ("false", False), + ("0", False), + ("random", True), + (1, True), + (0, False), + ], + ) + def test_cast_parameter_value_boolean(self, value, expected): + assert cast_parameter_value(PluginParameterType.BOOLEAN, value) is expected + + @pytest.mark.parametrize( + ("value", "expected"), + [ + (1, 1), + (1.5, 1.5), + ("2", 2), + ("2.5", 2.5), + ], + ) + def test_cast_parameter_value_number(self, value, expected): + assert cast_parameter_value(PluginParameterType.NUMBER, value) == expected + + def test_cast_parameter_value_file_and_files(self): + assert cast_parameter_value(PluginParameterType.FILES, "f1") == ["f1"] + assert cast_parameter_value(PluginParameterType.SYSTEM_FILES, ["f1", "f2"]) == ["f1", "f2"] + assert cast_parameter_value(PluginParameterType.FILE, ["one"]) == "one" + assert cast_parameter_value(PluginParameterType.FILE, "one") == "one" + with pytest.raises(ValueError, match="only accepts one file"): + cast_parameter_value(PluginParameterType.FILE, ["a", "b"]) + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.MODEL_SELECTOR, {"m": "gpt"}, {"m": "gpt"}), + (PluginParameterType.APP_SELECTOR, {"app": "a"}, {"app": "a"}), + (PluginParameterType.TOOLS_SELECTOR, [], []), + (PluginParameterType.ANY, {"k": "v"}, {"k": "v"}), + ], + ) + def test_cast_parameter_value_selectors_valid(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + @pytest.mark.parametrize( + ("parameter_type", "value", "message"), + [ + (PluginParameterType.MODEL_SELECTOR, "bad", "selector must be a dictionary"), + (PluginParameterType.APP_SELECTOR, "bad", "selector must be a dictionary"), + (PluginParameterType.TOOLS_SELECTOR, "bad", "tools selector must be a list"), + (PluginParameterType.ANY, object(), "var selector must be"), + ], + ) + def test_cast_parameter_value_selectors_invalid(self, parameter_type, value, message): + with pytest.raises(ValueError, match=message): + cast_parameter_value(parameter_type, value) + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.ARRAY, [1, 2], [1, 2]), + (PluginParameterType.ARRAY, "[1, 2]", [1, 2]), + (PluginParameterType.OBJECT, {"k": "v"}, {"k": "v"}), + (PluginParameterType.OBJECT, '{"a":1}', {"a": 1}), + ], + ) + def test_cast_parameter_value_array_and_object_valid(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + @pytest.mark.parametrize( + ("parameter_type", "value", "expected"), + [ + (PluginParameterType.ARRAY, "bad-json", ["bad-json"]), + (PluginParameterType.OBJECT, "bad-json", {}), + ], + ) + def test_cast_parameter_value_array_and_object_invalid_json_fallback(self, parameter_type, value, expected): + assert cast_parameter_value(parameter_type, value) == expected + + def test_cast_parameter_value_default_branch_and_wrapped_exception(self): + class _Unknown(StrEnum): + CUSTOM = "custom" + + assert cast_parameter_value(_Unknown.CUSTOM, 12) == "12" + + class _BadString: + def __str__(self): + raise RuntimeError("boom") + + with pytest.raises( + ValueError, + match=r"The tool parameter value <.*_BadString object at .* is not in correct type of string\.", + ): + cast_parameter_value(PluginParameterType.STRING, _BadString()) + + def test_init_frontend_parameter(self): + rule = PluginParameter( + name="choice", + label=self._label(), + required=True, + default="a", + options=[PluginParameterOption(value="a", label=self._label())], + ) + + assert init_frontend_parameter(rule, PluginParameterType.SELECT, None) == "a" + assert init_frontend_parameter(rule, PluginParameterType.NUMBER, 0) == 0 + with pytest.raises(ValueError, match="not in options"): + init_frontend_parameter(rule, PluginParameterType.SELECT, "b") + + required_rule = PluginParameter(name="required", label=self._label(), required=True, default=None) + with pytest.raises(ValueError, match="not found in tool config"): + init_frontend_parameter(required_rule, PluginParameterType.STRING, None) + + +class TestPluginDaemonEntities: + def test_credential_type_helpers(self): + assert CredentialType.API_KEY.get_name() == "API KEY" + assert CredentialType.OAUTH2.get_name() == "AUTH" + assert CredentialType.UNAUTHORIZED.get_name() == "UNAUTHORIZED" + + class _FakeCredential: + value = "custom-type" + + assert CredentialType.get_name(_FakeCredential()) == "CUSTOM TYPE" + assert CredentialType.API_KEY.is_editable() is True + assert CredentialType.OAUTH2.is_editable() is False + assert CredentialType.API_KEY.is_validate_allowed() is True + assert CredentialType.UNAUTHORIZED.is_validate_allowed() is False + assert set(CredentialType.values()) == {"api-key", "oauth2", "unauthorized"} + + @pytest.mark.parametrize( + ("raw", "expected"), + [ + ("api-key", CredentialType.API_KEY), + ("api_key", CredentialType.API_KEY), + ("oauth2", CredentialType.OAUTH2), + ("oauth", CredentialType.OAUTH2), + ("unauthorized", CredentialType.UNAUTHORIZED), + ], + ) + def test_credential_type_of(self, raw, expected): + assert CredentialType.of(raw) == expected + + def test_credential_type_of_invalid(self): + with pytest.raises(ValueError, match="Invalid credential type"): + CredentialType.of("invalid") + + +class TestPluginRequestEntities: + def test_request_invoke_llm_converts_prompt_messages(self): + payload = RequestInvokeLLM( + provider="openai", + model="gpt-4", + mode="chat", + prompt_messages=[ + {"role": "user", "content": "u"}, + {"role": "assistant", "content": "a"}, + {"role": "system", "content": "s"}, + {"role": "tool", "content": "t", "tool_call_id": "call-1"}, + ], + ) + + assert isinstance(payload.prompt_messages[0], UserPromptMessage) + assert isinstance(payload.prompt_messages[1], AssistantPromptMessage) + assert isinstance(payload.prompt_messages[2], SystemPromptMessage) + assert isinstance(payload.prompt_messages[3], ToolPromptMessage) + + def test_request_invoke_llm_prompt_messages_must_be_list(self): + with pytest.raises(ValidationError): + RequestInvokeLLM(provider="openai", model="gpt-4", mode="chat", prompt_messages="invalid") # type: ignore[arg-type] + + def test_request_invoke_speech2text_hex_conversion_and_error(self): + payload = RequestInvokeSpeech2Text(provider="openai", model="m", file=binascii.hexlify(b"abc").decode()) + assert payload.file == b"abc" + with pytest.raises(ValidationError): + RequestInvokeSpeech2Text(provider="openai", model="m", file=b"abc") # type: ignore[arg-type] + + def test_trigger_invoke_event_response_variables_conversion(self): + converted = TriggerInvokeEventResponse(variables='{"a": 1}', cancelled=False) + assert converted.variables == {"a": 1} + passthrough = TriggerInvokeEventResponse(variables={"b": 2}, cancelled=True) + assert passthrough.variables == {"b": 2} + + def test_trigger_dispatch_response_convert_response(self): + response = Response("ok", status=202, headers={"X-Req": "1"}) + encoded = binascii.hexlify(serialize_response(response)).decode() + parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded) + assert parsed.response.status_code == 202 + assert parsed.response.get_data() == b"ok" + with pytest.raises(ValidationError): + TriggerDispatchResponse(user_id="u", events=["e"], response="not-hex") + + def test_trigger_dispatch_response_payload_default(self): + response = Response("ok", status=200) + encoded = binascii.hexlify(serialize_response(response)).decode() + parsed = TriggerDispatchResponse(user_id="u", events=["e"], response=encoded) + assert parsed.payload == {} diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 9e871fcb74..4f038d4a5b 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -19,14 +19,6 @@ import httpx import pytest from pydantic import BaseModel -from core.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.plugin.entities.plugin_daemon import ( CredentialType, PluginDaemonInnerError, @@ -44,6 +36,14 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index e0eace0f2d..c7e94aa4cf 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -4,7 +4,10 @@ import pytest from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks -from core.tools.entities.tool_entities import ToolInvokeMessage +from core.plugin.utils.converter import convert_parameters_to_plugin_format +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File class TestChunkMerger: @@ -458,3 +461,89 @@ class TestChunkMerger: assert len(result) == 1 assert isinstance(result[0].message, ToolInvokeMessage.BlobMessage) assert result[0].message.blob == b"FirstSecondThird" + + +class TestConverter: + def test_convert_parameters_to_plugin_format_with_single_file_and_selector(self): + file_param = File( + tenant_id="tenant-1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/file.png", + storage_key="", + ) + selector = ToolSelector( + provider_id="org/plugin/provider", + credential_id=None, + tool_name="search", + tool_description="search tool", + tool_configuration={"k": "v"}, + tool_parameters={ + "query": ToolSelector.Parameter( + name="query", + type=ToolParameter.ToolParameterType.STRING, + required=True, + description="query", + default="python", + options=[], + ) + }, + ) + params = {"file": file_param, "selector": selector, "plain": 123} + + converted = convert_parameters_to_plugin_format(params) + + assert converted["file"]["url"] == "https://example.com/file.png" + assert converted["selector"]["provider_id"] == "org/plugin/provider" + assert converted["plain"] == 123 + + def test_convert_parameters_to_plugin_format_with_lists_and_passthrough_values(self): + file_one = File( + tenant_id="tenant-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/a.txt", + storage_key="", + ) + file_two = File( + tenant_id="tenant-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/b.txt", + storage_key="", + ) + selector_one = ToolSelector( + provider_id="org/plugin/provider", + credential_id="cred-1", + tool_name="t1", + tool_description="tool 1", + tool_configuration={}, + tool_parameters={}, + ) + selector_two = ToolSelector( + provider_id="org/plugin/provider", + credential_id="cred-2", + tool_name="t2", + tool_description="tool 2", + tool_configuration={}, + tool_parameters={}, + ) + + params = { + "files": [file_one, file_two], + "selectors": [selector_one, selector_two], + "empty_list": [], + "mixed_list": [file_one, "raw"], + "none_value": None, + } + + converted = convert_parameters_to_plugin_format(params) + + assert [item["url"] for item in converted["files"]] == [ + "https://example.com/a.txt", + "https://example.com/b.txt", + ] + assert [item["tool_name"] for item in converted["selectors"]] == ["t1", "t2"] + assert converted["empty_list"] == [] + assert converted["mixed_list"] == [file_one, "raw"] + assert converted["none_value"] is None diff --git a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py index 1c2e0c96f8..71144695bc 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_http_parser.py +++ b/api/tests/unit_tests/core/plugin/utils/test_http_parser.py @@ -381,6 +381,54 @@ class TestEdgeCases: assert response.status_code == 200 assert response.get_data() == binary_body + def test_deserialize_request_with_lf_only_newlines(self): + raw_data = b"POST /lf-only?x=1 HTTP/1.1\nHost: localhost\nX-Test: yes\n\npayload" + + request = deserialize_request(raw_data) + + assert request.method == "POST" + assert request.path == "/lf-only" + assert request.args.get("x") == "1" + assert request.headers.get("X-Test") == "yes" + assert request.get_data() == b"payload" + + def test_deserialize_request_without_header_separator_uses_full_input_as_headers(self): + raw_data = b"GET /no-separator HTTP/1.1\nHost: localhost\nInvalidHeader\n" + + request = deserialize_request(raw_data) + + assert request.method == "GET" + assert request.path == "/no-separator" + assert request.headers.get("Host") == "localhost" + assert request.headers.get("InvalidHeader") is None + + def test_deserialize_request_empty_payload_raises(self): + with pytest.raises(ValueError, match="Empty HTTP request"): + deserialize_request(b"") + + def test_deserialize_response_with_lf_only_newlines(self): + raw_data = b"HTTP/1.1 202 Accepted\nX-Test: yes\n\nbody" + + response = deserialize_response(raw_data) + + assert response.status_code == 202 + assert response.headers.get("X-Test") == "yes" + assert response.get_data() == b"body" + + def test_deserialize_response_without_header_separator_uses_full_input_as_headers(self): + raw_data = b"HTTP/1.1 204 No Content\nX-Test: yes\nInvalidHeader\n" + + response = deserialize_response(raw_data) + + assert response.status_code == 204 + assert response.headers.get("X-Test") == "yes" + assert response.headers.get("InvalidHeader") is None + assert response.get_data() == b"" + + def test_deserialize_response_empty_payload_raises(self): + with pytest.raises(ValueError, match="Empty HTTP response"): + deserialize_response(b"") + class TestFileUploads: def test_serialize_request_with_text_file_upload(self): diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 1d25639343..3d08525aba 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -1,3 +1,4 @@ +from typing import cast from unittest.mock import MagicMock, patch import pytest @@ -5,16 +6,18 @@ import pytest from configs import dify_config from core.app.app_config.entities import ModelConfigEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - UserPromptMessage, -) from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from core.workflow.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import Conversation @@ -142,7 +145,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("core.workflow.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: + with patch("dify_graph.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) @@ -188,3 +191,328 @@ def get_chat_model_args(): context = "I am superman." return model_config_mock, memory_config, prompt_messages, inputs, context + + +def test_get_prompt_dispatches_completion_and_chat_and_invalid(): + transform = AdvancedPromptTransform() + model_config = MagicMock(spec=ModelConfigEntity) + completion_template = CompletionModelPromptTemplate(text="Hello {{name}}", edition_type="basic") + chat_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="basic")] + + transform._get_completion_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="c")]) + transform._get_chat_model_prompt_messages = MagicMock(return_value=[UserPromptMessage(content="h")]) + + completion_result = transform.get_prompt( + prompt_template=completion_template, + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert completion_result[0].content == "c" + + chat_result = transform.get_prompt( + prompt_template=chat_template, + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert chat_result[0].content == "h" + + invalid_result = transform.get_prompt( + prompt_template=cast(list, ["not-chat-model-message"]), + inputs={"name": "john"}, + query="q", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config, + ) + assert invalid_result == [] + + +def test_completion_prompt_jinja2_with_files(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + transform = AdvancedPromptTransform() + completion_template = CompletionModelPromptTemplate(text="Hi {{name}}", edition_type="jinja2") + + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + with ( + patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hi John"), + patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content, + ): + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_completion_model_prompt_messages( + prompt_template=completion_template, + inputs={"name": "John"}, + query="", + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(messages) == 1 + assert isinstance(messages[0].content, list) + assert messages[0].content[0].data == "https://example.com/image.jpg" + assert isinstance(messages[0].content[1], TextPromptMessageContent) + assert messages[0].content[1].data == "Hi John" + + +def test_completion_prompt_basic_sets_query_variable(): + model_config_mock = MagicMock(spec=ModelConfigEntity) + transform = AdvancedPromptTransform() + template = CompletionModelPromptTemplate(text="Q={{#query#}}", edition_type="basic") + + messages = transform._get_completion_model_prompt_messages( + prompt_template=template, + inputs={}, + query="what?", + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert messages[0].content == "Q=what?" + + +def test_chat_prompt_with_variable_template_and_context(): + transform = AdvancedPromptTransform(with_variable_tmpl=True) + model_config_mock = MagicMock(spec=ModelConfigEntity) + prompt_template = [ChatModelMessage(text="sys={{#node.name#}} ctx={{#context#}}", role=PromptMessageRole.SYSTEM)] + + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={"#node.name#": "john"}, + query=None, + files=[], + context="context-text", + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert len(messages) == 1 + assert isinstance(messages[0], SystemPromptMessage) + assert messages[0].content == "sys=john ctx=context-text" + + +def test_chat_prompt_jinja2_branch_and_invalid_edition(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + prompt_template = [ChatModelMessage(text="Hello {{name}}", role=PromptMessageRole.USER, edition_type="jinja2")] + + with patch("core.prompt.advanced_prompt_transform.Jinja2Formatter.format", return_value="Hello John"): + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={"name": "John"}, + query=None, + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert messages[0].content == "Hello John" + + bad_prompt_template = [ChatModelMessage.model_construct(text="bad", role=PromptMessageRole.USER, edition_type="x")] + with pytest.raises(ValueError, match="Invalid edition type"): + transform._get_chat_model_prompt_messages( + prompt_template=bad_prompt_template, + inputs={}, + query=None, + files=[], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + +def test_chat_prompt_query_template_and_query_only_branch(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + query_prompt_template="query={{#sys.query#}} ctx={{#context#}}", + ) + prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] + + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={}, + query="what", + files=[], + context="ctx", + memory_config=memory_config, + memory=None, + model_config=model_config_mock, + ) + assert messages[-1].content == "query={{#sys.query#}} ctx=ctx" + + +def test_chat_prompt_memory_with_files_and_query(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) + memory = MagicMock(spec=TokenBufferMemory) + prompt_template = [ChatModelMessage(text="sys", role=PromptMessageRole.SYSTEM)] + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + transform._append_chat_histories = MagicMock( + side_effect=lambda memory, memory_config, prompt_messages, **kwargs: prompt_messages + ) + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_template, + inputs={}, + query="q", + files=[file], + context=None, + memory_config=memory_config, + memory=memory, + model_config=model_config_mock, + ) + + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "q" + + +def test_chat_prompt_files_without_query_updates_last_user_or_appends_new(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + prompt_with_last_user = [ChatModelMessage(text="u", role=PromptMessageRole.USER)] + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_with_last_user, + inputs={}, + query=None, + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "u" + + prompt_without_last_user = [ChatModelMessage(text="s", role=PromptMessageRole.SYSTEM)] + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=prompt_without_last_user, + inputs={}, + query=None, + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + assert isinstance(messages[-1], UserPromptMessage) + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "" + + +def test_chat_prompt_files_with_query_branch(): + transform = AdvancedPromptTransform() + model_config_mock = MagicMock(spec=ModelConfigEntity) + file = File( + id="file1", + tenant_id="tenant1", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + storage_key="", + ) + + with patch("core.prompt.advanced_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.return_value = ImagePromptMessageContent( + url="https://example.com/image.jpg", format="jpg", mime_type="image/jpg" + ) + messages = transform._get_chat_model_prompt_messages( + prompt_template=[], + inputs={}, + query="query-text", + files=[file], + context=None, + memory_config=None, + memory=None, + model_config=model_config_mock, + ) + + assert isinstance(messages[-1].content, list) + assert messages[-1].content[1].data == "query-text" + + +def test_set_context_query_histories_variable_helpers(): + transform = AdvancedPromptTransform() + parser_context = PromptTemplateParser(template="{{#context#}}") + parser_query = PromptTemplateParser(template="{{#query#}}") + parser_hist = PromptTemplateParser(template="{{#histories#}}") + model_config_mock = MagicMock(spec=ModelConfigEntity) + memory_config = MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ) + + assert transform._set_context_variable(None, parser_context, {})["#context#"] == "" + assert transform._set_query_variable("", parser_query, {})["#query#"] == "" + assert transform._set_query_variable("x", parser_query, {})["#query#"] == "x" + assert ( + transform._set_histories_variable( + memory=None, # type: ignore[arg-type] + memory_config=memory_config, + raw_prompt="{{#histories#}}", + role_prefix=memory_config.role_prefix, # type: ignore[arg-type] + parser=parser_hist, + prompt_inputs={}, + model_config=model_config_mock, + )["#histories#"] + == "" + ) diff --git a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index d157a41d2c..634703740c 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -5,14 +5,14 @@ from core.app.entities.app_invoke_entities import ( ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import ( +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, ToolPromptMessage, UserPromptMessage, ) -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py index e3e500e310..1b114b369a 100644 --- a/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py +++ b/api/tests/unit_tests/core/prompt/test_extract_thread_messages.py @@ -2,12 +2,14 @@ from uuid import uuid4 from constants import UUID_NIL from core.prompt.utils.extract_thread_messages import extract_thread_messages +from core.prompt.utils.get_thread_messages_length import get_thread_messages_length class MockMessage: - def __init__(self, id, parent_message_id): + def __init__(self, id, parent_message_id, answer="answer"): self.id = id self.parent_message_id = parent_message_id + self.answer = answer def __getitem__(self, item): return getattr(self, item) @@ -89,3 +91,44 @@ def test_extract_thread_messages_mixed_with_legacy_messages(): result = extract_thread_messages(messages) assert len(result) == 4 assert [msg["id"] for msg in result] == [id5, id4, id2, id1] + + +def test_extract_thread_messages_breaks_when_parent_is_none(): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [MockMessage(id2, None), MockMessage(id1, UUID_NIL)] + + result = extract_thread_messages(messages) + + assert len(result) == 1 + assert result[0].id == id2 + + +def test_get_thread_messages_length_excludes_newly_created_empty_answer(mocker): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [ + MockMessage(id2, id1, answer=""), # newest generated message should be excluded + MockMessage(id1, UUID_NIL, answer="ok"), + ] + + mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars") + mock_scalars.return_value.all.return_value = messages + + length = get_thread_messages_length("conversation-1") + + assert length == 1 + mock_scalars.assert_called_once() + + +def test_get_thread_messages_length_keeps_non_empty_latest_answer(mocker): + id1, id2 = str(uuid4()), str(uuid4()) + messages = [ + MockMessage(id2, id1, answer="latest-answer"), + MockMessage(id1, UUID_NIL, answer="older-answer"), + ] + + mock_scalars = mocker.patch("core.prompt.utils.get_thread_messages_length.db.session.scalars") + mock_scalars.return_value.all.return_value = messages + + length = get_thread_messages_length("conversation-2") + + assert length == 2 diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index e5da51d733..9fc300348a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,11 @@ -from core.model_runtime.entities.message_entities import ( +from core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, + ToolPromptMessage, UserPromptMessage, ) @@ -25,3 +30,82 @@ def test_dump_prompt_message(): ) data = prompt.model_dump() assert data["content"][0].get("url") == example_url + + +def test_prompt_messages_to_prompt_for_saving_chat_mode(): + chat_messages = [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="hello "), + ImagePromptMessageContent( + url="https://example.com/image1.jpg", + format="jpg", + mime_type="image/jpeg", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + AudioPromptMessageContent( + url="https://example.com/audio1.mp3", + format="mp3", + mime_type="audio/mpeg", + ), + TextPromptMessageContent(data="world"), + ] + ), + AssistantPromptMessage( + content="assistant-text", + tool_calls=[ + { + "id": "tool-1", + "type": "function", + "function": {"name": "search", "arguments": '{"q":"python"}'}, + } + ], + ), + ToolPromptMessage(content="tool-output", name="search", tool_call_id="tool-1"), + UserPromptMessage.model_construct(role="unknown", content="skip"), # type: ignore[arg-type] + ] + + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving(ModelMode.CHAT, chat_messages) + + assert len(prompts) == 3 + assert prompts[0]["role"] == "user" + assert prompts[0]["text"] == "hello world" + assert prompts[0]["files"][0]["type"] == "image" + assert prompts[0]["files"][1]["type"] == "audio" + + assert prompts[1]["role"] == "assistant" + assert prompts[1]["text"] == "assistant-text" + assert prompts[1]["tool_calls"][0]["function"]["name"] == "search" + assert prompts[2]["role"] == "tool" + + +def test_prompt_messages_to_prompt_for_saving_completion_mode_with_and_without_files(): + completion_message_with_files = UserPromptMessage( + content=[ + TextPromptMessageContent(data="first "), + TextPromptMessageContent(data="second"), + ImagePromptMessageContent( + url="https://example.com/image2.jpg", + format="jpg", + mime_type="image/jpeg", + detail=ImagePromptMessageContent.DETAIL.LOW, + ), + ] + ) + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + ModelMode.COMPLETION, [completion_message_with_files] + ) + assert prompts == [ + { + "role": "user", + "text": "first second", + "files": prompts[0]["files"], + } + ] + assert prompts[0]["files"][0]["type"] == "image" + + completion_message_text_only = UserPromptMessage(content="plain text") + prompts = PromptMessageUtil.prompt_messages_to_prompt_for_saving( + ModelMode.COMPLETION, [completion_message_text_only] + ) + assert prompts == [{"role": "user", "text": "plain text"}] diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index 16896a0c6c..d379e3067a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -1,52 +1,231 @@ -# from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +from core.prompt.prompt_transform import PromptTransform +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -# from core.model_runtime.entities.message_entities import UserPromptMessage -# from core.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule -# from core.model_runtime.entities.provider_entities import ProviderEntity -# from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from dify_graph.model_runtime.entities.message_entities import UserPromptMessage +# from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +# from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +# from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform -# def test__calculate_rest_token(): -# model_schema_mock = MagicMock(spec=AIModelEntity) -# parameter_rule_mock = MagicMock(spec=ParameterRule) -# parameter_rule_mock.name = "max_tokens" -# model_schema_mock.parameter_rules = [parameter_rule_mock] -# model_schema_mock.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 62} +class TestPromptTransform: + def test_resolve_model_runtime_requires_model_config_or_instance(self): + transform = PromptTransform() -# large_language_model_mock = MagicMock(spec=LargeLanguageModel) -# large_language_model_mock.get_num_tokens.return_value = 6 + with pytest.raises(ValueError, match="Either model_config or model_instance must be provided."): + transform._resolve_model_runtime() -# provider_mock = MagicMock(spec=ProviderEntity) -# provider_mock.provider = "openai" + def test_resolve_model_runtime_builds_model_instance_from_model_config(self): + transform = PromptTransform() + fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = fake_model_schema + fake_model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials=None, + parameters=None, + stop=None, + ) + model_config = SimpleNamespace( + provider_model_bundle=object(), + model="config-model", + credentials={"api_key": "secret"}, + parameters={"temperature": 0.1}, + stop=["END"], + model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]), + ) -# provider_configuration_mock = MagicMock(spec=ProviderConfiguration) -# provider_configuration_mock.provider = provider_mock -# provider_configuration_mock.model_settings = None + with patch( + "core.prompt.prompt_transform.ModelInstance", return_value=fake_model_instance + ) as model_instance_cls: + model_instance, model_schema = transform._resolve_model_runtime(model_config=model_config) -# provider_model_bundle_mock = MagicMock(spec=ProviderModelBundle) -# provider_model_bundle_mock.model_type_instance = large_language_model_mock -# provider_model_bundle_mock.configuration = provider_configuration_mock + model_instance_cls.assert_called_once_with( + provider_model_bundle=model_config.provider_model_bundle, + model=model_config.model, + ) + fake_model_type_instance.get_model_schema.assert_called_once_with( + model="resolved-model", + credentials={"api_key": "secret"}, + ) + assert model_instance is fake_model_instance + assert model_instance.credentials == {"api_key": "secret"} + assert model_instance.parameters == {"temperature": 0.1} + assert model_instance.stop == ["END"] + assert model_schema is fake_model_schema -# model_config_mock = MagicMock(spec=ModelConfigEntity) -# model_config_mock.model = "gpt-4" -# model_config_mock.credentials = {} -# model_config_mock.parameters = {"max_tokens": 50} -# model_config_mock.model_schema = model_schema_mock -# model_config_mock.provider_model_bundle = provider_model_bundle_mock + def test_resolve_model_runtime_uses_model_config_schema_fallback(self): + transform = PromptTransform() + fallback_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials={"api_key": "secret"}, + parameters={}, + ) + model_config = SimpleNamespace(model_schema=fallback_schema) -# prompt_transform = PromptTransform() + resolved_model_instance, resolved_schema = transform._resolve_model_runtime( + model_config=model_config, + model_instance=model_instance, + ) -# prompt_messages = [UserPromptMessage(content="Hello, how are you?")] -# rest_tokens = prompt_transform._calculate_rest_token(prompt_messages, model_config_mock) + assert resolved_model_instance is model_instance + assert resolved_schema is fallback_schema -# # Validate based on the mock configuration and expected logic -# expected_rest_tokens = ( -# model_schema_mock.model_properties[ModelPropertyKey.CONTEXT_SIZE] -# - model_config_mock.parameters["max_tokens"] -# - large_language_model_mock.get_num_tokens.return_value -# ) -# assert rest_tokens == expected_rest_tokens -# assert rest_tokens == 6 + def test_resolve_model_runtime_raises_when_schema_missing_without_model_config(self): + transform = PromptTransform() + fake_model_type_instance = MagicMock() + fake_model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=fake_model_type_instance, + model_name="resolved-model", + credentials={"api_key": "secret"}, + parameters={}, + ) + + with pytest.raises(ValueError, match="Model schema not found for the provided model instance."): + transform._resolve_model_runtime(model_instance=model_instance) + + def test_calculate_rest_token_defaults_when_context_size_missing(self): + transform = PromptTransform() + fake_model_instance = SimpleNamespace(parameters={}, get_llm_num_tokens=lambda _: 0) + fake_model_schema = SimpleNamespace(model_properties={}, parameter_rules=[]) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace(model_properties={}, parameter_rules=[]), + provider_model_bundle=object(), + model="test-model", + parameters={}, + ) + + rest = transform._calculate_rest_token([], model_config=model_config) + + assert rest == 2000 + + def test_calculate_rest_token_uses_max_tokens_and_clamps_to_zero(self): + transform = PromptTransform() + + parameter_rule = SimpleNamespace(name="max_tokens", use_template=None) + fake_model_instance = SimpleNamespace(parameters={"max_tokens": 50}, get_llm_num_tokens=lambda _: 95) + fake_model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[parameter_rule], + ) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 100}, + parameter_rules=[parameter_rule], + ), + provider_model_bundle=object(), + model="test-model", + parameters={"max_tokens": 50}, + ) + + rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config) + + assert rest == 0 + + def test_calculate_rest_token_supports_use_template_parameter(self): + transform = PromptTransform() + + parameter_rule = SimpleNamespace(name="generation_max", use_template="max_tokens") + fake_model_instance = SimpleNamespace(parameters={"max_tokens": 30}, get_llm_num_tokens=lambda _: 20) + fake_model_schema = SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 200}, + parameter_rules=[parameter_rule], + ) + transform._resolve_model_runtime = MagicMock(return_value=(fake_model_instance, fake_model_schema)) + model_config = SimpleNamespace( + model_schema=SimpleNamespace( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 200}, + parameter_rules=[parameter_rule], + ), + provider_model_bundle=object(), + model="test-model", + parameters={"max_tokens": 30}, + ) + + rest = transform._calculate_rest_token([SimpleNamespace()], model_config=model_config) + + assert rest == 150 + + def test_get_history_messages_from_memory_with_and_without_window(self): + transform = PromptTransform() + memory = MagicMock() + memory.get_history_prompt_text.return_value = "history" + + memory_config_with_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=3)) + result = transform._get_history_messages_from_memory( + memory=memory, + memory_config=memory_config_with_window, + max_token_limit=100, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert result == "history" + memory.get_history_prompt_text.assert_called_with( + max_token_limit=100, + human_prefix="Human", + ai_prefix="Assistant", + message_limit=3, + ) + + memory.reset_mock() + memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=False, size=2)) + transform._get_history_messages_from_memory( + memory=memory, + memory_config=memory_config_no_window, + max_token_limit=50, + ) + memory.get_history_prompt_text.assert_called_with(max_token_limit=50) + + def test_get_history_messages_list_from_memory_with_and_without_window(self): + transform = PromptTransform() + memory = MagicMock() + memory.get_history_prompt_messages.return_value = ["m1", "m2"] + + memory_config_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=2)) + result = transform._get_history_messages_list_from_memory(memory, memory_config_window, 120) + assert result == ["m1", "m2"] + memory.get_history_prompt_messages.assert_called_with(max_token_limit=120, message_limit=2) + + memory.reset_mock() + memory.get_history_prompt_messages.return_value = ["only"] + memory_config_no_window = SimpleNamespace(window=SimpleNamespace(enabled=True, size=0)) + result = transform._get_history_messages_list_from_memory(memory, memory_config_no_window, 10) + assert result == ["only"] + memory.get_history_prompt_messages.assert_called_with(max_token_limit=10, message_limit=None) + + def test_append_chat_histories_extends_prompt_messages(self, monkeypatch): + transform = PromptTransform() + memory = MagicMock() + memory_config = SimpleNamespace(window=SimpleNamespace(enabled=False, size=None)) + + monkeypatch.setattr(transform, "_calculate_rest_token", lambda prompt_messages, **kwargs: 99) + monkeypatch.setattr( + transform, + "_get_history_messages_list_from_memory", + lambda memory, memory_config, max_token_limit: ["h1", "h2"], + ) + + result = transform._append_chat_histories( + memory=memory, + memory_config=memory_config, + prompt_messages=["p1"], + model_config=SimpleNamespace(), + ) + + assert result == ["p1", "h1", "h2"] diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index c822ecbe78..e6d28224d7 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -1,9 +1,29 @@ -from unittest.mock import MagicMock +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_runtime.entities.message_entities import AssistantPromptMessage, UserPromptMessage +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) from core.prompt.simple_prompt_transform import SimplePromptTransform +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from models.model import AppMode, Conversation @@ -244,3 +264,178 @@ def test__get_completion_model_prompt_messages(): assert len(prompt_messages) == 1 assert stops == prompt_rules.get("stops") assert prompt_messages[0].content == real_prompt + + +def test_get_prompt_dispatches_chat_and_completion(): + transform = SimplePromptTransform() + model_config_chat = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_chat.mode = "chat" + model_config_completion = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config_completion.mode = "completion" + prompt_entity = SimpleNamespace(simple_prompt_template="hello") + + transform._get_chat_model_prompt_messages = MagicMock(return_value=(["chat-msg"], None)) + transform._get_completion_model_prompt_messages = MagicMock(return_value=(["completion-msg"], ["stop"])) + + chat_messages, chat_stops = transform.get_prompt( + app_mode=AppMode.CHAT, + prompt_template_entity=prompt_entity, + inputs={"n": 1}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config_chat, + ) + assert chat_messages == ["chat-msg"] + assert chat_stops is None + + completion_messages, completion_stops = transform.get_prompt( + app_mode=AppMode.CHAT, + prompt_template_entity=prompt_entity, + inputs={"n": 1}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config_completion, + ) + assert completion_messages == ["completion-msg"] + assert completion_stops == ["stop"] + + +def test_get_prompt_str_and_rules_type_validation_errors(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + model_config.provider = "openai" + model_config.model = "gpt-4" + valid_prompt_template = SimplePromptTransform().get_prompt_template( + AppMode.CHAT, "openai", "gpt-4", "", False, False + )["prompt_template"] + + bad_custom_keys = { + "prompt_template": valid_prompt_template, + "custom_variable_keys": "not-list", + "special_variable_keys": [], + "prompt_rules": {}, + } + transform.get_prompt_template = MagicMock(return_value=bad_custom_keys) + with pytest.raises(TypeError, match="custom_variable_keys"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_special_keys = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": "not-list", + } + transform.get_prompt_template = MagicMock(return_value=bad_special_keys) + with pytest.raises(TypeError, match="special_variable_keys"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_prompt_template = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": [], + "prompt_template": 123, + } + transform.get_prompt_template = MagicMock(return_value=bad_prompt_template) + with pytest.raises(TypeError, match="PromptTemplateParser"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + bad_prompt_rules = { + **bad_custom_keys, + "custom_variable_keys": [], + "special_variable_keys": [], + "prompt_template": valid_prompt_template, + "prompt_rules": "not-dict", + } + transform.get_prompt_template = MagicMock(return_value=bad_prompt_rules) + with pytest.raises(TypeError, match="prompt_rules"): + transform._get_prompt_str_and_rules(AppMode.CHAT, model_config, "", {}, query=None, context=None) + + +def test_chat_model_prompt_messages_uses_prompt_when_query_empty(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt-text", {})) + transform._get_last_user_message = MagicMock(return_value=UserPromptMessage(content="prompt-text")) + + prompt_messages, _ = transform._get_chat_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt="", + inputs={}, + query="", + files=[], + context=None, + memory=None, + model_config=model_config, + ) + + assert prompt_messages[0].content == "prompt-text" + transform._get_last_user_message.assert_called_once_with("prompt-text", [], None, None) + + +def test_completion_model_prompt_messages_empty_stops_becomes_none(): + transform = SimplePromptTransform() + model_config = MagicMock(spec=ModelConfigWithCredentialsEntity) + transform._get_prompt_str_and_rules = MagicMock(return_value=("prompt", {"stops": []})) + + prompt_messages, stops = transform._get_completion_model_prompt_messages( + app_mode=AppMode.CHAT, + pre_prompt="", + inputs={}, + query="q", + files=[], + context=None, + memory=None, + model_config=model_config, + ) + + assert len(prompt_messages) == 1 + assert stops is None + + +def test_get_last_user_message_with_files_and_context_files(): + transform = SimplePromptTransform() + file = SimpleNamespace() + context_file = SimpleNamespace() + + with patch("core.prompt.simple_prompt_transform.file_manager.to_prompt_message_content") as to_content: + to_content.side_effect = [ + ImagePromptMessageContent(url="https://example.com/a.jpg", format="jpg", mime_type="image/jpg"), + ImagePromptMessageContent(url="https://example.com/b.jpg", format="jpg", mime_type="image/jpg"), + ] + message = transform._get_last_user_message( + prompt="hello", + files=[file], + context_files=[context_file], + image_detail_config=None, + ) + + assert isinstance(message.content, list) + assert message.content[0].data == "https://example.com/a.jpg" + assert message.content[1].data == "https://example.com/b.jpg" + assert isinstance(message.content[2], TextPromptMessageContent) + assert message.content[2].data == "hello" + + +def test_prompt_file_name_branches(): + transform = SimplePromptTransform() + + assert transform._prompt_file_name(AppMode.CHAT, "openai", "gpt-4") == "common_chat" + assert transform._prompt_file_name(AppMode.COMPLETION, "openai", "gpt-4") == "common_completion" + assert transform._prompt_file_name(AppMode.COMPLETION, "baichuan", "Baichuan2") == "baichuan_completion" + assert transform._prompt_file_name(AppMode.CHAT, "huggingface_hub", "baichuan-13b") == "baichuan_chat" + + +def test_advanced_prompt_templates_constants_are_importable(): + assert isinstance(CONTEXT, str) + assert isinstance(BAICHUAN_CONTEXT, str) + assert "completion_prompt_config" in CHAT_APP_COMPLETION_PROMPT_CONFIG + assert "chat_prompt_config" in CHAT_APP_CHAT_PROMPT_CONFIG + assert "chat_prompt_config" in COMPLETION_APP_CHAT_PROMPT_CONFIG + assert "completion_prompt_config" in COMPLETION_APP_COMPLETION_PROMPT_CONFIG + assert "completion_prompt_config" in BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + assert "chat_prompt_config" in BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + assert "chat_prompt_config" in BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + assert "completion_prompt_config" in BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py new file mode 100644 index 0000000000..c25af79ae4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/tidb_on_qdrant/test_tidb_on_qdrant_vector.py @@ -0,0 +1,160 @@ +from unittest.mock import patch + +import httpx +import pytest +from qdrant_client.http import models as rest +from qdrant_client.http.exceptions import UnexpectedResponse + +from core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector import ( + TidbOnQdrantConfig, + TidbOnQdrantVector, +) + + +class TestTidbOnQdrantVectorDeleteByIds: + """Unit tests for TidbOnQdrantVector.delete_by_ids method.""" + + @pytest.fixture + def vector_instance(self): + """Create a TidbOnQdrantVector instance for testing.""" + config = TidbOnQdrantConfig( + endpoint="http://localhost:6333", + api_key="test_api_key", + ) + + with patch("core.rag.datasource.vdb.tidb_on_qdrant.tidb_on_qdrant_vector.qdrant_client.QdrantClient"): + vector = TidbOnQdrantVector( + collection_name="test_collection", + group_id="test_group", + config=config, + ) + return vector + + def test_delete_by_ids_with_multiple_ids(self, vector_instance): + """Test batch deletion with multiple document IDs.""" + ids = ["doc1", "doc2", "doc3"] + + vector_instance.delete_by_ids(ids) + + # Verify that delete was called once with MatchAny filter + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + # Check collection name + assert call_args[1]["collection_name"] == "test_collection" + + # Verify filter uses MatchAny with all IDs + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + assert len(filter_obj.must) == 1 + + field_condition = filter_obj.must[0] + assert field_condition.key == "metadata.doc_id" + assert isinstance(field_condition.match, rest.MatchAny) + assert set(field_condition.match.any) == {"doc1", "doc2", "doc3"} + + def test_delete_by_ids_with_single_id(self, vector_instance): + """Test deletion with a single document ID.""" + ids = ["doc1"] + + vector_instance.delete_by_ids(ids) + + # Verify that delete was called once + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + # Verify filter uses MatchAny with single ID + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + field_condition = filter_obj.must[0] + assert isinstance(field_condition.match, rest.MatchAny) + assert field_condition.match.any == ["doc1"] + + def test_delete_by_ids_with_empty_list(self, vector_instance): + """Test deletion with empty ID list returns early without API call.""" + vector_instance.delete_by_ids([]) + + # Verify that delete was NOT called + vector_instance._client.delete.assert_not_called() + + def test_delete_by_ids_with_404_error(self, vector_instance): + """Test that 404 errors (collection not found) are handled gracefully.""" + ids = ["doc1", "doc2"] + + # Mock a 404 error + error = UnexpectedResponse( + status_code=404, + reason_phrase="Not Found", + content=b"Collection not found", + headers=httpx.Headers(), + ) + vector_instance._client.delete.side_effect = error + + # Should not raise an exception + vector_instance.delete_by_ids(ids) + + # Verify delete was called + vector_instance._client.delete.assert_called_once() + + def test_delete_by_ids_with_unexpected_error(self, vector_instance): + """Test that non-404 errors are re-raised.""" + ids = ["doc1", "doc2"] + + # Mock a 500 error + error = UnexpectedResponse( + status_code=500, + reason_phrase="Internal Server Error", + content=b"Server error", + headers=httpx.Headers(), + ) + vector_instance._client.delete.side_effect = error + + # Should re-raise the exception + with pytest.raises(UnexpectedResponse) as exc_info: + vector_instance.delete_by_ids(ids) + + assert exc_info.value.status_code == 500 + + def test_delete_by_ids_with_large_batch(self, vector_instance): + """Test deletion with a large batch of IDs.""" + # Create 1000 IDs + ids = [f"doc_{i}" for i in range(1000)] + + vector_instance.delete_by_ids(ids) + + # Verify single delete call with all IDs + vector_instance._client.delete.assert_called_once() + call_args = vector_instance._client.delete.call_args + + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + field_condition = filter_obj.must[0] + + # Verify all 1000 IDs are in the batch + assert len(field_condition.match.any) == 1000 + assert "doc_0" in field_condition.match.any + assert "doc_999" in field_condition.match.any + + def test_delete_by_ids_filter_structure(self, vector_instance): + """Test that the filter structure is correctly constructed.""" + ids = ["doc1", "doc2"] + + vector_instance.delete_by_ids(ids) + + call_args = vector_instance._client.delete.call_args + filter_selector = call_args[1]["points_selector"] + filter_obj = filter_selector.filter + + # Verify Filter structure + assert isinstance(filter_obj, rest.Filter) + assert filter_obj.must is not None + assert len(filter_obj.must) == 1 + + # Verify FieldCondition structure + field_condition = filter_obj.must[0] + assert isinstance(field_condition, rest.FieldCondition) + assert field_condition.key == "metadata.doc_id" + + # Verify MatchAny structure + assert isinstance(field_condition.match, rest.MatchAny) + assert field_condition.match.any == ids diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/__init__.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py new file mode 100644 index 0000000000..baf8c9e5f8 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weavaite.py @@ -0,0 +1,33 @@ +from unittest.mock import MagicMock, patch + +from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector + + +def test_init_client_with_valid_config(): + """Test successful client initialization with valid configuration.""" + config = WeaviateConfig( + endpoint="http://localhost:8080", + api_key="WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih", + ) + + with patch("weaviate.connect_to_custom") as mock_connect: + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_connect.return_value = mock_client + + vector = WeaviateVector( + collection_name="test_collection", + config=config, + attributes=["doc_id"], + ) + + assert vector._client == mock_client + mock_connect.assert_called_once() + call_kwargs = mock_connect.call_args[1] + assert call_kwargs["http_host"] == "localhost" + assert call_kwargs["http_port"] == 8080 + assert call_kwargs["http_secure"] is False + assert call_kwargs["grpc_host"] == "localhost" + assert call_kwargs["grpc_port"] == 50051 + assert call_kwargs["grpc_secure"] is False + assert call_kwargs["auth_credentials"] is not None diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py new file mode 100644 index 0000000000..3bd656ba84 --- /dev/null +++ b/api/tests/unit_tests/core/rag/datasource/vdb/weaviate/test_weaviate_vector.py @@ -0,0 +1,335 @@ +"""Unit tests for Weaviate vector database implementation. + +Focuses on verifying that doc_type is properly handled in: +- Collection schema creation (_create_collection) +- Property migration (_ensure_properties) +- Vector search result metadata (search_by_vector) +- Full-text search result metadata (search_by_full_text) +""" + +import unittest +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from core.rag.datasource.vdb.weaviate import weaviate_vector as weaviate_vector_module +from core.rag.datasource.vdb.weaviate.weaviate_vector import WeaviateConfig, WeaviateVector +from core.rag.models.document import Document + + +class TestWeaviateVector(unittest.TestCase): + """Tests for WeaviateVector class with focus on doc_type metadata handling.""" + + def setUp(self): + weaviate_vector_module._weaviate_client = None + self.config = WeaviateConfig( + endpoint="http://localhost:8080", + api_key="test-key", + batch_size=100, + ) + self.collection_name = "Test_Collection_Node" + self.attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"] + + def tearDown(self): + weaviate_vector_module._weaviate_client = None + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def _create_weaviate_vector(self, mock_weaviate_module): + """Helper to create a WeaviateVector instance with mocked client.""" + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + return wv, mock_client + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_init(self, mock_weaviate_module): + """Test WeaviateVector initialization stores attributes including doc_type.""" + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + assert wv._collection_name == self.collection_name + assert "doc_type" in wv._attributes + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.redis_client") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.dify_config") + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_create_collection_includes_doc_type_property(self, mock_weaviate_module, mock_dify_config, mock_redis): + """Test that _create_collection defines doc_type in the schema properties.""" + # Mock Redis + mock_lock = MagicMock() + mock_lock.__enter__ = MagicMock() + mock_lock.__exit__ = MagicMock() + mock_redis.lock.return_value = mock_lock + mock_redis.get.return_value = None + mock_redis.set.return_value = None + + # Mock dify_config + mock_dify_config.WEAVIATE_TOKENIZATION = None + + # Mock client + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + mock_client.collections.exists.return_value = False + + # Mock _ensure_properties to avoid side effects + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + mock_cfg = MagicMock() + mock_cfg.properties = [] + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._create_collection() + + # Verify collections.create was called + mock_client.collections.create.assert_called_once() + + # Extract properties from the create call + call_kwargs = mock_client.collections.create.call_args + properties = call_kwargs.kwargs.get("properties") + + # Verify doc_type is among the defined properties + property_names = [p.name for p in properties] + assert "doc_type" in property_names, ( + f"doc_type should be in collection schema properties, got: {property_names}" + ) + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_adds_missing_doc_type(self, mock_weaviate_module): + """Test that _ensure_properties adds doc_type when it's missing from existing schema.""" + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + # Collection exists but doc_type property is missing + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # Simulate existing properties WITHOUT doc_type + existing_props = [ + SimpleNamespace(name="text"), + SimpleNamespace(name="document_id"), + SimpleNamespace(name="doc_id"), + SimpleNamespace(name="chunk_index"), + ] + mock_cfg = MagicMock() + mock_cfg.properties = existing_props + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() + + # Verify add_property was called and includes doc_type + add_calls = mock_col.config.add_property.call_args_list + added_names = [call.args[0].name for call in add_calls] + assert "doc_type" in added_names, f"doc_type should be added to existing collection, added: {added_names}" + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_ensure_properties_skips_existing_doc_type(self, mock_weaviate_module): + """Test that _ensure_properties does not add doc_type when it already exists.""" + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # Simulate existing properties WITH doc_type already present + existing_props = [ + SimpleNamespace(name="text"), + SimpleNamespace(name="document_id"), + SimpleNamespace(name="doc_id"), + SimpleNamespace(name="doc_type"), + SimpleNamespace(name="chunk_index"), + ] + mock_cfg = MagicMock() + mock_cfg.properties = existing_props + mock_col.config.get.return_value = mock_cfg + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + wv._ensure_properties() + + # No properties should be added + mock_col.config.add_property.assert_not_called() + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_vector_returns_doc_type_in_metadata(self, mock_weaviate_module): + """Test that search_by_vector returns doc_type in document metadata. + + This is the core bug fix verification: when doc_type is in _attributes, + it should appear in return_properties and thus be included in results. + """ + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # Simulate search result with doc_type in properties + mock_obj = MagicMock() + mock_obj.properties = { + "text": "image content description", + "doc_id": "upload_file_id_123", + "dataset_id": "dataset_1", + "document_id": "doc_1", + "doc_hash": "hash_abc", + "doc_type": "image", + } + mock_obj.metadata.distance = 0.1 + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.near_vector.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_vector(query_vector=[0.1] * 128, top_k=1) + + # Verify doc_type is in return_properties + call_kwargs = mock_col.query.near_vector.call_args + return_props = call_kwargs.kwargs.get("return_properties") + assert "doc_type" in return_props, f"doc_type should be in return_properties, got: {return_props}" + + # Verify doc_type is in result metadata + assert len(docs) == 1 + assert docs[0].metadata.get("doc_type") == "image" + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_search_by_full_text_returns_doc_type_in_metadata(self, mock_weaviate_module): + """Test that search_by_full_text also returns doc_type in document metadata.""" + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_client.collections.exists.return_value = True + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # Simulate BM25 search result with doc_type + mock_obj = MagicMock() + mock_obj.properties = { + "text": "image content description", + "doc_id": "upload_file_id_456", + "doc_type": "image", + } + mock_obj.vector = {"default": [0.1] * 128} + + mock_result = MagicMock() + mock_result.objects = [mock_obj] + mock_col.query.bm25.return_value = mock_result + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + docs = wv.search_by_full_text(query="image", top_k=1) + + # Verify doc_type is in return_properties + call_kwargs = mock_col.query.bm25.call_args + return_props = call_kwargs.kwargs.get("return_properties") + assert "doc_type" in return_props, ( + f"doc_type should be in return_properties for BM25 search, got: {return_props}" + ) + + # Verify doc_type is in result metadata + assert len(docs) == 1 + assert docs[0].metadata.get("doc_type") == "image" + + @patch("core.rag.datasource.vdb.weaviate.weaviate_vector.weaviate") + def test_add_texts_stores_doc_type_in_properties(self, mock_weaviate_module): + """Test that add_texts includes doc_type from document metadata in stored properties.""" + mock_client = MagicMock() + mock_client.is_ready.return_value = True + mock_weaviate_module.connect_to_custom.return_value = mock_client + + mock_col = MagicMock() + mock_client.collections.use.return_value = mock_col + + # Create a document with doc_type metadata (as produced by multimodal indexing) + doc = Document( + page_content="an image of a cat", + metadata={ + "doc_id": "upload_file_123", + "doc_type": "image", + "dataset_id": "ds_1", + "document_id": "doc_1", + "doc_hash": "hash_xyz", + }, + ) + + wv = WeaviateVector( + collection_name=self.collection_name, + config=self.config, + attributes=self.attributes, + ) + + # Mock batch context manager + mock_batch = MagicMock() + mock_batch.__enter__ = MagicMock(return_value=mock_batch) + mock_batch.__exit__ = MagicMock(return_value=False) + mock_col.batch.dynamic.return_value = mock_batch + + wv.add_texts(documents=[doc], embeddings=[[0.1] * 128]) + + # Verify batch.add_object was called with doc_type in properties + mock_batch.add_object.assert_called_once() + call_kwargs = mock_batch.add_object.call_args + stored_props = call_kwargs.kwargs.get("properties") + assert stored_props.get("doc_type") == "image", f"doc_type should be stored in properties, got: {stored_props}" + + +class TestVectorDefaultAttributes(unittest.TestCase): + """Tests for Vector class default attributes list.""" + + @patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings") + @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") + def test_default_attributes_include_doc_type(self, mock_init_vector, mock_get_embeddings): + """Test that Vector class default attributes include doc_type.""" + from core.rag.datasource.vdb.vector_factory import Vector + + mock_get_embeddings.return_value = MagicMock() + mock_init_vector.return_value = MagicMock() + + mock_dataset = MagicMock() + mock_dataset.index_struct_dict = None + + vector = Vector(dataset=mock_dataset) + + assert "doc_type" in vector._attributes, f"doc_type should be in default attributes, got: {vector._attributes}" + + +if __name__ == "__main__": + unittest.main() diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py new file mode 100644 index 0000000000..13285cdad0 --- /dev/null +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -0,0 +1,813 @@ +""" +Unit tests for DatasetDocumentStore. + +Tests cover all public methods and error paths of the DatasetDocumentStore class +which provides document storage and retrieval functionality for datasets in the RAG system. +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from core.rag.docstore.dataset_docstore import DatasetDocumentStore, DocumentSegment +from core.rag.models.document import AttachmentDocument, Document +from models.dataset import Dataset + + +class TestDatasetDocumentStoreInit: + """Tests for DatasetDocumentStore initialization.""" + + def test_init_with_all_parameters(self): + """Test initialization with dataset, user_id, and document_id.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + assert store._dataset == mock_dataset + assert store._user_id == "test-user-id" + assert store._document_id == "test-doc-id" + assert store.dataset_id == "test-dataset-id" + assert store.user_id == "test-user-id" + + def test_init_without_document_id(self): + """Test initialization without document_id.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + assert store._document_id is None + assert store.dataset_id == "test-dataset-id" + + +class TestDatasetDocumentStoreSerialization: + """Tests for to_dict and from_dict methods.""" + + def test_to_dict(self): + """Test serialization to dictionary.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.to_dict() + + assert result == {"dataset_id": "test-dataset-id"} + + def test_from_dict(self): + """Test deserialization from dictionary.""" + + config_dict = { + "dataset": MagicMock(spec=["id"]), + "user_id": "test-user", + "document_id": "test-doc", + } + config_dict["dataset"].id = "ds-123" + + store = DatasetDocumentStore.from_dict(config_dict) + + assert store._user_id == "test-user" + assert store._document_id == "test-doc" + + +class TestDatasetDocumentStoreDocs: + """Tests for the docs property.""" + + def test_docs_returns_document_dict(self): + """Test that docs property returns a dictionary of documents.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock(spec=DocumentSegment) + mock_segment.index_node_id = "node-1" + mock_segment.index_node_hash = "hash-1" + mock_segment.document_id = "doc-1" + mock_segment.dataset_id = "test-dataset-id" + mock_segment.content = "Test content" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalars.return_value.all.return_value = [mock_segment] + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.docs + + assert "node-1" in result + assert isinstance(result["node-1"], Document) + + def test_docs_empty_dataset(self): + """Test docs property with no segments.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalars.return_value.all.return_value = [] + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.docs + + assert result == {} + + +class TestDatasetDocumentStoreAddDocuments: + """Tests for add_documents method.""" + + def test_add_documents_new_document_with_embedding(self): + """Test adding new documents with embedding model.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "provider" + mock_dataset.embedding_model = "model" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_model_instance = MagicMock() + mock_model_instance.get_text_embedding_num_tokens.return_value = [10] + + with ( + patch("core.rag.docstore.dataset_docstore.db") as mock_db, + patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + ): + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + + mock_manager = MagicMock() + mock_manager.get_model_instance.return_value = mock_model_instance + mock_manager_class.return_value = mock_manager + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.add.assert_called() + mock_db.session.commit.assert_called() + + def test_add_documents_update_existing_document(self): + """Test updating existing document with allow_update=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + mock_dataset.embedding_model_provider = None + mock_dataset.embedding_model = None + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Updated content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "new-hash", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_existing_segment = MagicMock() + mock_existing_segment.id = "seg-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.commit.assert_called() + + def test_add_documents_raises_when_not_allowed(self): + """Test that adding existing doc without allow_update raises ValueError.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_existing_segment = MagicMock() + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + with pytest.raises(ValueError, match="already exists"): + store.add_documents([mock_doc], allow_update=False) + + def test_add_documents_with_answer_metadata(self): + """Test adding document with answer in metadata.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + "answer": "Test answer", + } + mock_doc.attachments = None + mock_doc.children = None + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.add.assert_called() + + def test_add_documents_with_invalid_document_type(self): + """Test that non-Document raises ValueError.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + with pytest.raises(ValueError, match="must be a Document"): + store.add_documents(["not a document"]) + + def test_add_documents_with_none_metadata(self): + """Test that document with None metadata raises ValueError.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = None + + with patch("core.rag.docstore.dataset_docstore.db"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + with pytest.raises(ValueError, match="metadata must be a dict"): + store.add_documents([mock_doc]) + + def test_add_documents_with_save_child(self): + """Test adding documents with save_child=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_child = MagicMock(spec=Document) + mock_child.page_content = "Child content" + mock_child.metadata = { + "doc_id": "child-1", + "doc_hash": "child-hash", + } + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Test content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "hash-1", + } + mock_doc.attachments = None + mock_doc.children = [mock_child] + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc], save_child=True) + + mock_db.session.add.assert_called() + + +class TestDatasetDocumentStoreExists: + """Tests for document_exists method.""" + + def test_document_exists_returns_true(self): + """Test document_exists returns True when segment exists.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.document_exists("doc-1") + + assert result is True + + def test_document_exists_returns_false(self): + """Test document_exists returns False when segment doesn't exist.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.document_exists("doc-1") + + assert result is False + + +class TestDatasetDocumentStoreGetDocument: + """Tests for get_document method.""" + + def test_get_document_success(self): + """Test getting a document successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock(spec=DocumentSegment) + mock_segment.index_node_id = "node-1" + mock_segment.index_node_hash = "hash-1" + mock_segment.document_id = "doc-1" + mock_segment.dataset_id = "test-dataset-id" + mock_segment.content = "Test content" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document("node-1", raise_error=False) + + assert isinstance(result, Document) + assert result.page_content == "Test content" + + def test_get_document_returns_none_when_not_found(self): + """Test get_document returns None when not found and raise_error=False.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document("nonexistent", raise_error=False) + + assert result is None + + def test_get_document_raises_when_not_found(self): + """Test get_document raises ValueError when not found and raise_error=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + with pytest.raises(ValueError, match="not found"): + store.get_document("nonexistent", raise_error=True) + + +class TestDatasetDocumentStoreDeleteDocument: + """Tests for delete_document method.""" + + def test_delete_document_success(self): + """Test deleting a document successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + store.delete_document("doc-1") + + mock_db.session.delete.assert_called_with(mock_segment) + mock_db.session.commit.assert_called() + + def test_delete_document_returns_none_when_not_found(self): + """Test delete_document returns None when not found and raise_error=False.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.delete_document("nonexistent", raise_error=False) + + assert result is None + + def test_delete_document_raises_when_not_found(self): + """Test delete_document raises ValueError when not found and raise_error=True.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + with pytest.raises(ValueError, match="not found"): + store.delete_document("nonexistent", raise_error=True) + + +class TestDatasetDocumentStoreHashOperations: + """Tests for set_document_hash and get_document_hash methods.""" + + def test_set_document_hash_success(self): + """Test setting document hash successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + mock_segment.index_node_hash = "old-hash" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + store.set_document_hash("doc-1", "new-hash") + + assert mock_segment.index_node_hash == "new-hash" + mock_db.session.commit.assert_called() + + def test_set_document_hash_returns_none_when_not_found(self): + """Test set_document_hash returns None when segment not found.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.set_document_hash("nonexistent", "new-hash") + + assert result is None + + def test_get_document_hash_success(self): + """Test getting document hash successfully.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock() + mock_segment.index_node_hash = "test-hash" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_segment): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_hash("doc-1") + + assert result == "test-hash" + + def test_get_document_hash_returns_none_when_not_found(self): + """Test get_document_hash returns None when segment not found.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db"): + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=None): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_hash("nonexistent") + + assert result is None + + +class TestDatasetDocumentStoreSegment: + """Tests for get_document_segment method.""" + + def test_get_document_segment_returns_segment(self): + """Test getting a document segment.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + mock_segment = MagicMock(spec=DocumentSegment) + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalar.return_value = mock_segment + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_segment("doc-1") + + assert result == mock_segment + + def test_get_document_segment_returns_none(self): + """Test getting a non-existent document segment.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.scalar.return_value = None + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + ) + + result = store.get_document_segment("nonexistent") + + assert result is None + + +class TestDatasetDocumentStoreMultimodelBinding: + """Tests for add_multimodel_documents_binding method.""" + + def test_add_multimodel_documents_binding_with_attachments(self): + """Test adding multimodel document bindings.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + mock_attachment = MagicMock(spec=AttachmentDocument) + mock_attachment.metadata = {"doc_id": "attachment-1"} + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_multimodel_documents_binding("seg-1", [mock_attachment]) + + mock_db.session.add.assert_called() + + def test_add_multimodel_documents_binding_without_attachments(self): + """Test adding bindings with None attachments.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_multimodel_documents_binding("seg-1", None) + + mock_db.session.add.assert_not_called() + + def test_add_multimodel_documents_binding_with_empty_list(self): + """Test adding bindings with empty list.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_multimodel_documents_binding("seg-1", []) + + mock_db.session.add.assert_not_called() + + +class TestDatasetDocumentStoreAddDocumentsUpdateChild: + """Tests for add_documents when updating existing documents with children.""" + + def test_add_documents_update_existing_with_children(self): + """Test updating existing document with save_child=True and children.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_child = MagicMock(spec=Document) + mock_child.page_content = "Updated child content" + mock_child.metadata = { + "doc_id": "child-1", + "doc_hash": "new-child-hash", + } + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Updated content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "new-hash", + } + mock_doc.attachments = None + mock_doc.children = [mock_child] + + mock_existing_segment = MagicMock() + mock_existing_segment.id = "seg-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc], save_child=True) + + mock_db.session.query.return_value.where.return_value.delete.assert_called() + mock_db.session.commit.assert_called() + + +class TestDatasetDocumentStoreAddDocumentsUpdateAnswer: + """Tests for add_documents when updating existing documents with answer metadata.""" + + def test_add_documents_update_existing_with_answer(self): + """Test updating existing document with answer in metadata.""" + + mock_dataset = MagicMock(spec=Dataset) + mock_dataset.id = "test-dataset-id" + mock_dataset.tenant_id = "tenant-1" + mock_dataset.indexing_technique = "economy" + + mock_doc = MagicMock(spec=Document) + mock_doc.page_content = "Updated content" + mock_doc.metadata = { + "doc_id": "doc-1", + "doc_hash": "new-hash", + "answer": "Updated answer", + } + mock_doc.attachments = None + mock_doc.children = None + + mock_existing_segment = MagicMock() + mock_existing_segment.id = "seg-1" + + with patch("core.rag.docstore.dataset_docstore.db") as mock_db: + mock_session = MagicMock() + mock_db.session = mock_session + mock_db.session.query.return_value.where.return_value.scalar.return_value = 5 + + with patch.object(DatasetDocumentStore, "get_document_segment", return_value=mock_existing_segment): + with patch.object(DatasetDocumentStore, "add_multimodel_documents_binding"): + store = DatasetDocumentStore( + dataset=mock_dataset, + user_id="test-user-id", + document_id="test-doc-id", + ) + + store.add_documents([mock_doc]) + + mock_db.session.commit.assert_called() diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py new file mode 100644 index 0000000000..a0db25174d --- /dev/null +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -0,0 +1,555 @@ +"""Unit tests for cached_embedding.py - CacheEmbedding class. + +This test file covers the methods not fully tested in test_embedding_service.py: +- embed_multimodal_documents +- embed_multimodal_query +- Error handling scenarios in embed_query (DEBUG mode) +""" + +import base64 +from decimal import Decimal +from unittest.mock import Mock, patch + +import numpy as np +import pytest +from sqlalchemy.exc import IntegrityError + +from core.rag.embedding.cached_embedding import CacheEmbedding +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from models.dataset import Embedding + + +class TestCacheEmbeddingMultimodalDocuments: + """Test suite for CacheEmbedding.embed_multimodal_documents method.""" + + @pytest.fixture + def mock_model_instance(self): + """Create a mock ModelInstance for testing.""" + model_instance = Mock() + model_instance.model = "vision-embedding-model" + model_instance.provider = "openai" + model_instance.credentials = {"api_key": "test-key"} + + model_type_instance = Mock() + model_instance.model_type_instance = model_type_instance + + model_schema = Mock() + model_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} + model_type_instance.get_model_schema.return_value = model_schema + + return model_instance + + @pytest.fixture + def sample_multimodal_result(self): + """Create a sample multimodal EmbeddingResult.""" + embedding_vector = np.random.randn(1536) + normalized_vector = (embedding_vector / np.linalg.norm(embedding_vector)).tolist() + + usage = EmbeddingUsage( + tokens=10, + total_tokens=10, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000001"), + currency="USD", + latency=0.5, + ) + + return EmbeddingResult( + model="vision-embedding-model", + embeddings=[normalized_vector], + usage=usage, + ) + + def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): + """Test embedding a single multimodal document when cache is empty.""" + cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + documents = [{"file_id": "file123", "content": "test content"}] + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 1 + assert isinstance(result[0], list) + assert len(result[0]) == 1536 + + mock_model_instance.invoke_multimodal_embedding.assert_called_once() + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + def test_embed_multiple_multimodal_documents_cache_miss(self, mock_model_instance): + """Test embedding multiple multimodal documents when cache is empty.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [ + {"file_id": "file1", "content": "content 1"}, + {"file_id": "file2", "content": "content 2"}, + {"file_id": "file3", "content": "content 3"}, + ] + + embeddings = [] + for _ in range(3): + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + embeddings.append(normalized) + + usage = EmbeddingUsage( + tokens=30, + total_tokens=30, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000003"), + currency="USD", + latency=0.8, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=embeddings, + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 3 + assert all(len(emb) == 1536 for emb in result) + + def test_embed_multimodal_documents_cache_hit(self, mock_model_instance): + """Test embedding multimodal documents when embeddings are cached.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "file123"}] + + cached_vector = np.random.randn(1536) + normalized_cached = (cached_vector / np.linalg.norm(cached_vector)).tolist() + + mock_cached_embedding = Mock(spec=Embedding) + mock_cached_embedding.get_embedding.return_value = normalized_cached + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = mock_cached_embedding + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 1 + assert result[0] == normalized_cached + mock_model_instance.invoke_multimodal_embedding.assert_not_called() + + def test_embed_multimodal_documents_partial_cache_hit(self, mock_model_instance): + """Test embedding multimodal documents with mixed cache hits and misses.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [ + {"file_id": "cached_file"}, + {"file_id": "new_file_1"}, + {"file_id": "new_file_2"}, + ] + + cached_vector = np.random.randn(1536) + normalized_cached = (cached_vector / np.linalg.norm(cached_vector)).tolist() + + mock_cached_embedding = Mock(spec=Embedding) + mock_cached_embedding.get_embedding.return_value = normalized_cached + + new_embeddings = [] + for _ in range(2): + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + new_embeddings.append(normalized) + + usage = EmbeddingUsage( + tokens=20, + total_tokens=20, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000002"), + currency="USD", + latency=0.6, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=new_embeddings, + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + call_count = [0] + + def mock_filter_by(**kwargs): + call_count[0] += 1 + mock_query = Mock() + if call_count[0] == 1: + mock_query.first.return_value = mock_cached_embedding + else: + mock_query.first.return_value = None + return mock_query + + mock_session.query.return_value.filter_by = mock_filter_by + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 3 + assert result[0] == normalized_cached + + def test_embed_multimodal_documents_nan_handling(self, mock_model_instance): + """Test handling of NaN values in multimodal embeddings.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "valid"}, {"file_id": "nan"}] + + valid_vector = np.random.randn(1536).tolist() + nan_vector = [float("nan")] * 1536 + + usage = EmbeddingUsage( + tokens=20, + total_tokens=20, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.000002"), + currency="USD", + latency=0.5, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[valid_vector, nan_vector], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 2 + assert result[0] is not None + assert result[1] is None + + mock_logger.warning.assert_called_once() + + def test_embed_multimodal_documents_large_batch(self, mock_model_instance): + """Test embedding large batch of multimodal documents respecting MAX_CHUNKS.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": f"file{i}"} for i in range(25)] + + def create_batch_result(batch_size): + embeddings = [] + for _ in range(batch_size): + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + embeddings.append(normalized) + + usage = EmbeddingUsage( + tokens=batch_size * 10, + total_tokens=batch_size * 10, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal(str(batch_size * 0.000001)), + currency="USD", + latency=0.5, + ) + + return EmbeddingResult( + model="vision-embedding-model", + embeddings=embeddings, + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + batch_results = [create_batch_result(10), create_batch_result(10), create_batch_result(5)] + mock_model_instance.invoke_multimodal_embedding.side_effect = batch_results + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 25 + assert mock_model_instance.invoke_multimodal_embedding.call_count == 3 + + def test_embed_multimodal_documents_api_error(self, mock_model_instance): + """Test handling of API errors during multimodal embedding.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "file123"}] + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") + + with pytest.raises(Exception) as exc_info: + cache_embedding.embed_multimodal_documents(documents) + + assert "API Error" in str(exc_info.value) + mock_session.rollback.assert_called() + + def test_embed_multimodal_documents_integrity_error_during_transform( + self, mock_model_instance, sample_multimodal_result + ): + """Test handling of IntegrityError during embedding transformation.""" + cache_embedding = CacheEmbedding(mock_model_instance) + documents = [{"file_id": "file123"}] + + with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: + mock_session.query.return_value.filter_by.return_value.first.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = sample_multimodal_result + + mock_session.commit.side_effect = IntegrityError("Duplicate key", None, None) + + result = cache_embedding.embed_multimodal_documents(documents) + + assert len(result) == 1 + mock_session.rollback.assert_called() + + +class TestCacheEmbeddingMultimodalQuery: + """Test suite for CacheEmbedding.embed_multimodal_query method.""" + + @pytest.fixture + def mock_model_instance(self): + """Create a mock ModelInstance for testing.""" + model_instance = Mock() + model_instance.model = "vision-embedding-model" + model_instance.provider = "openai" + model_instance.credentials = {"api_key": "test-key"} + return model_instance + + def test_embed_multimodal_query_cache_miss(self, mock_model_instance): + """Test embedding multimodal query when Redis cache is empty.""" + cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + document = {"file_id": "file123"} + + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[normalized], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + result = cache_embedding.embed_multimodal_query(document) + + assert isinstance(result, list) + assert len(result) == 1536 + mock_redis.setex.assert_called_once() + + def test_embed_multimodal_query_cache_hit(self, mock_model_instance): + """Test embedding multimodal query when Redis cache has the value.""" + cache_embedding = CacheEmbedding(mock_model_instance) + document = {"file_id": "file123"} + + embedding_vector = np.random.randn(1536) + vector_bytes = embedding_vector.tobytes() + encoded_vector = base64.b64encode(vector_bytes).decode("utf-8") + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = encoded_vector.encode() + + result = cache_embedding.embed_multimodal_query(document) + + assert isinstance(result, list) + assert len(result) == 1536 + mock_redis.expire.assert_called_once() + mock_model_instance.invoke_multimodal_embedding.assert_not_called() + + def test_embed_multimodal_query_nan_handling(self, mock_model_instance): + """Test handling of NaN values in multimodal query embeddings.""" + cache_embedding = CacheEmbedding(mock_model_instance) + + nan_vector = [float("nan")] * 1536 + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[nan_vector], + usage=usage, + ) + + document = {"file_id": "file123"} + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + + with pytest.raises(ValueError) as exc_info: + cache_embedding.embed_multimodal_query(document) + + assert "Normalized embedding is nan" in str(exc_info.value) + + def test_embed_multimodal_query_api_error(self, mock_model_instance): + """Test handling of API errors during multimodal query embedding.""" + cache_embedding = CacheEmbedding(mock_model_instance) + document = {"file_id": "file123"} + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.side_effect = Exception("API Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = False + + with pytest.raises(Exception) as exc_info: + cache_embedding.embed_multimodal_query(document) + + assert "API Error" in str(exc_info.value) + + def test_embed_multimodal_query_redis_set_error(self, mock_model_instance): + """Test handling of Redis set errors during multimodal query embedding.""" + cache_embedding = CacheEmbedding(mock_model_instance) + document = {"file_id": "file123"} + + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="vision-embedding-model", + embeddings=[normalized], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_multimodal_embedding.return_value = embedding_result + mock_redis.setex.side_effect = RuntimeError("Redis Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = True + + with pytest.raises(RuntimeError): + cache_embedding.embed_multimodal_query(document) + + +class TestCacheEmbeddingQueryErrors: + """Test suite for error handling in CacheEmbedding.embed_query method.""" + + @pytest.fixture + def mock_model_instance(self): + """Create a mock ModelInstance for testing.""" + model_instance = Mock() + model_instance.model = "text-embedding-ada-002" + model_instance.provider = "openai" + model_instance.credentials = {"api_key": "test-key"} + return model_instance + + def test_embed_query_api_error_debug_mode(self, mock_model_instance): + """Test handling of API errors in debug mode.""" + cache_embedding = CacheEmbedding(mock_model_instance) + query = "test query" + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_text_embedding.side_effect = RuntimeError("API Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = True + + with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with pytest.raises(RuntimeError) as exc_info: + cache_embedding.embed_query(query) + + assert "API Error" in str(exc_info.value) + mock_logger.exception.assert_called() + + def test_embed_query_redis_set_error_debug_mode(self, mock_model_instance): + """Test handling of Redis set errors in debug mode.""" + cache_embedding = CacheEmbedding(mock_model_instance) + query = "test query" + + vector = np.random.randn(1536) + normalized = (vector / np.linalg.norm(vector)).tolist() + + usage = EmbeddingUsage( + tokens=5, + total_tokens=5, + unit_price=Decimal("0.0001"), + price_unit=Decimal(1000), + total_price=Decimal("0.0000005"), + currency="USD", + latency=0.3, + ) + + embedding_result = EmbeddingResult( + model="text-embedding-ada-002", + embeddings=[normalized], + usage=usage, + ) + + with patch("core.rag.embedding.cached_embedding.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_model_instance.invoke_text_embedding.return_value = embedding_result + mock_redis.setex.side_effect = RuntimeError("Redis Error") + + with patch("core.rag.embedding.cached_embedding.dify_config") as mock_config: + mock_config.DEBUG = True + + with patch("core.rag.embedding.cached_embedding.logger") as mock_logger: + with pytest.raises(RuntimeError): + cache_embedding.embed_query(query) + + mock_logger.exception.assert_called() + + +class TestCacheEmbeddingInitialization: + """Test suite for CacheEmbedding initialization.""" + + def test_initialization_with_user(self): + """Test CacheEmbedding initialization with user parameter.""" + model_instance = Mock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + + cache_embedding = CacheEmbedding(model_instance, user="test-user") + + assert cache_embedding._model_instance == model_instance + assert cache_embedding._user == "test-user" + + def test_initialization_without_user(self): + """Test CacheEmbedding initialization without user parameter.""" + model_instance = Mock() + model_instance.model = "test-model" + model_instance.provider = "test-provider" + + cache_embedding = CacheEmbedding(model_instance) + + assert cache_embedding._model_instance == model_instance + assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py new file mode 100644 index 0000000000..033933e886 --- /dev/null +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_base.py @@ -0,0 +1,220 @@ +"""Unit tests for embedding_base.py - the abstract Embeddings base class.""" + +import asyncio +import inspect +from typing import Any + +import pytest + +from core.rag.embedding.embedding_base import Embeddings + + +class ConcreteEmbeddings(Embeddings): + """Concrete implementation of Embeddings for testing.""" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[1.0] * 10 for _ in texts] + + def embed_multimodal_documents(self, multimodel_documents: list[dict[str, Any]]) -> list[list[float]]: + return [[1.0] * 10 for _ in multimodel_documents] + + def embed_query(self, text: str) -> list[float]: + return [1.0] * 10 + + def embed_multimodal_query(self, multimodel_document: dict[str, Any]) -> list[float]: + return [1.0] * 10 + + +class TestEmbeddingsBase: + """Test suite for the abstract Embeddings base class.""" + + def test_embeddings_is_abc(self): + """Test that Embeddings is an abstract base class.""" + assert hasattr(Embeddings, "__abstractmethods__") + assert len(Embeddings.__abstractmethods__) > 0 + + def test_embed_documents_is_abstract(self): + """Test that embed_documents is an abstract method.""" + assert "embed_documents" in Embeddings.__abstractmethods__ + + def test_embed_multimodal_documents_is_abstract(self): + """Test that embed_multimodal_documents is an abstract method.""" + assert "embed_multimodal_documents" in Embeddings.__abstractmethods__ + + def test_embed_query_is_abstract(self): + """Test that embed_query is an abstract method.""" + assert "embed_query" in Embeddings.__abstractmethods__ + + def test_embed_multimodal_query_is_abstract(self): + """Test that embed_multimodal_query is an abstract method.""" + assert "embed_multimodal_query" in Embeddings.__abstractmethods__ + + def test_embed_documents_raises_not_implemented(self): + """Test that embed_documents raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_documents) + assert "raise NotImplementedError" in source + + def test_embed_multimodal_documents_raises_not_implemented(self): + """Test that embed_multimodal_documents raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_multimodal_documents) + assert "raise NotImplementedError" in source + + def test_embed_query_raises_not_implemented(self): + """Test that embed_query raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_query) + assert "raise NotImplementedError" in source + + def test_embed_multimodal_query_raises_not_implemented(self): + """Test that embed_multimodal_query raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.embed_multimodal_query) + assert "raise NotImplementedError" in source + + def test_aembed_documents_raises_not_implemented(self): + """Test that aembed_documents raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.aembed_documents) + assert "raise NotImplementedError" in source + + def test_aembed_query_raises_not_implemented(self): + """Test that aembed_query raises NotImplementedError in its body.""" + source = inspect.getsource(Embeddings.aembed_query) + assert "raise NotImplementedError" in source + + def test_concrete_implementation_works(self): + """Test that a concrete implementation of Embeddings works correctly.""" + concrete = ConcreteEmbeddings() + result = concrete.embed_documents(["test1", "test2"]) + assert len(result) == 2 + assert all(len(emb) == 10 for emb in result) + + def test_concrete_implementation_embed_query(self): + """Test concrete implementation of embed_query.""" + concrete = ConcreteEmbeddings() + result = concrete.embed_query("test query") + assert len(result) == 10 + + def test_concrete_implementation_embed_multimodal_documents(self): + """Test concrete implementation of embed_multimodal_documents.""" + concrete = ConcreteEmbeddings() + docs: list[dict[str, Any]] = [{"file_id": "file1"}, {"file_id": "file2"}] + result = concrete.embed_multimodal_documents(docs) + assert len(result) == 2 + + def test_concrete_implementation_embed_multimodal_query(self): + """Test concrete implementation of embed_multimodal_query.""" + concrete = ConcreteEmbeddings() + result = concrete.embed_multimodal_query({"file_id": "test"}) + assert len(result) == 10 + + +class TestEmbeddingsNotImplemented: + """Test that abstract methods raise NotImplementedError when called.""" + + def test_embed_query_raises_not_implemented(self): + """Test that embed_query raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_query("test") + + def test_embed_documents_raises_not_implemented(self): + """Test that embed_documents raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_documents(["test"]) + + def test_embed_multimodal_documents_raises_not_implemented(self): + """Test that embed_multimodal_documents raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_multimodal_documents([{"file_id": "test"}]) + + def test_embed_multimodal_query_raises_not_implemented(self): + """Test that embed_multimodal_query raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + with pytest.raises(NotImplementedError): + partial.embed_multimodal_query({"file_id": "test"}) + + def test_aembed_documents_raises_not_implemented(self): + """Test that aembed_documents raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + + async def run_test(): + with pytest.raises(NotImplementedError): + await partial.aembed_documents(["test"]) + + asyncio.run(run_test()) + + def test_aembed_query_raises_not_implemented(self): + """Test that aembed_query raises NotImplementedError.""" + + class PartialImpl: + pass + + PartialImpl.embed_query = lambda self, text: Embeddings.embed_query(self, text) + PartialImpl.embed_documents = lambda self, texts: Embeddings.embed_documents(self, texts) + PartialImpl.embed_multimodal_documents = lambda self, docs: Embeddings.embed_multimodal_documents(self, docs) + PartialImpl.embed_multimodal_query = lambda self, doc: Embeddings.embed_multimodal_query(self, doc) + PartialImpl.aembed_documents = lambda self, texts: Embeddings.aembed_documents(self, texts) + PartialImpl.aembed_query = lambda self, text: Embeddings.aembed_query(self, text) + + partial = PartialImpl() + + async def run_test(): + with pytest.raises(NotImplementedError): + await partial.aembed_query("test") + + asyncio.run(run_test()) diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 63596bc320..6e71f0c61f 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -52,14 +52,14 @@ import pytest from sqlalchemy.exc import IntegrityError from core.entities.embedding_type import EmbeddingInputType -from core.model_runtime.entities.model_entities import ModelPropertyKey -from core.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from core.model_runtime.errors.invoke import ( +from core.rag.embedding.cached_embedding import CacheEmbedding +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from dify_graph.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError, ) -from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding diff --git a/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py b/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py new file mode 100644 index 0000000000..eb14622d7a --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/blob/test_blob.py @@ -0,0 +1,85 @@ +from io import BytesIO + +import pytest + +from core.rag.extractor.blob.blob import Blob + + +class TestBlob: + def test_requires_data_or_path(self): + with pytest.raises(ValueError, match="Either data or path must be provided"): + Blob() + + def test_source_property_and_repr_include_path(self, tmp_path): + file_path = tmp_path / "sample.txt" + file_path.write_text("hello", encoding="utf-8") + + blob = Blob.from_path(str(file_path)) + + assert blob.source == str(file_path) + assert str(file_path) in repr(blob) + + def test_as_string_from_bytes_and_str(self): + assert Blob.from_data(b"abc").as_string() == "abc" + assert Blob.from_data("plain-text").as_string() == "plain-text" + + def test_as_string_from_path(self, tmp_path): + file_path = tmp_path / "sample.txt" + file_path.write_text("from-file", encoding="utf-8") + + blob = Blob.from_path(str(file_path)) + + assert blob.as_string() == "from-file" + + def test_as_string_raises_for_invalid_state(self): + blob = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8") + + with pytest.raises(ValueError, match="Unable to get string for blob"): + blob.as_string() + + def test_as_bytes_from_bytes_str_and_path(self, tmp_path): + from_bytes = Blob.from_data(b"abc") + from_str = Blob.from_data("abc", encoding="utf-8") + + file_path = tmp_path / "sample.bin" + file_path.write_bytes(b"from-path") + from_path = Blob.from_path(str(file_path)) + + assert from_bytes.as_bytes() == b"abc" + assert from_str.as_bytes() == b"abc" + assert from_path.as_bytes() == b"from-path" + + def test_as_bytes_raises_for_invalid_state(self): + blob = Blob.model_construct(data=None, path=None, mimetype=None, encoding="utf-8") + + with pytest.raises(ValueError, match="Unable to get bytes for blob"): + blob.as_bytes() + + def test_as_bytes_io_for_bytes_and_path(self, tmp_path): + data_blob = Blob.from_data(b"bytes-io") + with data_blob.as_bytes_io() as stream: + assert isinstance(stream, BytesIO) + assert stream.read() == b"bytes-io" + + file_path = tmp_path / "stream.bin" + file_path.write_bytes(b"path-stream") + path_blob = Blob.from_path(str(file_path)) + with path_blob.as_bytes_io() as stream: + assert stream.read() == b"path-stream" + + def test_as_bytes_io_raises_for_unsupported_data_type(self): + blob = Blob.from_data("text-value") + + with pytest.raises(NotImplementedError, match="Unable to convert blob"): + with blob.as_bytes_io(): + pass + + def test_from_path_respects_guessing_and_explicit_mime(self, tmp_path): + file_path = tmp_path / "example.txt" + file_path.write_text("x", encoding="utf-8") + + guessed = Blob.from_path(str(file_path)) + explicit = Blob.from_path(str(file_path), mime_type="custom/type", guess_type=False) + + assert guessed.mimetype == "text/plain" + assert explicit.mimetype == "custom/type" diff --git a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py index 4ee04ddebc..2add12fd09 100644 --- a/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py +++ b/api/tests/unit_tests/core/rag/extractor/firecrawl/test_firecrawl.py @@ -1,61 +1,340 @@ -import os +"""Unit tests for Firecrawl app and extractor integration points.""" + +import json +from collections.abc import Mapping +from typing import Any from unittest.mock import MagicMock import pytest from pytest_mock import MockerFixture +import core.rag.extractor.firecrawl.firecrawl_app as firecrawl_module from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp -from tests.unit_tests.core.rag.extractor.test_notion_extractor import _mock_response +from core.rag.extractor.firecrawl.firecrawl_web_extractor import FirecrawlWebExtractor -def test_firecrawl_web_extractor_crawl_mode(mocker: MockerFixture): - url = "https://firecrawl.dev" - api_key = os.getenv("FIRECRAWL_API_KEY") or "fc-" - base_url = "https://api.firecrawl.dev" - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=base_url) - params = { - "includePaths": [], - "excludePaths": [], - "maxDepth": 1, - "limit": 1, - } - mocked_firecrawl = { - "id": "test", - } - mocker.patch("httpx.post", return_value=_mock_response(mocked_firecrawl)) - job_id = firecrawl_app.crawl_url(url, params) - - assert job_id is not None - assert isinstance(job_id, str) +def _response(status_code: int, json_data: Mapping[str, Any] | None = None, text: str = "") -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.text = text + response.json.return_value = json_data if json_data is not None else {} + return response -def test_build_url_normalizes_slashes_for_crawl(mocker: MockerFixture): - api_key = "fc-" - base_urls = ["https://custom.firecrawl.dev", "https://custom.firecrawl.dev/"] - for base in base_urls: - app = FirecrawlApp(api_key=api_key, base_url=base) - mock_post = mocker.patch("httpx.post") - mock_resp = MagicMock() - mock_resp.status_code = 200 - mock_resp.json.return_value = {"id": "job123"} - mock_post.return_value = mock_resp - app.crawl_url("https://example.com", params=None) - called_url = mock_post.call_args[0][0] - assert called_url == "https://custom.firecrawl.dev/v2/crawl" +class TestFirecrawlApp: + def test_init_requires_api_key_for_default_base_url(self): + with pytest.raises(ValueError, match="No API key provided"): + FirecrawlApp(api_key=None, base_url="https://api.firecrawl.dev") + + def test_prepare_headers_and_build_url(self): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev/") + + assert app._prepare_headers() == { + "Content-Type": "application/json", + "Authorization": "Bearer fc-key", + } + assert app._build_url("/v2/crawl") == "https://custom.firecrawl.dev/v2/crawl" + + def test_scrape_url_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch( + "httpx.post", + return_value=_response( + 200, + { + "data": { + "metadata": { + "title": "t", + "description": "d", + "sourceURL": "https://example.com", + }, + "markdown": "body", + } + }, + ), + ) + + result = app.scrape_url("https://example.com", params={"onlyMainContent": False}) + + assert result == { + "title": "t", + "description": "d", + "source_url": "https://example.com", + "markdown": "body", + } + + def test_scrape_url_handles_known_error_status(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("boom")) + mocker.patch("httpx.post", return_value=_response(429, {"error": "limit"})) + + with pytest.raises(Exception, match="boom"): + app.scrape_url("https://example.com") + + mock_handle.assert_called_once() + + def test_scrape_url_unknown_status_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(404, text="Not Found")) + + with pytest.raises(Exception, match="Failed to scrape URL. Status code: 404"): + app.scrape_url("https://example.com") + + def test_crawl_url_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"id": "job-1"})) + + assert app.crawl_url("https://example.com") == "job-1" + + def test_crawl_url_non_200_uses_error_handler(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl failed")) + mocker.patch("httpx.post", return_value=_response(500, {"error": "server"})) + + with pytest.raises(Exception, match="crawl failed"): + app.crawl_url("https://example.com") + + mock_handle.assert_called_once() + + def test_map_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"success": True, "links": ["a", "b"]})) + + assert app.map("https://example.com") == {"success": True, "links": ["a", "b"]} + + def test_map_known_error(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("map error")) + mocker.patch("httpx.post", return_value=_response(409, {"error": "conflict"})) + + with pytest.raises(Exception, match="map error"): + app.map("https://example.com") + mock_handle.assert_called_once() + + def test_map_unknown_error_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(418, text="teapot")) + + with pytest.raises(Exception, match="Failed to start map job. Status code: 418"): + app.map("https://example.com") + + def test_check_crawl_status_completed_with_data(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + payload = { + "status": "completed", + "total": 2, + "completed": 2, + "data": [ + { + "metadata": {"title": "a", "description": "desc-a", "sourceURL": "https://a"}, + "markdown": "m-a", + }, + { + "metadata": {"title": "b", "description": "desc-b", "sourceURL": "https://b"}, + "markdown": "m-b", + }, + {"metadata": {"title": "skip"}}, + ], + } + mocker.patch("httpx.get", return_value=_response(200, payload)) + + save_calls: list[tuple[str, bytes]] = [] + delete_calls: list[str] = [] + + mock_storage = MagicMock() + mock_storage.exists.return_value = True + mock_storage.delete.side_effect = lambda key: delete_calls.append(key) + mock_storage.save.side_effect = lambda key, data: save_calls.append((key, data)) + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + result = app.check_crawl_status("job-42") + + assert result["status"] == "completed" + assert result["total"] == 2 + assert result["current"] == 2 + assert len(result["data"]) == 2 + assert delete_calls == ["website_files/job-42.txt"] + assert len(save_calls) == 1 + assert save_calls[0][0] == "website_files/job-42.txt" + + def test_check_crawl_status_completed_with_zero_total_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.get", return_value=_response(200, {"status": "completed", "total": 0, "data": []})) + + with pytest.raises(Exception, match="No page found"): + app.check_crawl_status("job-1") + + def test_check_crawl_status_non_completed(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + payload = {"status": "processing", "total": 5, "completed": 1, "data": []} + mocker.patch("httpx.get", return_value=_response(200, payload)) + + assert app.check_crawl_status("job-1") == { + "status": "processing", + "total": 5, + "current": 1, + "data": [], + } + + def test_check_crawl_status_non_200_uses_error_handler(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("crawl error")) + mocker.patch("httpx.get", return_value=_response(500, {"error": "server"})) + + with pytest.raises(Exception, match="crawl error"): + app.check_crawl_status("job-1") + mock_handle.assert_called_once() + + def test_check_crawl_status_save_failure_raises(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + payload = { + "status": "completed", + "total": 1, + "completed": 1, + "data": [{"metadata": {"title": "a", "sourceURL": "https://a"}, "markdown": "m-a"}], + } + mocker.patch("httpx.get", return_value=_response(200, payload)) + + mock_storage = MagicMock() + mock_storage.exists.return_value = False + mock_storage.save.side_effect = RuntimeError("save failed") + mocker.patch.object(firecrawl_module, "storage", mock_storage) + + with pytest.raises(Exception, match="Error saving crawl data"): + app.check_crawl_status("job-err") + + def test_extract_common_fields_and_status_formatter(self): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + + fields = app._extract_common_fields( + {"metadata": {"title": "t", "description": "d", "sourceURL": "u"}, "markdown": "m"} + ) + assert fields == {"title": "t", "description": "d", "source_url": "u", "markdown": "m"} + + status = app._format_crawl_status_response("completed", {"total": 1, "completed": 1}, [fields]) + assert status == {"status": "completed", "total": 1, "current": 1, "data": [fields]} + + def test_post_and_get_request_retry_logic(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep") + + resp_502_a = _response(502) + resp_502_b = _response(502) + resp_200 = _response(200) + + mocker.patch("httpx.post", side_effect=[resp_502_a, resp_200]) + post_result = app._post_request("u", {"x": 1}, {"h": 1}, retries=3, backoff_factor=0.5) + assert post_result is resp_200 + + mocker.patch("httpx.get", side_effect=[resp_502_b, _response(200)]) + get_result = app._get_request("u", {"h": 1}, retries=3, backoff_factor=0.25) + assert get_result.status_code == 200 + + assert sleep_mock.call_count == 2 + + def test_post_and_get_request_return_last_502(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + sleep_mock = mocker.patch.object(firecrawl_module.time, "sleep") + + last_post = _response(502) + mocker.patch("httpx.post", side_effect=[_response(502), last_post]) + assert app._post_request("u", {}, {}, retries=2).status_code == 502 + + last_get = _response(502) + mocker.patch("httpx.get", side_effect=[_response(502), last_get]) + assert app._get_request("u", {}, retries=2).status_code == 502 + + assert sleep_mock.call_count == 4 + + def test_handle_error_with_json_and_plain_text(self): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + + json_error = _response(400, {"message": "bad request"}) + with pytest.raises(Exception, match="bad request"): + app._handle_error(json_error, "run task") + + non_json = MagicMock() + non_json.status_code = 400 + non_json.text = "plain error" + non_json.json.side_effect = json.JSONDecodeError("bad", "x", 0) + + with pytest.raises(Exception, match="plain error"): + app._handle_error(non_json, "run task") + + def test_search_success(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"success": True, "data": [{"url": "x"}]})) + assert app.search("python")["success"] is True + + def test_search_warning_failure(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(200, {"success": False, "warning": "bad search"})) + with pytest.raises(Exception, match="bad search"): + app.search("python") + + def test_search_known_http_error(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mock_handle = mocker.patch.object(app, "_handle_error", side_effect=Exception("search error")) + mocker.patch("httpx.post", return_value=_response(408, {"error": "timeout"})) + with pytest.raises(Exception, match="search error"): + app.search("python") + mock_handle.assert_called_once() + + def test_search_unknown_http_error(self, mocker: MockerFixture): + app = FirecrawlApp(api_key="fc-key", base_url="https://custom.firecrawl.dev") + mocker.patch("httpx.post", return_value=_response(418, text="teapot")) + with pytest.raises(Exception, match="Failed to perform search. Status code: 418"): + app.search("python") -def test_error_handler_handles_non_json_error_bodies(mocker: MockerFixture): - api_key = "fc-" - app = FirecrawlApp(api_key=api_key, base_url="https://custom.firecrawl.dev/") - mock_post = mocker.patch("httpx.post") - mock_resp = MagicMock() - mock_resp.status_code = 404 - mock_resp.text = "Not Found" - mock_resp.json.side_effect = Exception("Not JSON") - mock_post.return_value = mock_resp +class TestFirecrawlWebExtractor: + def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data", + return_value={ + "markdown": "crawl content", + "source_url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) - with pytest.raises(Exception) as excinfo: - app.scrape_url("https://example.com") + extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + docs = extractor.extract() - # Should not raise a JSONDecodeError; current behavior reports status code only - assert str(excinfo.value) == "Failed to scrape URL. Status code: 404" + assert len(docs) == 1 + assert docs[0].page_content == "crawl content" + assert docs[0].metadata["source_url"] == "https://example.com" + + def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_crawl_url_data", + return_value=None, + ) + + extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + assert extractor.extract() == [] + + def test_extract_scrape_mode_returns_document(self, mocker: MockerFixture): + mock_scrape = mocker.patch( + "core.rag.extractor.firecrawl.firecrawl_web_extractor.WebsiteService.get_scrape_url_data", + return_value={ + "markdown": "scrape content", + "source_url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) + + extractor = FirecrawlWebExtractor( + "https://example.com", "job-1", "tenant-1", mode="scrape", only_main_content=False + ) + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "scrape content" + mock_scrape.assert_called_once_with("firecrawl", "https://example.com", "tenant-1", False) + + def test_extract_unknown_mode_returns_empty(self): + extractor = FirecrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="unknown") + assert extractor.extract() == [] diff --git a/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py new file mode 100644 index 0000000000..e6a06f163e --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_csv_extractor.py @@ -0,0 +1,95 @@ +import csv +import io +from types import SimpleNamespace + +import pandas as pd +import pytest + +import core.rag.extractor.csv_extractor as csv_module +from core.rag.extractor.csv_extractor import CSVExtractor + + +class _ManagedStringIO(io.StringIO): + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.close() + return False + + +class TestCSVExtractor: + def test_extract_success_with_source_column(self, tmp_path): + file_path = tmp_path / "data.csv" + file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8") + + extractor = CSVExtractor(str(file_path), source_column="id") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "id: source-1;body: hello" + assert docs[0].metadata == {"source": "source-1", "row": 0} + + def test_extract_raises_when_source_column_missing(self, tmp_path): + file_path = tmp_path / "data.csv" + file_path.write_text("id,body\nsource-1,hello\n", encoding="utf-8") + + extractor = CSVExtractor(str(file_path), source_column="missing_col") + + with pytest.raises(ValueError, match="Source column 'missing_col' not found"): + extractor.extract() + + def test_extract_wraps_unicode_error_when_autodetect_disabled(self, monkeypatch): + extractor = CSVExtractor("dummy.csv", autodetect_encoding=False) + + def raise_decode(*args, **kwargs): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + + monkeypatch.setattr("builtins.open", raise_decode) + + with pytest.raises(RuntimeError, match="Error loading dummy.csv"): + extractor.extract() + + def test_extract_autodetect_encoding_success(self, monkeypatch): + extractor = CSVExtractor("dummy.csv", autodetect_encoding=True) + attempted_encodings: list[str | None] = [] + + def fake_open(path, newline="", encoding=None): + attempted_encodings.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + if encoding == "bad": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + return _ManagedStringIO("id,body\nsource-1,hello\n") + + monkeypatch.setattr("builtins.open", fake_open) + monkeypatch.setattr( + csv_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")], + ) + + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "id: source-1;body: hello" + assert attempted_encodings == [None, "bad", "utf-8"] + + def test_extract_autodetect_encoding_all_attempts_fail_returns_empty(self, monkeypatch): + extractor = CSVExtractor("dummy.csv", autodetect_encoding=True) + + def always_raise(*args, **kwargs): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode error") + + monkeypatch.setattr("builtins.open", always_raise) + monkeypatch.setattr(csv_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="bad")]) + + assert extractor.extract() == [] + + def test_read_from_file_re_raises_csv_error(self, monkeypatch): + extractor = CSVExtractor("dummy.csv") + + monkeypatch.setattr(pd, "read_csv", lambda *args, **kwargs: (_ for _ in ()).throw(csv.Error("bad csv"))) + + with pytest.raises(csv.Error, match="bad csv"): + extractor._read_from_file(io.StringIO("x")) diff --git a/api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py new file mode 100644 index 0000000000..d2bcc1e2c4 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_excel_extractor.py @@ -0,0 +1,117 @@ +from types import SimpleNamespace + +import pandas as pd +import pytest + +import core.rag.extractor.excel_extractor as excel_module +from core.rag.extractor.excel_extractor import ExcelExtractor + + +class _FakeCell: + def __init__(self, value, hyperlink=None): + self.value = value + self.hyperlink = hyperlink + + +class _FakeSheet: + def __init__(self, header_rows, data_rows): + self._header_rows = header_rows + self._data_rows = data_rows + + def iter_rows(self, min_row=1, max_row=None, max_col=None, values_only=False): + if values_only: + for row in self._header_rows: + yield tuple(row) + return + + for row in self._data_rows: + if max_col is not None: + yield tuple(row[:max_col]) + else: + yield tuple(row) + + +class _FakeWorkbook: + def __init__(self, sheets): + self._sheets = sheets + self.sheetnames = list(sheets.keys()) + self.closed = False + + def __getitem__(self, key): + return self._sheets[key] + + def close(self): + self.closed = True + + +class TestExcelExtractor: + def test_extract_xlsx_with_hyperlinks_and_sheet_skip(self, monkeypatch): + sheet_with_data = _FakeSheet( + header_rows=[("Name", "Link")], + data_rows=[ + (_FakeCell("Alice"), _FakeCell("Doc", hyperlink=SimpleNamespace(target="https://example.com/doc"))), + (_FakeCell(None), _FakeCell(123)), + (_FakeCell(None), _FakeCell(None)), + ], + ) + empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[]) + + workbook = _FakeWorkbook({"Data": sheet_with_data, "Empty": empty_sheet}) + monkeypatch.setattr(excel_module, "load_workbook", lambda *args, **kwargs: workbook) + + extractor = ExcelExtractor("/tmp/sample.xlsx") + docs = extractor.extract() + + assert workbook.closed is True + assert len(docs) == 2 + assert docs[0].page_content == '"Name":"Alice";"Link":"[Doc](https://example.com/doc)"' + assert docs[1].page_content == '"Name":"";"Link":"123"' + assert all(doc.metadata["source"] == "/tmp/sample.xlsx" for doc in docs) + + def test_extract_xls_path(self, monkeypatch): + class FakeExcelFile: + sheet_names = ["Sheet1"] + + def parse(self, sheet_name): + assert sheet_name == "Sheet1" + return pd.DataFrame([{"A": "x", "B": 1}, {"A": None, "B": None}]) + + monkeypatch.setattr(pd, "ExcelFile", lambda path, engine=None: FakeExcelFile()) + + extractor = ExcelExtractor("/tmp/sample.xls") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == '"A":"x";"B":"1.0"' + assert docs[0].metadata == {"source": "/tmp/sample.xls"} + + def test_extract_unsupported_extension_raises(self): + extractor = ExcelExtractor("/tmp/sample.txt") + + with pytest.raises(ValueError, match="Unsupported file extension"): + extractor.extract() + + def test_find_header_and_columns_prefers_first_row_with_two_columns(self): + sheet = _FakeSheet( + header_rows=[(None, None, None), ("A", "B", None), ("X", None, None)], + data_rows=[], + ) + extractor = ExcelExtractor("dummy.xlsx") + + header_row_idx, column_map, max_col_idx = extractor._find_header_and_columns(sheet) + + assert header_row_idx == 2 + assert column_map == {0: "A", 1: "B"} + assert max_col_idx == 2 + + def test_find_header_and_columns_fallback_and_empty_case(self): + extractor = ExcelExtractor("dummy.xlsx") + + fallback_sheet = _FakeSheet(header_rows=[("Only", None), (None, "Second")], data_rows=[]) + row_idx, column_map, max_col_idx = extractor._find_header_and_columns(fallback_sheet) + assert row_idx == 1 + assert column_map == {0: "Only"} + assert max_col_idx == 1 + + empty_sheet = _FakeSheet(header_rows=[(None, None)], data_rows=[]) + assert extractor._find_header_and_columns(empty_sheet) == (0, {}, 0) diff --git a/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py b/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py new file mode 100644 index 0000000000..5beed88971 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_extract_processor.py @@ -0,0 +1,272 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.extract_processor as processor_module +from core.rag.extractor.entity.datasource_type import DatasourceType +from core.rag.extractor.extract_processor import ExtractProcessor +from core.rag.models.document import Document + + +class _ExtractorFactory: + def __init__(self) -> None: + self.calls = [] + + def make(self, name: str) -> type[object]: + calls = self.calls + + class DummyExtractor: + def __init__(self, *args, **kwargs): + calls.append((name, args, kwargs)) + + def extract(self): + return [Document(page_content=f"extracted-by-{name}")] + + return DummyExtractor + + +def _patch_all_extractors(monkeypatch) -> _ExtractorFactory: + factory = _ExtractorFactory() + + for cls_name in [ + "CSVExtractor", + "ExcelExtractor", + "FirecrawlWebExtractor", + "HtmlExtractor", + "JinaReaderWebExtractor", + "MarkdownExtractor", + "NotionExtractor", + "PdfExtractor", + "TextExtractor", + "UnstructuredEmailExtractor", + "UnstructuredEpubExtractor", + "UnstructuredMarkdownExtractor", + "UnstructuredMsgExtractor", + "UnstructuredPPTExtractor", + "UnstructuredPPTXExtractor", + "UnstructuredWordExtractor", + "UnstructuredXmlExtractor", + "WaterCrawlWebExtractor", + "WordExtractor", + ]: + monkeypatch.setattr(processor_module, cls_name, factory.make(cls_name)) + + return factory + + +class TestExtractProcessorLoaders: + def test_load_from_upload_file_return_docs_and_text(self, monkeypatch): + monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs)) + + monkeypatch.setattr( + ExtractProcessor, + "extract", + lambda extract_setting, is_automatic=False, file_path=None: [ + Document(page_content="doc-1"), + Document(page_content="doc-2"), + ], + ) + + upload_file = SimpleNamespace(key="file.txt") + + docs = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=False) + text = ExtractProcessor.load_from_upload_file(upload_file=upload_file, return_text=True) + + assert len(docs) == 2 + assert text == "doc-1\ndoc-2" + + @pytest.mark.parametrize( + ("url", "headers", "expected_suffix"), + [ + ("https://example.com/file.txt", {"Content-Type": "text/plain"}, ".txt"), + ("https://example.com/no_suffix", {"Content-Type": "application/pdf"}, ".pdf"), + ( + "https://example.com/no_suffix", + {"Content-Disposition": 'attachment; filename="report.md"'}, + ".md", + ), + ( + "https://example.com/no_suffix", + {"Content-Disposition": 'attachment; filename="report"'}, + "", + ), + ], + ) + def test_load_from_url_builds_temp_file_with_correct_suffix(self, monkeypatch, url, headers, expected_suffix): + response = SimpleNamespace(headers=headers, content=b"body") + monkeypatch.setattr(processor_module.ssrf_proxy, "get", lambda *args, **kwargs: response) + monkeypatch.setattr(processor_module, "ExtractSetting", lambda **kwargs: SimpleNamespace(**kwargs)) + + captured = {} + + def fake_extract(extract_setting, is_automatic=False, file_path=None): + key = "file_path_docs" if "file_path_docs" not in captured else "file_path_text" + captured[key] = file_path + return [Document(page_content="u1"), Document(page_content="u2")] + + monkeypatch.setattr(ExtractProcessor, "extract", fake_extract) + + docs = ExtractProcessor.load_from_url(url, return_text=False) + assert captured["file_path_docs"].endswith(expected_suffix) + + text = ExtractProcessor.load_from_url(url, return_text=True) + assert captured["file_path_text"].endswith(expected_suffix) + + assert len(docs) == 2 + assert text == "u1\nu2" + + +class TestExtractProcessorFileRouting: + @pytest.fixture(autouse=True) + def _set_unstructured_config(self, monkeypatch): + monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_URL", "https://unstructured") + monkeypatch.setattr(processor_module.dify_config, "UNSTRUCTURED_API_KEY", "key") + + def _run_extract_for_extension(self, monkeypatch, extension: str, etl_type: str, is_automatic: bool = False): + factory = _patch_all_extractors(monkeypatch) + monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type) + + def fake_download(key: str, local_path: str): + Path(local_path).write_text("content", encoding="utf-8") + + monkeypatch.setattr(processor_module.storage, "download", fake_download) + monkeypatch.setattr(processor_module.tempfile, "_get_candidate_names", lambda: iter(["candidate-name"])) + + setting = SimpleNamespace( + datasource_type=DatasourceType.FILE, + upload_file=SimpleNamespace(key=f"uploaded{extension}", tenant_id="tenant-1", created_by="user-1"), + ) + + docs = ExtractProcessor.extract(setting, is_automatic=is_automatic) + + assert len(docs) == 1 + assert docs[0].page_content.startswith("extracted-by-") + return factory.calls[-1][0], factory.calls[-1][1], factory.calls[-1][2] + + @pytest.mark.parametrize( + ("extension", "expected_extractor", "is_automatic"), + [ + (".xlsx", "ExcelExtractor", False), + (".xls", "ExcelExtractor", False), + (".pdf", "PdfExtractor", False), + (".md", "UnstructuredMarkdownExtractor", True), + (".mdx", "MarkdownExtractor", False), + (".htm", "HtmlExtractor", False), + (".html", "HtmlExtractor", False), + (".docx", "WordExtractor", False), + (".doc", "UnstructuredWordExtractor", False), + (".csv", "CSVExtractor", False), + (".msg", "UnstructuredMsgExtractor", False), + (".eml", "UnstructuredEmailExtractor", False), + (".ppt", "UnstructuredPPTExtractor", False), + (".pptx", "UnstructuredPPTXExtractor", False), + (".xml", "UnstructuredXmlExtractor", False), + (".epub", "UnstructuredEpubExtractor", False), + (".txt", "TextExtractor", False), + ], + ) + def test_extract_routes_file_extensions_for_unstructured_mode( + self, monkeypatch, extension, expected_extractor, is_automatic + ): + extractor_name, args, kwargs = self._run_extract_for_extension( + monkeypatch, extension, etl_type="Unstructured", is_automatic=is_automatic + ) + + assert extractor_name == expected_extractor + assert args + + @pytest.mark.parametrize( + ("extension", "expected_extractor"), + [ + (".xlsx", "ExcelExtractor"), + (".pdf", "PdfExtractor"), + (".markdown", "MarkdownExtractor"), + (".html", "HtmlExtractor"), + (".docx", "WordExtractor"), + (".csv", "CSVExtractor"), + (".epub", "UnstructuredEpubExtractor"), + (".txt", "TextExtractor"), + ], + ) + def test_extract_routes_file_extensions_for_default_mode(self, monkeypatch, extension, expected_extractor): + extractor_name, _, _ = self._run_extract_for_extension(monkeypatch, extension, etl_type="SelfHosted") + + assert extractor_name == expected_extractor + + def test_extract_requires_upload_file_when_file_path_not_provided(self): + setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None) + + with pytest.raises(AssertionError, match="upload_file is required"): + ExtractProcessor.extract(setting) + + +class TestExtractProcessorDatasourceRouting: + def test_extract_routes_notion_datasource(self, monkeypatch): + factory = _patch_all_extractors(monkeypatch) + + notion_info = SimpleNamespace( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + document="doc", + tenant_id="tenant", + credential_id="cred", + ) + setting = SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=notion_info) + + docs = ExtractProcessor.extract(setting) + + assert docs[0].page_content == "extracted-by-NotionExtractor" + assert factory.calls[-1][0] == "NotionExtractor" + + @pytest.mark.parametrize( + ("provider", "expected"), + [ + ("firecrawl", "FirecrawlWebExtractor"), + ("watercrawl", "WaterCrawlWebExtractor"), + ("jinareader", "JinaReaderWebExtractor"), + ], + ) + def test_extract_routes_website_datasource_providers(self, monkeypatch, provider: str, expected: str): + factory = _patch_all_extractors(monkeypatch) + + website_info = SimpleNamespace( + provider=provider, + url="https://example.com", + job_id="job", + tenant_id="tenant", + mode="crawl", + only_main_content=True, + ) + setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=website_info) + + docs = ExtractProcessor.extract(setting) + assert docs[0].page_content == f"extracted-by-{expected}" + assert factory.calls[-1][0] == expected + + def test_extract_unsupported_website_provider(self): + bad_provider = SimpleNamespace( + provider="unknown", + url="https://example.com", + job_id="job", + tenant_id="tenant", + mode="crawl", + only_main_content=True, + ) + setting = SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=bad_provider) + + with pytest.raises(ValueError, match="Unsupported website provider"): + ExtractProcessor.extract(setting) + + def test_extract_unsupported_datasource_type(self): + with pytest.raises(ValueError, match="Unsupported datasource type"): + ExtractProcessor.extract(SimpleNamespace(datasource_type="unknown")) + + def test_extract_requires_notion_info(self): + with pytest.raises(AssertionError, match="notion_info is required"): + ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.NOTION, notion_info=None)) + + def test_extract_requires_website_info(self): + with pytest.raises(AssertionError, match="website_info is required"): + ExtractProcessor.extract(SimpleNamespace(datasource_type=DatasourceType.WEBSITE, website_info=None)) diff --git a/api/tests/unit_tests/core/rag/extractor/test_extractor_base.py b/api/tests/unit_tests/core/rag/extractor/test_extractor_base.py new file mode 100644 index 0000000000..1d5f27181b --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_extractor_base.py @@ -0,0 +1,26 @@ +import pytest + +from core.rag.extractor.extractor_base import BaseExtractor + + +class _CallsBaseExtractor(BaseExtractor): + def extract(self): + return super().extract() + + +class _ConcreteExtractor(BaseExtractor): + def extract(self): + return ["ok"] + + +class TestBaseExtractor: + def test_extract_default_raises_not_implemented(self): + extractor = _CallsBaseExtractor() + + with pytest.raises(NotImplementedError): + extractor.extract() + + def test_concrete_extractor_can_override(self): + extractor = _ConcreteExtractor() + + assert extractor.extract() == ["ok"] diff --git a/api/tests/unit_tests/core/rag/extractor/test_helpers.py b/api/tests/unit_tests/core/rag/extractor/test_helpers.py index edf8735e57..74387f749d 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_helpers.py +++ b/api/tests/unit_tests/core/rag/extractor/test_helpers.py @@ -1,10 +1,55 @@ import tempfile +from types import SimpleNamespace -from core.rag.extractor.helpers import FileEncoding, detect_file_encodings +import pytest + +from core.rag.extractor import helpers +from core.rag.extractor.helpers import detect_file_encodings -def test_detect_file_encodings() -> None: - with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp: - temp.write("Shared data") - temp_path = temp.name - assert detect_file_encodings(temp_path) == [FileEncoding(encoding="utf_8", confidence=0.0, language="Unknown")] +class TestHelpers: + def test_detect_file_encodings(self) -> None: + with tempfile.NamedTemporaryFile(mode="w+t", suffix=".txt") as temp: + temp.write("Shared data") + temp.flush() + temp_path = temp.name + encodings = detect_file_encodings(temp_path) + + assert len(encodings) == 1 + assert encodings[0].encoding in {"utf_8", "ascii"} + assert encodings[0].confidence == 0.0 + # Assert the language field for full coverage + assert encodings[0].language is not None + + def test_detect_file_encodings_timeout(self, monkeypatch): + class FakeFuture: + def result(self, timeout=None): + raise helpers.concurrent.futures.TimeoutError() + + class FakeExecutor: + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def submit(self, fn, file_path): + return FakeFuture() + + monkeypatch.setattr(helpers.concurrent.futures, "ThreadPoolExecutor", lambda: FakeExecutor()) + + with pytest.raises(TimeoutError, match="Timeout reached while detecting encoding"): + detect_file_encodings("file.txt", timeout=1) + + def test_detect_file_encodings_raises_when_encoding_not_detected(self, monkeypatch): + class FakeResult: + encoding = None + coherence = 0.0 + language = None + + monkeypatch.setattr( + helpers.charset_normalizer, "from_path", lambda _: SimpleNamespace(best=lambda: FakeResult()) + ) + + with pytest.raises(RuntimeError, match="Could not detect encoding"): + detect_file_encodings("file.txt") diff --git a/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py new file mode 100644 index 0000000000..8bc65e5654 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_html_extractor.py @@ -0,0 +1,21 @@ +from core.rag.extractor.html_extractor import HtmlExtractor + + +class TestHtmlExtractor: + def test_extract_returns_text_content(self, tmp_path): + file_path = tmp_path / "sample.html" + file_path.write_text("

Title

Hello

", encoding="utf-8") + + extractor = HtmlExtractor(str(file_path)) + docs = extractor.extract() + + assert len(docs) == 1 + assert "".join(docs[0].page_content.split()) == "TitleHello" + + def test_load_as_text_strips_whitespace_and_handles_empty(self, tmp_path): + file_path = tmp_path / "sample.html" + file_path.write_text(" \n ", encoding="utf-8") + + extractor = HtmlExtractor(str(file_path)) + + assert extractor._load_as_text() == "" diff --git a/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py new file mode 100644 index 0000000000..0b4c9bd809 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_jina_reader_extractor.py @@ -0,0 +1,47 @@ +from pytest_mock import MockerFixture + +from core.rag.extractor.jina_reader_extractor import JinaReaderWebExtractor + + +class TestJinaReaderWebExtractor: + def test_extract_crawl_mode_returns_document(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value={ + "content": "markdown-content", + "url": "https://example.com", + "description": "desc", + "title": "title", + }, + ) + + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "markdown-content" + assert docs[0].metadata == { + "source_url": "https://example.com", + "description": "desc", + "title": "title", + } + + def test_extract_crawl_mode_with_missing_data_returns_empty(self, mocker: MockerFixture): + mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value=None, + ) + + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_non_crawl_mode_returns_empty(self, mocker: MockerFixture): + mock_get_crawl = mocker.patch( + "core.rag.extractor.jina_reader_extractor.WebsiteService.get_crawl_url_data", + return_value={"content": "unused"}, + ) + extractor = JinaReaderWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert extractor.extract() == [] + mock_get_crawl.assert_not_called() diff --git a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py index d4cf534c56..7e78c86c7d 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_markdown_extractor.py @@ -1,8 +1,15 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.markdown_extractor as markdown_module from core.rag.extractor.markdown_extractor import MarkdownExtractor -def test_markdown_to_tups(): - markdown = """ +class TestMarkdownExtractor: + def test_markdown_to_tups(self): + markdown = """ this is some text without header # title 1 @@ -11,12 +18,113 @@ this is balabala text ## title 2 this is more specific text. """ - extractor = MarkdownExtractor(file_path="dummy_path") - updated_output = extractor.markdown_to_tups(markdown) - assert len(updated_output) == 3 - key, header_value = updated_output[0] - assert key == None - assert header_value.strip() == "this is some text without header" - title_1, value = updated_output[1] - assert title_1.strip() == "title 1" - assert value.strip() == "this is balabala text" + extractor = MarkdownExtractor(file_path="dummy_path") + updated_output = extractor.markdown_to_tups(markdown) + + assert len(updated_output) == 3 + key, header_value = updated_output[0] + assert key is None + assert header_value.strip() == "this is some text without header" + + title_1, value = updated_output[1] + assert title_1.strip() == "title 1" + assert value.strip() == "this is balabala text" + + def test_markdown_to_tups_keeps_code_block_headers_literal(self): + markdown = """# Header +before +```python +# this is not a heading +print('x') +``` +after +""" + extractor = MarkdownExtractor(file_path="dummy_path") + + tups = extractor.markdown_to_tups(markdown) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "# this is not a heading" in tups[1][1] + + def test_remove_images_and_hyperlinks(self): + extractor = MarkdownExtractor(file_path="dummy_path") + + with_images = "before ![[image.png]] after" + with_links = "[OpenAI](https://openai.com)" + + assert extractor.remove_images(with_images) == "before after" + assert extractor.remove_hyperlinks(with_links) == "OpenAI" + + def test_parse_tups_reads_file_and_applies_options(self, tmp_path): + markdown_file = tmp_path / "doc.md" + markdown_file.write_text("# Header\nText with [link](https://example.com) and ![[img.png]]", encoding="utf-8") + + extractor = MarkdownExtractor( + file_path=str(markdown_file), + remove_hyperlinks=True, + remove_images=True, + autodetect_encoding=False, + ) + + tups = extractor.parse_tups(str(markdown_file)) + + assert len(tups) == 2 + assert tups[1][0] == "Header" + assert "[link]" not in tups[1][1] + assert "img.png" not in tups[1][1] + + def test_parse_tups_autodetects_encoding_after_decode_error(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=True) + + calls: list[str | None] = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + if encoding == "bad-encoding": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + return "# H\ncontent" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + markdown_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad-encoding"), SimpleNamespace(encoding="utf-8")], + ) + + tups = extractor.parse_tups("dummy_path") + + assert len(tups) == 2 + assert calls == [None, "bad-encoding", "utf-8"] + + def test_parse_tups_decode_error_with_autodetect_disabled_raises(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path", autodetect_encoding=False) + + def raise_decode(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "fail") + + monkeypatch.setattr(Path, "read_text", raise_decode, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_parse_tups_other_exceptions_are_wrapped(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + + def raise_other(self, encoding=None): + raise OSError("disk error") + + monkeypatch.setattr(Path, "read_text", raise_other, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy_path"): + extractor.parse_tups("dummy_path") + + def test_extract_builds_documents_for_header_and_non_header(self, monkeypatch): + extractor = MarkdownExtractor(file_path="dummy_path") + monkeypatch.setattr(extractor, "parse_tups", lambda _: [(None, "plain"), ("Header", "value")]) + + docs = extractor.extract() + + assert [doc.page_content for doc in docs] == ["plain", "\n\nHeader\nvalue"] diff --git a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py index 58bec7d19e..6daee11f8f 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_notion_extractor.py @@ -1,93 +1,499 @@ +from types import SimpleNamespace from unittest import mock +import httpx +import pytest from pytest_mock import MockerFixture from core.rag.extractor import notion_extractor -user_id = "user1" -database_id = "database1" -page_id = "page1" - -extractor = notion_extractor.NotionExtractor( - notion_workspace_id="x", notion_obj_id="x", notion_page_type="page", tenant_id="x", notion_access_token="x" -) - - -def _generate_page(page_title: str): - return { - "object": "page", - "id": page_id, - "properties": { - "Page": { - "type": "title", - "title": [{"type": "text", "text": {"content": page_title}, "plain_text": page_title}], - } - }, - } - - -def _generate_block(block_id: str, block_type: str, block_text: str): - return { - "object": "block", - "id": block_id, - "parent": {"type": "page_id", "page_id": page_id}, - "type": block_type, - "has_children": False, - block_type: { - "rich_text": [ - { - "type": "text", - "text": {"content": block_text}, - "plain_text": block_text, - } - ] - }, - } - - -def _mock_response(data): +def _mock_response(data, status_code: int = 200, text: str = ""): response = mock.Mock() - response.status_code = 200 + response.status_code = status_code + response.text = text response.json.return_value = data return response -def _remove_multiple_new_lines(text): - while "\n\n" in text: - text = text.replace("\n\n", "\n") - return text.strip() +class TestNotionExtractorInitAndPublicMethods: + def test_init_with_explicit_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + assert extractor._notion_access_token == "token" + + def test_init_falls_back_to_env_token_when_credential_lookup_fails(self, monkeypatch): + monkeypatch.setattr( + notion_extractor.NotionExtractor, + "_get_access_token", + classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))), + ) + monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", "env-token", raising=False) + + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + assert extractor._notion_access_token == "env-token" + + def test_init_raises_if_no_credential_and_no_env_token(self, monkeypatch): + monkeypatch.setattr( + notion_extractor.NotionExtractor, + "_get_access_token", + classmethod(lambda cls, tenant_id, credential_id: (_ for _ in ()).throw(Exception("credential error"))), + ) + monkeypatch.setattr(notion_extractor.dify_config, "NOTION_INTEGRATION_TOKEN", None, raising=False) + + with pytest.raises(ValueError, match="Must specify `integration_token`"): + notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + credential_id="cred", + ) + + def test_extract_updates_last_edited_and_loads_documents(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + update_mock = mock.Mock() + load_mock = mock.Mock(return_value=[SimpleNamespace(page_content="doc")]) + monkeypatch.setattr(extractor, "update_last_edited_time", update_mock) + monkeypatch.setattr(extractor, "_load_data_as_documents", load_mock) + + docs = extractor.extract() + + update_mock.assert_called_once_with(None) + load_mock.assert_called_once_with("obj", "page") + assert len(docs) == 1 + + def test_load_data_as_documents_page_database_and_invalid(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + monkeypatch.setattr(extractor, "_get_notion_block_data", lambda _: ["line1", "line2"]) + page_docs = extractor._load_data_as_documents("page-id", "page") + assert page_docs[0].page_content == "line1\nline2" + + monkeypatch.setattr(extractor, "_get_notion_database_data", lambda _: [SimpleNamespace(page_content="db")]) + db_docs = extractor._load_data_as_documents("db-id", "database") + assert db_docs[0].page_content == "db" + + with pytest.raises(ValueError, match="notion page type not supported"): + extractor._load_data_as_documents("obj", "unsupported") -def test_notion_page(mocker: MockerFixture): - texts = ["Head 1", "1.1", "paragraph 1", "1.1.1"] - mocked_notion_page = { - "object": "list", - "results": [ - _generate_block("b1", "heading_1", texts[0]), - _generate_block("b2", "heading_2", texts[1]), - _generate_block("b3", "paragraph", texts[2]), - _generate_block("b4", "heading_3", texts[3]), - ], - "next_cursor": None, - } - mocker.patch("httpx.request", return_value=_mock_response(mocked_notion_page)) +class TestNotionDatabase: + def test_get_notion_database_data_parses_property_types_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) - page_docs = extractor._load_data_as_documents(page_id, "page") - assert len(page_docs) == 1 - content = _remove_multiple_new_lines(page_docs[0].page_content) - assert content == "# Head 1\n## 1.1\nparagraph 1\n### 1.1.1" + first_page = { + "results": [ + { + "properties": { + "tags": { + "type": "multi_select", + "multi_select": [{"name": "A"}, {"name": "B"}], + }, + "title_prop": {"type": "title", "title": [{"plain_text": "Title"}]}, + "empty_title": {"type": "title", "title": []}, + "rich": {"type": "rich_text", "rich_text": [{"plain_text": "RichText"}]}, + "empty_rich": {"type": "rich_text", "rich_text": []}, + "select_prop": {"type": "select", "select": {"name": "Selected"}}, + "empty_select": {"type": "select", "select": None}, + "status_prop": {"type": "status", "status": {"name": "Open"}}, + "empty_status": {"type": "status", "status": None}, + "number_prop": {"type": "number", "number": 10}, + "dict_prop": {"type": "date", "date": {"start": "2024-01-01", "end": None}}, + }, + "url": "https://notion.so/page-1", + } + ], + "has_more": True, + "next_cursor": "cursor-2", + } + second_page = {"results": [], "has_more": False, "next_cursor": None} + + mock_post = mocker.patch("httpx.post", side_effect=[_mock_response(first_page), _mock_response(second_page)]) + + docs = extractor._get_notion_database_data("db-1", query_dict={"filter": {"x": 1}}) + + assert len(docs) == 1 + content = docs[0].page_content + assert "tags:['A', 'B']" in content + assert "title_prop:Title" in content + assert "rich:RichText" in content + assert "number_prop:10" in content + assert "dict_prop:start:2024-01-01" in content + assert "Row Page URL:https://notion.so/page-1" in content + assert mock_post.call_count == 2 + + def test_get_notion_database_data_handles_missing_results_and_empty_content(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.post", return_value=_mock_response({"results": None})) + assert extractor._get_notion_database_data("db-1") == [] + + def test_get_notion_database_data_requires_access_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + extractor._notion_access_token = None + + with pytest.raises(AssertionError, match="Notion access token is required"): + extractor._get_notion_database_data("db-1") -def test_notion_database(mocker: MockerFixture): - page_title_list = ["page1", "page2", "page3"] - mocked_notion_database = { - "object": "list", - "results": [_generate_page(i) for i in page_title_list], - "next_cursor": None, - } - mocker.patch("httpx.post", return_value=_mock_response(mocked_notion_database)) - database_docs = extractor._load_data_as_documents(database_id, "database") - assert len(database_docs) == 1 - content = _remove_multiple_new_lines(database_docs[0].page_content) - assert content == "\n".join([f"Page:{i}" for i in page_title_list]) +class TestNotionBlocks: + def test_get_notion_block_data_success_with_table_headings_children_and_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + first_response = { + "results": [ + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + { + "type": "heading_1", + "id": "h1", + "has_children": False, + "heading_1": {"rich_text": [{"text": {"content": "Heading"}}]}, + }, + { + "type": "paragraph", + "id": "p1", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Paragraph"}}]}, + }, + { + "type": "child_page", + "id": "cp1", + "has_children": True, + "child_page": {"rich_text": []}, + }, + ], + "next_cursor": "cursor-2", + } + second_response = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "SubHeading"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(first_response), _mock_response(second_response)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE") + mocker.patch.object(extractor, "_read_block", return_value="CHILD") + + lines = extractor._get_notion_block_data("page-1") + + assert lines[0] == "TABLE\n\n" + assert "# Heading" in lines[1] + assert "Paragraph\nCHILD\n\n" in lines[2] + assert "## SubHeading" in lines[-1] + + def test_get_notion_block_data_handles_http_error_and_invalid_payload(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + mocker.patch("httpx.request", side_effect=httpx.HTTPError("network")) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"bad": "payload"}, status_code=200)) + with pytest.raises(ValueError, match="Error fetching Notion block data"): + extractor._get_notion_block_data("page-1") + + mocker.patch("httpx.request", return_value=_mock_response({"results": []}, status_code=500, text="boom")) + with pytest.raises(ValueError, match="Error fetching Notion block data: boom"): + extractor._get_notion_block_data("page-1") + + def test_read_block_supports_heading_table_and_recursion(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + root_payload = { + "results": [ + { + "type": "heading_2", + "id": "h2", + "has_children": False, + "heading_2": {"rich_text": [{"text": {"content": "Root"}}]}, + }, + { + "type": "paragraph", + "id": "child-block", + "has_children": True, + "paragraph": {"rich_text": [{"text": {"content": "Parent"}}]}, + }, + {"type": "table", "id": "tbl-1", "has_children": False, "table": {}}, + ], + "next_cursor": None, + } + child_payload = { + "results": [ + { + "type": "paragraph", + "id": "leaf", + "has_children": False, + "paragraph": {"rich_text": [{"text": {"content": "Child"}}]}, + } + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(root_payload), _mock_response(child_payload)]) + mocker.patch.object(extractor, "_read_table_rows", return_value="TABLE-MD") + + content = extractor._read_block("root") + + assert "## Root" in content + assert "Parent" in content + assert "Child" in content + assert "TABLE-MD" in content + + def test_read_block_breaks_on_missing_results(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + mocker.patch("httpx.request", return_value=_mock_response({"results": None, "next_cursor": None})) + + assert extractor._read_block("root") == "" + + def test_read_table_rows_formats_markdown_with_pagination(self, mocker: MockerFixture): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + page_one = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [{"text": {"content": "H2"}}], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R1C1"}}], + [{"text": {"content": "R1C2"}}], + ] + } + }, + ], + "next_cursor": "next", + } + page_two = { + "results": [ + { + "table_row": { + "cells": [ + [{"text": {"content": "H1"}}], + [], + ] + } + }, + { + "table_row": { + "cells": [ + [{"text": {"content": "R2C1"}}], + [{"text": {"content": "R2C2"}}], + ] + } + }, + ], + "next_cursor": None, + } + + mocker.patch("httpx.request", side_effect=[_mock_response(page_one), _mock_response(page_two)]) + + markdown = extractor._read_table_rows("tbl-1") + + assert "| H1 | H2 |" in markdown + assert "| R1C1 | R1C2 |" in markdown + assert "| H1 | |" in markdown + assert "| R2C1 | R2C2 |" in markdown + + +class TestNotionMetadataAndCredentialMethods: + def test_update_last_edited_time_no_document_model(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + assert extractor.update_last_edited_time(None) is None + + def test_update_last_edited_time_updates_document_and_commits(self, monkeypatch): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + + class FakeDocumentModel: + data_source_info = "data_source_info" + + update_calls = [] + + class FakeQuery: + def filter_by(self, **kwargs): + return self + + def update(self, payload): + update_calls.append(payload) + + class FakeSession: + committed = False + + def query(self, model): + assert model is FakeDocumentModel + return FakeQuery() + + def commit(self): + self.committed = True + + fake_db = SimpleNamespace(session=FakeSession()) + monkeypatch.setattr(notion_extractor, "DocumentModel", FakeDocumentModel) + monkeypatch.setattr(notion_extractor, "db", fake_db) + monkeypatch.setattr(extractor, "get_notion_last_edited_time", lambda: "2026-01-01T00:00:00.000Z") + + doc_model = SimpleNamespace(id="doc-1", data_source_info_dict={"source": "notion"}) + extractor.update_last_edited_time(doc_model) + + assert update_calls + assert fake_db.session.committed is True + + def test_get_notion_last_edited_time_uses_page_and_database_urls(self, mocker: MockerFixture): + extractor_page = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="page-id", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + request_mock = mocker.patch( + "httpx.request", return_value=_mock_response({"last_edited_time": "2025-05-01T00:00:00.000Z"}) + ) + + assert extractor_page.get_notion_last_edited_time() == "2025-05-01T00:00:00.000Z" + assert "pages/page-id" in request_mock.call_args[0][1] + + extractor_db = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="db-id", + notion_page_type="database", + tenant_id="tenant", + notion_access_token="token", + ) + request_mock = mocker.patch( + "httpx.request", return_value=_mock_response({"last_edited_time": "2025-06-01T00:00:00.000Z"}) + ) + + assert extractor_db.get_notion_last_edited_time() == "2025-06-01T00:00:00.000Z" + assert "databases/db-id" in request_mock.call_args[0][1] + + def test_get_notion_last_edited_time_requires_access_token(self): + extractor = notion_extractor.NotionExtractor( + notion_workspace_id="ws", + notion_obj_id="obj", + notion_page_type="page", + tenant_id="tenant", + notion_access_token="token", + ) + extractor._notion_access_token = None + + with pytest.raises(AssertionError, match="Notion access token is required"): + extractor.get_notion_last_edited_time() + + def test_get_access_token_success_and_errors(self, monkeypatch): + with pytest.raises(Exception, match="No credential id found"): + notion_extractor.NotionExtractor._get_access_token("tenant", None) + + class FakeProviderServiceMissing: + def get_datasource_credentials(self, **kwargs): + return None + + monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceMissing) + with pytest.raises(Exception, match="No notion credential found"): + notion_extractor.NotionExtractor._get_access_token("tenant", "cred") + + class FakeProviderServiceFound: + def get_datasource_credentials(self, **kwargs): + return {"integration_secret": "token-from-credential"} + + monkeypatch.setattr(notion_extractor, "DatasourceProviderService", FakeProviderServiceFound) + + assert notion_extractor.NotionExtractor._get_access_token("tenant", "cred") == "token-from-credential" diff --git a/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py new file mode 100644 index 0000000000..fb3c6e52c6 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/test_text_extractor.py @@ -0,0 +1,79 @@ +from pathlib import Path +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.text_extractor as text_module +from core.rag.extractor.text_extractor import TextExtractor + + +class TestTextExtractor: + def test_extract_success(self, tmp_path): + file_path = tmp_path / "data.txt" + file_path.write_text("hello world", encoding="utf-8") + + extractor = TextExtractor(str(file_path)) + docs = extractor.extract() + + assert len(docs) == 1 + assert docs[0].page_content == "hello world" + assert docs[0].metadata == {"source": str(file_path)} + + def test_extract_autodetect_success_after_decode_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=True) + + calls = [] + + def fake_read_text(self, encoding=None): + calls.append(encoding) + if encoding is None: + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + if encoding == "bad": + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + return "decoded text" + + monkeypatch.setattr(Path, "read_text", fake_read_text, raising=True) + monkeypatch.setattr( + text_module, + "detect_file_encodings", + lambda _: [SimpleNamespace(encoding="bad"), SimpleNamespace(encoding="utf-8")], + ) + + docs = extractor.extract() + + assert docs[0].page_content == "decoded text" + assert calls == [None, "bad", "utf-8"] + + def test_extract_autodetect_all_fail_raises_runtime_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=True) + + def always_decode_error(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + + monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True) + monkeypatch.setattr(text_module, "detect_file_encodings", lambda _: [SimpleNamespace(encoding="latin-1")]) + + with pytest.raises(RuntimeError, match="all detected encodings failed"): + extractor.extract() + + def test_extract_decode_error_without_autodetect_raises_runtime_error(self, monkeypatch): + extractor = TextExtractor("dummy.txt", autodetect_encoding=False) + + def always_decode_error(self, encoding=None): + raise UnicodeDecodeError("utf-8", b"x", 0, 1, "decode") + + monkeypatch.setattr(Path, "read_text", always_decode_error, raising=True) + + with pytest.raises(RuntimeError, match="specified encoding failed"): + extractor.extract() + + def test_extract_wraps_non_decode_exceptions(self, monkeypatch): + extractor = TextExtractor("dummy.txt") + + def raise_other(self, encoding=None): + raise OSError("io error") + + monkeypatch.setattr(Path, "read_text", raise_other, raising=True) + + with pytest.raises(RuntimeError, match="Error loading dummy.txt"): + extractor.extract() diff --git a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py index 0792ada194..64eb89590a 100644 --- a/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py +++ b/api/tests/unit_tests/core/rag/extractor/test_word_extractor.py @@ -3,9 +3,12 @@ import io import os import tempfile +from collections import UserDict from pathlib import Path from types import SimpleNamespace +from unittest.mock import MagicMock +import pytest from docx import Document from docx.oxml import OxmlElement from docx.oxml.ns import qn @@ -136,7 +139,7 @@ def test_extract_images_from_docx(monkeypatch): monkeypatch.setattr(we, "UploadFile", FakeUploadFile) # Patch external image fetcher - def fake_get(url: str): + def fake_get(url: str, **kwargs): assert url == "https://example.com/image.png" return SimpleNamespace(status_code=200, headers={"Content-Type": "image/png"}, content=external_bytes) @@ -203,10 +206,8 @@ def test_extract_images_from_docx_uses_internal_files_url(): finally: # Restore original values - if original_files_url is not None: - dify_config.FILES_URL = original_files_url - if original_internal_files_url is not None: - dify_config.INTERNAL_FILES_URL = original_internal_files_url + dify_config.FILES_URL = original_files_url + dify_config.INTERNAL_FILES_URL = original_internal_files_url def test_extract_hyperlinks(monkeypatch): @@ -314,3 +315,405 @@ def test_extract_legacy_hyperlinks(monkeypatch): finally: if os.path.exists(tmp_path): os.remove(tmp_path) + + +def test_init_rejects_invalid_url_status(monkeypatch): + class FakeResponse: + status_code = 404 + content = b"" + closed = False + + def close(self): + self.closed = True + + fake_response = FakeResponse() + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=lambda url, **kwargs: fake_response)) + + with pytest.raises(ValueError, match="returned status code 404"): + WordExtractor("https://example.com/missing.docx", "tenant", "user") + + assert fake_response.closed is True + + +def test_init_expands_home_path_and_invalid_local_path(monkeypatch, tmp_path): + target_file = tmp_path / "expanded.docx" + target_file.write_bytes(b"docx") + + monkeypatch.setattr(we.os.path, "expanduser", lambda p: str(target_file)) + monkeypatch.setattr( + we.os.path, + "isfile", + lambda p: p == str(target_file), + ) + + extractor = WordExtractor("~/expanded.docx", "tenant", "user") + assert extractor.file_path == str(target_file) + + monkeypatch.setattr(we.os.path, "isfile", lambda p: False) + with pytest.raises(ValueError, match="is not a valid file or url"): + WordExtractor("not-a-file", "tenant", "user") + + +def test_del_closes_temp_file(): + extractor = object.__new__(WordExtractor) + extractor.temp_file = MagicMock() + + WordExtractor.__del__(extractor) + + extractor.temp_file.close.assert_called_once() + + +def test_extract_images_handles_invalid_external_cases(monkeypatch): + class FakeTargetRef: + def __contains__(self, item): + return item == "image" + + def split(self, sep): + return [None] + + rel_invalid_url = SimpleNamespace(is_external=True, target_ref="image-no-url") + rel_request_error = SimpleNamespace(is_external=True, target_ref="https://example.com/image-error") + rel_unknown_mime = SimpleNamespace(is_external=True, target_ref="https://example.com/image-unknown") + rel_internal_none_ext = SimpleNamespace(is_external=False, target_ref=FakeTargetRef(), target_part=object()) + + doc = SimpleNamespace( + part=SimpleNamespace( + rels={ + "r1": rel_invalid_url, + "r2": rel_request_error, + "r3": rel_unknown_mime, + "r4": rel_internal_none_ext, + } + ) + ) + + def fake_get(url, **kwargs): + if "image-error" in url: + raise RuntimeError("network") + return SimpleNamespace(status_code=200, headers={"Content-Type": "application/unknown"}, content=b"x") + + monkeypatch.setattr(we, "ssrf_proxy", SimpleNamespace(get=fake_get)) + db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=MagicMock())) + monkeypatch.setattr(we, "db", db_stub) + monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: None)) + monkeypatch.setattr(we.dify_config, "FILES_URL", "http://files.local", raising=False) + + extractor = object.__new__(WordExtractor) + extractor.tenant_id = "tenant" + extractor.user_id = "user" + + result = extractor._extract_images_from_docx(doc) + + assert result == {} + db_stub.session.commit.assert_called_once() + + +def test_table_to_markdown_and_parse_helpers(monkeypatch): + extractor = object.__new__(WordExtractor) + + table = SimpleNamespace( + rows=[ + SimpleNamespace(cells=[1, 2]), + SimpleNamespace(cells=[3, 4]), + ] + ) + parse_row_mock = MagicMock(side_effect=[["H1", "H2"], ["A", "B"]]) + monkeypatch.setattr(extractor, "_parse_row", parse_row_mock) + + markdown = extractor._table_to_markdown(table, {}) + assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |" + + class FakeBlip: + def __init__(self, image_id): + self.image_id = image_id + + def get(self, key): + return self.image_id + + class FakeRunChild: + def __init__(self, blips, text=""): + self._blips = blips + self.text = text + self.tag = qn("w:r") + + def xpath(self, pattern): + if pattern == ".//a:blip": + return self._blips + return [] + + class FakeRun: + def __init__(self, element, paragraph): + # Mirror the subset used by _parse_cell_paragraph + self.element = element + self.text = getattr(element, "text", "") + + # Patch we.Run so our lightweight child objects work with the extractor + monkeypatch.setattr(we, "Run", FakeRun) + + image_part = object() + paragraph = SimpleNamespace( + _element=[ + FakeRunChild([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")], text=""), + FakeRunChild([], text="plain"), + ], + part=SimpleNamespace( + rels={ + "ext": SimpleNamespace(is_external=True), + "int": SimpleNamespace(is_external=False, target_part=image_part), + } + ), + ) + + image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"} + assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain" + + cell = SimpleNamespace(paragraphs=[paragraph, paragraph]) + assert extractor._parse_cell(cell, image_map) == "EXT-IMGINT-IMGplain" + + +def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monkeypatch): + extractor = object.__new__(WordExtractor) + + ext_image_id = "ext-image" + int_embed_id = "int-embed" + shape_ext_id = "shape-ext" + shape_int_id = "shape-int" + + internal_part = object() + shape_internal_part = object() + + class Rels(UserDict): + def get(self, key, default=None): + if key == "link-bad": + raise RuntimeError("cannot resolve relation") + return super().get(key, default) + + rels = Rels( + { + ext_image_id: SimpleNamespace(is_external=True, target_ref="https://img/ext.png"), + int_embed_id: SimpleNamespace(is_external=False, target_part=internal_part), + shape_ext_id: SimpleNamespace(is_external=True, target_ref="https://img/shape.png"), + shape_int_id: SimpleNamespace(is_external=False, target_part=shape_internal_part), + "link-ok": SimpleNamespace(is_external=True, target_ref="https://example.com"), + } + ) + + image_map = { + ext_image_id: "[EXT]", + internal_part: "[INT]", + shape_ext_id: "[SHAPE_EXT]", + shape_internal_part: "[SHAPE_INT]", + } + + class FakeBlip: + def __init__(self, embed_id): + self.embed_id = embed_id + + def get(self, key): + return self.embed_id + + class FakeDrawing: + def __init__(self, embed_ids): + self.embed_ids = embed_ids + + def findall(self, pattern): + return [FakeBlip(embed_id) for embed_id in self.embed_ids] + + class FakeNode: + def __init__(self, text=None, attrs=None): + self.text = text + self._attrs = attrs or {} + + def get(self, key): + return self._attrs.get(key) + + class FakeShape: + def __init__(self, bin_id=None, img_id=None): + self.bin_id = bin_id + self.img_id = img_id + + def find(self, pattern): + if "binData" in pattern and self.bin_id: + return FakeNode( + text="shape", + attrs={"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id": self.bin_id}, + ) + if "imagedata" in pattern and self.img_id: + return FakeNode(attrs={"id": self.img_id}) + return None + + class FakeChild: + def __init__( + self, + tag, + text="", + fld_chars=None, + instr_texts=None, + drawings=None, + shapes=None, + attrs=None, + hyperlink_runs=None, + ): + self.tag = tag + self.text = text + self._fld_chars = fld_chars or [] + self._instr_texts = instr_texts or [] + self._drawings = drawings or [] + self._shapes = shapes or [] + self._attrs = attrs or {} + self._hyperlink_runs = hyperlink_runs or [] + + def findall(self, pattern): + if pattern == qn("w:fldChar"): + return self._fld_chars + if pattern == qn("w:instrText"): + return self._instr_texts + if pattern == qn("w:r"): + return self._hyperlink_runs + if pattern.endswith("}drawing"): + return self._drawings + if pattern.endswith("}pict"): + return self._shapes + return [] + + def get(self, key): + return self._attrs.get(key) + + class FakeRun: + def __init__(self, element, paragraph): + self.element = element + self.text = getattr(element, "text", "") + + paragraph_main = SimpleNamespace( + _element=[ + FakeChild( + qn("w:r"), + text="run-text", + drawings=[FakeDrawing([ext_image_id, int_embed_id])], + shapes=[FakeShape(bin_id=shape_ext_id, img_id=shape_int_id)], + ), + FakeChild( + qn("w:r"), + text="", + drawings=[], + shapes=[FakeShape(bin_id=shape_ext_id)], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-ok"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="LinkText")], + ), + FakeChild( + qn("w:hyperlink"), + attrs={qn("r:id"): "link-bad"}, + hyperlink_runs=[FakeChild(qn("w:r"), text="BrokenLink")], + ), + ] + ) + paragraph_empty = SimpleNamespace(_element=[FakeChild(qn("w:r"), text=" ")]) + + fake_doc = SimpleNamespace( + part=SimpleNamespace(rels=rels, related_parts={int_embed_id: internal_part}), + paragraphs=[paragraph_main, paragraph_empty], + tables=[SimpleNamespace(rows=[])], + element=SimpleNamespace( + body=[SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:p"), SimpleNamespace(tag="w:tbl")] + ), + ) + + monkeypatch.setattr(we, "DocxDocument", lambda _: fake_doc) + monkeypatch.setattr(we, "Run", FakeRun) + monkeypatch.setattr(extractor, "_extract_images_from_docx", lambda doc: image_map) + monkeypatch.setattr(extractor, "_table_to_markdown", lambda table, image_map: "TABLE-MARKDOWN") + logger_exception = MagicMock() + monkeypatch.setattr(we.logger, "exception", logger_exception) + + content = extractor.parse_docx("dummy.docx") + + assert "[EXT]" in content + assert "[INT]" in content + assert "[SHAPE_EXT]" in content + assert "[LinkText](https://example.com)" in content + assert "BrokenLink" in content + assert "TABLE-MARKDOWN" in content + logger_exception.assert_called_once() + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_http(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + # Build modern hyperlink inside table cell + r_id = "rIdHttp1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "Dify" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + # Relationship for external http link + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "https://dify.ai", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[Dify](https://dify.ai)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) + + +def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto(): + doc = Document() + table = doc.add_table(rows=1, cols=1) + cell = table.cell(0, 0) + p = cell.paragraphs[0] + + r_id = "rIdMail1" + hyperlink = OxmlElement("w:hyperlink") + hyperlink.set(qn("r:id"), r_id) + + run_elem = OxmlElement("w:r") + t = OxmlElement("w:t") + t.text = "john@test.com" + run_elem.append(t) + hyperlink.append(run_elem) + p._p.append(hyperlink) + + doc.part.rels.add_relationship( + "http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink", + "mailto:john@test.com", + r_id, + is_external=True, + ) + + with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp: + doc.save(tmp.name) + tmp_path = tmp.name + + try: + reopened = Document(tmp_path) + para = reopened.tables[0].cell(0, 0).paragraphs[0] + extractor = object.__new__(WordExtractor) + out = extractor._parse_cell_paragraph(para, {}) + assert out == "[john@test.com](mailto:john@test.com)" + finally: + if os.path.exists(tmp_path): + os.remove(tmp_path) diff --git a/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py new file mode 100644 index 0000000000..26ce333e11 --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/unstructured/test_unstructured_extractors.py @@ -0,0 +1,300 @@ +"""Unit tests for unstructured extractors and their local/API partitioning paths.""" + +import base64 +import sys +import types +from types import SimpleNamespace + +import pytest + +import core.rag.extractor.unstructured.unstructured_epub_extractor as epub_module +from core.rag.extractor.unstructured.unstructured_doc_extractor import UnstructuredWordExtractor +from core.rag.extractor.unstructured.unstructured_eml_extractor import UnstructuredEmailExtractor +from core.rag.extractor.unstructured.unstructured_epub_extractor import UnstructuredEpubExtractor +from core.rag.extractor.unstructured.unstructured_markdown_extractor import UnstructuredMarkdownExtractor +from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor +from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor +from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor +from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor + + +def _register_module(monkeypatch: pytest.MonkeyPatch, name: str, **attrs: object) -> types.ModuleType: + module = types.ModuleType(name) + for k, v in attrs.items(): + setattr(module, k, v) + monkeypatch.setitem(sys.modules, name, module) + return module + + +def _register_unstructured_packages(monkeypatch: pytest.MonkeyPatch) -> None: + _register_module(monkeypatch, "unstructured", __path__=[]) + _register_module(monkeypatch, "unstructured.partition", __path__=[]) + _register_module(monkeypatch, "unstructured.chunking", __path__=[]) + _register_module(monkeypatch, "unstructured.file_utils", __path__=[]) + + +def _install_chunk_by_title(monkeypatch: pytest.MonkeyPatch, chunks: list[SimpleNamespace]) -> None: + _register_unstructured_packages(monkeypatch) + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + return chunks + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + +class TestUnstructuredMarkdownMsgXml: + def test_markdown_extractor_without_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" chunk-1 "), SimpleNamespace(text=" chunk-2 ")]) + _register_module( + monkeypatch, "unstructured.partition.md", partition_md=lambda filename: [SimpleNamespace(text="x")] + ) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md").extract() + + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_markdown_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text=" via-api ")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="ignored")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + docs = UnstructuredMarkdownExtractor("/tmp/file.md", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "via-api" + assert calls == {"filename": "/tmp/file.md", "api_url": "https://u", "api_key": "k"} + + def test_msg_extractor_local(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + _register_module( + monkeypatch, "unstructured.partition.msg", partition_msg=lambda filename: [SimpleNamespace(text="x")] + ) + + assert UnstructuredMsgExtractor("/tmp/file.msg").extract()[0].page_content == "msg-doc" + + def test_msg_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="msg-doc")]) + calls = {} + + def partition_via_api(filename, api_url, api_key): + calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredMsgExtractor("/tmp/file.msg", api_url="https://u", api_key="k").extract()[0].page_content + == "msg-doc" + ) + assert calls["filename"] == "/tmp/file.msg" + + def test_xml_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="xml-doc")]) + + xml_calls = {} + + def partition_xml(filename, xml_keep_tags): + xml_calls.update({"filename": filename, "xml_keep_tags": xml_keep_tags}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.xml", partition_xml=partition_xml) + + assert UnstructuredXmlExtractor("/tmp/file.xml").extract()[0].page_content == "xml-doc" + assert xml_calls == {"filename": "/tmp/file.xml", "xml_keep_tags": True} + + api_calls = {} + + def partition_via_api(filename, api_url, api_key): + api_calls.update({"filename": filename, "api_url": api_url, "api_key": api_key}) + return [SimpleNamespace(text="x")] + + _register_module(monkeypatch, "unstructured.partition.api", partition_via_api=partition_via_api) + + assert ( + UnstructuredXmlExtractor("/tmp/file.xml", api_url="https://u", api_key="k").extract()[0].page_content + == "xml-doc" + ) + assert api_calls["filename"] == "/tmp/file.xml" + + +class TestUnstructuredEmailAndEpub: + def test_email_extractor_local_decodes_html_and_suppresses_decode_errors(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + captured = {} + + def chunk_by_title( + elements: list[SimpleNamespace], max_characters: int, combine_text_under_n_chars: int + ) -> list[SimpleNamespace]: + captured["elements"] = list(elements) + return [SimpleNamespace(text=" chunked-email ")] + + _register_module(monkeypatch, "unstructured.chunking.title", chunk_by_title=chunk_by_title) + + html = "

Hello Email

" + encoded_html = base64.b64encode(html.encode("utf-8")).decode("utf-8") + bad_base64 = "not-base64" + + elements = [SimpleNamespace(text=encoded_html), SimpleNamespace(text=bad_base64)] + _register_module(monkeypatch, "unstructured.partition.email", partition_email=lambda filename: elements) + + docs = UnstructuredEmailExtractor("/tmp/file.eml").extract() + + assert docs[0].page_content == "chunked-email" + chunk_elements = captured["elements"] + assert "Hello Email" in chunk_elements[0].text + assert chunk_elements[1].text == bad_base64 + + def test_email_extractor_with_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="api-email")]) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="abc")], + ) + + docs = UnstructuredEmailExtractor("/tmp/file.eml", api_url="https://u", api_key="k").extract() + + assert docs[0].page_content == "api-email" + + def test_epub_extractor_local_and_api(self, monkeypatch): + _install_chunk_by_title(monkeypatch, [SimpleNamespace(text="epub-doc")]) + + calls = {"download": 0, "partition": 0} + + def fake_download_pandoc(): + calls["download"] += 1 + + def partition_epub(filename, xml_keep_tags): + calls["partition"] += 1 + assert xml_keep_tags is True + return [SimpleNamespace(text="x")] + + monkeypatch.setattr(epub_module.pypandoc, "download_pandoc", fake_download_pandoc) + _register_module(monkeypatch, "unstructured.partition.epub", partition_epub=partition_epub) + + docs = UnstructuredEpubExtractor("/tmp/file.epub").extract() + + assert docs[0].page_content == "epub-doc" + assert calls == {"download": 1, "partition": 1} + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="x")], + ) + + docs = UnstructuredEpubExtractor("/tmp/file.epub", api_url="https://u", api_key="k").extract() + assert docs[0].page_content == "epub-doc" + + +class TestUnstructuredPPTAndPPTX: + def test_ppt_extractor_requires_api_url(self): + with pytest.raises(NotImplementedError, match="Unstructured API Url is not configured"): + UnstructuredPPTExtractor("/tmp/file.ppt").extract() + + def test_ppt_extractor_groups_text_by_page(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="A", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="B", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="skip", metadata=SimpleNamespace(page_number=None)), + SimpleNamespace(text="C", metadata=SimpleNamespace(page_number=2)), + ], + ) + + docs = UnstructuredPPTExtractor("/tmp/file.ppt", api_url="https://u", api_key="k").extract() + + assert [doc.page_content for doc in docs] == ["A\nB", "C"] + + def test_pptx_extractor_local_and_api(self, monkeypatch): + _register_unstructured_packages(monkeypatch) + _register_module( + monkeypatch, + "unstructured.partition.pptx", + partition_pptx=lambda filename: [ + SimpleNamespace(text="P1", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="P2", metadata=SimpleNamespace(page_number=2)), + SimpleNamespace(text="Skip", metadata=SimpleNamespace(page_number=None)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx").extract() + assert [doc.page_content for doc in docs] == ["P1", "P2"] + + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [ + SimpleNamespace(text="X", metadata=SimpleNamespace(page_number=1)), + SimpleNamespace(text="Y", metadata=SimpleNamespace(page_number=1)), + ], + ) + + docs = UnstructuredPPTXExtractor("/tmp/file.pptx", api_url="https://u", api_key="k").extract() + assert [doc.page_content for doc in docs] == ["X\nY"] + + +class TestUnstructuredWord: + def _install_doc_modules(self, monkeypatch, version: str, filetype_value): + _register_unstructured_packages(monkeypatch) + + class FileType: + DOC = "doc" + + _register_module(monkeypatch, "unstructured.__version__", __version__=version) + _register_module( + monkeypatch, + "unstructured.file_utils.filetype", + FileType=FileType, + detect_filetype=lambda filename: filetype_value, + ) + _register_module( + monkeypatch, + "unstructured.partition.api", + partition_via_api=lambda filename, api_url, api_key: [SimpleNamespace(text="api-doc")], + ) + _register_module( + monkeypatch, + "unstructured.partition.docx", + partition_docx=lambda filename: [SimpleNamespace(text="docx-doc")], + ) + _register_module( + monkeypatch, + "unstructured.chunking.title", + chunk_by_title=lambda elements, max_characters, combine_text_under_n_chars: [ + SimpleNamespace(text="chunk-1"), + SimpleNamespace(text="chunk-2"), + ], + ) + + def test_word_extractor_rejects_doc_on_old_unstructured_version(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="doc") + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + + def test_word_extractor_doc_and_docx_paths(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.11", filetype_value="doc") + + docs = UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + self._install_doc_modules(monkeypatch, version="0.5.0", filetype_value="not-doc") + docs = UnstructuredWordExtractor("/tmp/file.docx", "https://u", "k").extract() + assert [doc.page_content for doc in docs] == ["chunk-1", "chunk-2"] + + def test_word_extractor_magic_import_error_fallback_to_extension(self, monkeypatch): + self._install_doc_modules(monkeypatch, version="0.4.10", filetype_value="not-used") + monkeypatch.setitem(sys.modules, "magic", None) + + with pytest.raises(ValueError, match="Partitioning .doc files is only supported"): + UnstructuredWordExtractor("/tmp/file.doc", "https://u", "k").extract() diff --git a/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py new file mode 100644 index 0000000000..d758be218a --- /dev/null +++ b/api/tests/unit_tests/core/rag/extractor/watercrawl/test_watercrawl.py @@ -0,0 +1,434 @@ +"""Unit tests for WaterCrawl client, provider, and extractor behavior.""" + +import json +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import core.rag.extractor.watercrawl.client as client_module +from core.rag.extractor.watercrawl.client import BaseAPIClient, WaterCrawlAPIClient +from core.rag.extractor.watercrawl.exceptions import ( + WaterCrawlAuthenticationError, + WaterCrawlBadRequestError, + WaterCrawlPermissionError, +) +from core.rag.extractor.watercrawl.extractor import WaterCrawlWebExtractor +from core.rag.extractor.watercrawl.provider import WaterCrawlProvider + + +def _response( + status_code: int, + json_data: dict[str, Any] | None = None, + content_type: str = "application/json", + content: bytes = b"", + text: str = "", +) -> MagicMock: + response = MagicMock() + response.status_code = status_code + response.headers = {"Content-Type": content_type} + response.content = content + response.text = text + response.json.return_value = json_data if json_data is not None else {} + response.raise_for_status.return_value = None + response.close.return_value = None + return response + + +class TestWaterCrawlExceptions: + def test_bad_request_error_properties_and_string(self): + response = _response(400, {"message": "bad request", "errors": {"url": ["invalid"]}}) + + err = WaterCrawlBadRequestError(response) + parsed_errors = json.loads(err.flat_errors) + + assert err.status_code == 400 + assert err.message == "bad request" + assert "url" in parsed_errors + assert any("invalid" in str(item) for item in parsed_errors["url"]) + assert "WaterCrawlBadRequestError" in str(err) + + def test_permission_and_authentication_error_strings(self): + response = _response(403, {"message": "quota exceeded", "errors": {}}) + + permission = WaterCrawlPermissionError(response) + authentication = WaterCrawlAuthenticationError(response) + + assert "exceeding your WaterCrawl API limits" in str(permission) + assert "API key is invalid or expired" in str(authentication) + + +class TestBaseAPIClient: + def test_init_session_builds_expected_headers(self, monkeypatch): + captured = {} + + def fake_client(**kwargs): + captured.update(kwargs) + return "session" + + monkeypatch.setattr(client_module.httpx, "Client", fake_client) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client.session == "session" + assert captured["headers"]["X-API-Key"] == "k" + assert captured["headers"]["User-Agent"] == "WaterCrawl-Plugin" + + def test_request_stream_and_non_stream_paths(self, monkeypatch): + class FakeSession: + def __init__(self): + self.request_calls = [] + self.build_calls = [] + self.send_calls = [] + + def request(self, method, url, params=None, json=None, **kwargs): + self.request_calls.append((method, url, params, json, kwargs)) + return "non-stream-response" + + def build_request(self, method, url, params=None, json=None): + req = (method, url, params, json) + self.build_calls.append(req) + return req + + def send(self, request, stream=False, **kwargs): + self.send_calls.append((request, stream, kwargs)) + return "stream-response" + + fake_session = FakeSession() + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: fake_session) + + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + assert client._request("GET", "/v1/items", query_params={"a": 1}) == "non-stream-response" + assert fake_session.request_calls[0][1] == "https://watercrawl.dev/v1/items" + + assert client._request("GET", "/v1/items", stream=True) == "stream-response" + assert fake_session.build_calls + assert fake_session.send_calls[0][1] is True + + def test_http_method_helpers_delegate_to_request(self, monkeypatch): + monkeypatch.setattr(BaseAPIClient, "init_session", lambda self: MagicMock()) + client = BaseAPIClient(api_key="k", base_url="https://watercrawl.dev") + + calls = [] + + def fake_request(method, endpoint, query_params=None, data=None, **kwargs): + calls.append((method, endpoint, query_params, data)) + return "ok" + + monkeypatch.setattr(client, "_request", fake_request) + + assert client._get("/a") == "ok" + assert client._post("/b", data={"x": 1}) == "ok" + assert client._put("/c", data={"x": 2}) == "ok" + assert client._delete("/d") == "ok" + assert client._patch("/e", data={"x": 3}) == "ok" + assert [c[0] for c in calls] == ["GET", "POST", "PUT", "DELETE", "PATCH"] + + +class TestWaterCrawlAPIClient: + def test_process_eventstream_and_download(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = MagicMock() + response.iter_lines.return_value = [ + b"event: keep-alive", + b'data: {"type":"result","data":{"result":"http://x"}}', + b'data: {"type":"log","data":{"msg":"ok"}}', + ] + + monkeypatch.setattr(client, "download_result", lambda data: {"result": {"markdown": "body"}, "url": "u"}) + + events = list(client.process_eventstream(response, download=True)) + + assert events[0]["data"]["result"]["markdown"] == "body" + assert events[1]["type"] == "log" + response.close.assert_called_once() + + @pytest.mark.parametrize( + ("status", "expected_exception"), + [ + (401, WaterCrawlAuthenticationError), + (403, WaterCrawlPermissionError), + (422, WaterCrawlBadRequestError), + ], + ) + def test_process_response_error_statuses(self, status: int, expected_exception: type[Exception]): + client = WaterCrawlAPIClient(api_key="k") + + with pytest.raises(expected_exception): + client.process_response(_response(status, {"message": "bad", "errors": {"url": ["x"]}})) + + def test_process_response_204_returns_none(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(204, None)) is None + + def test_process_response_json_payloads(self): + client = WaterCrawlAPIClient(api_key="k") + assert client.process_response(_response(200, {"ok": True})) == {"ok": True} + assert client.process_response(_response(200, None)) == {} + + def test_process_response_octet_stream_returns_bytes(self): + client = WaterCrawlAPIClient(api_key="k") + assert ( + client.process_response(_response(200, content_type="application/octet-stream", content=b"bin")) == b"bin" + ) + + def test_process_response_event_stream_returns_generator(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + generator = (item for item in [{"type": "result", "data": {}}]) + monkeypatch.setattr(client, "process_eventstream", lambda response, download=False: generator) + assert client.process_response(_response(200, content_type="text/event-stream")) is generator + + def test_process_response_unknown_content_type_raises(self): + client = WaterCrawlAPIClient(api_key="k") + with pytest.raises(Exception, match="Unknown response type"): + client.process_response(_response(200, content_type="text/plain", text="x")) + + def test_process_response_uses_raise_for_status(self): + client = WaterCrawlAPIClient(api_key="k") + response = _response(500, {"message": "server"}) + response.raise_for_status.side_effect = RuntimeError("http error") + + with pytest.raises(RuntimeError, match="http error"): + client.process_response(response) + + def test_endpoint_wrappers(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda resp: "processed") + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "get-resp") + monkeypatch.setattr(client, "_post", lambda *args, **kwargs: "post-resp") + monkeypatch.setattr(client, "_delete", lambda *args, **kwargs: "delete-resp") + + assert client.get_crawl_requests_list() == "processed" + assert client.get_crawl_request("id") == "processed" + assert client.create_crawl_request(url="https://x") == "processed" + assert client.stop_crawl_request("id") == "processed" + assert client.download_crawl_request("id") == "processed" + assert client.get_crawl_request_results("id") == "processed" + + def test_monitor_crawl_request_generator_and_validation(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + monkeypatch.setattr(client, "process_response", lambda _: (x for x in [{"type": "result", "data": 1}])) + monkeypatch.setattr(client, "_get", lambda *args, **kwargs: "stream-resp") + + events = list(client.monitor_crawl_request("job-1", prefetched=True)) + assert events == [{"type": "result", "data": 1}] + + monkeypatch.setattr(client, "process_response", lambda _: [{"type": "result"}]) + with pytest.raises(ValueError, match="Generator expected"): + list(client.monitor_crawl_request("job-1")) + + def test_scrape_url_sync_and_async(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + monkeypatch.setattr(client, "create_crawl_request", lambda **kwargs: {"uuid": "job-1"}) + + async_result = client.scrape_url("https://example.com", sync=False) + assert async_result == {"uuid": "job-1"} + + monkeypatch.setattr( + client, + "monitor_crawl_request", + lambda item_id, prefetched: iter( + [{"type": "log", "data": {}}, {"type": "result", "data": {"url": "https://example.com"}}] + ), + ) + sync_result = client.scrape_url("https://example.com", sync=True) + assert sync_result == {"url": "https://example.com"} + + def test_download_result_fetches_json_and_closes(self, monkeypatch): + client = WaterCrawlAPIClient(api_key="k") + + response = _response(200, {"markdown": "body"}) + monkeypatch.setattr(client_module.httpx, "get", lambda *args, **kwargs: response) + + result = client.download_result({"result": "https://example.com/result.json"}) + + assert result["result"] == {"markdown": "body"} + response.close.assert_called_once() + + +class TestWaterCrawlProvider: + def test_crawl_url_builds_options_and_min_wait_time(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + captured_kwargs = {} + + def create_crawl_request_spy(**kwargs): + captured_kwargs.update(kwargs) + return {"uuid": "job-1"} + + monkeypatch.setattr(provider.client, "create_crawl_request", create_crawl_request_spy) + + result = provider.crawl_url( + "https://example.com", + { + "crawl_sub_pages": True, + "limit": 5, + "max_depth": 2, + "includes": "a,b", + "excludes": "x,y", + "exclude_tags": "nav,footer", + "include_tags": "main", + "wait_time": 100, + "only_main_content": False, + }, + ) + + assert result == {"status": "active", "job_id": "job-1"} + assert captured_kwargs["url"] == "https://example.com" + assert captured_kwargs["spider_options"] == { + "max_depth": 2, + "page_limit": 5, + "allowed_domains": [], + "exclude_paths": ["x", "y"], + "include_paths": ["a", "b"], + } + assert captured_kwargs["page_options"]["exclude_tags"] == ["nav", "footer"] + assert captured_kwargs["page_options"]["include_tags"] == ["main"] + assert captured_kwargs["page_options"]["only_main_content"] is False + assert captured_kwargs["page_options"]["wait_time"] == 1000 + + def test_get_crawl_status_active_and_completed(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "running", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 3}}, + "number_of_documents": 1, + "duration": "00:00:01.500000", + }, + ) + + active = provider.get_crawl_status("job-1") + assert active["status"] == "active" + assert active["data"] == [] + assert active["time_consuming"] == pytest.approx(1.5) + + monkeypatch.setattr( + provider.client, + "get_crawl_request", + lambda job_id: { + "status": "completed", + "uuid": job_id, + "options": {"spider_options": {"page_limit": 2}}, + "number_of_documents": 2, + "duration": "00:00:02.000000", + }, + ) + monkeypatch.setattr(provider, "_get_results", lambda crawl_request_id, query_params=None: iter([{"url": "u"}])) + + completed = provider.get_crawl_status("job-2") + assert completed["status"] == "completed" + assert completed["data"] == [{"url": "u"}] + + def test_get_crawl_url_data_and_scrape(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + monkeypatch.setattr(provider, "scrape_url", lambda url: {"source_url": url}) + assert provider.get_crawl_url_data("", "https://example.com") == {"source_url": "https://example.com"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([{"source_url": "u1"}])) + assert provider.get_crawl_url_data("job", "u1") == {"source_url": "u1"} + + monkeypatch.setattr(provider, "_get_results", lambda job_id, query_params=None: iter([])) + assert provider.get_crawl_url_data("job", "u1") is None + + def test_structure_data_validation_and_get_results_pagination(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + + with pytest.raises(ValueError, match="Invalid result object"): + provider._structure_data({"result": "not-a-dict"}) + + structured = provider._structure_data( + { + "url": "https://example.com", + "result": { + "metadata": {"title": "Title", "description": "Desc"}, + "markdown": "Body", + }, + } + ) + assert structured["title"] == "Title" + assert structured["markdown"] == "Body" + + responses = [ + { + "results": [ + { + "url": "https://a", + "result": {"metadata": {"title": "A", "description": "DA"}, "markdown": "MA"}, + } + ], + "next": "next-page", + }, + {"results": [], "next": None}, + ] + + monkeypatch.setattr( + provider.client, + "get_crawl_request_results", + lambda crawl_request_id, page, page_size, query_params: responses.pop(0), + ) + + results = list(provider._get_results("job-1")) + assert len(results) == 1 + assert results[0]["source_url"] == "https://a" + + def test_scrape_url_uses_client_and_structure(self, monkeypatch): + provider = WaterCrawlProvider(api_key="k") + monkeypatch.setattr( + provider.client, "scrape_url", lambda **kwargs: {"result": {"metadata": {}, "markdown": "m"}, "url": "u"} + ) + + result = provider.scrape_url("u") + + assert result["source_url"] == "u" + + +class TestWaterCrawlWebExtractor: + def test_extract_crawl_and_scrape_modes(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: { + "markdown": "crawl", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_scrape_url_data", + lambda provider, url, tenant_id, only_main_content: { + "markdown": "scrape", + "source_url": url, + "description": "d", + "title": "t", + }, + ) + + crawl_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + scrape_extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="scrape") + + assert crawl_extractor.extract()[0].page_content == "crawl" + assert scrape_extractor.extract()[0].page_content == "scrape" + + def test_extract_crawl_returns_empty_when_service_returns_none(self, monkeypatch): + monkeypatch.setattr( + "core.rag.extractor.watercrawl.extractor.WebsiteService.get_crawl_url_data", + lambda job_id, provider, url, tenant_id: None, + ) + + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="crawl") + + assert extractor.extract() == [] + + def test_extract_unknown_mode_returns_empty(self): + extractor = WaterCrawlWebExtractor("https://example.com", "job-1", "tenant-1", mode="other") + + assert extractor.extract() == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/conftest.py b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py new file mode 100644 index 0000000000..2a3860e107 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/conftest.py @@ -0,0 +1,33 @@ +from contextlib import AbstractContextManager, nullcontext +from typing import Any + +import pytest + + +class _FakeFlaskApp: + def app_context(self) -> AbstractContextManager[None]: + return nullcontext() + + +class _FakeExecutor: + def __init__(self, future: Any) -> None: + self._future = future + + def __enter__(self) -> "_FakeExecutor": + return self + + def __exit__(self, exc_type: object, exc_value: object, traceback: object) -> bool: + return False + + def submit(self, func: object, preview: object) -> Any: + return self._future + + +@pytest.fixture +def fake_flask_app() -> _FakeFlaskApp: + return _FakeFlaskApp() + + +@pytest.fixture +def fake_executor_cls() -> type[_FakeExecutor]: + return _FakeExecutor diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py new file mode 100644 index 0000000000..e6cc582398 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -0,0 +1,630 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.models.document import AttachmentDocument, Document +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from dify_graph.model_runtime.entities.model_entities import ModelFeature + + +class TestParagraphIndexProcessor: + @pytest.fixture + def processor(self) -> ParagraphIndexProcessor: + return ParagraphIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def _llm_result(self, content: str = "summary") -> LLMResult: + return LLMResult( + model="llm-model", + message=AssistantPromptMessage(content=content), + usage=LLMUsage.empty_usage(), + ) + + def test_extract_forwards_automatic_flag(self, processor: ParagraphIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected_docs + docs = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParagraphIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_validates_segmentation(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + rules_without_segmentation = SimpleNamespace(segmentation=None) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=rules_without_segmentation, + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_builds_split_documents(self, processor: ParagraphIndexProcessor, process_rule: dict) -> None: + source_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".first", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + return_value=".first", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + ): + documents = processor.transform([source_document], process_rule=process_rule) + + assert len(documents) == 1 + assert documents[0].page_content == "first" + assert documents[0].attachments is not None + assert documents[0].metadata["doc_hash"] == "hash" + + def test_transform_automatic_mode_uses_default_rules(self, processor: ParagraphIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content="text", metadata={})] + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.Rule.model_validate", + return_value=self._rules(), + ) as mock_validate, + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_get_content_files", return_value=[]), + ): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "automatic"}) + + assert mock_validate.call_count == 1 + + def test_load_creates_vector_and_multimodal_when_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + docs = [Document(page_content="chunk", metadata={})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + vector = mock_vector_cls.return_value + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + mock_keyword_cls.assert_not_called() + + def test_load_uses_keyword_add_texts_with_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs, keywords_list=["k1", "k2"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs, keywords_list=["k1", "k2"]) + + def test_load_uses_keyword_add_texts_without_keywords_when_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.load(dataset, docs) + + mock_keyword_cls.return_value.add_texts.assert_called_once_with(docs) + + def test_clean_deletes_summaries_and_vector(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_economy_deletes_summaries_and_keywords( + self, processor: ParagraphIndexProcessor, dataset: Mock + ) -> None: + dataset.indexing_technique = "economy" + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + mock_keyword_cls.return_value.delete.assert_called_once() + + def test_clean_deletes_keywords_by_ids(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + with patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls: + processor.clean(dataset, ["node-2"], with_keywords=True) + + mock_keyword_cls.return_value.delete_by_ids.assert_called_once_with(["node-2"]) + + def test_retrieve_filters_by_threshold(self, processor: ParagraphIndexProcessor, dataset: Mock) -> None: + accepted = SimpleNamespace(page_content="keep", metadata={"source": "a"}, score=0.9) + rejected = SimpleNamespace(page_content="drop", metadata={"source": "b"}, score=0.1) + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [accepted, rejected] + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) + + assert len(docs) == 1 + assert docs[0].metadata["score"] == 0.9 + + def test_index_list_chunks_high_quality( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, ["chunk-1", "chunk-2"]) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_list_chunks_economy( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Keyword") as mock_keyword_cls, + ): + processor.index(dataset, dataset_document, ["chunk-3"]) + + mock_keyword_cls.return_value.add_texts.assert_called_once() + + def test_index_multimodal_structure_handles_files_and_account_lookup( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + chunk_with_files = SimpleNamespace( + content="content-1", + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + chunk_without_files = SimpleNamespace(content="content-2", files=None) + structure = SimpleNamespace(general_chunks=[chunk_with_files, chunk_without_files]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="img", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.paragraph_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.paragraph_index_processor.Vector"), + ): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + assert mock_files.call_count == 1 + + def test_index_multimodal_structure_requires_valid_account( + self, processor: ParagraphIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + structure = SimpleNamespace(general_chunks=[SimpleNamespace(content="content", files=None)]) + + with ( + patch( + "core.rag.index_processor.processor.paragraph_index_processor.MultimodalGeneralStructureChunk.model_validate", + return_value=structure, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"general_chunks": []}) + + def test_format_preview_validates_chunk_shape(self, processor: ParagraphIndexProcessor) -> None: + preview = processor.format_preview(["chunk-1", "chunk-2"]) + assert preview["chunk_structure"] == "text_model" + assert preview["total_segments"] == 2 + + with pytest.raises(ValueError, match="Chunks is not a list"): + processor.format_preview({"not": "a-list"}) + + def test_generate_summary_preview_success_and_failure(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())): + result = processor.generate_summary_preview( + "tenant-1", preview_items, {"enable": True}, doc_language="English" + ) + assert all(item.summary == "summary" for item in result) + + with patch.object(processor, "generate_summary", side_effect=RuntimeError("summary failed")): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", [PreviewDetail(content="chunk-1")], {"enable": True}) + + def test_generate_summary_preview_fallback_without_flask_context(self, processor: ParagraphIndexProcessor) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch.object(processor, "generate_summary", return_value=("summary", LLMUsage.empty_usage())), + ): + result = processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_timeout( + self, processor: ParagraphIndexProcessor, fake_executor_cls: type + ) -> None: + preview_items = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_items, {"enable": True}) + + future.cancel.assert_called_once() + + def test_generate_summary_validates_input(self) -> None: + with pytest.raises(ValueError, match="must be enabled"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": False}) + + with pytest.raises(ValueError, match="model_name and model_provider_name"): + ParagraphIndexProcessor.generate_summary("tenant-1", "text", {"enable": True}) + + def test_generate_summary_text_only_flow(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + model_instance.invoke_llm.return_value = self._llm_result("text summary") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota", + side_effect=RuntimeError("quota"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, usage = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + document_language="English", + ) + + assert summary == "text summary" + assert isinstance(usage, LLMUsage) + mock_logger.warning.assert_called_with("Failed to deduct quota for summary generation: %s", "quota") + + def test_generate_summary_handles_vision_and_image_conversion(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = self._llm_result("vision summary") + image_file = SimpleNamespace() + image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch.object( + ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[image_file] + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[]) as mock_extract_text, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + return_value=image_content, + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + summary, _ = ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + segment_id="seg-1", + ) + + assert summary == "vision summary" + mock_extract_text.assert_not_called() + + def test_generate_summary_fallbacks_for_prompt_and_result_types(self) -> None: + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.VISION] + ) + model_instance.invoke_llm.return_value = object() + image_file = SimpleNamespace() + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", + return_value=model_instance, + ), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.DEFAULT_GENERATOR_SUMMARY_PROMPT", + "Prompt {missing}", + ), + patch.object(ParagraphIndexProcessor, "_extract_images_from_segment_attachments", return_value=[]), + patch.object(ParagraphIndexProcessor, "_extract_images_from_text", return_value=[image_file]), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.file_manager.to_prompt_message_content", + side_effect=RuntimeError("bad image"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + with pytest.raises(ValueError, match="Expected LLMResult"): + ParagraphIndexProcessor.generate_summary( + "tenant-1", + "text content", + {"enable": True, "model_name": "model-a", "model_provider_name": "provider-a"}, + ) + + mock_logger.warning.assert_called_with( + "Failed to convert image file to prompt message content: %s", "bad image" + ) + + def test_extract_images_from_text_handles_patterns_and_build_errors(self) -> None: + text = ( + "![img](/files/11111111-1111-1111-1111-111111111111/image-preview) " + "![img2](/files/22222222-2222-2222-2222-222222222222/file-preview) " + "![tool](/files/tools/33333333-3333-3333-3333-333333333333.png)" + ) + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + non_image_upload = SimpleNamespace( + id="22222222-2222-2222-2222-222222222222", + tenant_id="tenant-1", + name="file.txt", + mime_type="text/plain", + extension="txt", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload, non_image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + return_value=SimpleNamespace(id="file-1"), + ) as mock_builder, + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert len(files) == 1 + assert mock_builder.call_count == 1 + mock_logger.warning.assert_not_called() + + def test_extract_images_from_text_returns_empty_when_no_matches(self) -> None: + assert ParagraphIndexProcessor._extract_images_from_text("tenant-1", "no images here") == [] + + def test_extract_images_from_text_logs_when_build_fails(self) -> None: + text = "![img](/files/11111111-1111-1111-1111-111111111111/image-preview)" + image_upload = SimpleNamespace( + id="11111111-1111-1111-1111-111111111111", + tenant_id="tenant-1", + name="image.png", + mime_type="image/png", + extension="png", + source_url="", + size=1, + key="key", + ) + query = Mock() + query.where.return_value.all.return_value = [image_upload] + session = Mock() + session.query.return_value = query + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.build_from_mapping", + side_effect=RuntimeError("build failed"), + ), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_text("tenant-1", text) + + assert files == [] + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments(self) -> None: + image_upload = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="", + size=1, + key="k1", + ) + bad_upload = SimpleNamespace( + id="file-2", + name="broken", + extension=None, + mime_type="image/png", + source_url="", + size=1, + key="k2", + ) + non_image_upload = SimpleNamespace( + id="file-3", + name="text", + extension="txt", + mime_type="text/plain", + source_url="", + size=1, + key="k3", + ) + execute_result = Mock() + execute_result.all.return_value = [(None, image_upload), (None, bad_upload), (None, non_image_upload)] + session = Mock() + session.execute.return_value = execute_result + + with ( + patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session), + patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, + ): + files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert len(files) == 1 + mock_logger.warning.assert_called_once() + + def test_extract_images_from_segment_attachments_empty(self) -> None: + execute_result = Mock() + execute_result.all.return_value = [] + session = Mock() + session.execute.return_value = execute_result + + with patch("core.rag.index_processor.processor.paragraph_index_processor.db.session", session): + empty_files = ParagraphIndexProcessor._extract_images_from_segment_attachments("tenant-1", "seg-1") + + assert empty_files == [] diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py new file mode 100644 index 0000000000..5c78cae7c1 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_parent_child_index_processor.py @@ -0,0 +1,524 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.models.document import AttachmentDocument, ChildDocument, Document +from services.entities.knowledge_entities.knowledge_entities import ParentMode + + +class TestParentChildIndexProcessor: + @pytest.fixture + def processor(self) -> ParentChildIndexProcessor: + return ParentChildIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + document.dataset_process_rule_id = None + return document + + def _segmentation(self) -> SimpleNamespace: + return SimpleNamespace(max_tokens=200, chunk_overlap=10, separator="\n") + + def _paragraph_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + segmentation=self._segmentation(), + subchunk_segmentation=self._segmentation(), + ) + + def _full_doc_rules(self) -> SimpleNamespace: + return SimpleNamespace( + parent_mode=ParentMode.FULL_DOC, segmentation=None, subchunk_segmentation=self._segmentation() + ) + + def test_extract_forwards_automatic_flag(self, processor: ParentChildIndexProcessor) -> None: + extract_setting = Mock() + expected = [Document(page_content="chunk", metadata={})] + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ExtractProcessor.extract" + ) as mock_extract: + mock_extract.return_value = expected + documents = processor.extract(extract_setting, process_rule_mode="hierarchical") + + assert documents == expected + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_validates_process_rule(self, processor: ParentChildIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_paragraph_requires_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(parent_mode=ParentMode.PARAGRAPH, segmentation=None) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", return_value=rules + ): + with pytest.raises(ValueError, match="No segmentation found in rules"): + processor.transform( + [Document(page_content="text", metadata={})], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + ) + + def test_transform_paragraph_builds_parent_and_child_docs(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".parent", metadata={}), + Document(page_content=" ", metadata={}), + ] + parent_document = Document(page_content="source", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + child_docs = [ChildDocument(page_content="child-1", metadata={"dataset_id": "dataset-1"})] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + return_value=".parent", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + ): + result = processor.transform( + [parent_document], + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=False, + ) + + assert len(result) == 1 + assert result[0].page_content == "parent" + assert result[0].children == child_docs + assert result[0].attachments is not None + + def test_transform_preview_returns_after_ten_parent_chunks(self, processor: ParentChildIndexProcessor) -> None: + splitter = Mock() + splitter.split_documents.return_value = [Document(page_content=f"chunk-{i}", metadata={}) for i in range(10)] + documents = [ + Document(page_content="doc-1", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="doc-2", metadata={"dataset_id": "dataset-1", "document_id": "doc-2"}), + ] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._paragraph_rules(), + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch.object(processor, "_get_content_files", return_value=[]), + patch.object(processor, "_split_child_nodes", return_value=[]), + ): + result = processor.transform( + documents, + process_rule={"mode": "custom", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 10 + + def test_transform_full_doc_mode_trims_children_for_preview(self, processor: ParentChildIndexProcessor) -> None: + docs = [ + Document(page_content="first", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + Document(page_content="second", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}), + ] + child_docs = [ChildDocument(page_content=f"child-{i}", metadata={}) for i in range(5)] + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.Rule.model_validate", + return_value=self._full_doc_rules(), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ), + patch.object(processor, "_split_child_nodes", return_value=child_docs), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.dify_config.CHILD_CHUNKS_PREVIEW_NUMBER", + 2, + ), + ): + result = processor.transform( + docs, + process_rule={"mode": "hierarchical", "rules": {"enabled": True}}, + preview=True, + ) + + assert len(result) == 1 + assert len(result[0].children or []) == 2 + assert result[0].attachments is not None + + def test_load_creates_vectors_for_child_docs(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + parent_doc = Document( + page_content="parent", + metadata={}, + children=[ + ChildDocument(page_content="child-1", metadata={}), + ChildDocument(page_content="child-2", metadata={}), + ], + ) + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, [parent_doc], multimodal_documents=multimodal_docs) + + assert vector.create.call_count == 1 + formatted_docs = vector.create.call_args[0][0] + assert len(formatted_docs) == 2 + assert all(isinstance(doc, Document) for doc in formatted_docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_clean_with_precomputed_child_ids(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + delete_query = Mock() + where_query = Mock() + where_query.delete.return_value = 2 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean( + dataset, + ["node-1"], + delete_child_chunks=True, + precomputed_child_node_ids=["child-1", "child-2"], + ) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_queries_child_ids_when_not_precomputed( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + child_query = Mock() + child_query.join.return_value.where.return_value.all.return_value = [("child-1",), (None,), ("child-2",)] + session = Mock() + session.query.return_value = child_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_child_chunks=False) + + vector.delete_by_ids.assert_called_once_with(["child-1", "child-2"]) + + def test_clean_dataset_wide_cleanup(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + where_query = Mock() + where_query.delete.return_value = 3 + session = Mock() + session.query.return_value.where.return_value = where_query + + with ( + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_child_chunks=True) + + vector.delete.assert_called_once() + where_query.delete.assert_called_once_with(synchronize_session=False) + session.commit.assert_called_once() + + def test_clean_deletes_summaries_when_requested(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + segment_query = Mock() + segment_query.filter.return_value.all.return_value = [SimpleNamespace(id="seg-1")] + session = Mock() + session.query.return_value = segment_query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.session_factory.create_session", + return_value=session_ctx, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, ["node-1"], delete_summaries=True, precomputed_child_node_ids=[]) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + + def test_clean_deletes_all_summaries_when_node_ids_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock + ) -> None: + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + ): + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + + def test_retrieve_filters_by_score_threshold(self, processor: ParentChildIndexProcessor, dataset: Mock) -> None: + ok_result = SimpleNamespace(page_content="keep", metadata={"m": 1}, score=0.8) + low_result = SimpleNamespace(page_content="drop", metadata={"m": 2}, score=0.2) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.RetrievalService.retrieve" + ) as mock_retrieve: + mock_retrieve.return_value = [ok_result, low_result] + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 3, 0.5, reranking_model) + + assert len(docs) == 1 + assert docs[0].page_content == "keep" + assert docs[0].metadata["score"] == 0.8 + + def test_split_child_nodes_requires_subchunk_segmentation(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=None) + + with pytest.raises(ValueError, match="No subchunk segmentation found"): + processor._split_child_nodes(Document(page_content="parent", metadata={}), rules, "custom", None) + + def test_split_child_nodes_generates_child_documents(self, processor: ParentChildIndexProcessor) -> None: + rules = SimpleNamespace(subchunk_segmentation=self._segmentation()) + splitter = Mock() + splitter.split_documents.return_value = [ + Document(page_content=".child-1", metadata={}), + Document(page_content=" ", metadata={}), + ] + + with ( + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + ): + child_docs = processor._split_child_nodes( + Document(page_content="parent", metadata={}), rules, "custom", None + ) + + assert len(child_docs) == 1 + assert child_docs[0].page_content == "child-1" + assert child_docs[0].metadata["doc_hash"] == "hash" + + def test_index_creates_process_rule_segments_and_vectors( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[ + SimpleNamespace( + parent_content="parent text", + child_contents=["child-1", "child-2"], + files=[SimpleNamespace(id="file-1", filename="image.png")], + ) + ], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + side_effect=lambda text: f"hash-{text}", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore" + ) as mock_store_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector") as mock_vector_cls, + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + assert dataset_document.dataset_process_rule_id == "rule-1" + session.add.assert_called_once_with(dataset_rule) + session.flush.assert_called_once() + session.commit.assert_called_once() + mock_store_cls.return_value.add_documents.assert_called_once() + assert mock_vector_cls.return_value.create.call_count == 1 + mock_vector_cls.return_value.create_multimodal.assert_called_once() + + def test_index_uses_content_files_when_files_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + dataset_rule = SimpleNamespace(id="rule-1") + session = Mock() + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.DatasetProcessRule", + return_value=dataset_rule, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=SimpleNamespace(id="user-1"), + ), + patch.object( + processor, "_get_content_files", return_value=[AttachmentDocument(page_content="image", metadata={})] + ) as mock_files, + patch("core.rag.index_processor.processor.parent_child_index_processor.DatasetDocumentStore"), + patch("core.rag.index_processor.processor.parent_child_index_processor.Vector"), + patch("core.rag.index_processor.processor.parent_child_index_processor.db.session", session), + ): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + mock_files.assert_called_once() + + def test_index_raises_when_account_missing( + self, processor: ParentChildIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child"], files=None)], + ) + + with ( + patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.helper.generate_text_hash", + return_value="hash", + ), + patch( + "core.rag.index_processor.processor.parent_child_index_processor.AccountService.load_user", + return_value=None, + ), + ): + with pytest.raises(ValueError, match="Invalid account"): + processor.index(dataset, dataset_document, {"parent_child_chunks": []}) + + def test_format_preview_returns_parent_child_structure(self, processor: ParentChildIndexProcessor) -> None: + parent_childs = SimpleNamespace( + parent_mode=ParentMode.PARAGRAPH, + parent_child_chunks=[SimpleNamespace(parent_content="parent", child_contents=["child-1", "child-2"])], + ) + + with patch( + "core.rag.index_processor.processor.parent_child_index_processor.ParentChildStructureChunk.model_validate", + return_value=parent_childs, + ): + preview = processor.format_preview({"parent_child_chunks": []}) + + assert preview["chunk_structure"] == "hierarchical_model" + assert preview["parent_mode"] == ParentMode.PARAGRAPH + assert preview["total_segments"] == 1 + + def test_generate_summary_preview_sets_summaries(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1"), PreviewDetail(content="chunk-2")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ): + result = processor.generate_summary_preview( + "tenant-1", preview_texts, {"enable": True}, doc_language="English" + ) + + assert all(item.summary == "summary" for item in result) + + def test_generate_summary_preview_raises_when_worker_fails(self, processor: ParentChildIndexProcessor) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + + with patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + side_effect=RuntimeError("summary failed"), + ): + with pytest.raises(ValueError, match="Failed to generate summaries"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + def test_generate_summary_preview_falls_back_without_flask_context( + self, processor: ParentChildIndexProcessor + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + fake_current_app = SimpleNamespace(_get_current_object=Mock(side_effect=RuntimeError("no app"))) + + with ( + patch("flask.current_app", fake_current_app), + patch( + "core.rag.index_processor.processor.paragraph_index_processor.ParagraphIndexProcessor.generate_summary", + return_value=("summary", None), + ), + ): + result = processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + assert result[0].summary == "summary" + + def test_generate_summary_preview_handles_timeout( + self, processor: ParentChildIndexProcessor, fake_executor_cls: type + ) -> None: + preview_texts = [PreviewDetail(content="chunk-1")] + future = Mock() + executor = fake_executor_cls(future) + + with ( + patch("concurrent.futures.ThreadPoolExecutor", return_value=executor), + patch("concurrent.futures.wait", side_effect=[(set(), {future}), (set(), set())]), + ): + with pytest.raises(ValueError, match="timeout"): + processor.generate_summary_preview("tenant-1", preview_texts, {"enable": True}) + + future.cancel.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py new file mode 100644 index 0000000000..99323eeec9 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_qa_index_processor.py @@ -0,0 +1,383 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ImmediateThread: + def __init__(self, target, args=(), kwargs=None): + self._target = target + self._args = args + self._kwargs = kwargs or {} + + def start(self) -> None: + self._target(*self._args, **self._kwargs) + + def join(self) -> None: + return None + + +class TestQAIndexProcessor: + @pytest.fixture + def processor(self) -> QAIndexProcessor: + return QAIndexProcessor() + + @pytest.fixture + def dataset(self) -> Mock: + dataset = Mock() + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = "high_quality" + dataset.is_multimodal = True + return dataset + + @pytest.fixture + def dataset_document(self) -> Mock: + document = Mock() + document.id = "doc-1" + document.created_by = "user-1" + return document + + @pytest.fixture + def process_rule(self) -> dict: + return { + "mode": "custom", + "rules": {"segmentation": {"max_tokens": 256, "chunk_overlap": 10, "separator": "\n"}}, + } + + def _rules(self) -> SimpleNamespace: + segmentation = SimpleNamespace(max_tokens=256, chunk_overlap=10, separator="\n") + return SimpleNamespace(segmentation=segmentation) + + def test_extract_forwards_automatic_flag(self, processor: QAIndexProcessor) -> None: + extract_setting = Mock() + expected_docs = [Document(page_content="chunk", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.ExtractProcessor.extract") as mock_extract: + mock_extract.return_value = expected_docs + + docs = processor.extract(extract_setting, process_rule_mode="automatic") + + assert docs == expected_docs + mock_extract.assert_called_once_with(extract_setting=extract_setting, is_automatic=True) + + def test_transform_rejects_none_process_rule(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No process rule found"): + processor.transform([Document(page_content="text", metadata={})], process_rule=None) + + def test_transform_rejects_missing_rules_key(self, processor: QAIndexProcessor) -> None: + with pytest.raises(ValueError, match="No rules found in process rule"): + processor.transform([Document(page_content="text", metadata={})], process_rule={"mode": "custom"}) + + def test_transform_preview_calls_formatter_once( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + document = Document(page_content="raw text", metadata={"dataset_id": "dataset-1", "document_id": "doc-1"}) + split_node = Document(page_content=".question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content="Q1", metadata={"answer": "A1"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", return_value="clean text" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text.lstrip("."), + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform( + [document], + process_rule=process_rule, + preview=True, + tenant_id="tenant-1", + doc_language="English", + ) + + assert len(result) == 1 + assert result[0].metadata["answer"] == "A1" + mock_format.assert_called_once() + + def test_transform_non_preview_uses_thread_batches( + self, processor: QAIndexProcessor, process_rule: dict, fake_flask_app + ) -> None: + documents = [ + Document(page_content="doc-1", metadata={"document_id": "doc-1", "dataset_id": "dataset-1"}), + Document(page_content="doc-2", metadata={"document_id": "doc-2", "dataset_id": "dataset-1"}), + ] + split_node = Document(page_content="question", metadata={}) + splitter = Mock() + splitter.split_documents.return_value = [split_node] + + def _append_document(flask_app, tenant_id, document_node, all_qa_documents, document_language): + all_qa_documents.append(Document(page_content=f"Q-{document_node.page_content}", metadata={"answer": "A"})) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.Rule.model_validate", return_value=self._rules() + ), + patch.object(processor, "_get_splitter", return_value=splitter), + patch( + "core.rag.index_processor.processor.qa_index_processor.CleanProcessor.clean", + side_effect=lambda text, _: text, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.remove_leading_symbols", + side_effect=lambda text: text, + ), + patch.object(processor, "_format_qa_document", side_effect=_append_document) as mock_format, + patch("core.rag.index_processor.processor.qa_index_processor.current_app") as mock_current_app, + patch( + "core.rag.index_processor.processor.qa_index_processor.threading.Thread", side_effect=_ImmediateThread + ), + ): + mock_current_app._get_current_object.return_value = fake_flask_app + result = processor.transform(documents, process_rule=process_rule, preview=False, tenant_id="tenant-1") + + assert len(result) == 2 + assert mock_format.call_count == 2 + + def test_format_by_template_validates_file_type(self, processor: QAIndexProcessor) -> None: + not_csv_file = Mock(spec=FileStorage) + not_csv_file.filename = "qa.txt" + + with pytest.raises(ValueError, match="Only CSV files"): + processor.format_by_template(not_csv_file) + + def test_format_by_template_parses_csv_rows(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + dataframe = pd.DataFrame([["Q1", "A1"], ["Q2", "A2"]]) + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=dataframe): + docs = processor.format_by_template(csv_file) + + assert [doc.page_content for doc in docs] == ["Q1", "Q2"] + assert [doc.metadata["answer"] for doc in docs] == ["A1", "A2"] + + def test_format_by_template_raises_on_empty_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch("core.rag.index_processor.processor.qa_index_processor.pd.read_csv", return_value=pd.DataFrame()): + with pytest.raises(ValueError, match="empty"): + processor.format_by_template(csv_file) + + def test_format_by_template_raises_on_invalid_csv(self, processor: QAIndexProcessor) -> None: + csv_file = Mock(spec=FileStorage) + csv_file.filename = "qa.csv" + + with patch( + "core.rag.index_processor.processor.qa_index_processor.pd.read_csv", side_effect=Exception("bad csv") + ): + with pytest.raises(ValueError, match="bad csv"): + processor.format_by_template(csv_file) + + def test_load_creates_vectors_for_high_quality_dataset(self, processor: QAIndexProcessor, dataset: Mock) -> None: + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + multimodal_docs = [AttachmentDocument(page_content="image", metadata={})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + vector = mock_vector_cls.return_value + processor.load(dataset, docs, multimodal_documents=multimodal_docs) + + vector.create.assert_called_once_with(docs) + vector.create_multimodal.assert_called_once_with(multimodal_docs) + + def test_load_skips_vector_for_non_high_quality(self, processor: QAIndexProcessor, dataset: Mock) -> None: + dataset.indexing_technique = "economy" + docs = [Document(page_content="Q1", metadata={"answer": "A1"})] + + with patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls: + processor.load(dataset, docs) + + mock_vector_cls.assert_not_called() + + def test_clean_handles_summary_deletion_and_vector_cleanup( + self, processor: QAIndexProcessor, dataset: Mock + ) -> None: + mock_segment = SimpleNamespace(id="seg-1") + mock_query = Mock() + mock_query.filter.return_value.all.return_value = [mock_segment] + mock_session = Mock() + mock_session.query.return_value = mock_query + session_context = MagicMock() + session_context.__enter__.return_value = mock_session + session_context.__exit__.return_value = False + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.session_factory.create_session", + return_value=session_context, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, ["node-1"], delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, ["seg-1"]) + vector.delete_by_ids.assert_called_once_with(["node-1"]) + + def test_clean_handles_dataset_wide_cleanup(self, processor: QAIndexProcessor, dataset: Mock) -> None: + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.SummaryIndexService.delete_summaries_for_segments" + ) as mock_summary, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + vector = mock_vector_cls.return_value + processor.clean(dataset, None, delete_summaries=True) + + mock_summary.assert_called_once_with(dataset, None) + vector.delete.assert_called_once() + + def test_retrieve_filters_by_score_threshold(self, processor: QAIndexProcessor, dataset: Mock) -> None: + result_ok = SimpleNamespace(page_content="accepted", metadata={"source": "a"}, score=0.9) + result_low = SimpleNamespace(page_content="rejected", metadata={"source": "b"}, score=0.1) + + with patch("core.rag.index_processor.processor.qa_index_processor.RetrievalService.retrieve") as mock_retrieve: + mock_retrieve.return_value = [result_ok, result_low] + reranking_model = {"reranking_provider_name": "", "reranking_model_name": ""} + docs = processor.retrieve("semantic_search", "query", dataset, 5, 0.5, reranking_model) + + assert len(docs) == 1 + assert docs[0].page_content == "accepted" + assert docs[0].metadata["score"] == 0.9 + + def test_index_adds_documents_and_vectors_for_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + qa_chunks = SimpleNamespace( + qa_chunks=[ + SimpleNamespace(question="Q1", answer="A1"), + SimpleNamespace(question="Q2", answer="A2"), + ] + ) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore") as mock_store_cls, + patch("core.rag.index_processor.processor.qa_index_processor.Vector") as mock_vector_cls, + ): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + mock_store_cls.return_value.add_documents.assert_called_once() + mock_vector_cls.return_value.create.assert_called_once() + + def test_index_requires_high_quality( + self, processor: QAIndexProcessor, dataset: Mock, dataset_document: Mock + ) -> None: + dataset.indexing_technique = "economy" + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + patch("core.rag.index_processor.processor.qa_index_processor.DatasetDocumentStore"), + ): + with pytest.raises(ValueError, match="must be high quality"): + processor.index(dataset, dataset_document, {"qa_chunks": []}) + + def test_format_preview_returns_qa_preview(self, processor: QAIndexProcessor) -> None: + qa_chunks = SimpleNamespace(qa_chunks=[SimpleNamespace(question="Q1", answer="A1")]) + + with patch( + "core.rag.index_processor.processor.qa_index_processor.QAStructureChunk.model_validate", + return_value=qa_chunks, + ): + preview = processor.format_preview({"qa_chunks": []}) + + assert preview["chunk_structure"] == "qa_model" + assert preview["total_segments"] == 1 + assert preview["qa_preview"] == [{"question": "Q1", "answer": "A1"}] + + def test_generate_summary_preview_returns_input(self, processor: QAIndexProcessor) -> None: + preview_items = [PreviewDetail(content="Q1")] + assert processor.generate_summary_preview("tenant-1", preview_items, {}) is preview_items + + def test_format_qa_document_ignores_blank_text(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + blank_document = Document(page_content=" ", metadata={}) + + processor._format_qa_document(fake_flask_app, "tenant-1", blank_document, all_qa_documents, "English") + + assert all_qa_documents == [] + + def test_format_qa_document_builds_question_answer_documents( + self, processor: QAIndexProcessor, fake_flask_app + ) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + return_value="Q1: What is this?\nA1: A test.\nQ2: Why?\nA2: Coverage.", + ), + patch( + "core.rag.index_processor.processor.qa_index_processor.helper.generate_text_hash", return_value="hash" + ), + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert len(all_qa_documents) == 2 + assert all_qa_documents[0].page_content == "What is this?" + assert all_qa_documents[0].metadata["answer"] == "A test." + assert all_qa_documents[1].metadata["answer"] == "Coverage." + + def test_format_qa_document_logs_errors(self, processor: QAIndexProcessor, fake_flask_app) -> None: + all_qa_documents: list[Document] = [] + source_document = Document(page_content="source text", metadata={"origin": "doc-1"}) + + with ( + patch( + "core.rag.index_processor.processor.qa_index_processor.LLMGenerator.generate_qa_document", + side_effect=RuntimeError("llm failure"), + ), + patch("core.rag.index_processor.processor.qa_index_processor.logger") as mock_logger, + ): + processor._format_qa_document(fake_flask_app, "tenant-1", source_document, all_qa_documents, "English") + + assert all_qa_documents == [] + mock_logger.exception.assert_called_once_with("Failed to format qa document") + + def test_format_split_text_extracts_question_answer_pairs(self, processor: QAIndexProcessor) -> None: + parsed = processor._format_split_text("Q1: First?\nA1: One.\nQ2: Second?\nA2: Two.\n") + + assert parsed == [{"question": "First?", "answer": "One."}, {"question": "Second?", "answer": "Two."}] diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py new file mode 100644 index 0000000000..b31bb6eea7 --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_base.py @@ -0,0 +1,291 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import httpx +import pytest + +from core.entities.knowledge_entities import PreviewDetail +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.index_processor_base import BaseIndexProcessor +from core.rag.models.document import AttachmentDocument, Document + + +class _ForwardingBaseIndexProcessor(BaseIndexProcessor): + def extract(self, extract_setting, **kwargs): + return super().extract(extract_setting, **kwargs) + + def transform(self, documents, current_user=None, **kwargs): + return super().transform(documents, current_user=current_user, **kwargs) + + def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting, doc_language=None): + return super().generate_summary_preview( + tenant_id=tenant_id, + preview_texts=preview_texts, + summary_index_setting=summary_index_setting, + doc_language=doc_language, + ) + + def load(self, dataset, documents, multimodal_documents=None, with_keywords=True, **kwargs): + return super().load( + dataset=dataset, + documents=documents, + multimodal_documents=multimodal_documents, + with_keywords=with_keywords, + **kwargs, + ) + + def clean(self, dataset, node_ids, with_keywords=True, **kwargs): + return super().clean(dataset=dataset, node_ids=node_ids, with_keywords=with_keywords, **kwargs) + + def index(self, dataset, document, chunks): + return super().index(dataset=dataset, document=document, chunks=chunks) + + def format_preview(self, chunks): + return super().format_preview(chunks) + + def retrieve(self, retrieval_method, query, dataset, top_k, score_threshold, reranking_model): + return super().retrieve( + retrieval_method=retrieval_method, + query=query, + dataset=dataset, + top_k=top_k, + score_threshold=score_threshold, + reranking_model=reranking_model, + ) + + +class TestBaseIndexProcessor: + @pytest.fixture + def processor(self) -> _ForwardingBaseIndexProcessor: + return _ForwardingBaseIndexProcessor() + + def test_abstract_methods_raise_not_implemented(self, processor: _ForwardingBaseIndexProcessor) -> None: + with pytest.raises(NotImplementedError): + processor.extract(Mock()) + with pytest.raises(NotImplementedError): + processor.transform([]) + with pytest.raises(NotImplementedError): + processor.generate_summary_preview("tenant", [PreviewDetail(content="c")], {}) + with pytest.raises(NotImplementedError): + processor.load(Mock(), []) + with pytest.raises(NotImplementedError): + processor.clean(Mock(), None) + with pytest.raises(NotImplementedError): + processor.index(Mock(), Mock(), {}) + with pytest.raises(NotImplementedError): + processor.format_preview([]) + with pytest.raises(NotImplementedError): + processor.retrieve("semantic_search", "q", Mock(), 3, 0.5, {}) + + def test_get_splitter_validates_custom_length(self, processor: _ForwardingBaseIndexProcessor) -> None: + with patch( + "core.rag.index_processor.index_processor_base.dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH", 1000 + ): + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 49, 0, "", None) + with pytest.raises(ValueError, match="between 50 and 1000"): + processor._get_splitter("custom", 1001, 0, "", None) + + def test_get_splitter_custom_mode_uses_fixed_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + fixed_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.FixedRecursiveCharacterTextSplitter.from_encoder", + return_value=fixed_splitter, + ) as mock_fixed: + splitter = processor._get_splitter("hierarchical", 120, 10, "\\n\\n", None) + + assert splitter is fixed_splitter + assert mock_fixed.call_args.kwargs["fixed_separator"] == "\n\n" + assert mock_fixed.call_args.kwargs["chunk_size"] == 120 + + def test_get_splitter_automatic_mode_uses_enhance_splitter(self, processor: _ForwardingBaseIndexProcessor) -> None: + auto_splitter = Mock() + with patch( + "core.rag.index_processor.index_processor_base.EnhanceRecursiveCharacterTextSplitter.from_encoder", + return_value=auto_splitter, + ) as mock_enhance: + splitter = processor._get_splitter("automatic", 0, 0, "", None) + + assert splitter is auto_splitter + assert "chunk_size" in mock_enhance.call_args.kwargs + + def test_extract_markdown_images(self, processor: _ForwardingBaseIndexProcessor) -> None: + markdown = "text ![a](https://a/img.png) and ![b](/files/123/file-preview)" + images = processor._extract_markdown_images(markdown) + assert images == ["https://a/img.png", "/files/123/file-preview"] + + def test_get_content_files_without_images_returns_empty(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="no image markdown", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + assert processor._get_content_files(document) == [] + + def test_get_content_files_handles_all_sources_and_duplicates( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = [ + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview", + "/files/bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb/file-preview", + "/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", + "https://example.com/remote.png?x=1", + ] + upload_a = SimpleNamespace(id="aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa", name="a.png") + upload_b = SimpleNamespace(id="bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", name="b.png") + upload_tool = SimpleNamespace(id="tool-upload-id", name="tool.png") + upload_remote = SimpleNamespace(id="remote-upload-id", name="remote.png") + db_query = Mock() + db_query.where.return_value.all.return_value = [upload_a, upload_b, upload_tool, upload_remote] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch.object(processor, "_download_tool_file", return_value="tool-upload-id") as mock_tool_download, + patch.object(processor, "_download_image", return_value="remote-upload-id") as mock_image_download, + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document, current_user=Mock()) + + assert len(files) == 5 + assert all(isinstance(file, AttachmentDocument) for file in files) + assert files[0].metadata["doc_type"] == DocType.IMAGE + assert files[0].metadata["document_id"] == "doc-1" + assert files[0].metadata["dataset_id"] == "ds-1" + assert files[0].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + assert files[1].metadata["doc_id"] == "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa" + mock_tool_download.assert_called_once() + mock_image_download.assert_called_once() + + def test_get_content_files_skips_tool_and_remote_download_without_user( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/tools/cccccccc-cccc-cccc-cccc-cccccccccccc.png", "https://example.com/remote.png"] + + with patch.object(processor, "_extract_markdown_images", return_value=images): + files = processor._get_content_files(document, current_user=None) + + assert files == [] + + def test_get_content_files_ignores_missing_upload_records(self, processor: _ForwardingBaseIndexProcessor) -> None: + document = Document(page_content="ignored", metadata={"document_id": "doc-1", "dataset_id": "ds-1"}) + images = ["/files/aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa/image-preview"] + db_query = Mock() + db_query.where.return_value.all.return_value = [] + db_session = Mock() + db_session.query.return_value = db_query + + with ( + patch.object(processor, "_extract_markdown_images", return_value=images), + patch("core.rag.index_processor.index_processor_base.db.session", db_session), + ): + files = processor._get_content_files(document) + + assert files == [] + + def test_download_image_success_with_filename_from_content_disposition( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + response = Mock() + response.headers = { + "Content-Length": "4", + "content-disposition": "attachment; filename=test-image.png", + "content-type": "image/png", + } + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"data"] + upload_result = SimpleNamespace(id="upload-id") + + mock_db = Mock() + mock_db.engine = Mock() + + with ( + patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response), + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + upload_id = processor._download_image("https://example.com/test.png", current_user=Mock()) + + assert upload_id == "upload-id" + mock_file_service.return_value.upload_file.assert_called_once() + + def test_download_image_validates_size_and_empty_content(self, processor: _ForwardingBaseIndexProcessor) -> None: + too_large = Mock() + too_large.headers = {"Content-Length": str(3 * 1024 * 1024), "content-type": "image/png"} + too_large.raise_for_status.return_value = None + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=too_large): + assert processor._download_image("https://example.com/too-large.png", current_user=Mock()) is None + + empty = Mock() + empty.headers = {"Content-Length": "0", "content-type": "image/png"} + empty.raise_for_status.return_value = None + empty.iter_bytes.return_value = [] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=empty): + assert processor._download_image("https://example.com/empty.png", current_user=Mock()) is None + + def test_download_image_limits_stream_size(self, processor: _ForwardingBaseIndexProcessor) -> None: + response = Mock() + response.headers = {"content-type": "image/png"} + response.raise_for_status.return_value = None + response.iter_bytes.return_value = [b"a" * (3 * 1024 * 1024)] + + with patch("core.rag.index_processor.index_processor_base.ssrf_proxy.get", return_value=response): + assert processor._download_image("https://example.com/big-stream.png", current_user=Mock()) is None + + def test_download_image_handles_timeout_request_and_unexpected_errors( + self, processor: _ForwardingBaseIndexProcessor + ) -> None: + request = httpx.Request("GET", "https://example.com/image.png") + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.TimeoutException("timeout"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=httpx.RequestError("bad request", request=request), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + with patch( + "core.rag.index_processor.index_processor_base.ssrf_proxy.get", + side_effect=RuntimeError("unexpected"), + ): + assert processor._download_image("https://example.com/image.png", current_user=Mock()) is None + + def test_download_tool_file_returns_none_when_not_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + db_query = Mock() + db_query.where.return_value.first.return_value = None + db_session = Mock() + db_session.query.return_value = db_query + + with patch("core.rag.index_processor.index_processor_base.db.session", db_session): + assert processor._download_tool_file("tool-id", current_user=Mock()) is None + + def test_download_tool_file_uploads_file_when_found(self, processor: _ForwardingBaseIndexProcessor) -> None: + tool_file = SimpleNamespace(file_key="k1", name="tool.png", mimetype="image/png") + db_query = Mock() + db_query.where.return_value.first.return_value = tool_file + db_session = Mock() + db_session.query.return_value = db_query + mock_db = Mock() + mock_db.session = db_session + mock_db.engine = Mock() + upload_result = SimpleNamespace(id="upload-id") + + with ( + patch("core.rag.index_processor.index_processor_base.db", mock_db), + patch("core.rag.index_processor.index_processor_base.storage.load_once", return_value=b"blob") as mock_load, + patch("services.file_service.FileService") as mock_file_service, + ): + mock_file_service.return_value.upload_file.return_value = upload_result + result = processor._download_tool_file("tool-id", current_user=Mock()) + + assert result == "upload-id" + mock_load.assert_called_once_with("k1") + mock_file_service.return_value.upload_file.assert_called_once() diff --git a/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py new file mode 100644 index 0000000000..0fc666dbbf --- /dev/null +++ b/api/tests/unit_tests/core/rag/indexing/test_index_processor_factory.py @@ -0,0 +1,42 @@ +import pytest + +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.index_processor.index_processor_factory import IndexProcessorFactory +from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor +from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor +from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor + + +class TestIndexProcessorFactory: + def test_requires_index_type(self) -> None: + factory = IndexProcessorFactory(index_type=None) + + with pytest.raises(ValueError, match="Index type must be specified"): + factory.init_index_processor() + + def test_builds_paragraph_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARAGRAPH_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParagraphIndexProcessor) + + def test_builds_qa_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.QA_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, QAIndexProcessor) + + def test_builds_parent_child_processor(self) -> None: + factory = IndexProcessorFactory(index_type=IndexStructureType.PARENT_CHILD_INDEX) + + processor = factory.init_index_processor() + + assert isinstance(processor, ParentChildIndexProcessor) + + def test_rejects_unsupported_index_type(self) -> None: + factory = IndexProcessorFactory(index_type="unsupported") + + with pytest.raises(ValueError, match="is not supported"): + factory.init_index_processor() diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index c00fee8fe5..b011ade884 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -61,9 +61,9 @@ from core.indexing_runner import ( DocumentIsPausedError, IndexingRunner, ) -from core.model_runtime.entities.model_entities import ModelType from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import ChildDocument, Document +from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index e4597e7f8c..b150d677f1 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -12,21 +12,26 @@ All tests use mocking to avoid external dependencies and ensure fast, reliable e Tests follow the Arrange-Act-Assert pattern for clarity. """ +from operator import itemgetter +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest from core.model_manager import ModelInstance -from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights +from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner +from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult -def create_mock_model_instance(): +def create_mock_model_instance() -> ModelInstance: """Create a properly configured mock ModelInstance for reranking tests.""" mock_instance = Mock(spec=ModelInstance) # Setup provider_model_bundle chain for check_model_support_vision @@ -59,14 +64,7 @@ class TestRerankModelRunner: @pytest.fixture def mock_model_instance(self): """Create a mock ModelInstance for reranking.""" - mock_instance = Mock(spec=ModelInstance) - # Setup provider_model_bundle chain for check_model_support_vision - mock_instance.provider_model_bundle = Mock() - mock_instance.provider_model_bundle.configuration = Mock() - mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id" - mock_instance.provider = "test-provider" - mock_instance.model_name = "test-model" - return mock_instance + return create_mock_model_instance() @pytest.fixture def rerank_runner(self, mock_model_instance): @@ -382,6 +380,206 @@ class TestRerankModelRunner: assert call_kwargs["user"] == "user123" +class _ForwardingBaseRerankRunner(BaseRerankRunner): + def run( + self, + query: str, + documents: list[Document], + score_threshold: float | None = None, + top_n: int | None = None, + user: str | None = None, + query_type: QueryType = QueryType.TEXT_QUERY, + ) -> list[Document]: + return super().run( + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=top_n, + user=user, + query_type=query_type, + ) + + +class TestBaseRerankRunner: + def test_run_raises_not_implemented(self): + runner = _ForwardingBaseRerankRunner() + + with pytest.raises(NotImplementedError): + runner.run(query="python", documents=[]) + + +class TestRerankModelRunnerMultimodal: + @pytest.fixture + def mock_model_instance(self): + return create_mock_model_instance() + + @pytest.fixture + def rerank_runner(self, mock_model_instance): + return RerankModelRunner(rerank_model_instance=mock_model_instance) + + def test_run_returns_original_documents_for_non_text_query_without_vision_support( + self, rerank_runner, mock_model_instance + ): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), + ] + + with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + mock_mm.return_value.check_model_support_vision.return_value = False + result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) + + assert result == documents + mock_model_instance.invoke_rerank.assert_not_called() + + def test_run_uses_multimodal_path_when_vision_support_is_enabled(self, rerank_runner): + documents = [ + Document(page_content="doc", metadata={"doc_id": "doc1", "source": "wiki"}, provider="dify"), + ] + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="doc", score=0.88)], + ) + + with ( + patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch.object( + rerank_runner, + "fetch_multimodal_rerank", + return_value=(rerank_result, documents), + ) as mock_multimodal, + ): + mock_mm.return_value.check_model_support_vision.return_value = True + result = rerank_runner.run(query="python", documents=documents, query_type=QueryType.TEXT_QUERY) + + mock_multimodal.assert_called_once() + assert len(result) == 1 + assert result[0].metadata["score"] == 0.88 + + def test_fetch_multimodal_rerank_builds_docs_and_calls_text_rerank(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-1", "doc_type": DocType.IMAGE}, + provider="dify", + ) + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + external_doc = Document( + page_content="external-content", + metadata={}, + provider="external", + ) + query = Mock() + query.where.return_value.first.return_value = SimpleNamespace(key="image-key") + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"image-bytes") as mock_load_once, + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc, text_doc, external_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc, text_doc, external_doc, external_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert len(unique_documents) == 3 + mock_load_once.assert_called_once_with("image-key") + text_rerank_call_args = mock_text_rerank.call_args.args + assert len(text_rerank_call_args[1]) == 3 + + def test_fetch_multimodal_rerank_skips_missing_image_upload(self, rerank_runner): + image_doc = Document( + page_content="image-content", + metadata={"doc_id": "img-missing", "doc_type": DocType.IMAGE}, + provider="dify", + ) + query = Mock() + query.where.return_value.first.return_value = None + rerank_result = RerankResult(model="rerank-model", docs=[]) + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query), + patch.object( + rerank_runner, + "fetch_text_rerank", + return_value=(rerank_result, [image_doc]), + ) as mock_text_rerank, + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[image_doc], + query_type=QueryType.TEXT_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [image_doc] + docs_arg = mock_text_rerank.call_args.args[1] + assert len(docs_arg) == 1 + + def test_fetch_multimodal_rerank_image_query_invokes_multimodal_model(self, rerank_runner, mock_model_instance): + text_doc = Document( + page_content="text-content", + metadata={"doc_id": "txt-1", "doc_type": DocType.TEXT}, + provider="dify", + ) + query_chain = Mock() + query_chain.where.return_value.first.return_value = SimpleNamespace(key="query-image-key") + rerank_result = RerankResult( + model="rerank-model", + docs=[RerankDocument(index=0, text="text-content", score=0.77)], + ) + mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + + with ( + patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), + ): + result, unique_documents = rerank_runner.fetch_multimodal_rerank( + query="query-upload-id", + documents=[text_doc], + score_threshold=0.2, + top_n=2, + user="user-1", + query_type=QueryType.IMAGE_QUERY, + ) + + assert result == rerank_result + assert unique_documents == [text_doc] + invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs + assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE + assert invoke_kwargs["docs"][0]["content"] == "text-content" + assert invoke_kwargs["user"] == "user-1" + + def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): + query_chain = Mock() + query_chain.where.return_value.first.return_value = None + + with patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain): + with pytest.raises(ValueError, match="Upload file not found for query"): + rerank_runner.fetch_multimodal_rerank( + query="missing-upload-id", + documents=[], + query_type=QueryType.IMAGE_QUERY, + ) + + def test_fetch_multimodal_rerank_rejects_unsupported_query_type(self, rerank_runner): + with pytest.raises(ValueError, match="is not supported"): + rerank_runner.fetch_multimodal_rerank( + query="python", + documents=[], + query_type="unsupported_query_type", + ) + + class TestWeightRerankRunner: """Unit tests for WeightRerankRunner. @@ -512,34 +710,39 @@ class TestWeightRerankRunner: - TF-IDF scores are calculated correctly - Cosine similarity is computed for keyword vectors """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - - # Mock keyword extraction with specific keywords + keyword_map = { + "python programming": ["python", "programming"], + "Python is a programming language": ["python", "programming", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.side_effect = [ - ["python", "programming"], # query - ["python", "programming", "language"], # doc1 - ["javascript", "web"], # doc2 - ["java", "programming"], # doc3 - ] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("python programming", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "python programming", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="python programming", documents=sample_documents_with_vectors) - # Assert: Keywords are extracted and scores are calculated - assert len(result) == 3 - # Document 1 should have highest keyword score (matches both query terms) - # Document 3 should have medium score (matches one term) - # Document 2 should have lowest score (matches no terms) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_vector_score_calculation( self, @@ -556,30 +759,42 @@ class TestWeightRerankRunner: - Cosine similarity is calculated with document vectors - Vector scores are properly normalized """ - # Arrange: Create runner runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test query": ["test"], + "Python is a programming language": ["python"], + "JavaScript for web development": ["javascript"], + "Java object-oriented programming": ["java"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding model mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance - # Mock cache embedding with specific query vector mock_cache_instance = MagicMock() query_vector = [0.2, 0.3, 0.4, 0.5] mock_cache_instance.embed_query.return_value = query_vector mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test query", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test query", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test query", documents=sample_documents_with_vectors) - # Assert: Vector scores are calculated - assert len(result) == 3 - # Verify cosine similarity was computed (doc2 vector is closest to query vector) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_score_threshold_filtering_weighted( self, @@ -742,28 +957,40 @@ class TestWeightRerankRunner: - Keyword weight (0.4) is applied to keyword scores - Combined score is the sum of weighted components """ - # Arrange: Create runner with known weights runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Python is a programming language": ["python", "language"], + "JavaScript for web development": ["javascript", "web"], + "Java object-oriented programming": ["java", "programming"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2, 0.3, 0.4] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", sample_documents_with_vectors) + vector_scores = runner._calculate_cosine( + "tenant123", "test", sample_documents_with_vectors, weights_config.vector_setting + ) + expected_scores = { + doc.metadata["doc_id"]: (0.6 * vector_score + 0.4 * query_score) + for doc, query_score, vector_score in zip(sample_documents_with_vectors, query_scores, vector_scores) + } + result = runner.run(query="test", documents=sample_documents_with_vectors) - # Assert: Scores are combined with weights - # Score = 0.6 * vector_score + 0.4 * keyword_score - assert len(result) == 3 - assert all("score" in doc.metadata for doc in result) + expected_order = [doc_id for doc_id, _ in sorted(expected_scores.items(), key=itemgetter(1), reverse=True)] + assert [doc.metadata["doc_id"] for doc in result] == expected_order + for doc in result: + doc_id = doc.metadata["doc_id"] + assert doc.metadata["score"] == pytest.approx(expected_scores[doc_id], rel=1e-6) def test_existing_vector_score_in_metadata( self, @@ -778,7 +1005,6 @@ class TestWeightRerankRunner: - If document already has a score in metadata, it's used - Cosine similarity calculation is skipped for such documents """ - # Arrange: Documents with pre-existing scores documents = [ Document( page_content="Content with existing score", @@ -790,24 +1016,29 @@ class TestWeightRerankRunner: runner = WeightRerankRunner(tenant_id="tenant123", weights=weights_config) - # Mock keyword extraction + keyword_map = { + "test": ["test"], + "Content with existing score": ["test"], + } mock_handler_instance = MagicMock() - mock_handler_instance.extract_keywords.return_value = ["test"] + mock_handler_instance.extract_keywords.side_effect = lambda text, _: keyword_map[text] mock_jieba_handler.return_value = mock_handler_instance - # Mock embedding mock_embedding_instance = MagicMock() mock_model_manager.return_value.get_model_instance.return_value = mock_embedding_instance mock_cache_instance = MagicMock() mock_cache_instance.embed_query.return_value = [0.1, 0.2] mock_cache_embedding.return_value = mock_cache_instance - # Act: Run reranking + query_scores = runner._calculate_keyword_score("test", documents) + vector_scores = runner._calculate_cosine("tenant123", "test", documents, weights_config.vector_setting) + expected_score = 0.6 * vector_scores[0] + 0.4 * query_scores[0] + result = runner.run(query="test", documents=documents) - # Assert: Existing score is used in calculation assert len(result) == 1 - # The final score should incorporate the existing score (0.95) with vector weight (0.6) + assert result[0].metadata["doc_id"] == "doc1" + assert result[0].metadata["score"] == pytest.approx(expected_score, rel=1e-6) class TestRerankRunnerFactory: diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index ca08cb0591..665e98bd9c 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -1,80 +1,42 @@ -""" -Unit tests for dataset retrieval functionality. - -This module provides comprehensive test coverage for the RetrievalService class, -which is responsible for retrieving relevant documents from datasets using various -search strategies. - -Core Retrieval Mechanisms Tested: -================================== -1. **Vector Search (Semantic Search)** - - Uses embedding vectors to find semantically similar documents - - Supports score thresholds and top-k limiting - - Can filter by document IDs and metadata - -2. **Keyword Search** - - Traditional text-based search using keyword matching - - Handles special characters and query escaping - - Supports document filtering - -3. **Full-Text Search** - - BM25-based full-text search for text matching - - Used in hybrid search scenarios - -4. **Hybrid Search** - - Combines vector and full-text search results - - Implements deduplication to avoid duplicate chunks - - Uses DataPostProcessor for score merging with configurable weights - -5. **Score Merging Algorithms** - - Deduplication based on doc_id - - Retains higher-scoring duplicates - - Supports weighted score combination - -6. **Metadata Filtering** - - Filters documents based on metadata conditions - - Supports document ID filtering - -Test Architecture: -================== -- **Fixtures**: Provide reusable mock objects (datasets, documents, Flask app) -- **Mocking Strategy**: Mock at the method level (embedding_search, keyword_search, etc.) - rather than at the class level to properly simulate the ThreadPoolExecutor behavior -- **Pattern**: All tests follow Arrange-Act-Assert (AAA) pattern -- **Isolation**: Each test is independent and doesn't rely on external state - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py -v - - # Run a specific test class - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::TestRetrievalService -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py::\ -TestRetrievalService::test_vector_search_basic -v - -Notes: -====== -- The RetrievalService uses ThreadPoolExecutor for concurrent search operations -- Tests mock the individual search methods to avoid threading complexity -- All mocked search methods modify the all_documents list in-place -- Score thresholds and top-k limits are enforced by the search methods -""" - +import threading +from contextlib import contextmanager, nullcontext +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch from uuid import uuid4 import pytest +from flask import Flask, current_app +from sqlalchemy import column +from core.app.app_config.entities import ( + Condition as AppCondition, +) +from core.app.app_config.entities import ( + DatasetEntity, + DatasetRetrieveConfigEntity, +) +from core.app.app_config.entities import ( + MetadataFilteringCondition as AppMetadataFilteringCondition, +) +from core.app.app_config.entities import ( + ModelConfig as AppModelConfig, +) +from core.app.app_config.entities import ModelConfig as WorkflowModelConfig +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.entities.agent_entities import PlanningStrategy +from core.entities.model_entities import ModelStatus +from core.rag.data_post_processor.data_post_processor import WeightsDict from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.index_processor.constant.doc_type import DocType +from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.models.document import Document +from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod +from core.workflow.nodes.knowledge_retrieval import exc +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset # ==================== Helper Functions ==================== @@ -2013,3 +1975,3094 @@ class TestDocumentModel: assert doc1 == doc2 assert doc1 != doc3 + + +# ==================== Helper Functions ==================== + + +def create_mock_dataset_methods( + dataset_id: str | None = None, + tenant_id: str | None = None, + provider: str = "dify", + indexing_technique: str = "high_quality", + available_document_count: int = 10, +) -> Mock: + """ + Create a mock Dataset object for testing. + + Args: + dataset_id: Unique identifier for the dataset + tenant_id: Tenant ID for the dataset + provider: Provider type ("dify" or "external") + indexing_technique: Indexing technique ("high_quality" or "economy") + available_document_count: Number of available documents + + Returns: + Mock: A properly configured Dataset mock + """ + dataset = Mock(spec=Dataset) + dataset.id = dataset_id or str(uuid4()) + dataset.tenant_id = tenant_id or str(uuid4()) + dataset.name = "test_dataset" + dataset.provider = provider + dataset.indexing_technique = indexing_technique + dataset.available_document_count = available_document_count + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "top_k": 4, + "score_threshold_enabled": False, + } + return dataset + + +def create_mock_document_methods( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +# ==================== Test _check_knowledge_rate_limit ==================== + + +class TestCheckKnowledgeRateLimit: + """ + Test suite for _check_knowledge_rate_limit method. + + The _check_knowledge_rate_limit method validates whether a tenant has + exceeded their knowledge retrieval rate limit. This is important for: + - Preventing abuse of the knowledge retrieval system + - Enforcing subscription plan limits + - Tracking usage for billing purposes + + Test Cases: + ============ + 1. Rate limit disabled - no exception raised + 2. Rate limit enabled but not exceeded - no exception raised + 3. Rate limit enabled and exceeded - RateLimitExceededError raised + 4. Redis operations are performed correctly + 5. RateLimitLog is created when limit is exceeded + """ + + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service): + """ + Test that when rate limit is disabled, no exception is raised. + + This test verifies the behavior when the tenant's subscription + does not have rate limiting enabled. + + Verifies: + - FeatureService.get_knowledge_rate_limit is called + - No Redis operations are performed + - No exception is raised + - Retrieval proceeds normally + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit disabled + mock_limit = Mock() + mock_limit.enabled = False + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify FeatureService was called + mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id) + + # Verify no Redis operations were performed + assert not mock_redis.zadd.called + assert not mock_redis.zremrangebyscore.called + assert not mock_redis.zcard.called + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory): + """ + Test that when rate limit is enabled but not exceeded, no exception is raised. + + This test simulates a tenant making requests within their rate limit. + The Redis sorted set stores timestamps of recent requests, and old + requests (older than 60 seconds) are removed. + + Verifies: + - Redis zadd is called to track the request + - Redis zremrangebyscore removes old entries + - Redis zcard returns count within limit + - No exception is raised + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 # Current time in milliseconds + mock_time.time.return_value = current_time / 1000 # Return seconds + mock_time.time.__mul__ = lambda self, x: int(self * x) # Multiply to get milliseconds + + # Mock Redis operations + # zcard returns 50 (within limit of 100) + mock_redis.zcard.return_value = 50 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify Redis operations + expected_key = f"rate_limit_{tenant_id}" + mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time}) + mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000) + mock_redis.zcard.assert_called_once_with(expected_key) + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_exceeded_raises_exception( + self, mock_time, mock_redis, mock_feature_service, mock_session_factory + ): + """ + Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised. + + This test simulates a tenant exceeding their rate limit. When the count + of recent requests exceeds the limit, an exception should be raised and + a RateLimitLog should be created. + + Verifies: + - Redis zcard returns count exceeding limit + - RateLimitExceededError is raised with correct message + - RateLimitLog is created in database + - Session operations are performed correctly + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 + mock_time.time.return_value = current_time / 1000 + + # Mock Redis operations - return count exceeding limit + mock_redis.zcard.return_value = 150 # Exceeds limit of 100 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert + with pytest.raises(exc.RateLimitExceededError) as exc_info: + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify exception message + assert "knowledge base request rate limit" in str(exc_info.value) + + # Verify RateLimitLog was created + mock_session.add.assert_called_once() + added_log = mock_session.add.call_args[0][0] + assert added_log.tenant_id == tenant_id + assert added_log.subscription_plan == "professional" + assert added_log.operation == "knowledge" + + +# ==================== Test _get_available_datasets ==================== + + +class TestGetAvailableDatasets: + """ + Test suite for _get_available_datasets method. + + The _get_available_datasets method retrieves datasets that are available + for retrieval. A dataset is considered available if: + - It belongs to the specified tenant + - It's in the list of requested dataset_ids + - It has at least one completed, enabled, non-archived document OR + - It's an external provider dataset + + Note: Due to SQLAlchemy subquery complexity, full testing is done in + integration tests. Unit tests here verify basic behavior. + """ + + def test_method_exists_and_has_correct_signature(self): + """ + Test that the method exists and has the correct signature. + + Verifies: + - Method exists on DatasetRetrieval class + - Accepts tenant_id and dataset_ids parameters + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + + # Assert - method exists + assert hasattr(dataset_retrieval, "_get_available_datasets") + # Assert - method is callable + assert callable(dataset_retrieval._get_available_datasets) + + +# ==================== Test knowledge_retrieval ==================== + + +class TestDatasetRetrievalKnowledgeRetrieval: + """ + Test suite for knowledge_retrieval method. + + The knowledge_retrieval method is the main entry point for retrieving + knowledge from datasets. It orchestrates the entire retrieval process: + 1. Checks rate limits + 2. Gets available datasets + 3. Applies metadata filtering if enabled + 4. Performs retrieval (single or multiple mode) + 5. Formats and returns results + + Test Cases: + ============ + 1. Single mode retrieval + 2. Multiple mode retrieval + 3. Metadata filtering disabled + 4. Metadata filtering automatic + 5. Metadata filtering manual + 6. External documents handling + 7. Dify documents handling + 8. Empty results handling + 9. Rate limit exceeded + 10. No available datasets + """ + + def test_knowledge_retrieval_single_mode_basic(self): + """ + Test knowledge_retrieval in single retrieval mode - basic check. + + Note: Full single mode testing requires complex model mocking and + is better suited for integration tests. This test verifies the + method accepts single mode requests. + + Verifies: + - Method can accept single mode request + - Request parameters are correctly structured + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + model_mode="chat", + completion_params={"temperature": 0.7}, + ) + + # Assert - request is properly structured + assert request.retrieval_mode == "single" + assert request.model_provider == "openai" + assert request.model_name == "gpt-4" + assert request.model_mode == "chat" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor): + """ + Test knowledge_retrieval in multiple retrieval mode. + + In multiple mode, retrieval is performed across all datasets and + results are combined and reranked. + + Verifies: + - Rate limit is checked + - Available datasets are retrieved + - Multiple retrieval is performed + - Results are combined and reranked + - Results are formatted correctly + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id1 = str(uuid4()) + dataset_id2 = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id1, dataset_id2], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + score_threshold=0.7, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets + mock_dataset1 = create_mock_dataset_methods(dataset_id=dataset_id1, tenant_id=tenant_id) + mock_dataset2 = create_mock_dataset_methods(dataset_id=dataset_id2, tenant_id=tenant_id) + with patch.object( + dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2] + ): + # Mock get_metadata_filter_condition + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return documents + doc1 = create_mock_document_methods("Python is great", "doc1", score=0.9) + doc2 = create_mock_document_methods("Python is awesome", "doc2", score=0.8) + with patch.object( + dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2] + ) as mock_multiple_retrieve: + # Mock format_retrieval_documents + mock_record = Mock() + mock_record.segment = Mock() + mock_record.segment.dataset_id = dataset_id1 + mock_record.segment.document_id = str(uuid4()) + mock_record.segment.index_node_hash = "hash123" + mock_record.segment.hit_count = 5 + mock_record.segment.word_count = 100 + mock_record.segment.position = 1 + mock_record.segment.get_sign_content.return_value = "Python is great" + mock_record.segment.answer = None + mock_record.score = 0.9 + mock_record.child_chunks = [] + mock_record.summary = None + mock_record.files = None + + mock_retrieval_service = Mock() + mock_retrieval_service.format_retrieval_documents.return_value = [mock_record] + + with patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService", + return_value=mock_retrieval_service, + ): + # Mock database queries + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + mock_dataset_from_db = Mock() + mock_dataset_from_db.id = dataset_id1 + mock_dataset_from_db.name = "test_dataset" + + mock_document = Mock() + mock_document.id = str(uuid4()) + mock_document.name = "test_doc" + mock_document.data_source_type = "upload_file" + mock_document.doc_metadata = {} + + mock_session.query.return_value.filter.return_value.all.return_value = [ + mock_dataset_from_db + ] + mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( + [mock_dataset_from_db, mock_document] + ) + + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + mock_multiple_retrieve.assert_called_once() + + def test_knowledge_retrieval_metadata_filtering_disabled(self): + """ + Test knowledge_retrieval with metadata filtering disabled. + + When metadata filtering is disabled, get_metadata_filter_condition is + NOT called (the method checks metadata_filtering_mode != "disabled"). + + Verifies: + - get_metadata_filter_condition is NOT called when mode is "disabled" + - Retrieval proceeds without metadata filters + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + metadata_filtering_mode="disabled", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + # Mock get_metadata_filter_condition - should NOT be called when disabled + with patch.object( + dataset_retrieval, + "get_metadata_filter_condition", + return_value=(None, None), + ) as mock_get_metadata: + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + # get_metadata_filter_condition should NOT be called when mode is "disabled" + mock_get_metadata.assert_not_called() + + def test_knowledge_retrieval_with_external_documents(self): + """ + Test knowledge_retrieval with external documents. + + External documents come from external knowledge bases and should + be formatted differently than Dify documents. + + Verifies: + - External documents are handled correctly + - Provider is set to "external" + - Metadata includes external-specific fields + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id, provider="external") + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Create external document + external_doc = create_mock_document_methods( + "External knowledge", + "doc1", + score=0.9, + provider="external", + additional_metadata={ + "dataset_id": dataset_id, + "dataset_name": "external_kb", + "document_id": "ext_doc1", + "title": "External Document", + }, + ) + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + if result: + assert result[0].metadata.data_source_type == "external" + + def test_knowledge_retrieval_empty_results(self): + """ + Test knowledge_retrieval when no documents are found. + + Verifies: + - Empty list is returned + - No errors are raised + - All dependencies are still called + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset_methods(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return empty list + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_rate_limit_exceeded(self): + """ + Test knowledge_retrieval when rate limit is exceeded. + + Verifies: + - RateLimitExceededError is raised + - No further processing occurs + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit to raise exception + with patch.object( + dataset_retrieval, + "_check_knowledge_rate_limit", + side_effect=exc.RateLimitExceededError("Rate limit exceeded"), + ): + # Act & Assert + with pytest.raises(exc.RateLimitExceededError): + dataset_retrieval.knowledge_retrieval(request) + + def test_knowledge_retrieval_no_available_datasets(self): + """ + Test knowledge_retrieval when no datasets are available. + + Verifies: + - Empty list is returned + - No retrieval is attempted + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets to return empty list + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self): + """ + Test that knowledge_retrieval processes multiple documents with different scores. + + Note: Full sorting and position testing requires complex SQLAlchemy mocking + which is better suited for integration tests. This test verifies documents + with different scores can be created and have their metadata. + + Verifies: + - Documents can be created with different scores + - Score metadata is properly set + """ + # Create documents with different scores + doc1 = create_mock_document_methods("Low score", "doc1", score=0.6) + doc2 = create_mock_document_methods("High score", "doc2", score=0.95) + doc3 = create_mock_document_methods("Medium score", "doc3", score=0.8) + + # Assert - each document has the correct score + assert doc1.metadata["score"] == 0.6 + assert doc2.metadata["score"] == 0.95 + assert doc3.metadata["score"] == 0.8 + + # Assert - documents are correctly sorted (not the retrieval result, just the list) + unsorted = [doc1, doc2, doc3] + sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True) + assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6] + + +class TestProcessMetadataFilterFunc: + """ + Comprehensive test suite for process_metadata_filter_func method. + + This test class validates all metadata filtering conditions supported by + the DatasetRetrieval class, including string operations, numeric comparisons, + null checks, and list operations. + + Method Signature: + ================== + def process_metadata_filter_func( + self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list + ) -> list: + + The method builds SQLAlchemy filter expressions by: + 1. Validating value is not None (except for empty/not empty conditions) + 2. Using DatasetDocument.doc_metadata JSON field operations + 3. Adding appropriate SQLAlchemy expressions to the filters list + 4. Returning the updated filters list + + Mocking Strategy: + ================== + - Mock DatasetDocument.doc_metadata to avoid database dependencies + - Verify filter expressions are created correctly + - Test with various data types (str, int, float, list) + """ + + @pytest.fixture + def retrieval(self): + """ + Create a DatasetRetrieval instance for testing. + + Returns: + DatasetRetrieval: Instance to test process_metadata_filter_func + """ + return DatasetRetrieval() + + @pytest.fixture + def mock_doc_metadata(self): + """ + Mock the DatasetDocument.doc_metadata JSON field. + + The method uses DatasetDocument.doc_metadata[metadata_name] to access + JSON fields. We mock this to avoid database dependencies. + + Returns: + Mock: Mocked doc_metadata attribute + """ + mock_metadata_field = MagicMock() + + # Create mock for string access + mock_string_access = MagicMock() + mock_string_access.like = MagicMock() + mock_string_access.notlike = MagicMock() + mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_string_access.in_ = MagicMock(return_value=MagicMock()) + + # Create mock for float access (for numeric comparisons) + mock_float_access = MagicMock() + mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) + mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) + mock_float_access.__le__ = MagicMock(return_value=MagicMock()) + mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) + + # Create mock for null checks + mock_null_access = MagicMock() + mock_null_access.is_ = MagicMock(return_value=MagicMock()) + mock_null_access.isnot = MagicMock(return_value=MagicMock()) + + # Setup __getitem__ to return appropriate mock based on usage + def getitem_side_effect(name): + if name in ["author", "title", "category"]: + return mock_string_access + elif name in ["year", "price", "rating"]: + return mock_float_access + else: + return mock_string_access + + mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) + mock_metadata_field.as_string.return_value = mock_string_access + mock_metadata_field.as_float.return_value = mock_float_access + mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ + mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot + + return mock_metadata_field + + # ==================== String Condition Tests ==================== + + def test_contains_condition_string_value(self, retrieval): + """ + Test 'contains' condition with string value. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value% syntax + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "John" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_contains_condition(self, retrieval): + """ + Test 'not contains' condition. + + Verifies: + - Filters list is populated with NOT LIKE expression + - Pattern matching uses %value% syntax with negation + """ + filters = [] + sequence = 0 + condition = "not contains" + metadata_name = "title" + value = "banned" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_start_with_condition(self, retrieval): + """ + Test 'start with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses value% syntax + """ + filters = [] + sequence = 0 + condition = "start with" + metadata_name = "category" + value = "tech" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_end_with_condition(self, retrieval): + """ + Test 'end with' condition. + + Verifies: + - Filters list is populated with LIKE expression + - Pattern matching uses %value syntax + """ + filters = [] + sequence = 0 + condition = "end with" + metadata_name = "filename" + value = ".pdf" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Equality Condition Tests ==================== + + def test_is_condition_with_string_value(self, retrieval): + """ + Test 'is' (=) condition with string value. + + Verifies: + - Filters list is populated with equality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = "Jane Doe" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_equals_condition_with_string_value(self, retrieval): + """ + Test '=' condition with string value. + + Verifies: + - Same behavior as 'is' condition + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "=" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_int_value(self, retrieval): + """ + Test 'is' condition with integer value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_condition_with_float_value(self, retrieval): + """ + Test 'is' condition with float value. + + Verifies: + - Numeric comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "price" + value = 19.99 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_string_value(self, retrieval): + """ + Test 'is not' (≠) condition with string value. + + Verifies: + - Filters list is populated with inequality expression + - String comparison is used + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "author" + value = "Unknown" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_equals_condition(self, retrieval): + """ + Test '≠' condition with string value. + + Verifies: + - Same behavior as 'is not' condition + - Inequality expression is used + """ + filters = [] + sequence = 0 + condition = "≠" + metadata_name = "category" + value = "archived" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_is_not_condition_with_numeric_value(self, retrieval): + """ + Test 'is not' condition with numeric value. + + Verifies: + - Numeric inequality comparison is used + - as_float() is called on the metadata field + """ + filters = [] + sequence = 0 + condition = "is not" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Null Condition Tests ==================== + + def test_empty_condition(self, retrieval): + """ + Test 'empty' condition (null check). + + Verifies: + - Filters list is populated with IS NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "empty" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_not_empty_condition(self, retrieval): + """ + Test 'not empty' condition (not null check). + + Verifies: + - Filters list is populated with IS NOT NULL expression + - Value can be None for this condition + """ + filters = [] + sequence = 0 + condition = "not empty" + metadata_name = "description" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Numeric Comparison Tests ==================== + + def test_before_condition(self, retrieval): + """ + Test 'before' (<) condition. + + Verifies: + - Filters list is populated with less than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "before" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_condition(self, retrieval): + """ + Test '<' condition. + + Verifies: + - Same behavior as 'before' condition + - Less than expression is used + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "price" + value = 100.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_after_condition(self, retrieval): + """ + Test 'after' (>) condition. + + Verifies: + - Filters list is populated with greater than expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "after" + metadata_name = "year" + value = 2020 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_condition(self, retrieval): + """ + Test '>' condition. + + Verifies: + - Same behavior as 'after' condition + - Greater than expression is used + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≤' condition. + + Verifies: + - Filters list is populated with less than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≤" + metadata_name = "price" + value = 50.0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_less_than_or_equal_condition_ascii(self, retrieval): + """ + Test '<=' condition. + + Verifies: + - Same behavior as '≤' condition + - Less than or equal expression is used + """ + filters = [] + sequence = 0 + condition = "<=" + metadata_name = "year" + value = 2023 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_unicode(self, retrieval): + """ + Test '≥' condition. + + Verifies: + - Filters list is populated with greater than or equal expression + - Numeric comparison is used + """ + filters = [] + sequence = 0 + condition = "≥" + metadata_name = "rating" + value = 3.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_greater_than_or_equal_condition_ascii(self, retrieval): + """ + Test '>=' condition. + + Verifies: + - Same behavior as '≥' condition + - Greater than or equal expression is used + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "year" + value = 2000 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== List/In Condition Tests ==================== + + def test_in_condition_with_comma_separated_string(self, retrieval): + """ + Test 'in' condition with comma-separated string value. + + Verifies: + - String is split into list + - Whitespace is trimmed from each value + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "tech, science, AI " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_list_value(self, retrieval): + """ + Test 'in' condition with list value. + + Verifies: + - List is processed correctly + - None values are filtered out + - IN expression is created with valid values + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "tags" + value = ["python", "javascript", None, "golang"] + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_tuple_value(self, retrieval): + """ + Test 'in' condition with tuple value. + + Verifies: + - Tuple is processed like a list + - IN expression is created + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = ("tech", "science", "ai") + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_empty_string(self, retrieval): + """ + Test 'in' condition with empty string value. + + Verifies: + - Empty string results in literal(False) filter + - No valid values to match + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + # Verify it's a literal(False) expression + # This is a bit tricky to test without access to the actual expression + + def test_in_condition_with_only_whitespace(self, retrieval): + """ + Test 'in' condition with whitespace-only string value. + + Verifies: + - Whitespace-only string results in literal(False) filter + - All values are stripped and filtered out + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = " , , " + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_in_condition_with_single_string(self, retrieval): + """ + Test 'in' condition with single non-comma string. + + Verifies: + - Single string is treated as single-item list + - IN expression is created with one value + """ + filters = [] + sequence = 0 + condition = "in" + metadata_name = "category" + value = "technology" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + # ==================== Edge Case Tests ==================== + + def test_none_value_with_non_empty_condition(self, retrieval): + """ + Test None value with conditions that require value. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values (except empty/not empty) + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 # No filter added + + def test_none_value_with_equals_condition(self, retrieval): + """ + Test None value with 'is' (=) condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = "is" + metadata_name = "author" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_none_value_with_numeric_condition(self, retrieval): + """ + Test None value with numeric comparison condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for None values + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "year" + value = None + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_existing_filters_preserved(self, retrieval): + """ + Test that existing filters are preserved. + + Verifies: + - Existing filters in the list are not removed + - New filters are appended to the list + """ + existing_filter = MagicMock() + filters = [existing_filter] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 2 + assert filters[0] == existing_filter + + def test_multiple_filters_accumulated(self, retrieval): + """ + Test multiple calls to accumulate filters. + + Verifies: + - Each call adds a new filter to the list + - All filters are preserved across calls + """ + filters = [] + + # First filter + retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) + assert len(filters) == 1 + + # Second filter + retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) + assert len(filters) == 2 + + # Third filter + retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) + assert len(filters) == 3 + + def test_unknown_condition(self, retrieval): + """ + Test unknown/unsupported condition. + + Verifies: + - Original filters list is returned unchanged + - No filter is added for unknown conditions + """ + filters = [] + sequence = 0 + condition = "unknown_condition" + metadata_name = "author" + value = "test" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 0 + + def test_empty_string_value_with_contains(self, retrieval): + """ + Test empty string value with 'contains' condition. + + Verifies: + - Filter is added even with empty string + - LIKE expression is created + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "author" + value = "" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_special_characters_in_value(self, retrieval): + """ + Test special characters in value string. + + Verifies: + - Special characters are handled in value + - LIKE expression is created correctly + """ + filters = [] + sequence = 0 + condition = "contains" + metadata_name = "title" + value = "C++ & Python's features" + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_zero_value_with_numeric_condition(self, retrieval): + """ + Test zero value with numeric comparison condition. + + Verifies: + - Zero is treated as valid value + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">" + metadata_name = "price" + value = 0 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_negative_value_with_numeric_condition(self, retrieval): + """ + Test negative value with numeric comparison condition. + + Verifies: + - Negative numbers are handled correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = "<" + metadata_name = "temperature" + value = -10.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + def test_float_value_with_integer_comparison(self, retrieval): + """ + Test float value with numeric comparison condition. + + Verifies: + - Float values work correctly + - Numeric comparison is performed + """ + filters = [] + sequence = 0 + condition = ">=" + metadata_name = "rating" + value = 4.5 + + result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) + + assert result == filters + assert len(filters) == 1 + + +class TestKnowledgeRetrievalRegression: + @pytest.fixture + def mock_dataset(self) -> Dataset: + dataset = Mock(spec=Dataset) + dataset.id = str(uuid4()) + dataset.tenant_id = str(uuid4()) + dataset.name = "test_dataset" + dataset.indexing_technique = "high_quality" + dataset.provider = "dify" + return dataset + + def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): + """ + Repro test for current bug: + reranking runs after `with flask_app.app_context():` exits. + `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, + so we must assert from that list (not from an outer try/except). + """ + dataset_retrieval = DatasetRetrieval() + flask_app = Flask(__name__) + tenant_id = str(uuid4()) + + # second dataset to ensure dataset_count > 1 reranking branch + secondary_dataset = Mock(spec=Dataset) + secondary_dataset.id = str(uuid4()) + secondary_dataset.provider = "dify" + secondary_dataset.indexing_technique = "high_quality" + + # retriever returns 1 doc into internal list (all_documents_item) + document = Document( + page_content="Context aware doc", + metadata={ + "doc_id": "doc1", + "score": 0.95, + "document_id": str(uuid4()), + "dataset_id": mock_dataset.id, + }, + provider="dify", + ) + + def fake_retriever( + flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids + ): + all_documents.append(document) + + called = {"init": 0, "invoke": 0} + + class ContextRequiredPostProcessor: + def __init__(self, *args, **kwargs): + called["init"] += 1 + # will raise RuntimeError if no Flask app context exists + _ = current_app.name + + def invoke(self, *args, **kwargs): + called["invoke"] += 1 + _ = current_app.name + return kwargs.get("documents") or args[1] + + # output list from _multiple_retrieve_thread + all_documents: list[Document] = [] + + # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here + thread_exceptions: list[Exception] = [] + + def target(): + with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): + with patch( + "core.rag.retrieval.dataset_retrieval.DataPostProcessor", + ContextRequiredPostProcessor, + ): + dataset_retrieval._multiple_retrieve_thread( + flask_app=flask_app, + available_datasets=[mock_dataset, secondary_dataset], + metadata_condition=None, + metadata_filter_document_ids=None, + all_documents=all_documents, + tenant_id=tenant_id, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={ + "reranking_provider_name": "cohere", + "reranking_model_name": "rerank-v2", + }, + weights=None, + top_k=3, + score_threshold=0.0, + query="test query", + attachment_id=None, + dataset_count=2, # force reranking branch + thread_exceptions=thread_exceptions, # ✅ key + ) + + t = threading.Thread(target=target) + t.start() + t.join() + + # Ensure reranking branch was actually executed + assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." + + # Current buggy code should record an exception (not raise it) + assert not thread_exceptions, thread_exceptions + + +class _FakeFlaskApp: + def app_context(self): + return nullcontext() + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class TestDatasetRetrievalAdditionalHelpers: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_llm_usage_and_record_usage(self, retrieval: DatasetRetrieval) -> None: + empty_usage = retrieval.llm_usage + assert empty_usage.total_tokens == 0 + + retrieval._record_usage(None) + assert retrieval.llm_usage.total_tokens == 0 + + usage_1 = LLMUsage.from_metadata({"prompt_tokens": 2, "completion_tokens": 3, "total_tokens": 5}) + usage_2 = LLMUsage.from_metadata({"prompt_tokens": 4, "completion_tokens": 1, "total_tokens": 5}) + retrieval._record_usage(usage_1) + retrieval._record_usage(usage_2) + assert retrieval.llm_usage.total_tokens == 10 + + def test_replace_metadata_filter_value(self, retrieval: DatasetRetrieval) -> None: + assert retrieval._replace_metadata_filter_value("plain", {}) == "plain" + replaced = retrieval._replace_metadata_filter_value( + "hello {{name}}\n\t{{missing}}", + {"name": "world"}, + ) + assert replaced == "hello world {{missing}}" + + def test_process_metadata_filter_in_with_scalar_fallback(self) -> None: + filters: list = [] + result = DatasetRetrieval.process_metadata_filter_func( + sequence=0, + condition="in", + metadata_name="category", + value=123, + filters=filters, + ) + assert result is filters + assert len(filters) == 1 + + def test_calculate_vector_score(self, retrieval: DatasetRetrieval) -> None: + doc_high = Document(page_content="a", metadata={"score": 0.9}, provider="dify") + doc_low = Document(page_content="b", metadata={"score": 0.2}, provider="dify") + doc_no_meta = Document(page_content="c", metadata={}, provider="dify") + + filtered = retrieval.calculate_vector_score([doc_low, doc_high, doc_no_meta], top_k=1, score_threshold=0.5) + assert len(filtered) == 1 + assert filtered[0].metadata["score"] == 0.9 + + assert retrieval.calculate_vector_score([doc_low], top_k=2, score_threshold=1.0) == [] + + def test_calculate_keyword_score(self, retrieval: DatasetRetrieval) -> None: + documents = [ + Document(page_content="python language", metadata={"doc_id": "1"}, provider="dify"), + Document(page_content="java language", metadata={"doc_id": "2"}, provider="dify"), + ] + keyword_handler = Mock() + keyword_handler.extract_keywords.side_effect = [ + ["python", "language"], + ["python", "language"], + ["java", "language"], + ] + + with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler): + ranked = retrieval.calculate_keyword_score("python language", documents, top_k=1) + + assert len(ranked) == 1 + assert "keywords" in ranked[0].metadata + assert ranked[0].metadata["doc_id"] == "1" + + def test_send_trace_task(self, retrieval: DatasetRetrieval) -> None: + trace_manager = Mock() + retrieval.application_generate_entity = SimpleNamespace(trace_manager=trace_manager) + docs = [Document(page_content="d", metadata={}, provider="dify")] + + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_called_once() + + retrieval.application_generate_entity = None + trace_manager.reset_mock() + retrieval._send_trace_task("m1", docs, {"cost": 1}) + trace_manager.add_trace_task.assert_not_called() + + def test_on_query(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query=None, + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="account", + user_id="u1", + ) + mock_session.add_all.assert_not_called() + + retrieval._on_query( + query="python", + attachment_ids=["f1"], + dataset_ids=["d1", "d2"], + app_id="a1", + user_from="account", + user_id="u1", + ) + mock_session.add_all.assert_called() + mock_session.commit.assert_called() + + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: + usage = LLMUsage.empty_usage() + chunk_1 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace(message=SimpleNamespace(content="hello "), usage=usage), + ) + chunk_2 = SimpleNamespace( + model="m1", + prompt_messages=[Mock()], + delta=SimpleNamespace( + message=SimpleNamespace(content=[SimpleNamespace(data="world")]), + usage=None, + ), + ) + text, returned_usage = retrieval._handle_invoke_result(iter([chunk_1, chunk_2])) + assert text == "hello world" + assert returned_usage == usage + + text_empty, usage_empty = retrieval._handle_invoke_result(iter([])) + assert text_empty == "" + assert usage_empty == LLMUsage.empty_usage() + + def test_get_prompt_template(self, retrieval: DatasetRetrieval) -> None: + model_config_chat = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=["x"], + ) + model_config_completion = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="completion", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + + with patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform") as mock_prompt_transform: + mock_prompt_transform.return_value.get_prompt.return_value = ["prompt"] + prompt_messages, stop = retrieval._get_prompt_template( + model_config=model_config_chat, + mode="chat", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages == ["prompt"] + assert stop == ["x"] + + with patch( + "core.rag.retrieval.dataset_retrieval.METADATA_FILTER_COMPLETION_PROMPT", + "{input_text} {metadata_fields}", + ): + prompt_messages_completion, stop_completion = retrieval._get_prompt_template( + model_config=model_config_completion, + mode="completion", + metadata_fields=["author"], + query="python", + ) + assert prompt_messages_completion == ["prompt"] + assert stop_completion == [] + + with pytest.raises(ValueError): + retrieval._get_prompt_template( + model_config=model_config_chat, + mode="unknown-mode", + metadata_fields=[], + query="python", + ) + + def test_fetch_model_config_validation_and_success(self, retrieval: DatasetRetrieval) -> None: + with pytest.raises(ValueError, match="single_retrieval_config is required"): + retrieval._fetch_model_config("tenant-1", None) # type: ignore[arg-type] + + model_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat", completion_params={"stop": ["END"]}) + model_instance = Mock() + model_instance.credentials = {"k": "v"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance = Mock() + model_instance.model_type_instance.get_model_schema.return_value = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance + mock_cfg_entity.return_value = SimpleNamespace( + provider="openai", + model="gpt", + stop=["END"], + parameters={"temperature": 0.1}, + ) + + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model = SimpleNamespace(status=ModelStatus.NO_CONFIGURE) + model_instance.provider_model_bundle.configuration.get_provider_model.return_value = provider_model + with pytest.raises(ValueError, match="credentials is not initialized"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.NO_PERMISSION + with pytest.raises(ValueError, match="currently not support"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.QUOTA_EXCEEDED + with pytest.raises(ValueError, match="quota exceeded"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + provider_model.status = ModelStatus.ACTIVE + bad_mode_cfg = AppModelConfig(provider="openai", name="gpt", mode="chat") + bad_mode_cfg.mode = None # type: ignore[assignment] + with pytest.raises(ValueError, match="LLM mode is required"): + retrieval._fetch_model_config("tenant-1", bad_mode_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = None + with pytest.raises(ValueError, match="not exist"): + retrieval._fetch_model_config("tenant-1", model_cfg) + + model_instance.model_type_instance.get_model_schema.return_value = Mock() + model_cfg_success = AppModelConfig( + provider="openai", + name="gpt", + mode="chat", + completion_params={"temperature": 0.1, "stop": ["END"]}, + ) + _, config = retrieval._fetch_model_config("tenant-1", model_cfg_success) + assert config.provider == "openai" + assert config.model == "gpt" + assert config.stop == ["END"] + assert "stop" not in config.parameters + + def test_automatic_metadata_filter_func(self, retrieval: DatasetRetrieval) -> None: + metadata_field = SimpleNamespace(name="author") + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([Mock()]) + model_config = ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + session_scalars = Mock() + session_scalars.all.return_value = [metadata_field] + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, model_config)), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, "_handle_invoke_result", return_value=('{"metadata_map":[]}', usage)), + patch("core.rag.retrieval.dataset_retrieval.parse_and_check_json_markdown") as mock_parse, + patch.object(retrieval, "_record_usage") as mock_record_usage, + ): + mock_parse.return_value = { + "metadata_map": [ + { + "metadata_field_name": "author", + "metadata_field_value": "Alice", + "comparison_operator": "contains", + }, + { + "metadata_field_name": "ignored", + "metadata_field_value": "value", + "comparison_operator": "contains", + }, + ] + } + result = retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + assert result == [{"metadata_name": "author", "value": "Alice", "condition": "contains"}] + mock_record_usage.assert_called_once_with(usage) + + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", side_effect=RuntimeError("boom")), + ): + with pytest.raises(RuntimeError, match="boom"): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + ) + + def test_get_metadata_filter_condition(self, retrieval: DatasetRetrieval) -> None: + db_query = Mock() + db_query.where.return_value = db_query + db_query.all.return_value = [SimpleNamespace(dataset_id="d1", id="doc-1")] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="disabled", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + assert mapping is None + assert condition is None + + automatic_filters = [{"condition": "contains", "metadata_name": "author", "value": "Alice"}] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query), + patch.object(retrieval, "_automatic_metadata_filter_func", return_value=automatic_filters), + ): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="automatic", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=AppMetadataFilteringCondition(logical_operator="or", conditions=[]), + inputs={}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.logical_operator == "or" + + manual_conditions = AppMetadataFilteringCondition( + logical_operator="and", + conditions=[AppCondition(name="author", comparison_operator="contains", value="{{name}}")], + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + mapping, condition = retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="manual", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=manual_conditions, + inputs={"name": "Alice"}, + ) + assert mapping == {"d1": ["doc-1"]} + assert condition is not None + assert condition.conditions[0].value == "Alice" + + with patch("core.rag.retrieval.dataset_retrieval.db.session.query", return_value=db_query): + with pytest.raises(ValueError, match="Invalid metadata filtering mode"): + retrieval.get_metadata_filter_condition( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="u1", + metadata_filtering_mode="unsupported", + metadata_model_config=AppModelConfig(provider="openai", name="gpt", mode="chat"), + metadata_filtering_conditions=None, + inputs={}, + ) + + def test_get_available_datasets(self, retrieval: DatasetRetrieval) -> None: + session = Mock() + subquery_query = Mock() + subquery_query.where.return_value = subquery_query + subquery_query.group_by.return_value = subquery_query + subquery_query.having.return_value = subquery_query + subquery_query.subquery.return_value = SimpleNamespace( + c=SimpleNamespace( + dataset_id=column("dataset_id"), available_document_count=column("available_document_count") + ) + ) + + dataset_query = Mock() + dataset_query.outerjoin.return_value = dataset_query + dataset_query.where.return_value = dataset_query + dataset_query.all.return_value = [SimpleNamespace(id="d1"), None, SimpleNamespace(id="d2")] + session.query.side_effect = [subquery_query, dataset_query] + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx): + available = retrieval._get_available_datasets("tenant-1", ["d1", "d2"]) + + assert [dataset.id for dataset in available] == ["d1", "d2"] + + def test_check_knowledge_rate_limit(self, retrieval: DatasetRetrieval) -> None: + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=2, subscription_plan="pro") + mock_redis.zcard.return_value = 1 + retrieval._check_knowledge_rate_limit("tenant-1") + mock_redis.zadd.assert_called_once() + + session = Mock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit, + patch("core.rag.retrieval.dataset_retrieval.redis_client") as mock_redis, + patch("core.rag.retrieval.dataset_retrieval.time.time", return_value=100.0), + patch("core.rag.retrieval.dataset_retrieval.session_factory.create_session", return_value=session_ctx), + ): + mock_limit.return_value = SimpleNamespace(enabled=True, limit=1, subscription_plan="pro") + mock_redis.zcard.return_value = 2 + with pytest.raises(exc.RateLimitExceededError): + retrieval._check_knowledge_rate_limit("tenant-1") + session.add.assert_called_once() + + with patch("core.rag.retrieval.dataset_retrieval.FeatureService.get_knowledge_rate_limit") as mock_limit: + mock_limit.return_value = SimpleNamespace(enabled=False) + retrieval._check_knowledge_rate_limit("tenant-1") + + +def _doc( + provider: str = "dify", + content: str = "content", + score: float = 0.9, + dataset_id: str = "dataset-1", + document_id: str = "document-1", + doc_id: str = "node-1", + extra: dict | None = None, +) -> Document: + metadata = { + "score": score, + "dataset_id": dataset_id, + "document_id": document_id, + "doc_id": doc_id, + } + if extra: + metadata.update(extra) + return Document(page_content=content, metadata=metadata, provider=provider) + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._alive = False + + def start(self) -> None: + self._alive = True + if self._target: + self._target(**self._kwargs) + self._alive = False + + def join(self, timeout=None) -> None: + return None + + def is_alive(self) -> bool: + return self._alive + + +class _JoinDrivenThread: + def __init__(self, target=None, kwargs=None): + self._target = target + self._kwargs = kwargs or {} + self._started = False + self._alive = False + + def start(self) -> None: + self._started = True + self._alive = True + + def join(self, timeout=None) -> None: + if self._started and self._alive and self._target: + self._target(**self._kwargs) + self._alive = False + + def is_alive(self) -> bool: + return self._alive + + +@contextmanager +def _timer(): + yield {"cost": 1} + + +class TestKnowledgeRetrievalCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_returns_empty_when_query_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query=None, + retrieval_mode="multiple", + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + assert retrieval.knowledge_retrieval(request) == [] + + def test_raises_when_metadata_model_config_missing(self, retrieval: DatasetRetrieval) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["d1"], + query="query", + retrieval_mode="multiple", + metadata_filtering_mode="automatic", + metadata_model_config=None, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + ): + with pytest.raises(ValueError, match="metadata_model_config is required"): + retrieval.knowledge_retrieval(request) + + @pytest.mark.parametrize( + ("status", "error_cls"), + [ + (ModelStatus.NO_CONFIGURE, "ModelCredentialsNotInitializedError"), + (ModelStatus.NO_PERMISSION, "ModelNotSupportedError"), + (ModelStatus.QUOTA_EXCEEDED, "ModelQuotaExceededError"), + ], + ) + def test_single_mode_raises_for_model_status( + self, + retrieval: DatasetRetrieval, + status: ModelStatus, + error_cls: str, + ) -> None: + request = KnowledgeRetrievalRequest( + tenant_id="tenant-1", + user_id="user-1", + app_id="app-1", + user_from="workflow", + dataset_ids=["dataset-1"], + query="python", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + ) + provider_model_bundle = Mock() + provider_model_bundle.configuration.get_provider_model.return_value = SimpleNamespace(status=status) + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = Mock() + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + credentials={}, + ) + with ( + patch.object(retrieval, "_check_knowledge_rate_limit"), + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + ): + mock_model_manager.return_value.get_model_instance.return_value = model_instance + with pytest.raises(Exception) as exc_info: + retrieval.knowledge_retrieval(request) + assert error_cls in type(exc_info.value).__name__ + + +class TestRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def _build_model_config(self, features: list[ModelFeature] | None = None): + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = SimpleNamespace(features=features or []) + provider_bundle = SimpleNamespace(model_type_instance=model_type_instance) + return ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt-4", + model_schema=Mock(), + mode="chat", + provider_model_bundle=provider_bundle, + credentials={}, + parameters={}, + stop=[], + ) + + def test_returns_none_when_dataset_ids_empty(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=[], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=self._build_model_config(), + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_returns_none_when_model_schema_missing(self, retrieval: DatasetRetrieval) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + ), + ) + model_config = self._build_model_config() + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = Mock() + result = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert result == (None, []) + + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config() + external_doc = _doc( + provider="external", + content="external content", + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External", "dataset_name": "External DS"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[external_doc]), + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + assert context == "external content" + assert files == [] + + def test_multiple_strategy_with_vision_and_source_details(self, retrieval: DatasetRetrieval) -> None: + retrieve_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=4, + score_threshold=0.1, + rerank_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + reranking_enabled=True, + metadata_filtering_mode="disabled", + ) + config = DatasetEntity(dataset_ids=["d1"], retrieve_config=retrieve_config) + model_config = self._build_model_config(features=[ModelFeature.TOOL_CALL]) + external_doc = _doc( + provider="external", + content="external body", + score=0.8, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"title": "External Title", "dataset_name": "External DS"}, + ) + dify_doc = _doc( + provider="dify", + content="dify body", + score=0.9, + dataset_id="d1", + document_id="doc-1", + doc_id="node-1", + ) + record = SimpleNamespace( + segment=SimpleNamespace( + id="segment-1", + dataset_id="d1", + document_id="doc-1", + tenant_id="tenant-1", + hit_count=3, + word_count=11, + position=1, + index_node_hash="hash-1", + content="segment content", + answer="segment answer", + get_sign_content=lambda: "segment content", + ), + score=0.9, + summary="short summary", + files=None, + ) + dataset_item = SimpleNamespace(id="d1", name="Dataset One") + document_item = SimpleNamespace( + id="doc-1", + name="Document One", + data_source_type="upload_file", + doc_metadata={"lang": "en"}, + ) + upload_file = SimpleNamespace( + id="file-1", + name="image", + extension="png", + mime_type="image/png", + source_url="https://example.com/img.png", + size=123, + key="k1", + ) + execute_attachments = SimpleNamespace(all=lambda: [(SimpleNamespace(), upload_file)]) + execute_docs = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [document_item])) + execute_datasets = SimpleNamespace(scalars=lambda: SimpleNamespace(all=lambda: [dataset_item])) + hit_callback = Mock() + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.format_retrieval_documents", + return_value=[record], + ), + patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), + patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, + ): + mock_model_manager.return_value.get_model_instance.return_value = Mock() + mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.DEBUGGER, + show_retrieve_source=True, + hit_callback=hit_callback, + message_id="m1", + vision_enabled=True, + ) + + assert "short summary" in (context or "") + assert "question:segment content answer:segment answer" in (context or "") + assert len(files or []) == 1 + hit_callback.return_retriever_resource_info.assert_called_once() + + +class TestSingleAndMultipleRetrieveCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_single_retrieve_external_path(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="External DS", + description=None, + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}) + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end") as mock_end, + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + mock_external.return_value = [ + {"content": "ext result", "metadata": {"k": "v"}, "score": 0.9, "title": "Ext Doc"} + ] + result = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + message_id="m1", + ) + + assert len(result) == 1 + assert result[0].provider == "external" + mock_end.assert_called_once() + assert retrieval.llm_usage.total_tokens == 2 + + def test_single_retrieve_dify_path_and_filters(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="dataset desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "cohere", "reranking_model_name": "rerank"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + "top_k": 3, + "score_threshold_enabled": True, + "score_threshold": 0.2, + }, + ) + app = Flask(__name__) + usage = LLMUsage.from_metadata({"prompt_tokens": 1, "completion_tokens": 0, "total_tokens": 1}) + result_doc = _doc(provider="dify", score=0.7, dataset_id="ds-1", document_id="doc-1", doc_id="node-1") + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.FunctionCallMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[result_doc] + ) as mock_retrieve, + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_on_retrieval_end"), + patch.object(retrieval, "_on_query"), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", usage) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.ROUTER, + metadata_filter_document_ids={"ds-1": ["doc-1"]}, + metadata_condition=SimpleNamespace(), + ) + + assert results == [result_doc] + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc-1"] + assert retrieval.llm_usage.total_tokens == 1 + + def test_single_retrieve_returns_empty_when_no_dataset_selected(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls: + mock_router_cls.return_value.invoke.return_value = (None, LLMUsage.empty_usage()) + results = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[ + SimpleNamespace(id="ds-1", name="DS", description=None), + ], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + ) + assert results == [] + + def test_single_retrieve_respects_metadata_filter_shortcuts(self, retrieval: DatasetRetrieval) -> None: + dataset = SimpleNamespace( + id="ds-1", + name="Internal DS", + description="desc", + provider="dify", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={"top_k": 2, "search_method": "semantic_search", "reranking_enable": False}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.ReactMultiDatasetRouter") as mock_router_cls, + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset), + ): + mock_router_cls.return_value.invoke.return_value = ("ds-1", LLMUsage.empty_usage()) + no_filter = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids=None, + metadata_condition=SimpleNamespace(), + ) + missing_doc_ids = retrieval.single_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + query="python", + available_datasets=[dataset], + model_instance=Mock(), + model_config=Mock(), + planning_strategy=PlanningStrategy.REACT_ROUTER, + metadata_filter_document_ids={"other-ds": ["x"]}, + metadata_condition=None, + ) + assert no_filter == [] + assert missing_doc_ids == [] + + def test_multiple_retrieve_validation_paths(self, retrieval: DatasetRetrieval) -> None: + assert ( + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=[], + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + == [] + ) + + mixed = [ + SimpleNamespace(id="d1", indexing_technique="high_quality"), + SimpleNamespace(id="d2", indexing_technique="economy"), + ] + with pytest.raises(ValueError, match="different indexing technique"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=mixed, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="weighted_score", + reranking_enable=False, + ) + + high_quality_mismatch = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-b", + embedding_model_provider="provider-b", + ), + ] + with pytest.raises(ValueError, match="different embedding model"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=high_quality_mismatch, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + ) + + def test_multiple_retrieve_threads_and_dedup(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + SimpleNamespace( + id="d2", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ), + ] + doc_a = _doc(provider="dify", score=0.8, dataset_id="d1", document_id="doc-1", doc_id="dup") + doc_b = _doc(provider="dify", score=0.7, dataset_id="d2", document_id="doc-2", doc_id="dup") + doc_external = _doc( + provider="external", + score=0.9, + dataset_id="ext-ds", + document_id="ext-doc", + doc_id="ext-node", + extra={"dataset_name": "Ext", "title": "Ext"}, + ) + app = Flask(__name__) + weights: WeightsDict = { + "vector_setting": {"vector_weight": 0.5, "embedding_provider_name": "", "embedding_model_name": ""}, + "keyword_setting": {"keyword_weight": 0.5}, + } + + def fake_multiple_thread(**kwargs): + if kwargs["query"]: + kwargs["all_documents"].extend([doc_a, doc_b]) + if kwargs["attachment_id"]: + kwargs["all_documents"].append(doc_external) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=fake_multiple_thread), + patch.object(retrieval, "_on_query") as mock_on_query, + patch.object(retrieval, "_on_retrieval_end") as mock_end, + ): + result = retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode=RerankMode.WEIGHTED_SCORE, + reranking_enable=True, + weights=weights, + attachment_ids=["att-1"], + message_id="m1", + ) + + assert len(result) == 2 + assert any(doc.provider == "external" for doc in result) + assert weights["vector_setting"]["embedding_provider_name"] == "provider-a" + assert weights["vector_setting"]["embedding_model_name"] == "model-a" + mock_on_query.assert_called_once() + mock_end.assert_called_once() + + def test_multiple_retrieve_propagates_thread_exception(self, retrieval: DatasetRetrieval) -> None: + datasets = [ + SimpleNamespace( + id="d1", + indexing_technique="high_quality", + embedding_model="model-a", + embedding_model_provider="provider-a", + ) + ] + app = Flask(__name__) + + def failing_thread(**kwargs): + kwargs["thread_exceptions"].append(RuntimeError("thread boom")) + + with app.app_context(): + with ( + patch("core.rag.retrieval.dataset_retrieval.measure_time", _timer), + patch("core.rag.retrieval.dataset_retrieval.threading.Thread", _ImmediateThread), + patch.object(retrieval, "_multiple_retrieve_thread", side_effect=failing_thread), + ): + with pytest.raises(RuntimeError, match="thread boom"): + retrieval.multiple_retrieve( + app_id="app-1", + tenant_id="tenant-1", + user_id="user-1", + user_from="workflow", + available_datasets=datasets, + query="python", + top_k=2, + score_threshold=0.0, + reranking_mode="reranking_model", + ) + + +class TestInternalHooksCoverage: + @pytest.fixture + def retrieval(self) -> DatasetRetrieval: + return DatasetRetrieval() + + def test_on_retrieval_end_without_dify_documents(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + with patch.object(retrieval, "_send_trace_task") as mock_trace: + retrieval._on_retrieval_end( + flask_app=app, + documents=[_doc(provider="external")], + message_id="m1", + timer={"cost": 1}, + ) + mock_trace.assert_called_once() + + def test_on_retrieval_end_dify_without_document_ids(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + doc = Document(page_content="x", metadata={"doc_id": "n1"}, provider="dify") + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=[doc], message_id="m1", timer={"cost": 1}) + mock_trace.assert_called_once() + + def test_on_retrieval_end_updates_segments_for_text_and_image(self, retrieval: DatasetRetrieval) -> None: + app = Flask(__name__) + docs = [ + _doc(provider="dify", document_id="doc-a", doc_id="idx-a", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-b", doc_id="att-b", extra={"doc_type": DocType.IMAGE}), + _doc(provider="dify", document_id="doc-c", doc_id="idx-c", extra={"doc_type": "text"}), + _doc(provider="dify", document_id="doc-d", doc_id="att-d", extra={"doc_type": DocType.IMAGE}), + ] + dataset_docs = [ + SimpleNamespace(id="doc-a", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-b", doc_form=IndexStructureType.PARENT_CHILD_INDEX), + SimpleNamespace(id="doc-c", doc_form="qa_model"), + SimpleNamespace(id="doc-d", doc_form="qa_model"), + ] + child_chunks = [SimpleNamespace(index_node_id="idx-a", segment_id="seg-a")] + segments = [SimpleNamespace(index_node_id="idx-c", id="seg-c")] + bindings = [SimpleNamespace(segment_id="seg-b"), SimpleNamespace(segment_id="seg-d")] + + def _scalars(items): + result = Mock() + result.all.return_value = items + return result + + session = Mock() + session.scalars.side_effect = [ + _scalars(dataset_docs), + _scalars(child_chunks), + _scalars(segments), + _scalars(bindings), + ] + query = Mock() + query.where.return_value = query + session.query.return_value = query + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = False + + with ( + patch("core.rag.retrieval.dataset_retrieval.db", SimpleNamespace(engine=Mock())), + patch("core.rag.retrieval.dataset_retrieval.Session", return_value=session_ctx), + patch.object(retrieval, "_send_trace_task") as mock_trace, + ): + retrieval._on_retrieval_end(flask_app=app, documents=docs, message_id="m1", timer={"cost": 1}) + + query.update.assert_called_once() + session.commit.assert_called_once() + mock_trace.assert_called_once() + + def test_retriever_variants(self, retrieval: DatasetRetrieval) -> None: + flask_app = SimpleNamespace(app_context=lambda: nullcontext()) + all_documents: list[Document] = [] + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=None): + assert ( + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="d1", + query="python", + top_k=1, + all_documents=all_documents, + ) + == [] + ) + + external_dataset = SimpleNamespace( + id="ext-ds", + name="External", + provider="external", + tenant_id="tenant-1", + retrieval_model={"top_k": 2}, + indexing_technique="high_quality", + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=external_dataset), + patch( + "core.rag.retrieval.dataset_retrieval.ExternalDatasetService.fetch_external_knowledge_retrieval" + ) as mock_external, + ): + mock_external.return_value = [{"content": "e", "metadata": {}, "score": 0.8, "title": "Ext"}] + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="ext-ds", + query="python", + top_k=1, + all_documents=all_documents, + ) + + economy_dataset = SimpleNamespace( + id="eco-ds", + provider="dify", + retrieval_model={"top_k": 1}, + indexing_technique="economy", + ) + high_dataset = SimpleNamespace( + id="hq-ds", + provider="dify", + retrieval_model={ + "search_method": "semantic_search", + "top_k": 4, + "score_threshold": 0.3, + "score_threshold_enabled": True, + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "x", "reranking_model_name": "y"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {}}, + }, + indexing_technique="high_quality", + ) + with ( + patch( + "core.rag.retrieval.dataset_retrieval.db.session.scalar", side_effect=[economy_dataset, high_dataset] + ), + patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService.retrieve", return_value=[_doc(provider="dify")] + ) as mock_retrieve, + ): + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="eco-ds", + query="python", + top_k=2, + all_documents=all_documents, + ) + retrieval._retriever( + flask_app=flask_app, # type: ignore[arg-type] + dataset_id="hq-ds", + query="python", + top_k=2, + all_documents=all_documents, + attachment_ids=["att-1"], + ) + assert mock_retrieve.call_count == 2 + assert len(all_documents) >= 3 + + def test_to_dataset_retriever_tool_paths(self, retrieval: DatasetRetrieval) -> None: + dataset_skip_zero = SimpleNamespace(id="d1", provider="dify", available_document_count=0) + dataset_ok_single = SimpleNamespace( + id="d2", + provider="dify", + available_document_count=2, + retrieval_model={"top_k": 2, "score_threshold_enabled": True, "score_threshold": 0.1}, + ) + single_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ) + with ( + patch( + "core.rag.retrieval.dataset_retrieval.db.session.scalar", + side_effect=[None, dataset_skip_zero, dataset_ok_single], + ), + patch( + "core.tools.utils.dataset_retriever.dataset_retriever_tool.DatasetRetrieverTool.from_dataset", + return_value="single-tool", + ) as mock_single_tool, + ): + single_tools = retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["missing", "d1", "d2"], + retrieve_config=single_config, + return_resource=True, + invoke_from=InvokeFrom.WEB_APP, + hit_callback=Mock(), + user_id="user-1", + inputs={"k": "v"}, + ) + + assert single_tools == ["single-tool"] + mock_single_tool.assert_called_once() + + multiple_config_missing = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + metadata_filtering_mode="disabled", + reranking_model=None, + ) + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single): + with pytest.raises(ValueError, match="Reranking model is required"): + retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["d2"], + retrieve_config=multiple_config_missing, + return_resource=True, + invoke_from=InvokeFrom.WEB_APP, + hit_callback=Mock(), + user_id="user-1", + inputs={}, + ) + + multiple_config = DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + metadata_filtering_mode="disabled", + top_k=3, + score_threshold=0.2, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v3"}, + ) + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalar", return_value=dataset_ok_single), + patch( + "core.tools.utils.dataset_retriever.dataset_multi_retriever_tool.DatasetMultiRetrieverTool.from_dataset", + return_value="multi-tool", + ) as mock_multi_tool, + ): + multi_tools = retrieval.to_dataset_retriever_tool( + tenant_id="tenant-1", + dataset_ids=["d2"], + retrieve_config=multiple_config, + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="user-1", + inputs={}, + ) + assert multi_tools == ["multi-tool"] + mock_multi_tool.assert_called_once() + + def test_additional_small_branches(self, retrieval: DatasetRetrieval) -> None: + keyword_handler = Mock() + keyword_handler.extract_keywords.side_effect = [[], []] + doc = Document(page_content="doc", metadata={"doc_id": "1"}, provider="dify") + with patch("core.rag.retrieval.dataset_retrieval.JiebaKeywordTableHandler", return_value=keyword_handler): + ranked = retrieval.calculate_keyword_score("query", [doc], top_k=1) + assert len(ranked) == 1 + assert ranked[0].metadata.get("score") == 0.0 + + with patch("core.rag.retrieval.dataset_retrieval.db.session.scalars") as mock_scalars: + mock_scalars.return_value.all.return_value = [] + with pytest.raises(ValueError): + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="user-1", + metadata_model_config=None, # type: ignore[arg-type] + ) + + session_scalars = Mock() + session_scalars.all.return_value = [SimpleNamespace(name="author")] + with ( + patch("core.rag.retrieval.dataset_retrieval.db.session.scalars", return_value=session_scalars), + patch.object(retrieval, "_fetch_model_config", return_value=(Mock(), Mock())), + patch.object(retrieval, "_get_prompt_template", return_value=(["prompt"], [])), + patch.object(retrieval, "_record_usage"), + ): + model_instance = Mock() + model_instance.invoke_llm.side_effect = RuntimeError("nope") + with patch.object(retrieval, "_fetch_model_config", return_value=(model_instance, Mock())): + assert ( + retrieval._automatic_metadata_filter_func( + dataset_ids=["d1"], + query="python", + tenant_id="tenant-1", + user_id="user-1", + metadata_model_config=WorkflowModelConfig(provider="openai", name="gpt", mode="chat"), + ) + is None + ) + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelMode", return_value=object()), + patch("core.rag.retrieval.dataset_retrieval.AdvancedPromptTransform"), + ): + with pytest.raises(ValueError, match="not support"): + retrieval._get_prompt_template( + model_config=ModelConfigWithCredentialsEntity.model_construct( + provider="openai", + model="gpt", + model_schema=Mock(), + mode="chat", + provider_model_bundle=Mock(), + credentials={}, + parameters={}, + stop=[], + ), + mode="chat", + metadata_fields=[], + query="q", + ) diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py deleted file mode 100644 index 07d6e51e4b..0000000000 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py +++ /dev/null @@ -1,873 +0,0 @@ -""" -Unit tests for DatasetRetrieval.process_metadata_filter_func. - -This module provides comprehensive test coverage for the process_metadata_filter_func -method in the DatasetRetrieval class, which is responsible for building SQLAlchemy -filter expressions based on metadata filtering conditions. - -Conditions Tested: -================== -1. **String Conditions**: contains, not contains, start with, end with -2. **Equality Conditions**: is / =, is not / ≠ -3. **Null Conditions**: empty, not empty -4. **Numeric Comparisons**: before / <, after / >, ≤ / <=, ≥ / >= -5. **List Conditions**: in -6. **Edge Cases**: None values, different data types (str, int, float) - -Test Architecture: -================== -- Direct instantiation of DatasetRetrieval -- Mocking of DatasetDocument model attributes -- Verification of SQLAlchemy filter expressions -- Follows Arrange-Act-Assert (AAA) pattern - -Running Tests: -============== - # Run all tests in this module - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py -v - - # Run a specific test - uv run --project api pytest \ - api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_metadata_filter.py::\ -TestProcessMetadataFilterFunc::test_contains_condition -v -""" - -from unittest.mock import MagicMock - -import pytest - -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval - - -class TestProcessMetadataFilterFunc: - """ - Comprehensive test suite for process_metadata_filter_func method. - - This test class validates all metadata filtering conditions supported by - the DatasetRetrieval class, including string operations, numeric comparisons, - null checks, and list operations. - - Method Signature: - ================== - def process_metadata_filter_func( - self, sequence: int, condition: str, metadata_name: str, value: Any | None, filters: list - ) -> list: - - The method builds SQLAlchemy filter expressions by: - 1. Validating value is not None (except for empty/not empty conditions) - 2. Using DatasetDocument.doc_metadata JSON field operations - 3. Adding appropriate SQLAlchemy expressions to the filters list - 4. Returning the updated filters list - - Mocking Strategy: - ================== - - Mock DatasetDocument.doc_metadata to avoid database dependencies - - Verify filter expressions are created correctly - - Test with various data types (str, int, float, list) - """ - - @pytest.fixture - def retrieval(self): - """ - Create a DatasetRetrieval instance for testing. - - Returns: - DatasetRetrieval: Instance to test process_metadata_filter_func - """ - return DatasetRetrieval() - - @pytest.fixture - def mock_doc_metadata(self): - """ - Mock the DatasetDocument.doc_metadata JSON field. - - The method uses DatasetDocument.doc_metadata[metadata_name] to access - JSON fields. We mock this to avoid database dependencies. - - Returns: - Mock: Mocked doc_metadata attribute - """ - mock_metadata_field = MagicMock() - - # Create mock for string access - mock_string_access = MagicMock() - mock_string_access.like = MagicMock() - mock_string_access.notlike = MagicMock() - mock_string_access.__eq__ = MagicMock(return_value=MagicMock()) - mock_string_access.__ne__ = MagicMock(return_value=MagicMock()) - mock_string_access.in_ = MagicMock(return_value=MagicMock()) - - # Create mock for float access (for numeric comparisons) - mock_float_access = MagicMock() - mock_float_access.__eq__ = MagicMock(return_value=MagicMock()) - mock_float_access.__ne__ = MagicMock(return_value=MagicMock()) - mock_float_access.__lt__ = MagicMock(return_value=MagicMock()) - mock_float_access.__gt__ = MagicMock(return_value=MagicMock()) - mock_float_access.__le__ = MagicMock(return_value=MagicMock()) - mock_float_access.__ge__ = MagicMock(return_value=MagicMock()) - - # Create mock for null checks - mock_null_access = MagicMock() - mock_null_access.is_ = MagicMock(return_value=MagicMock()) - mock_null_access.isnot = MagicMock(return_value=MagicMock()) - - # Setup __getitem__ to return appropriate mock based on usage - def getitem_side_effect(name): - if name in ["author", "title", "category"]: - return mock_string_access - elif name in ["year", "price", "rating"]: - return mock_float_access - else: - return mock_string_access - - mock_metadata_field.__getitem__ = MagicMock(side_effect=getitem_side_effect) - mock_metadata_field.as_string.return_value = mock_string_access - mock_metadata_field.as_float.return_value = mock_float_access - mock_metadata_field[metadata_name:str].is_ = mock_null_access.is_ - mock_metadata_field[metadata_name:str].isnot = mock_null_access.isnot - - return mock_metadata_field - - # ==================== String Condition Tests ==================== - - def test_contains_condition_string_value(self, retrieval): - """ - Test 'contains' condition with string value. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses %value% syntax - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "John" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_contains_condition(self, retrieval): - """ - Test 'not contains' condition. - - Verifies: - - Filters list is populated with NOT LIKE expression - - Pattern matching uses %value% syntax with negation - """ - filters = [] - sequence = 0 - condition = "not contains" - metadata_name = "title" - value = "banned" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_start_with_condition(self, retrieval): - """ - Test 'start with' condition. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses value% syntax - """ - filters = [] - sequence = 0 - condition = "start with" - metadata_name = "category" - value = "tech" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_end_with_condition(self, retrieval): - """ - Test 'end with' condition. - - Verifies: - - Filters list is populated with LIKE expression - - Pattern matching uses %value syntax - """ - filters = [] - sequence = 0 - condition = "end with" - metadata_name = "filename" - value = ".pdf" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Equality Condition Tests ==================== - - def test_is_condition_with_string_value(self, retrieval): - """ - Test 'is' (=) condition with string value. - - Verifies: - - Filters list is populated with equality expression - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "author" - value = "Jane Doe" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_equals_condition_with_string_value(self, retrieval): - """ - Test '=' condition with string value. - - Verifies: - - Same behavior as 'is' condition - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "=" - metadata_name = "category" - value = "technology" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_condition_with_int_value(self, retrieval): - """ - Test 'is' condition with integer value. - - Verifies: - - Numeric comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "year" - value = 2023 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_condition_with_float_value(self, retrieval): - """ - Test 'is' condition with float value. - - Verifies: - - Numeric comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "price" - value = 19.99 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_not_condition_with_string_value(self, retrieval): - """ - Test 'is not' (≠) condition with string value. - - Verifies: - - Filters list is populated with inequality expression - - String comparison is used - """ - filters = [] - sequence = 0 - condition = "is not" - metadata_name = "author" - value = "Unknown" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_equals_condition(self, retrieval): - """ - Test '≠' condition with string value. - - Verifies: - - Same behavior as 'is not' condition - - Inequality expression is used - """ - filters = [] - sequence = 0 - condition = "≠" - metadata_name = "category" - value = "archived" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_is_not_condition_with_numeric_value(self, retrieval): - """ - Test 'is not' condition with numeric value. - - Verifies: - - Numeric inequality comparison is used - - as_float() is called on the metadata field - """ - filters = [] - sequence = 0 - condition = "is not" - metadata_name = "year" - value = 2000 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Null Condition Tests ==================== - - def test_empty_condition(self, retrieval): - """ - Test 'empty' condition (null check). - - Verifies: - - Filters list is populated with IS NULL expression - - Value can be None for this condition - """ - filters = [] - sequence = 0 - condition = "empty" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_not_empty_condition(self, retrieval): - """ - Test 'not empty' condition (not null check). - - Verifies: - - Filters list is populated with IS NOT NULL expression - - Value can be None for this condition - """ - filters = [] - sequence = 0 - condition = "not empty" - metadata_name = "description" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Numeric Comparison Tests ==================== - - def test_before_condition(self, retrieval): - """ - Test 'before' (<) condition. - - Verifies: - - Filters list is populated with less than expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "before" - metadata_name = "year" - value = 2020 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_condition(self, retrieval): - """ - Test '<' condition. - - Verifies: - - Same behavior as 'before' condition - - Less than expression is used - """ - filters = [] - sequence = 0 - condition = "<" - metadata_name = "price" - value = 100.0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_after_condition(self, retrieval): - """ - Test 'after' (>) condition. - - Verifies: - - Filters list is populated with greater than expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "after" - metadata_name = "year" - value = 2020 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_condition(self, retrieval): - """ - Test '>' condition. - - Verifies: - - Same behavior as 'after' condition - - Greater than expression is used - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "rating" - value = 4.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_or_equal_condition_unicode(self, retrieval): - """ - Test '≤' condition. - - Verifies: - - Filters list is populated with less than or equal expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "≤" - metadata_name = "price" - value = 50.0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_less_than_or_equal_condition_ascii(self, retrieval): - """ - Test '<=' condition. - - Verifies: - - Same behavior as '≤' condition - - Less than or equal expression is used - """ - filters = [] - sequence = 0 - condition = "<=" - metadata_name = "year" - value = 2023 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_or_equal_condition_unicode(self, retrieval): - """ - Test '≥' condition. - - Verifies: - - Filters list is populated with greater than or equal expression - - Numeric comparison is used - """ - filters = [] - sequence = 0 - condition = "≥" - metadata_name = "rating" - value = 3.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_greater_than_or_equal_condition_ascii(self, retrieval): - """ - Test '>=' condition. - - Verifies: - - Same behavior as '≥' condition - - Greater than or equal expression is used - """ - filters = [] - sequence = 0 - condition = ">=" - metadata_name = "year" - value = 2000 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== List/In Condition Tests ==================== - - def test_in_condition_with_comma_separated_string(self, retrieval): - """ - Test 'in' condition with comma-separated string value. - - Verifies: - - String is split into list - - Whitespace is trimmed from each value - - IN expression is created - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "tech, science, AI " - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_list_value(self, retrieval): - """ - Test 'in' condition with list value. - - Verifies: - - List is processed correctly - - None values are filtered out - - IN expression is created with valid values - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "tags" - value = ["python", "javascript", None, "golang"] - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_tuple_value(self, retrieval): - """ - Test 'in' condition with tuple value. - - Verifies: - - Tuple is processed like a list - - IN expression is created - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = ("tech", "science", "ai") - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_empty_string(self, retrieval): - """ - Test 'in' condition with empty string value. - - Verifies: - - Empty string results in literal(False) filter - - No valid values to match - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - # Verify it's a literal(False) expression - # This is a bit tricky to test without access to the actual expression - - def test_in_condition_with_only_whitespace(self, retrieval): - """ - Test 'in' condition with whitespace-only string value. - - Verifies: - - Whitespace-only string results in literal(False) filter - - All values are stripped and filtered out - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = " , , " - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_in_condition_with_single_string(self, retrieval): - """ - Test 'in' condition with single non-comma string. - - Verifies: - - Single string is treated as single-item list - - IN expression is created with one value - """ - filters = [] - sequence = 0 - condition = "in" - metadata_name = "category" - value = "technology" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - # ==================== Edge Case Tests ==================== - - def test_none_value_with_non_empty_condition(self, retrieval): - """ - Test None value with conditions that require value. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values (except empty/not empty) - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 # No filter added - - def test_none_value_with_equals_condition(self, retrieval): - """ - Test None value with 'is' (=) condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values - """ - filters = [] - sequence = 0 - condition = "is" - metadata_name = "author" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_none_value_with_numeric_condition(self, retrieval): - """ - Test None value with numeric comparison condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for None values - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "year" - value = None - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_existing_filters_preserved(self, retrieval): - """ - Test that existing filters are preserved. - - Verifies: - - Existing filters in the list are not removed - - New filters are appended to the list - """ - existing_filter = MagicMock() - filters = [existing_filter] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "test" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 2 - assert filters[0] == existing_filter - - def test_multiple_filters_accumulated(self, retrieval): - """ - Test multiple calls to accumulate filters. - - Verifies: - - Each call adds a new filter to the list - - All filters are preserved across calls - """ - filters = [] - - # First filter - retrieval.process_metadata_filter_func(0, "contains", "author", "John", filters) - assert len(filters) == 1 - - # Second filter - retrieval.process_metadata_filter_func(1, ">", "year", 2020, filters) - assert len(filters) == 2 - - # Third filter - retrieval.process_metadata_filter_func(2, "is", "category", "tech", filters) - assert len(filters) == 3 - - def test_unknown_condition(self, retrieval): - """ - Test unknown/unsupported condition. - - Verifies: - - Original filters list is returned unchanged - - No filter is added for unknown conditions - """ - filters = [] - sequence = 0 - condition = "unknown_condition" - metadata_name = "author" - value = "test" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 0 - - def test_empty_string_value_with_contains(self, retrieval): - """ - Test empty string value with 'contains' condition. - - Verifies: - - Filter is added even with empty string - - LIKE expression is created - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "author" - value = "" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_special_characters_in_value(self, retrieval): - """ - Test special characters in value string. - - Verifies: - - Special characters are handled in value - - LIKE expression is created correctly - """ - filters = [] - sequence = 0 - condition = "contains" - metadata_name = "title" - value = "C++ & Python's features" - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_zero_value_with_numeric_condition(self, retrieval): - """ - Test zero value with numeric comparison condition. - - Verifies: - - Zero is treated as valid value - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = ">" - metadata_name = "price" - value = 0 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_negative_value_with_numeric_condition(self, retrieval): - """ - Test negative value with numeric comparison condition. - - Verifies: - - Negative numbers are handled correctly - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = "<" - metadata_name = "temperature" - value = -10.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 - - def test_float_value_with_integer_comparison(self, retrieval): - """ - Test float value with numeric comparison condition. - - Verifies: - - Float values work correctly - - Numeric comparison is performed - """ - filters = [] - sequence = 0 - condition = ">=" - metadata_name = "rating" - value = 4.5 - - result = retrieval.process_metadata_filter_func(sequence, condition, metadata_name, value, filters) - - assert result == filters - assert len(filters) == 1 diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py index 4bc802dc23..48782515d0 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -6,7 +6,7 @@ import pytest from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.workflow.nodes.knowledge_retrieval import exc -from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest from models.dataset import Dataset # ==================== Helper Functions ==================== diff --git a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py deleted file mode 100644 index 5f461d53ae..0000000000 --- a/api/tests/unit_tests/core/rag/retrieval/test_knowledge_retrieval.py +++ /dev/null @@ -1,113 +0,0 @@ -import threading -from unittest.mock import Mock, patch -from uuid import uuid4 - -import pytest -from flask import Flask, current_app - -from core.rag.models.document import Document -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from models.dataset import Dataset - - -class TestRetrievalService: - @pytest.fixture - def mock_dataset(self) -> Dataset: - dataset = Mock(spec=Dataset) - dataset.id = str(uuid4()) - dataset.tenant_id = str(uuid4()) - dataset.name = "test_dataset" - dataset.indexing_technique = "high_quality" - dataset.provider = "dify" - return dataset - - def test_multiple_retrieve_reranking_with_app_context(self, mock_dataset): - """ - Repro test for current bug: - reranking runs after `with flask_app.app_context():` exits. - `_multiple_retrieve_thread` catches exceptions and stores them into `thread_exceptions`, - so we must assert from that list (not from an outer try/except). - """ - dataset_retrieval = DatasetRetrieval() - flask_app = Flask(__name__) - tenant_id = str(uuid4()) - - # second dataset to ensure dataset_count > 1 reranking branch - secondary_dataset = Mock(spec=Dataset) - secondary_dataset.id = str(uuid4()) - secondary_dataset.provider = "dify" - secondary_dataset.indexing_technique = "high_quality" - - # retriever returns 1 doc into internal list (all_documents_item) - document = Document( - page_content="Context aware doc", - metadata={ - "doc_id": "doc1", - "score": 0.95, - "document_id": str(uuid4()), - "dataset_id": mock_dataset.id, - }, - provider="dify", - ) - - def fake_retriever( - flask_app, dataset_id, query, top_k, all_documents, document_ids_filter, metadata_condition, attachment_ids - ): - all_documents.append(document) - - called = {"init": 0, "invoke": 0} - - class ContextRequiredPostProcessor: - def __init__(self, *args, **kwargs): - called["init"] += 1 - # will raise RuntimeError if no Flask app context exists - _ = current_app.name - - def invoke(self, *args, **kwargs): - called["invoke"] += 1 - _ = current_app.name - return kwargs.get("documents") or args[1] - - # output list from _multiple_retrieve_thread - all_documents: list[Document] = [] - - # IMPORTANT: _multiple_retrieve_thread swallows exceptions and appends them here - thread_exceptions: list[Exception] = [] - - def target(): - with patch.object(dataset_retrieval, "_retriever", side_effect=fake_retriever): - with patch( - "core.rag.retrieval.dataset_retrieval.DataPostProcessor", - ContextRequiredPostProcessor, - ): - dataset_retrieval._multiple_retrieve_thread( - flask_app=flask_app, - available_datasets=[mock_dataset, secondary_dataset], - metadata_condition=None, - metadata_filter_document_ids=None, - all_documents=all_documents, - tenant_id=tenant_id, - reranking_enable=True, - reranking_mode="reranking_model", - reranking_model={ - "reranking_provider_name": "cohere", - "reranking_model_name": "rerank-v2", - }, - weights=None, - top_k=3, - score_threshold=0.0, - query="test query", - attachment_id=None, - dataset_count=2, # force reranking branch - thread_exceptions=thread_exceptions, # ✅ key - ) - - t = threading.Thread(target=target) - t.start() - t.join() - - # Ensure reranking branch was actually executed - assert called["init"] >= 1, "DataPostProcessor was never constructed; reranking branch may not have run." - - # Current buggy code should record an exception (not raise it) - assert not thread_exceptions, thread_exceptions diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py new file mode 100644 index 0000000000..cfa9094e12 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -0,0 +1,100 @@ +from unittest.mock import Mock + +from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter +from dify_graph.model_runtime.entities.llm_entities import LLMUsage + + +class TestFunctionCallMultiDatasetRouter: + def test_invoke_returns_none_when_no_tools(self) -> None: + router = FunctionCallMultiDatasetRouter() + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[], + model_config=Mock(), + model_instance=Mock(), + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_single_tool_directly(self) -> None: + router = FunctionCallMultiDatasetRouter() + tool = Mock() + tool.name = "dataset-1" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool], + model_config=Mock(), + model_instance=Mock(), + ) + + assert dataset_id == "dataset-1" + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_tool_from_model_response(self) -> None: + router = FunctionCallMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + usage = LLMUsage.empty_usage() + response = Mock() + response.usage = usage + response.message.tool_calls = [Mock(function=Mock())] + response.message.tool_calls[0].function.name = "dataset-2" + model_instance = Mock() + model_instance.invoke_llm.return_value = response + + dataset_id, returned_usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id == "dataset-2" + assert returned_usage == usage + model_instance.invoke_llm.assert_called_once() + + def test_invoke_returns_none_when_no_tool_calls(self) -> None: + router = FunctionCallMultiDatasetRouter() + response = Mock() + response.usage = LLMUsage.empty_usage() + response.message.tool_calls = [] + model_instance = Mock() + model_instance.invoke_llm.return_value = response + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id is None + assert usage == response.usage + + def test_invoke_returns_empty_usage_when_model_raises(self) -> None: + router = FunctionCallMultiDatasetRouter() + model_instance = Mock() + model_instance.invoke_llm.side_effect = RuntimeError("boom") + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=model_instance, + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py new file mode 100644 index 0000000000..e429563739 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -0,0 +1,252 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole + + +class TestReactMultiDatasetRouter: + def test_invoke_returns_none_when_no_tools(self) -> None: + router = ReactMultiDatasetRouter() + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_single_tool_directly(self) -> None: + router = ReactMultiDatasetRouter() + tool = Mock() + tool.name = "dataset-1" + + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id == "dataset-1" + assert usage == LLMUsage.empty_usage() + + def test_invoke_returns_tool_from_react_invoke(self) -> None: + router = ReactMultiDatasetRouter() + usage = LLMUsage.empty_usage() + tool_1 = Mock(name="dataset-1") + tool_1.name = "dataset-1" + tool_2 = Mock(name="dataset-2") + tool_2.name = "dataset-2" + + with patch.object(router, "_react_invoke", return_value=("dataset-2", usage)) as mock_react: + dataset_id, returned_usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + mock_react.assert_called_once() + assert dataset_id == "dataset-2" + assert returned_usage == usage + + def test_invoke_handles_react_invoke_errors(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_2 = Mock() + tool_2.name = "dataset-2" + + with patch.object(router, "_react_invoke", side_effect=RuntimeError("boom")): + dataset_id, usage = router.invoke( + query="python", + dataset_tools=[tool_1, tool_2], + model_config=Mock(), + model_instance=Mock(), + user_id="u1", + tenant_id="t1", + ) + + assert dataset_id is None + assert usage == LLMUsage.empty_usage() + + def test_react_invoke_returns_action_tool(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "chat" + model_config.parameters = {"temperature": 0.1} + usage = LLMUsage.empty_usage() + tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] + tools[0].name = "dataset-1" + tools[0].description = "desc" + tools[1].name = "dataset-2" + tools[1].description = "desc" + + with ( + patch.object(router, "create_chat_prompt", return_value=[Mock()]) as mock_chat_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object(router, "_invoke_llm", return_value=('{"action":"dataset-2","action_input":{}}', usage)), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactAction("dataset-2", {}, "log") + + dataset_id, returned_usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=tools, + user_id="u1", + tenant_id="t1", + ) + + mock_chat_prompt.assert_called_once() + assert dataset_id == "dataset-2" + assert returned_usage == usage + + def test_react_invoke_returns_none_for_finish(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "completion" + model_config.parameters = {"temperature": 0.1} + usage = LLMUsage.empty_usage() + tool = Mock() + tool.name = "dataset-1" + tool.description = "desc" + + with ( + patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object( + router, "_invoke_llm", return_value=('{"action":"Final Answer","action_input":"done"}', usage) + ), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log") + + dataset_id, returned_usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=[tool], + user_id="u1", + tenant_id="t1", + ) + + mock_completion_prompt.assert_called_once() + assert dataset_id is None + assert returned_usage == usage + + def test_invoke_llm_and_handle_result(self) -> None: + router = ReactMultiDatasetRouter() + usage = LLMUsage.empty_usage() + delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=usage) + chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta) + model_instance = Mock() + model_instance.invoke_llm.return_value = iter([chunk]) + + with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + text, returned_usage = router._invoke_llm( + completion_param={"temperature": 0.1}, + model_instance=model_instance, + prompt_messages=[Mock()], + stop=["Observation:"], + user_id="u1", + tenant_id="t1", + ) + + assert text == "part" + assert returned_usage == usage + mock_deduct.assert_called_once() + + def test_handle_invoke_result_with_empty_usage(self) -> None: + router = ReactMultiDatasetRouter() + delta = SimpleNamespace(message=SimpleNamespace(content="part"), usage=None) + chunk = SimpleNamespace(model="m1", prompt_messages=[Mock()], delta=delta) + + text, usage = router._handle_invoke_result(iter([chunk])) + + assert text == "part" + assert usage == LLMUsage.empty_usage() + + def test_create_chat_prompt(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_1.description = "d1" + tool_2 = Mock() + tool_2.name = "dataset-2" + tool_2.description = "d2" + + chat_prompt = router.create_chat_prompt(query="python", tools=[tool_1, tool_2]) + assert len(chat_prompt) == 2 + assert chat_prompt[0].role == PromptMessageRole.SYSTEM + assert chat_prompt[1].role == PromptMessageRole.USER + assert "dataset-1" in chat_prompt[0].text + assert "dataset-2" in chat_prompt[0].text + + def test_create_completion_prompt(self) -> None: + router = ReactMultiDatasetRouter() + tool_1 = Mock() + tool_1.name = "dataset-1" + tool_1.description = "d1" + tool_2 = Mock() + tool_2.name = "dataset-2" + tool_2.description = "d2" + + completion_prompt = router.create_completion_prompt(tools=[tool_1, tool_2]) + assert "dataset-1: d1" in completion_prompt.text + assert "dataset-2: d2" in completion_prompt.text + + def test_react_invoke_uses_completion_branch_for_non_chat_mode(self) -> None: + router = ReactMultiDatasetRouter() + model_config = Mock() + model_config.mode = "unknown-mode" + model_config.parameters = {} + tool = Mock() + tool.name = "dataset-1" + tool.description = "desc" + + with ( + patch.object(router, "create_completion_prompt", return_value=Mock()) as mock_completion_prompt, + patch( + "core.rag.retrieval.router.multi_dataset_react_route.AdvancedPromptTransform" + ) as mock_prompt_transform, + patch.object( + router, + "_invoke_llm", + return_value=('{"action":"Final Answer","action_input":"done"}', LLMUsage.empty_usage()), + ), + patch("core.rag.retrieval.router.multi_dataset_react_route.StructuredChatOutputParser") as mock_parser_cls, + ): + mock_prompt_transform.return_value.get_prompt.return_value = [Mock()] + mock_parser_cls.return_value.parse.return_value = ReactFinish({"output": "done"}, "log") + dataset_id, usage = router._react_invoke( + query="python", + model_config=model_config, + model_instance=Mock(), + tools=[tool], + user_id="u1", + tenant_id="t1", + ) + + mock_completion_prompt.assert_called_once() + assert dataset_id is None + assert usage == LLMUsage.empty_usage() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py new file mode 100644 index 0000000000..c8fa0ea62f --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_structured_chat_output_parser.py @@ -0,0 +1,69 @@ +import pytest + +from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish +from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser + + +class TestStructuredChatOutputParser: + def test_parse_action_without_action_input(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"action":"some_action"}\n```' + result = parser.parse(text) + + assert isinstance(result, ReactAction) + assert result.tool == "some_action" + assert result.tool_input == {} + + def test_parse_json_without_action_key(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"not_action":"search"}\n```' + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.parse(text) + + def test_parse_returns_action_for_tool_call(self) -> None: + parser = StructuredChatOutputParser() + text = ( + 'Thought: call tool\nAction:\n```json\n{"action":"search_dataset","action_input":{"query":"python"}}\n```' + ) + + result = parser.parse(text) + + assert isinstance(result, ReactAction) + assert result.tool == "search_dataset" + assert result.tool_input == {"query": "python"} + assert result.log == text + + def test_parse_returns_finish_for_final_answer(self) -> None: + parser = StructuredChatOutputParser() + text = 'Thought: done\nAction:\n```json\n{"action":"Final Answer","action_input":"final text"}\n```' + + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": "final text"} + assert result.log == text + + def test_parse_returns_finish_for_json_array_payload(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n[{"action":"search","action_input":"hello"}]\n```' + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": text} + assert result.log == text + + def test_parse_returns_finish_for_plain_text(self) -> None: + parser = StructuredChatOutputParser() + text = "No structured action block" + + result = parser.parse(text) + + assert isinstance(result, ReactFinish) + assert result.return_values == {"output": text} + + def test_parse_raises_value_error_for_invalid_json(self) -> None: + parser = StructuredChatOutputParser() + text = 'Action:\n```json\n{"action":"search","action_input": }\n```' + + with pytest.raises(ValueError, match="Could not parse LLM output"): + parser.parse(text) diff --git a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py index 943a9e5712..976de10d89 100644 --- a/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py +++ b/api/tests/unit_tests/core/rag/splitter/test_text_splitter.py @@ -125,7 +125,11 @@ Run with coverage: - Tests are organized by functionality in classes for better organization """ +import asyncio import string +import sys +import types +from inspect import currentframe from unittest.mock import Mock, patch import pytest @@ -604,6 +608,51 @@ class TestRecursiveCharacterTextSplitter: assert "def hello_world" in combined or "hello_world" in combined +class TestTextSplitterBasePaths: + """Target uncovered base TextSplitter paths.""" + + def test_from_huggingface_tokenizer_success_path(self): + """Cover from_huggingface_tokenizer success branch with mocked transformers.""" + + class _FakePreTrainedTokenizerBase: + pass + + class _FakeTokenizer(_FakePreTrainedTokenizerBase): + def encode(self, text: str): + return [ord(c) for c in text] + + fake_transformers = types.SimpleNamespace(PreTrainedTokenizerBase=_FakePreTrainedTokenizerBase) + with patch.dict(sys.modules, {"transformers": fake_transformers}): + splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer( + tokenizer=_FakeTokenizer(), + chunk_size=5, + chunk_overlap=1, + ) + + chunks = splitter.split_text("abcdef") + assert chunks + + def test_from_huggingface_tokenizer_import_error(self): + """Cover from_huggingface_tokenizer import-error branch.""" + with patch.dict(sys.modules, {"transformers": None}): + with pytest.raises(ValueError, match="Could not import transformers"): + RecursiveCharacterTextSplitter.from_huggingface_tokenizer(tokenizer=object(), chunk_size=5) + + def test_atransform_documents_raises_not_implemented(self): + """Cover atransform_documents NotImplemented branch.""" + splitter = RecursiveCharacterTextSplitter(chunk_size=20, chunk_overlap=5) + with pytest.raises(NotImplementedError): + asyncio.run(splitter.atransform_documents([Document(page_content="x", metadata={})])) + + def test_merge_splits_logs_warning_for_oversized_total(self): + """Cover logger.warning path in _merge_splits.""" + splitter = RecursiveCharacterTextSplitter(chunk_size=5, chunk_overlap=1) + with patch("core.rag.splitter.text_splitter.logger.warning") as mock_warning: + merged = splitter._merge_splits(["abcdefghij", "b"], "", [10, 1]) + assert merged + mock_warning.assert_called_once() + + # ============================================================================ # Test TokenTextSplitter # ============================================================================ @@ -662,6 +711,44 @@ class TestTokenTextSplitter: except ImportError: pytest.skip("tiktoken not installed") + def test_initialization_and_split_with_mocked_tiktoken_encoding(self): + """Cover TokenTextSplitter __init__ else-path and split_text logic.""" + + class _FakeEncoding: + def encode(self, text: str, allowed_special=None, disallowed_special=None): + return [ord(c) for c in text] + + def decode(self, token_ids: list[int]) -> str: + return "".join(chr(i) for i in token_ids) + + fake_tiktoken = types.SimpleNamespace(get_encoding=lambda name: _FakeEncoding()) + with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}): + splitter = TokenTextSplitter(encoding_name="gpt2", chunk_size=4, chunk_overlap=1) + result = splitter.split_text("abcdefgh") + + assert result + assert all(isinstance(chunk, str) for chunk in result) + + def test_initialization_with_model_name_uses_encoding_for_model(self): + """Cover TokenTextSplitter model_name init branch.""" + + class _FakeEncoding: + def encode(self, text: str, allowed_special=None, disallowed_special=None): + return [ord(c) for c in text] + + def decode(self, token_ids: list[int]) -> str: + return "".join(chr(i) for i in token_ids) + + fake_encoding = _FakeEncoding() + fake_tiktoken = types.SimpleNamespace( + encoding_for_model=lambda model_name: fake_encoding, + get_encoding=lambda name: _FakeEncoding(), + ) + with patch.dict(sys.modules, {"tiktoken": fake_tiktoken}): + splitter = TokenTextSplitter(model_name="gpt-4", chunk_size=5, chunk_overlap=1) + + assert splitter._tokenizer is fake_encoding + # ============================================================================ # Test EnhanceRecursiveCharacterTextSplitter @@ -731,6 +818,50 @@ class TestEnhanceRecursiveCharacterTextSplitter: assert len(result) > 0 assert all(isinstance(chunk, str) for chunk in result) + def test_from_encoder_internal_token_encoder_paths(self): + """ + Test internal _token_encoder branches by capturing local closure from frame. + + This validates: + - empty texts path + - embedding model path + - GPT2Tokenizer fallback path + - _character_encoder empty-path branch + """ + + class _SpySplitter(EnhanceRecursiveCharacterTextSplitter): + captured_token_encoder = None + captured_character_encoder = None + + def __init__(self, **kwargs): + frame = currentframe() + if frame and frame.f_back: + _SpySplitter.captured_token_encoder = frame.f_back.f_locals.get("_token_encoder") + _SpySplitter.captured_character_encoder = frame.f_back.f_locals.get("_character_encoder") + super().__init__(**kwargs) + + mock_model = Mock() + mock_model.get_text_embedding_num_tokens.return_value = [3, 5] + + _SpySplitter.from_encoder(embedding_model_instance=mock_model, chunk_size=10, chunk_overlap=1) + token_encoder = _SpySplitter.captured_token_encoder + character_encoder = _SpySplitter.captured_character_encoder + + assert token_encoder is not None + assert character_encoder is not None + assert token_encoder([]) == [] + assert token_encoder(["abc", "defgh"]) == [3, 5] + assert character_encoder([]) == [] + + with patch( + "core.rag.splitter.fixed_text_splitter.GPT2Tokenizer.get_num_tokens", + side_effect=lambda text: len(text) + 1, + ): + _SpySplitter.from_encoder(embedding_model_instance=None, chunk_size=10, chunk_overlap=1) + token_encoder_without_model = _SpySplitter.captured_token_encoder + assert token_encoder_without_model is not None + assert token_encoder_without_model(["ab", "cdef"]) == [3, 5] + # ============================================================================ # Test FixedRecursiveCharacterTextSplitter @@ -908,6 +1039,56 @@ class TestFixedRecursiveCharacterTextSplitter: chunks = splitter.split_text(data) assert chunks == ["chunk 1\n\nsubchunk 1.\nsubchunk 2.", "chunk 2\n\nsubchunk 1\nsubchunk 2."] + def test_recursive_split_keep_separator_and_recursive_fallback(self): + """Cover keep-separator split branch and recursive _split_text fallback.""" + text = "short." + ("x" * 60) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=[".", " ", ""], + chunk_size=10, + chunk_overlap=2, + keep_separator=True, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert any("short." in chunk for chunk in chunks) + assert any(len(chunk) <= 12 for chunk in chunks) + + def test_recursive_split_newline_separator_filtering(self): + """Cover newline-specific empty filtering branch.""" + text = "line1\n\nline2\n\nline3" + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n", ""], + chunk_size=50, + chunk_overlap=5, + ) + + chunks = splitter.recursive_split_text(text) + + assert chunks + assert all(chunk != "" for chunk in chunks) + assert "line1" in "".join(chunks) + assert "line2" in "".join(chunks) + assert "line3" in "".join(chunks) + + def test_recursive_split_without_new_separator_appends_long_chunk(self): + """Cover branch where no further separators exist and long split is appended directly.""" + text = "aa\n" + ("b" * 40) + splitter = FixedRecursiveCharacterTextSplitter( + fixed_separator="", + separators=["\n"], + chunk_size=10, + chunk_overlap=2, + ) + + chunks = splitter.recursive_split_text(text) + + assert "aa" in chunks + assert any(len(chunk) >= 40 for chunk in chunks) + # ============================================================================ # Test Metadata Preservation diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e6d0371cd5..e7eecfa297 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -11,7 +11,7 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowType +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index f6211f4cca..2a83a4e802 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from uuid import uuid4 import pytest from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import NodeType -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -61,7 +61,7 @@ def sample_workflow_node_execution(): workflow_execution_id=str(uuid4()), index=1, node_id="test_node", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", inputs={"input1": "value1"}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -259,7 +259,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=1, node_id="node1", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Node 1", inputs={"input1": "value1"}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -272,7 +272,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=2, node_id="node2", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Node 2", inputs={"input2": "value2"}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -310,7 +310,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=2, node_id="node2", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Node 2", inputs={}, status=WorkflowNodeExecutionStatus.RUNNING, @@ -323,7 +323,7 @@ class TestCeleryWorkflowNodeExecutionRepository: workflow_execution_id=workflow_run_id, index=1, node_id="node1", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Node 1", inputs={}, status=WorkflowNodeExecutionStatus.RUNNING, diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index 7f1e2c5e5b..fe9eed0307 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -12,8 +12,8 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 811ed2143b..9af4d12664 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -5,7 +5,6 @@ from __future__ import annotations import dataclasses from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock import pytest @@ -15,7 +14,7 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -24,7 +23,7 @@ from core.workflow.nodes.human_input.entities import ( MemberRecipient, UserAction, ) -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -35,7 +34,7 @@ from models.human_input import ( def _build_repository() -> HumanInputFormRepositoryImpl: - return HumanInputFormRepositoryImpl(session_factory=MagicMock(), tenant_id="tenant-id") + return HumanInputFormRepositoryImpl(tenant_id="tenant-id") def _patch_recipient_factory(monkeypatch: pytest.MonkeyPatch) -> list[SimpleNamespace]: @@ -389,8 +388,21 @@ def _session_factory(session: _FakeSession): return _factory +def _patch_repo_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + """Patch repository's global session factory to return our fake session. + + The repositories under test now use a global session factory; patch its + create_session method so unit tests don't hit a real database. + """ + monkeypatch.setattr( + "core.repositories.human_input_repository.session_factory.create_session", + _session_factory(session), + raising=True, + ) + + class TestHumanInputFormRepositoryImplPublicMethods: - def test_get_form_returns_entity_and_recipients(self): + def test_get_form_returns_entity_and_recipients(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -408,7 +420,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: access_token="token-123", ) session = _FakeSession(scalars_results=[form, [recipient]]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -418,13 +431,14 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" - def test_get_form_returns_none_when_missing(self): + def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") assert repo.get_form("run-1", "node-1") is None - def test_get_form_returns_unsubmitted_state(self): + def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -436,7 +450,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: expiration_time=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -445,7 +460,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: assert entity.selected_action_id is None assert entity.submitted_data is None - def test_get_form_returns_submission_when_completed(self): + def test_get_form_returns_submission_when_completed(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -460,7 +475,8 @@ class TestHumanInputFormRepositoryImplPublicMethods: submitted_at=naive_utc_now(), ) session = _FakeSession(scalars_results=[form, []]) - repo = HumanInputFormRepositoryImpl(_session_factory(session), tenant_id="tenant-id") + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") entity = repo.get_form(form.workflow_run_id, form.node_id) @@ -471,7 +487,7 @@ class TestHumanInputFormRepositoryImplPublicMethods: class TestHumanInputFormSubmissionRepository: - def test_get_by_token_returns_record(self): + def test_get_by_token_returns_record(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -490,7 +506,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_token("token-123") @@ -499,7 +516,7 @@ class TestHumanInputFormSubmissionRepository: assert record.recipient_type == RecipientType.STANDALONE_WEB_APP assert record.submitted is False - def test_get_by_form_id_and_recipient_type_uses_recipient(self): + def test_get_by_form_id_and_recipient_type_uses_recipient(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( id="form-1", workflow_run_id="run-1", @@ -518,7 +535,8 @@ class TestHumanInputFormSubmissionRepository: form=form, ) session = _FakeSession(scalars_result=recipient) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record = repo.get_by_form_id_and_recipient_type( form_id=form.id, @@ -553,7 +571,8 @@ class TestHumanInputFormSubmissionRepository: forms={form.id: form}, recipients={recipient.id: recipient}, ) - repo = HumanInputFormSubmissionRepository(_session_factory(session)) + _patch_repo_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() record: HumanInputFormRecord = repo.mark_submitted( form_id=form.id, diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py new file mode 100644 index 0000000000..4116e8b4a5 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -0,0 +1,677 @@ +from __future__ import annotations + +import dataclasses +import json +from collections.abc import Sequence +from datetime import datetime, timedelta +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.repositories.human_input_repository import ( + HumanInputFormRecord, + HumanInputFormRepositoryImpl, + HumanInputFormSubmissionRepository, + _HumanInputFormEntityImpl, + _HumanInputFormRecipientEntityImpl, + _InvalidTimeoutStatusError, + _WorkspaceMemberInfo, +) +from dify_graph.nodes.human_input.entities import ( + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + ExternalRecipient, + HumanInputNodeData, + MemberRecipient, + UserAction, + WebAppDeliveryMethod, +) +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError +from libs.datetime_utils import naive_utc_now +from models.human_input import HumanInputFormRecipient, RecipientType + + +@pytest.fixture(autouse=True) +def _stub_select(monkeypatch: pytest.MonkeyPatch) -> None: + class _FakeSelect: + def join(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + def where(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + def options(self, *_args: Any, **_kwargs: Any) -> _FakeSelect: + return self + + monkeypatch.setattr("core.repositories.human_input_repository.select", lambda *_args, **_kwargs: _FakeSelect()) + monkeypatch.setattr("core.repositories.human_input_repository.selectinload", lambda *_args, **_kwargs: "_loader") + + +def _make_form_definition_json(*, include_expiration_time: bool) -> str: + payload: dict[str, Any] = { + "form_content": "hi", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "rendered_content": "

hi

", + } + if include_expiration_time: + payload["expiration_time"] = naive_utc_now() + return json.dumps(payload, default=str) + + +@dataclasses.dataclass +class _DummyForm: + id: str + workflow_run_id: str | None + node_id: str + tenant_id: str + app_id: str + form_definition: str + rendered_content: str + expiration_time: datetime + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + created_at: datetime = dataclasses.field(default_factory=naive_utc_now) + selected_action_id: str | None = None + submitted_data: str | None = None + submitted_at: datetime | None = None + submission_user_id: str | None = None + submission_end_user_id: str | None = None + completed_by_recipient_id: str | None = None + status: HumanInputFormStatus = HumanInputFormStatus.WAITING + + +@dataclasses.dataclass +class _DummyRecipient: + id: str + form_id: str + recipient_type: RecipientType + access_token: str | None + + +class _FakeScalarResult: + def __init__(self, obj: Any): + self._obj = obj + + def first(self) -> Any: + if isinstance(self._obj, list): + return self._obj[0] if self._obj else None + return self._obj + + def all(self) -> list[Any]: + if self._obj is None: + return [] + if isinstance(self._obj, list): + return list(self._obj) + return [self._obj] + + +class _FakeExecuteResult: + def __init__(self, rows: Sequence[tuple[Any, ...]]): + self._rows = list(rows) + + def all(self) -> list[tuple[Any, ...]]: + return list(self._rows) + + +class _FakeSession: + def __init__( + self, + *, + scalars_result: Any = None, + scalars_results: list[Any] | None = None, + forms: dict[str, _DummyForm] | None = None, + recipients: dict[str, _DummyRecipient] | None = None, + execute_rows: Sequence[tuple[Any, ...]] = (), + ): + if scalars_results is not None: + self._scalars_queue = list(scalars_results) + else: + self._scalars_queue = [scalars_result] + self._forms = forms or {} + self._recipients = recipients or {} + self._execute_rows = list(execute_rows) + self.added: list[Any] = [] + + def scalars(self, _query: Any) -> _FakeScalarResult: + if self._scalars_queue: + value = self._scalars_queue.pop(0) + else: + value = None + return _FakeScalarResult(value) + + def execute(self, _stmt: Any) -> _FakeExecuteResult: + return _FakeExecuteResult(self._execute_rows) + + def get(self, model_cls: Any, obj_id: str) -> Any: + name = getattr(model_cls, "__name__", "") + if name == "HumanInputForm": + return self._forms.get(obj_id) + if name == "HumanInputFormRecipient": + return self._recipients.get(obj_id) + return None + + def add(self, obj: Any) -> None: + self.added.append(obj) + + def add_all(self, objs: Sequence[Any]) -> None: + self.added.extend(list(objs)) + + def flush(self) -> None: + # Simulate DB default population for attributes referenced in entity wrappers. + for obj in self.added: + if hasattr(obj, "id") and obj.id in (None, ""): + obj.id = f"gen-{len(str(self.added))}" + if isinstance(obj, HumanInputFormRecipient) and obj.access_token is None: + if obj.recipient_type == RecipientType.CONSOLE: + obj.access_token = "token-console" + elif obj.recipient_type == RecipientType.BACKSTAGE: + obj.access_token = "token-backstage" + else: + obj.access_token = "token-webapp" + + def refresh(self, _obj: Any) -> None: + return None + + def begin(self) -> _FakeSession: + return self + + def __enter__(self) -> _FakeSession: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +class _SessionFactoryStub: + def __init__(self, session: _FakeSession): + self._session = session + + def create_session(self) -> _FakeSession: + return self._session + + +def _patch_session_factory(monkeypatch: pytest.MonkeyPatch, session: _FakeSession) -> None: + monkeypatch.setattr("core.repositories.human_input_repository.session_factory", _SessionFactoryStub(session)) + + +def test_recipient_entity_token_raises_when_missing() -> None: + recipient = SimpleNamespace(id="r1", access_token=None) + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + with pytest.raises(AssertionError, match="access_token should not be None"): + _ = entity.token + + +def test_recipient_entity_id_and_token_success() -> None: + recipient = SimpleNamespace(id="r1", access_token="tok") + entity = _HumanInputFormRecipientEntityImpl(recipient) # type: ignore[arg-type] + assert entity.id == "r1" + assert entity.token == "tok" + + +def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + console = _DummyRecipient(id="c1", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="ctok") + webapp = _DummyRecipient( + id="w1", form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP, access_token="wtok" + ) + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] + assert entity.web_app_token == "ctok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] + assert entity.web_app_token == "wtok" + + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.web_app_token is None + + +def test_form_entity_submitted_data_parsed() -> None: + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + submitted_data='{"a": 1}', + submitted_at=naive_utc_now(), + ) + entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] + assert entity.submitted is True + assert entity.submitted_data == {"a": 1} + assert entity.rendered_content == "

x

" + assert entity.selected_action_id is None + assert entity.status == HumanInputFormStatus.WAITING + + +def test_form_record_from_models_injects_expiration_time_when_missing() -> None: + expiration = naive_utc_now() + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=False), + rendered_content="

x

", + expiration_time=expiration, + submitted_data='{"k": "v"}', + ) + record = HumanInputFormRecord.from_models(form, None) # type: ignore[arg-type] + assert record.definition.expiration_time == expiration + assert record.submitted_data == {"k": "v"} + assert record.submitted is False + + +def test_create_email_recipients_from_resolved_dedupes_and_skips_blank(monkeypatch: pytest.MonkeyPatch) -> None: + created: list[SimpleNamespace] = [] + + def fake_new(cls, form_id: str, delivery_id: str, payload: Any): # type: ignore[no-untyped-def] + recipient = SimpleNamespace( + id=f"{payload.TYPE}-{len(created)}", + form_id=form_id, + delivery_id=delivery_id, + recipient_type=payload.TYPE, + recipient_payload=payload.model_dump_json(), + access_token="tok", + ) + created.append(recipient) + return recipient + + monkeypatch.setattr("core.repositories.human_input_repository.HumanInputFormRecipient.new", classmethod(fake_new)) + + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + recipients = repo._create_email_recipients_from_resolved( # type: ignore[attr-defined] + form_id="f", + delivery_id="d", + members=[ + _WorkspaceMemberInfo(user_id="u1", email=""), + _WorkspaceMemberInfo(user_id="u2", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u3", email="a@example.com"), + ], + external_emails=["", "a@example.com", "b@example.com", "b@example.com"], + ) + assert [r.recipient_type for r in recipients] == [RecipientType.EMAIL_MEMBER, RecipientType.EMAIL_EXTERNAL] + + +def test_query_workspace_members_by_ids_empty_returns_empty() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._query_workspace_members_by_ids(session=MagicMock(), restrict_to_user_ids=["", ""]) == [] + + +def test_query_workspace_members_by_ids_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com"), ("u2", "b@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_workspace_members_by_ids(session=session, restrict_to_user_ids=["u1", "u2"]) + assert rows == [ + _WorkspaceMemberInfo(user_id="u1", email="a@example.com"), + _WorkspaceMemberInfo(user_id="u2", email="b@example.com"), + ] + + +def test_query_all_workspace_members_maps_rows() -> None: + session = _FakeSession(execute_rows=[("u1", "a@example.com")]) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + rows = repo._query_all_workspace_members(session=session) + assert rows == [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + +def test_repository_init_sets_tenant_id() -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo._tenant_id == "tenant" + + +def test_delivery_method_to_model_webapp_creates_delivery_and_recipient(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + result = repo._delivery_method_to_model( + session=MagicMock(), form_id="form-1", delivery_method=WebAppDeliveryMethod() + ) + assert result.delivery.id == "del-1" + assert result.delivery.form_id == "form-1" + assert len(result.recipients) == 1 + assert result.recipients[0].recipient_type == RecipientType.STANDALONE_WEB_APP + + +def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: "del-1") + called: dict[str, Any] = {} + + def fake_build(*, session: Any, form_id: str, delivery_id: str, recipients_config: Any) -> list[Any]: + called.update( + {"session": session, "form_id": form_id, "delivery_id": delivery_id, "recipients_config": recipients_config} + ) + return ["r"] + + monkeypatch.setattr(repo, "_build_email_recipients", fake_build) + + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients( + whole_workspace=False, + items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + ), + subject="s", + body="b", + ) + ) + result = repo._delivery_method_to_model(session="sess", form_id="form-1", delivery_method=method) + assert result.recipients == ["r"] + assert called["delivery_id"] == "del-1" + + +def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + monkeypatch.setattr( + repo, + "_query_all_workspace_members", + lambda *, session: [_WorkspaceMemberInfo(user_id="u", email="a@example.com")], + ) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + ) + assert recipients == ["ok"] + + +def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(monkeypatch: pytest.MonkeyPatch) -> None: + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + def fake_query(*, session: Any, restrict_to_user_ids: Sequence[str]) -> list[_WorkspaceMemberInfo]: + assert restrict_to_user_ids == ["u1"] + return [_WorkspaceMemberInfo(user_id="u1", email="a@example.com")] + + monkeypatch.setattr(repo, "_query_workspace_members_by_ids", fake_query) + monkeypatch.setattr(repo, "_create_email_recipients_from_resolved", lambda **_: ["ok"]) + recipients = repo._build_email_recipients( + session=MagicMock(), + form_id="f", + delivery_id="d", + recipients_config=EmailRecipients( + whole_workspace=False, + items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + ), + ) + assert recipients == ["ok"] + + +def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + assert repo.get_form("run", "node") is None + + form = _DummyForm( + id="f1", + workflow_run_id="run", + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + recipient = _DummyRecipient( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + ) + session = _FakeSession(scalars_results=[form, [recipient]]) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + entity = repo.get_form("run", "node") + assert entity is not None + assert entity.id == "f1" + assert entity.recipients[0].id == "r1" + assert entity.recipients[0].token == "tok" + + +def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + ids = iter(["form-id", "del-web", "del-console", "del-backstage"]) + monkeypatch.setattr("core.repositories.human_input_repository.uuidv7", lambda: next(ids)) + + session = _FakeSession() + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + + form_config = HumanInputNodeData( + title="Title", + delivery_methods=[], + form_content="hello", + inputs=[], + user_actions=[UserAction(id="submit", title="Submit")], + ) + params = FormCreateParams( + app_id="app", + workflow_execution_id="run", + node_id="node", + form_config=form_config, + rendered_content="

hello

", + delivery_methods=[WebAppDeliveryMethod()], + display_in_ui=True, + resolved_default_values={}, + form_kind=HumanInputFormKind.RUNTIME, + console_recipient_required=True, + console_creator_account_id="acc-1", + backstage_recipient_required=True, + ) + + entity = repo.create_form(params) + assert entity.id == "form-id" + assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) + # Console token should take precedence when console recipient is present. + assert entity.web_app_token == "token-console" + assert len(entity.recipients) == 3 + + +def test_submission_get_by_token_returns_none_when_missing_or_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + recipient = SimpleNamespace(form=None) + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_token("tok") is None + + +def test_submission_repository_init_no_args() -> None: + repo = HumanInputFormSubmissionRepository() + assert isinstance(repo, HumanInputFormSubmissionRepository) + + +def test_submission_get_by_token_and_get_by_form_id_success_paths(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f1", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + recipient = SimpleNamespace( + id="r1", + form_id=form.id, + recipient_type=RecipientType.STANDALONE_WEB_APP, + access_token="tok", + form=form, + ) + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_token("tok") + assert record is not None + assert record.access_token == "tok" + + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=recipient)) + repo = HumanInputFormSubmissionRepository() + record = repo.get_by_form_id_and_recipient_type(form_id=form.id, recipient_type=RecipientType.STANDALONE_WEB_APP) + assert record is not None + assert record.recipient_id == "r1" + + +def test_submission_get_by_form_id_returns_none_on_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(scalars_result=None)) + repo = HumanInputFormSubmissionRepository() + assert repo.get_by_form_id_and_recipient_type(form_id="f", recipient_type=RecipientType.CONSOLE) is None + + +def test_mark_submitted_updates_and_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + fixed_now = datetime(2024, 1, 1, 0, 0, 0) + monkeypatch.setattr("core.repositories.human_input_repository.naive_utc_now", lambda: fixed_now) + + missing_session = _FakeSession(forms={}) + _patch_session_factory(monkeypatch, missing_session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_submitted( + form_id="missing", + recipient_id=None, + selected_action_id="a", + form_data={}, + submission_user_id=None, + submission_end_user_id=None, + ) + + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=fixed_now, + ) + recipient = _DummyRecipient(id="r", form_id=form.id, recipient_type=RecipientType.CONSOLE, access_token="tok") + session = _FakeSession(forms={form.id: form}, recipients={recipient.id: recipient}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_submitted( + form_id=form.id, + recipient_id=recipient.id, + selected_action_id="approve", + form_data={"k": "v"}, + submission_user_id="u", + submission_end_user_id="eu", + ) + assert form.status == HumanInputFormStatus.SUBMITTED + assert form.submitted_at == fixed_now + assert record.submitted_data == {"k": "v"} + + +def test_mark_timeout_invalid_status_raises(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(_InvalidTimeoutStatusError, match="invalid timeout status"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.SUBMITTED) # type: ignore[arg-type] + + +def test_mark_timeout_already_timed_out_returns_record(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.TIMEOUT, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.TIMEOUT, reason="r") + assert record.status == HumanInputFormStatus.TIMEOUT + + +def test_mark_timeout_submitted_raises_form_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + status=HumanInputFormStatus.SUBMITTED, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form already submitted"): + repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + + +def test_mark_timeout_updates_fields(monkeypatch: pytest.MonkeyPatch) -> None: + form = _DummyForm( + id="f", + workflow_run_id=None, + node_id="node", + tenant_id="tenant", + app_id="app", + form_definition=_make_form_definition_json(include_expiration_time=True), + rendered_content="

x

", + expiration_time=naive_utc_now(), + selected_action_id="a", + submitted_data="{}", + submission_user_id="u", + submission_end_user_id="eu", + completed_by_recipient_id="r", + status=HumanInputFormStatus.WAITING, + ) + session = _FakeSession(forms={form.id: form}) + _patch_session_factory(monkeypatch, session) + repo = HumanInputFormSubmissionRepository() + record = repo.mark_timeout(form_id=form.id, timeout_status=HumanInputFormStatus.EXPIRED) + assert form.status == HumanInputFormStatus.EXPIRED + assert form.selected_action_id is None + assert form.submitted_data is None + assert form.submission_user_id is None + assert form.submission_end_user_id is None + assert form.completed_by_recipient_id is None + assert record.status == HumanInputFormStatus.EXPIRED + + +def test_mark_timeout_raises_when_form_missing(monkeypatch: pytest.MonkeyPatch) -> None: + _patch_session_factory(monkeypatch, _FakeSession(forms={})) + repo = HumanInputFormSubmissionRepository() + with pytest.raises(FormNotFoundError, match="form not found"): + repo.mark_timeout(form_id="missing", timeout_status=HumanInputFormStatus.TIMEOUT) diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py new file mode 100644 index 0000000000..232ab07882 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -0,0 +1,291 @@ +from datetime import UTC, datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from models import Account, CreatorUserRole, EndUser, WorkflowRun +from models.enums import WorkflowRunTriggeredFrom + + +@pytest.fixture +def mock_session_factory(): + """Mock SQLAlchemy session factory.""" + session_factory = MagicMock(spec=sessionmaker) + session = MagicMock() + session.get.return_value = None + session_factory.return_value.__enter__.return_value = session + return session_factory + + +@pytest.fixture +def mock_engine(): + """Mock SQLAlchemy Engine.""" + return MagicMock(spec=Engine) + + +@pytest.fixture +def mock_account(): + """Mock Account user.""" + account = MagicMock(spec=Account) + account.id = str(uuid4()) + account.current_tenant_id = str(uuid4()) + return account + + +@pytest.fixture +def mock_end_user(): + """Mock EndUser.""" + user = MagicMock(spec=EndUser) + user.id = str(uuid4()) + user.tenant_id = str(uuid4()) + return user + + +@pytest.fixture +def sample_workflow_execution(): + """Sample WorkflowExecution for testing.""" + return WorkflowExecution( + id_=str(uuid4()), + workflow_id=str(uuid4()), + workflow_type=WorkflowType.WORKFLOW, + workflow_version="1.0", + graph={"nodes": [], "edges": []}, + inputs={"input1": "value1"}, + outputs={"output1": "result1"}, + status=WorkflowExecutionStatus.SUCCEEDED, + error_message="", + total_tokens=100, + total_steps=5, + exceptions_count=0, + started_at=datetime.now(UTC), + finished_at=datetime.now(UTC), + ) + + +class TestSQLAlchemyWorkflowExecutionRepository: + def test_init_with_sessionmaker(self, mock_session_factory, mock_account): + app_id = "test_app_id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=app_id, triggered_from=triggered_from + ) + + assert repo._session_factory == mock_session_factory + assert repo._tenant_id == mock_account.current_tenant_id + assert repo._app_id == app_id + assert repo._triggered_from == triggered_from + assert repo._creator_user_id == mock_account.id + assert repo._creator_user_role == CreatorUserRole.ACCOUNT + + def test_init_with_engine(self, mock_engine, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_engine, + user=mock_account, + app_id="test_app_id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + assert isinstance(repo._session_factory, sessionmaker) + assert repo._session_factory.kw["bind"] == mock_engine + + def test_init_invalid_session_factory(self, mock_account): + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowExecutionRepository( + session_factory="invalid", user=mock_account, app_id=None, triggered_from=None + ) + + def test_init_no_tenant_id(self, mock_session_factory): + user = MagicMock(spec=Account) + user.current_tenant_id = None + + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=user, app_id=None, triggered_from=None + ) + + def test_init_with_end_user(self, mock_session_factory, mock_end_user): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_end_user, app_id=None, triggered_from=None + ) + assert repo._tenant_id == mock_end_user.tenant_id + assert repo._creator_user_role == CreatorUserRole.END_USER + + def test_to_domain_model(self, mock_session_factory, mock_account): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + db_model = MagicMock(spec=WorkflowRun) + db_model.id = str(uuid4()) + db_model.workflow_id = str(uuid4()) + db_model.type = "workflow" + db_model.version = "1.0" + db_model.inputs_dict = {"in": "val"} + db_model.outputs_dict = {"out": "val"} + db_model.graph_dict = {"nodes": []} + db_model.status = "succeeded" + db_model.error = "some error" + db_model.total_tokens = 50 + db_model.total_steps = 3 + db_model.exceptions_count = 1 + db_model.created_at = datetime.now(UTC) + db_model.finished_at = datetime.now(UTC) + + domain_model = repo._to_domain_model(db_model) + + assert domain_model.id_ == db_model.id + assert domain_model.workflow_id == db_model.workflow_id + assert domain_model.status == WorkflowExecutionStatus.SUCCEEDED + assert domain_model.inputs == db_model.inputs_dict + assert domain_model.error_message == "some error" + + def test_to_db_model(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + + # Make elapsed time deterministic to avoid flaky tests + sample_workflow_execution.started_at = datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC) + sample_workflow_execution.finished_at = datetime(2023, 1, 1, 0, 0, 10, tzinfo=UTC) + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.id == sample_workflow_execution.id_ + assert db_model.tenant_id == repo._tenant_id + assert db_model.app_id == "test_app" + assert db_model.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING + assert db_model.status == sample_workflow_execution.status.value + assert db_model.total_tokens == sample_workflow_execution.total_tokens + assert db_model.elapsed_time == 10.0 + + def test_to_db_model_edge_cases(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Test with empty/None fields + sample_workflow_execution.graph = None + sample_workflow_execution.inputs = None + sample_workflow_execution.outputs = None + sample_workflow_execution.error_message = None + sample_workflow_execution.finished_at = None + + db_model = repo._to_db_model(sample_workflow_execution) + + assert db_model.graph is None + assert db_model.inputs is None + assert db_model.outputs is None + assert db_model.error is None + assert db_model.elapsed_time == 0 + + def test_to_db_model_app_id_none(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id=None, + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + db_model = repo._to_db_model(sample_workflow_execution) + assert not hasattr(db_model, "app_id") or db_model.app_id is None + assert db_model.tenant_id == repo._tenant_id + + def test_to_db_model_missing_context(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, user=mock_account, app_id=None, triggered_from=None + ) + + # Test triggered_from missing + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(sample_workflow_execution) + + repo._triggered_from = WorkflowRunTriggeredFrom.APP_RUN + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(sample_workflow_execution) + + repo._creator_user_id = "some_id" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(sample_workflow_execution) + + def test_save(self, mock_session_factory, mock_account, sample_workflow_execution): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + repo.save(sample_workflow_execution) + + session = mock_session_factory.return_value.__enter__.return_value + session.merge.assert_called_once() + session.commit.assert_called_once() + + # Check cache + assert sample_workflow_execution.id_ in repo._execution_cache + cached_model = repo._execution_cache[sample_workflow_execution.id_] + assert cached_model.id == sample_workflow_execution.id_ + + def test_save_uses_execution_started_at_when_record_does_not_exist( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + started_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + sample_workflow_execution.started_at = started_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = None + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == started_at + session.commit.assert_called_once() + + def test_save_preserves_existing_created_at_when_record_already_exists( + self, mock_session_factory, mock_account, sample_workflow_execution + ): + repo = SQLAlchemyWorkflowExecutionRepository( + session_factory=mock_session_factory, + user=mock_account, + app_id="test_app", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + execution_id = sample_workflow_execution.id_ + existing_created_at = datetime(2026, 1, 1, 12, 0, 0, tzinfo=UTC) + + existing_run = WorkflowRun() + existing_run.id = execution_id + existing_run.tenant_id = repo._tenant_id + existing_run.created_at = existing_created_at + + session = mock_session_factory.return_value.__enter__.return_value + session.get.return_value = existing_run + + sample_workflow_execution.started_at = datetime(2026, 1, 1, 12, 30, 0, tzinfo=UTC) + + repo.save(sample_workflow_execution) + + saved_model = session.merge.call_args.args[0] + assert saved_model.created_at == existing_created_at + session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py new file mode 100644 index 0000000000..73de15e2cf --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -0,0 +1,772 @@ +from __future__ import annotations + +import json +import logging +from collections.abc import Mapping +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock + +import psycopg2.errors +import pytest +from sqlalchemy import Engine, create_engine +from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories.sqlalchemy_workflow_node_execution_repository import ( + SQLAlchemyWorkflowNodeExecutionRepository, + _deterministic_json_dump, + _filter_by_offload_type, + _find_first, + _replace_or_append_offload, +) +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from models import Account, EndUser +from models.enums import ExecutionOffLoadType +from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom + + +def _mock_account(*, tenant_id: str = "tenant", user_id: str = "user") -> Account: + user = Mock(spec=Account) + user.id = user_id + user.current_tenant_id = tenant_id + return user + + +def _mock_end_user(*, tenant_id: str = "tenant", user_id: str = "user") -> EndUser: + user = Mock(spec=EndUser) + user.id = user_id + user.tenant_id = tenant_id + return user + + +def _execution( + *, + execution_id: str = "exec-id", + node_execution_id: str = "node-exec-id", + workflow_run_id: str = "run-id", + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.SUCCEEDED, + inputs: Mapping[str, Any] | None = None, + outputs: Mapping[str, Any] | None = None, + process_data: Mapping[str, Any] | None = None, + metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, +) -> WorkflowNodeExecution: + return WorkflowNodeExecution( + id=execution_id, + node_execution_id=node_execution_id, + workflow_id="workflow-id", + workflow_execution_id=workflow_run_id, + index=1, + predecessor_node_id=None, + node_id="node-id", + node_type=BuiltinNodeTypes.LLM, + title="Title", + inputs=inputs, + outputs=outputs, + process_data=process_data, + status=status, + error=None, + elapsed_time=1.0, + metadata=metadata, + created_at=datetime.now(UTC), + finished_at=None, + ) + + +class _SessionCtx: + def __init__(self, session: Any): + self._session = session + + def __enter__(self) -> Any: + return self._session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _session_factory(session: Any) -> sessionmaker: + factory = Mock(spec=sessionmaker) + factory.return_value = _SessionCtx(session) + return factory + + +def test_init_accepts_engine_and_sessionmaker_and_sets_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + engine: Engine = create_engine("sqlite:///:memory:") + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=engine, + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert isinstance(repo._session_factory, sessionmaker) + + sm = Mock(spec=sessionmaker) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=sm, + user=_mock_end_user(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, + ) + assert repo._creator_user_role.value == "end_user" + + +def test_init_rejects_invalid_session_factory_type(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + with pytest.raises(ValueError, match="Invalid session_factory type"): + SQLAlchemyWorkflowNodeExecutionRepository( # type: ignore[arg-type] + session_factory=object(), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_init_requires_tenant_id(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + user = _mock_account() + user.current_tenant_id = None + with pytest.raises(ValueError, match="User must have a tenant_id"): + SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=user, + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + +def test_create_truncator_uses_config(monkeypatch: pytest.MonkeyPatch) -> None: + created: dict[str, Any] = {} + + class FakeTruncator: + def __init__(self, *, max_size_bytes: int, array_element_limit: int, string_length_limit: int): + created.update( + { + "max_size_bytes": max_size_bytes, + "array_element_limit": array_element_limit, + "string_length_limit": string_length_limit, + } + ) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.VariableTruncator", + FakeTruncator, + ) + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + _ = repo._create_truncator() + assert created["max_size_bytes"] == dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE + + +def test_helpers_find_first_and_replace_or_append_and_filter() -> None: + assert _deterministic_json_dump({"b": 1, "a": 2}) == '{"a": 2, "b": 1}' + assert _find_first([], lambda _: True) is None + assert _find_first([1, 2, 3], lambda x: x > 1) == 2 + + off1 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off2 = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + assert _find_first([off1, off2], _filter_by_offload_type(ExecutionOffLoadType.OUTPUTS)) is off2 + + replaced = _replace_or_append_offload([off1, off2], WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)) + assert len(replaced) == 2 + assert [o.type_ for o in replaced] == [ExecutionOffLoadType.OUTPUTS, ExecutionOffLoadType.INPUTS] + + +def test_to_db_model_requires_constructor_context(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution(inputs={"b": 1, "a": 2}, metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 1}) + + # Happy path: deterministic json dump should be sorted + db_model = repo._to_db_model(execution) + assert json.loads(db_model.inputs or "{}") == {"a": 2, "b": 1} + assert json.loads(db_model.execution_metadata or "{}")["total_tokens"] == 1 + + repo._triggered_from = None + with pytest.raises(ValueError, match="triggered_from is required"): + repo._to_db_model(execution) + + +def test_to_db_model_requires_creator_user_id_and_role(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + execution = _execution() + db_model = repo._to_db_model(execution) + assert db_model.app_id == "app" + + repo._creator_user_id = None + with pytest.raises(ValueError, match="created_by is required"): + repo._to_db_model(execution) + + repo._creator_user_id = "user" + repo._creator_user_role = None + with pytest.raises(ValueError, match="created_by_role is required"): + repo._to_db_model(execution) + + +def test_is_duplicate_key_error_and_regenerate_id( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + assert repo._is_duplicate_key_error(duplicate_error) is True + assert repo._is_duplicate_key_error(IntegrityError("other", params=None, orig=None)) is False + + execution = _execution(execution_id="old-id") + db_model = WorkflowNodeExecutionModel() + db_model.id = "old-id" + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + caplog.set_level(logging.WARNING) + repo._regenerate_id_on_duplicate(execution, db_model) + assert execution.id == "new-id" + assert db_model.id == "new-id" + assert any("Duplicate key conflict" in r.message for r in caplog.records) + + +def test_persist_to_database_updates_existing_and_inserts_new(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id1" + db_model.node_execution_id = "node1" + db_model.foo = "bar" # type: ignore[attr-defined] + db_model.__dict__["_private"] = "x" + + existing = SimpleNamespace() + session.get.return_value = existing + repo._persist_to_database(db_model) + assert existing.foo == "bar" + session.add.assert_not_called() + assert repo._node_execution_cache["node1"] is db_model + + session.reset_mock() + session.get.return_value = None + repo._node_execution_cache.clear() + repo._persist_to_database(db_model) + session.add.assert_called_once_with(db_model) + assert repo._node_execution_cache["node1"] is db_model + + +def test_truncate_and_upload_returns_none_when_no_values_or_not_truncated(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + assert repo._truncate_and_upload(None, "e", ExecutionOffLoadType.INPUTS) is None + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return value, False + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + assert repo._truncate_and_upload({"a": 1}, "e", ExecutionOffLoadType.INPUTS) is None + + +def test_truncate_and_upload_uploads_and_builds_offload(monkeypatch: pytest.MonkeyPatch) -> None: + uploaded: dict[str, Any] = {} + + class FakeFileService: + def upload_file(self, *, filename: str, content: bytes, mimetype: str, user: Any): # type: ignore[no-untyped-def] + uploaded.update({"filename": filename, "content": content, "mimetype": mimetype, "user": user}) + return SimpleNamespace(id="file-id", key="file-key") + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", lambda *_: FakeFileService() + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "offload-id") + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + class FakeTruncator: + def truncate_variable_mapping(self, value: Any): # type: ignore[no-untyped-def] + return {"truncated": True}, True + + monkeypatch.setattr(repo, "_create_truncator", lambda: FakeTruncator()) + + result = repo._truncate_and_upload({"a": 1}, "exec", ExecutionOffLoadType.INPUTS) + assert result is not None + assert result.truncated_value == {"truncated": True} + assert uploaded["filename"].startswith("node_execution_exec_inputs.json") + assert result.offload.file_id == "file-id" + assert result.offload.type_ == ExecutionOffLoadType.INPUTS + + +def test_to_domain_model_loads_offloaded_files(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = BuiltinNodeTypes.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"trunc": "i"}) + db_model.process_data = json.dumps({"trunc": "p"}) + db_model.outputs = json.dumps({"trunc": "o"}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = json.dumps({"total_tokens": 3}) + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + + off_in = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS) + off_out = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS) + off_proc = WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA) + off_in.file = SimpleNamespace(key="k-in") + off_out.file = SimpleNamespace(key="k-out") + off_proc.file = SimpleNamespace(key="k-proc") + db_model.offload_data = [off_out, off_in, off_proc] + + def fake_load(key: str) -> bytes: + return json.dumps({"full": key}).encode() + + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.storage.load", fake_load) + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"full": "k-in"} + assert domain.outputs == {"full": "k-out"} + assert domain.process_data == {"full": "k-proc"} + assert domain.get_truncated_inputs() == {"trunc": "i"} + assert domain.get_truncated_outputs() == {"trunc": "o"} + assert domain.get_truncated_process_data() == {"trunc": "p"} + + +def test_to_domain_model_returns_early_when_no_offload_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_model = WorkflowNodeExecutionModel() + db_model.id = "id" + db_model.node_execution_id = "node-exec" + db_model.workflow_id = "wf" + db_model.workflow_run_id = "run" + db_model.index = 1 + db_model.predecessor_node_id = None + db_model.node_id = "node" + db_model.node_type = BuiltinNodeTypes.LLM + db_model.title = "t" + db_model.inputs = json.dumps({"i": 1}) + db_model.process_data = json.dumps({"p": 2}) + db_model.outputs = json.dumps({"o": 3}) + db_model.status = WorkflowNodeExecutionStatus.SUCCEEDED + db_model.error = None + db_model.elapsed_time = 0.1 + db_model.execution_metadata = "{}" + db_model.created_at = datetime.now(UTC) + db_model.finished_at = None + db_model.offload_data = [] + + domain = repo._to_domain_model(db_model) + assert domain.inputs == {"i": 1} + assert domain.outputs == {"o": 3} + + +def test_json_encode_uses_runtime_converter(monkeypatch: pytest.MonkeyPatch) -> None: + class FakeConverter: + def to_json_encodable(self, values: Mapping[str, Any]) -> Mapping[str, Any]: + return {"wrapped": values["a"]} + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowRuntimeTypeConverter", + FakeConverter, + ) + assert SQLAlchemyWorkflowNodeExecutionRepository._json_encode({"a": 1}) == '{"wrapped": 1}' + + +def test_save_execution_data_handles_existing_db_model_and_truncation(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = SimpleNamespace( + id="id", + offload_data=[WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS)], + inputs=None, + outputs=None, + process_data=None, + ) + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + trunc_result = SimpleNamespace( + truncated_value={"trunc": True}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.INPUTS, file_id="f1"), + ) + monkeypatch.setattr( + repo, "_truncate_and_upload", lambda values, *_args, **_kwargs: trunc_result if values == {"a": 1} else None + ) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + # Inputs should be truncated, outputs/process_data encoded directly + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.inputs) == {"trunc": True} + assert json.loads(db_model.outputs) == {"b": 2} + assert json.loads(db_model.process_data) == {"c": 3} + assert any(off.type_ == ExecutionOffLoadType.INPUTS for off in db_model.offload_data) + assert execution.get_truncated_inputs() == {"trunc": True} + + +def test_save_execution_data_truncates_outputs_and_process_data(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + existing = SimpleNamespace( + id="id", + offload_data=[], + inputs=None, + outputs=None, + process_data=None, + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = existing + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}, outputs={"b": 2}, process_data={"c": 3}) + + def trunc(values: Mapping[str, Any], *_args: Any, **_kwargs: Any) -> Any: + if values == {"b": 2}: + return SimpleNamespace( + truncated_value={"b": "trunc"}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.OUTPUTS, file_id="f2"), + ) + if values == {"c": 3}: + return SimpleNamespace( + truncated_value={"c": "trunc"}, + offload=WorkflowNodeExecutionOffload(type_=ExecutionOffLoadType.PROCESS_DATA, file_id="f3"), + ) + return None + + monkeypatch.setattr(repo, "_truncate_and_upload", trunc) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values, sort_keys=True)) + + repo.save_execution_data(execution) + db_model = session.merge.call_args.args[0] + assert json.loads(db_model.outputs) == {"b": "trunc"} + assert json.loads(db_model.process_data) == {"c": "trunc"} + assert execution.get_truncated_outputs() == {"b": "trunc"} + assert execution.get_truncated_process_data() == {"c": "trunc"} + + +def test_save_execution_data_handles_missing_db_model(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + session = MagicMock() + session.execute.return_value.scalars.return_value.first.return_value = None + session.merge = Mock() + session.flush = Mock() + session.begin.return_value.__enter__ = Mock(return_value=session) + session.begin.return_value.__exit__ = Mock(return_value=None) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(inputs={"a": 1}) + fake_db_model = SimpleNamespace(id=execution.id, offload_data=[], inputs=None, outputs=None, process_data=None) + monkeypatch.setattr(repo, "_to_db_model", lambda *_: fake_db_model) + monkeypatch.setattr(repo, "_truncate_and_upload", lambda *_args, **_kwargs: None) + monkeypatch.setattr(repo, "_json_encode", lambda values: json.dumps(values)) + + repo.save_execution_data(execution) + merged = session.merge.call_args.args[0] + assert merged.inputs == '{"a": 1}' + + +def test_save_retries_duplicate_and_logs_non_duplicate( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + execution = _execution(execution_id="id") + unique = Mock(spec=psycopg2.errors.UniqueViolation) + duplicate_error = IntegrityError("dup", params=None, orig=unique) + other_error = IntegrityError("other", params=None, orig=None) + + calls = {"n": 0} + + def persist(_db_model: Any) -> None: + calls["n"] += 1 + if calls["n"] == 1: + raise duplicate_error + + monkeypatch.setattr(repo, "_persist_to_database", persist) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.uuidv7", lambda: "new-id") + repo.save(execution) + assert execution.id == "new-id" + assert repo._node_execution_cache[execution.node_execution_id] is not None + + caplog.set_level(logging.ERROR) + monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(other_error)) + with pytest.raises(IntegrityError): + repo.save(_execution(execution_id="id2", node_execution_id="node2")) + assert any("Non-duplicate key integrity error" in r.message for r in caplog.records) + + +def test_save_logs_and_reraises_on_unexpected_error( + monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture +) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + caplog.set_level(logging.ERROR) + monkeypatch.setattr(repo, "_persist_to_database", lambda _db: (_ for _ in ()).throw(RuntimeError("boom"))) + with pytest.raises(RuntimeError, match="boom"): + repo.save(_execution(execution_id="id3", node_execution_id="node3")) + assert any("Failed to save workflow node execution" in r.message for r in caplog.records) + + +def test_get_db_models_by_workflow_run_orders_and_caches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + class FakeStmt: + def __init__(self) -> None: + self.where_calls = 0 + self.order_by_args: tuple[Any, ...] | None = None + + def where(self, *_args: Any) -> FakeStmt: + self.where_calls += 1 + return self + + def order_by(self, *args: Any) -> FakeStmt: + self.order_by_args = args + return self + + stmt = FakeStmt() + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files", + lambda _q: stmt, + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select") + + model1 = SimpleNamespace(node_execution_id="n1") + model2 = SimpleNamespace(node_execution_id=None) + session = MagicMock() + session.scalars.return_value.all.return_value = [model1, model2] + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id="app", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + order = OrderConfig(order_by=["index", "missing"], order_direction="desc") + db_models = repo.get_db_models_by_workflow_run("run", order) + assert db_models == [model1, model2] + assert repo._node_execution_cache["n1"] is model1 + assert stmt.order_by_args is not None + + +def test_get_db_models_by_workflow_run_uses_asc_order(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + class FakeStmt: + def where(self, *_args: Any) -> FakeStmt: + return self + + def order_by(self, *args: Any) -> FakeStmt: + self.args = args # type: ignore[attr-defined] + return self + + stmt = FakeStmt() + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.WorkflowNodeExecutionModel.preload_offload_data_and_files", + lambda _q: stmt, + ) + monkeypatch.setattr("core.repositories.sqlalchemy_workflow_node_execution_repository.select", lambda *_: "select") + + session = MagicMock() + session.scalars.return_value.all.return_value = [] + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=_session_factory(session), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + repo.get_db_models_by_workflow_run("run", OrderConfig(order_by=["index"], order_direction="asc")) + + +def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.FileService", + lambda *_: SimpleNamespace(upload_file=Mock()), + ) + + repo = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=Mock(spec=sessionmaker), + user=_mock_account(), + app_id=None, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + db_models = [SimpleNamespace(id="db1"), SimpleNamespace(id="db2")] + monkeypatch.setattr(repo, "get_db_models_by_workflow_run", lambda *_args, **_kwargs: db_models) + monkeypatch.setattr(repo, "_to_domain_model", lambda m: f"domain:{m.id}") + + class FakeExecutor: + def __enter__(self) -> FakeExecutor: + return self + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + def map(self, func, items, timeout: int): # type: ignore[no-untyped-def] + assert timeout == 30 + return list(map(func, items)) + + monkeypatch.setattr( + "core.repositories.sqlalchemy_workflow_node_execution_repository.ThreadPoolExecutor", + lambda max_workers: FakeExecutor(), + ) + + result = repo.get_by_workflow_run("run", order_config=None) + assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 07f28f162a..456c3dde12 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -10,11 +10,11 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom @@ -70,7 +70,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.RUNNING, @@ -108,7 +108,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -153,7 +153,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.RUNNING, @@ -195,7 +195,7 @@ class TestWorkflowNodeExecutionConflictHandling: workflow_execution_id="test-workflow-execution-id", node_execution_id="test-node-execution-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", index=1, status=WorkflowNodeExecutionStatus.RUNNING, diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index 485be90eae..eeab81a178 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -13,14 +13,15 @@ from unittest.mock import MagicMock from sqlalchemy import Engine +from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from core.workflow.entities.workflow_node_execution import ( +from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload @@ -41,7 +42,7 @@ class TruncationTestCase: def create_test_cases() -> list[TruncationTestCase]: """Create test cases for different truncation scenarios.""" # Create large data that will definitely exceed the threshold (10KB) - large_data = {"data": "x" * (TRUNCATION_SIZE_THRESHOLD + 1000)} + large_data = {"data": "x" * (dify_config.WORKFLOW_VARIABLE_TRUNCATION_MAX_SIZE + 1000)} small_data = {"data": "small"} return [ @@ -101,7 +102,7 @@ def create_workflow_node_execution( workflow_execution_id="test-workflow-execution-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", inputs=inputs, outputs=outputs, @@ -145,7 +146,7 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryTruncation: db_model.index = 1 db_model.predecessor_node_id = None db_model.node_id = "node-id" - db_model.node_type = NodeType.LLM + db_model.node_type = BuiltinNodeTypes.LLM db_model.title = "Test Node" db_model.inputs = json.dumps({"value": "inputs"}) db_model.process_data = json.dumps({"value": "process_data"}) diff --git a/api/tests/unit_tests/core/schemas/test_registry.py b/api/tests/unit_tests/core/schemas/test_registry.py new file mode 100644 index 0000000000..5749e72eb0 --- /dev/null +++ b/api/tests/unit_tests/core/schemas/test_registry.py @@ -0,0 +1,137 @@ +import json +from unittest.mock import patch + +from core.schemas.registry import SchemaRegistry + + +class TestSchemaRegistry: + def test_initialization(self, tmp_path): + base_dir = tmp_path / "schemas" + base_dir.mkdir() + registry = SchemaRegistry(str(base_dir)) + assert registry.base_dir == base_dir + assert registry.versions == {} + assert registry.metadata == {} + + def test_default_registry_singleton(self): + registry1 = SchemaRegistry.default_registry() + registry2 = SchemaRegistry.default_registry() + assert registry1 is registry2 + assert isinstance(registry1, SchemaRegistry) + + def test_load_all_versions_non_existent_dir(self, tmp_path): + base_dir = tmp_path / "non_existent" + registry = SchemaRegistry(str(base_dir)) + registry.load_all_versions() + assert registry.versions == {} + + def test_load_all_versions_filtering(self, tmp_path): + base_dir = tmp_path / "schemas" + base_dir.mkdir() + (base_dir / "not_a_version_dir").mkdir() + (base_dir / "v1").mkdir() + (base_dir / "some_file.txt").write_text("content") + + registry = SchemaRegistry(str(base_dir)) + with patch.object(registry, "_load_version_dir") as mock_load: + registry.load_all_versions() + mock_load.assert_called_once() + assert mock_load.call_args[0][0] == "v1" + + def test_load_version_dir_filtering(self, tmp_path): + version_dir = tmp_path / "v1" + version_dir.mkdir() + (version_dir / "schema1.json").write_text("{}") + (version_dir / "not_a_schema.txt").write_text("content") + + registry = SchemaRegistry(str(tmp_path)) + with patch.object(registry, "_load_schema") as mock_load: + registry._load_version_dir("v1", version_dir) + mock_load.assert_called_once() + assert mock_load.call_args[0][1] == "schema1" + + def test_load_version_dir_non_existent(self, tmp_path): + version_dir = tmp_path / "non_existent" + registry = SchemaRegistry(str(tmp_path)) + registry._load_version_dir("v1", version_dir) + assert "v1" not in registry.versions + + def test_load_schema_success(self, tmp_path): + schema_path = tmp_path / "test.json" + schema_content = {"title": "Test Schema", "description": "A test schema"} + schema_path.write_text(json.dumps(schema_content)) + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + registry._load_schema("v1", "test", schema_path) + + assert registry.versions["v1"]["test"] == schema_content + uri = "https://dify.ai/schemas/v1/test.json" + assert registry.metadata[uri]["title"] == "Test Schema" + assert registry.metadata[uri]["version"] == "v1" + + def test_load_schema_invalid_json(self, tmp_path, caplog): + schema_path = tmp_path / "invalid.json" + schema_path.write_text("invalid json") + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + registry._load_schema("v1", "invalid", schema_path) + + assert "Failed to load schema v1/invalid" in caplog.text + + def test_load_schema_os_error(self, tmp_path, caplog): + schema_path = tmp_path / "error.json" + schema_path.write_text("{}") + + registry = SchemaRegistry(str(tmp_path)) + registry.versions["v1"] = {} + + with patch("builtins.open", side_effect=OSError("Read error")): + registry._load_schema("v1", "error", schema_path) + + assert "Failed to load schema v1/error" in caplog.text + + def test_get_schema(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"test": {"type": "object"}}} + + # Valid URI + assert registry.get_schema("https://dify.ai/schemas/v1/test.json") == {"type": "object"} + + # Invalid URI + assert registry.get_schema("invalid-uri") is None + + # Missing version + assert registry.get_schema("https://dify.ai/schemas/v2/test.json") is None + + def test_list_versions(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v2": {}, "v1": {}} + assert registry.list_versions() == ["v1", "v2"] + + def test_list_schemas(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"b": {}, "a": {}}} + + assert registry.list_schemas("v1") == ["a", "b"] + assert registry.list_schemas("v2") == [] + + def test_get_all_schemas_for_version(self): + registry = SchemaRegistry("/tmp") + registry.versions = {"v1": {"test": {"title": "Test Label"}}} + + results = registry.get_all_schemas_for_version("v1") + assert len(results) == 1 + assert results[0]["name"] == "test" + assert results[0]["label"] == "Test Label" + assert results[0]["schema"] == {"title": "Test Label"} + + # Default label if title missing + registry.versions["v1"]["no_title"] = {} + results = registry.get_all_schemas_for_version("v1") + item = next(r for r in results if r["name"] == "no_title") + assert item["label"] == "no_title" + + # Empty if version missing + assert registry.get_all_schemas_for_version("v2") == [] diff --git a/api/tests/unit_tests/core/schemas/test_schema_manager.py b/api/tests/unit_tests/core/schemas/test_schema_manager.py new file mode 100644 index 0000000000..cb07340c6d --- /dev/null +++ b/api/tests/unit_tests/core/schemas/test_schema_manager.py @@ -0,0 +1,80 @@ +from unittest.mock import MagicMock, patch + +from core.schemas.registry import SchemaRegistry +from core.schemas.schema_manager import SchemaManager + + +def test_init_with_provided_registry(): + mock_registry = MagicMock(spec=SchemaRegistry) + manager = SchemaManager(registry=mock_registry) + assert manager.registry == mock_registry + + +@patch("core.schemas.schema_manager.SchemaRegistry.default_registry") +def test_init_with_default_registry(mock_default_registry): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_default_registry.return_value = mock_registry + + manager = SchemaManager() + + mock_default_registry.assert_called_once() + assert manager.registry == mock_registry + + +def test_get_all_schema_definitions(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_definitions = [{"name": "schema1", "schema": {}}, {"name": "schema2", "schema": {}}] + mock_registry.get_all_schemas_for_version.return_value = expected_definitions + + manager = SchemaManager(registry=mock_registry) + result = manager.get_all_schema_definitions(version="v2") + + mock_registry.get_all_schemas_for_version.assert_called_once_with("v2") + assert result == expected_definitions + + +def test_get_schema_by_name_success(): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_schema = {"type": "object"} + mock_registry.get_schema.return_value = mock_schema + + manager = SchemaManager(registry=mock_registry) + result = manager.get_schema_by_name("my_schema", version="v1") + + expected_uri = "https://dify.ai/schemas/v1/my_schema.json" + mock_registry.get_schema.assert_called_once_with(expected_uri) + assert result == {"name": "my_schema", "schema": mock_schema} + + +def test_get_schema_by_name_not_found(): + mock_registry = MagicMock(spec=SchemaRegistry) + mock_registry.get_schema.return_value = None + + manager = SchemaManager(registry=mock_registry) + result = manager.get_schema_by_name("non_existent", version="v1") + + assert result is None + + +def test_list_available_schemas(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_schemas = ["schema1", "schema2"] + mock_registry.list_schemas.return_value = expected_schemas + + manager = SchemaManager(registry=mock_registry) + result = manager.list_available_schemas(version="v1") + + mock_registry.list_schemas.assert_called_once_with("v1") + assert result == expected_schemas + + +def test_list_available_versions(): + mock_registry = MagicMock(spec=SchemaRegistry) + expected_versions = ["v1", "v2"] + mock_registry.list_versions.return_value = expected_versions + + manager = SchemaManager(registry=mock_registry) + result = manager.list_available_versions() + + mock_registry.list_versions.assert_called_once() + assert result == expected_versions diff --git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index b9c5fbd7d8..251d6fd25e 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,6 @@ import json -from core.workflow.file import File, FileTransferMethod, FileType, FileUploadConfig +from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 5a7547e85c..92e4b58473 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -6,7 +6,7 @@ from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 636fac7a40..90ed1647aa 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -12,9 +12,9 @@ from core.entities.provider_entities import ( RestrictModel, SystemConfiguration, ) -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ( +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, CredentialFormSchema, FormOption, diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 3163d53b87..69567c54eb 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,32 +1,34 @@ +from unittest.mock import Mock, PropertyMock, patch + import pytest -from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings -from core.model_runtime.entities.model_entities import ModelType from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting @pytest.fixture -def mock_provider_entity(mocker: MockerFixture): - mock_entity = mocker.Mock() +def mock_provider_entity(): + mock_entity = Mock() mock_entity.provider = "openai" mock_entity.configurate_methods = ["predefined-model"] mock_entity.supported_model_types = [ModelType.LLM] # Use PropertyMock to ensure credential_form_schemas is iterable - provider_credential_schema = mocker.Mock() - type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[]) + provider_credential_schema = Mock() + type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[]) mock_entity.provider_credential_schema = provider_credential_schema - model_credential_schema = mocker.Mock() - type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[]) + model_credential_schema = Mock() + type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[]) mock_entity.model_credential_schema = model_credential_schema return mock_entity -def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): +def test__to_model_settings(mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -63,18 +65,18 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): load_balancing_model_configs[0].id = "id1" load_balancing_model_configs[1].id = "id2" - mocker.patch( - "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} - ) + with patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", + return_value={"openai_api_key": "fake_key"}, + ): + provider_manager = ProviderManager() - provider_manager = ProviderManager() - - # Running the method - result = provider_manager._to_model_settings( - provider_entity=mock_provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=load_balancing_model_configs, - ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) # Asserting that the result is as expected assert len(result) == 1 @@ -87,7 +89,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): +def test__to_model_settings_only_one_lb(mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -113,18 +115,18 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent ] load_balancing_model_configs[0].id = "id1" - mocker.patch( - "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} - ) + with patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", + return_value={"openai_api_key": "fake_key"}, + ): + provider_manager = ProviderManager() - provider_manager = ProviderManager() - - # Running the method - result = provider_manager._to_model_settings( - provider_entity=mock_provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=load_balancing_model_configs, - ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) # Asserting that the result is as expected assert len(result) == 1 @@ -135,7 +137,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): +def test__to_model_settings_lb_disabled(mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -170,18 +172,18 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent load_balancing_model_configs[0].id = "id1" load_balancing_model_configs[1].id = "id2" - mocker.patch( - "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} - ) + with patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", + return_value={"openai_api_key": "fake_key"}, + ): + provider_manager = ProviderManager() - provider_manager = ProviderManager() - - # Running the method - result = provider_manager._to_model_settings( - provider_entity=mock_provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=load_balancing_model_configs, - ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) # Asserting that the result is as expected assert len(result) == 1 @@ -190,3 +192,39 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 + + +def test_get_default_model_uses_first_available_active_model(): + mock_session = Mock() + mock_session.scalar.return_value = None + + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")), + Mock(model="gpt-4", provider=Mock(provider="openai")), + ] + + manager = ProviderManager() + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is not None + assert result.model == "gpt-3.5-turbo" + assert result.provider.provider == "openai" + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_session.add.assert_called_once() + saved_default_model = mock_session.add.call_args.args[0] + assert saved_default_model.model_name == "gpt-3.5-turbo" + assert saved_default_model.provider_name == "openai" + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/tools/__init__.py b/api/tests/unit_tests/core/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py new file mode 100644 index 0000000000..f123f60a34 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Generator +from types import SimpleNamespace +from typing import Any +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + +class _BuiltinDummyTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +def _build_tool() -> _BuiltinDummyTool: + entity = ToolEntity( + identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=[], + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) + + +def test_builtin_tool_fork_and_provider_type(): + tool = _build_tool() + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, _BuiltinDummyTool) + assert forked.runtime.tenant_id == "tenant-2" + assert tool.tool_provider_type() == ToolProviderType.BUILT_IN + + +def test_invoke_model_calls_model_invocation_utils_invoke(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: + assert ( + tool.invoke_model( + user_id="u1", + prompt_messages=[UserPromptMessage(content="hello")], + stop=[], + ) + == "result" + ) + mock_invoke.assert_called_once() + + +def test_get_max_tokens_returns_value(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + assert tool.get_max_tokens() == 4096 + + +def test_get_prompt_tokens_returns_value(): + tool = _build_tool() + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + +def test_runtime_none_raises(): + tool = _build_tool() + tool.runtime = None + with pytest.raises(ValueError, match="runtime is required"): + tool.get_max_tokens() + with pytest.raises(ValueError, match="runtime is required"): + tool.get_prompt_tokens([UserPromptMessage(content="hello")]) + + +def test_builtin_tool_summary_short_and_long_content_paths(): + tool = _build_tool() + + with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=100): + with patch.object(_BuiltinDummyTool, "get_prompt_tokens", return_value=10): + assert tool.summary(user_id="u1", content="short") == "short" + + with patch.object(_BuiltinDummyTool, "get_max_tokens", return_value=10): + with patch.object( + _BuiltinDummyTool, + "get_prompt_tokens", + side_effect=lambda prompt_messages: len(prompt_messages[-1].content), + ): + with patch.object( + _BuiltinDummyTool, + "invoke_model", + return_value=SimpleNamespace(message=SimpleNamespace(content="S")), + ): + result = tool.summary(user_id="u1", content="x" * 30 + "\n" + "y" * 5) + + assert result + assert "S" in result diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py new file mode 100644 index 0000000000..ad6d5906ae --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_provider.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderEntity, ToolProviderType +from core.tools.errors import ToolProviderNotFoundError + + +class _FakeBuiltinTool(BuiltinTool): + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +class _ConcreteBuiltinProvider(BuiltinToolProviderController): + last_validation: tuple[str, dict[str, Any]] | None = None + + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + self.last_validation = (user_id, credentials) + + +def _provider_yaml() -> dict[str, Any]: + return { + "identity": { + "author": "Dify", + "name": "fake_provider", + "label": {"en_US": "Fake Provider"}, + "description": {"en_US": "Fake description"}, + "icon": "icon.svg", + "tags": ["utilities"], + }, + "credentials_for_provider": { + "api_key": { + "type": "secret-input", + "required": True, + } + }, + "oauth_schema": { + "client_schema": [ + { + "name": "client_id", + "type": "text-input", + } + ], + "credentials_schema": [ + { + "name": "access_token", + "type": "secret-input", + } + ], + }, + } + + +def _tool_yaml() -> dict[str, Any]: + return { + "identity": { + "author": "Dify", + "name": "tool_a", + "label": {"en_US": "Tool A"}, + }, + "parameters": [], + } + + +def test_builtin_tool_provider_init_load_tools_and_basic_accessors(monkeypatch): + yaml_payloads = [_provider_yaml(), _tool_yaml()] + + def _load_yaml(*args, **kwargs): + return yaml_payloads.pop(0) + + monkeypatch.setattr("core.tools.builtin_tool.provider.load_yaml_file_cached", _load_yaml) + monkeypatch.setattr( + "core.tools.builtin_tool.provider.listdir", + lambda *args, **kwargs: ["tool_a.yaml", "__init__.py", "readme.md"], + ) + monkeypatch.setattr( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + lambda *args, **kwargs: _FakeBuiltinTool, + ) + provider = _ConcreteBuiltinProvider() + + assert provider.get_credentials_schema() + assert provider.get_tools() + assert provider.get_tool("tool_a") is not None + assert provider.get_tool("missing") is None + assert provider.provider_type == ToolProviderType.BUILT_IN + assert provider.tool_labels == ["utilities"] + assert provider.need_credentials is True + + oauth_schema = provider.get_credentials_schema_by_type(CredentialType.OAUTH2) + assert len(oauth_schema) == 1 + api_schema = provider.get_credentials_schema_by_type(CredentialType.API_KEY) + assert len(api_schema) == 1 + assert provider.get_oauth_client_schema()[0].name == "client_id" + assert set(provider.get_supported_credential_types()) == {CredentialType.API_KEY, CredentialType.OAUTH2} + + +def test_builtin_tool_provider_invalid_credential_type_raises(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + with pytest.raises(ValueError, match="Invalid credential type: invalid"): + provider.get_credentials_schema_by_type("invalid") + + +def test_builtin_tool_provider_validate_credentials_delegates(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + provider.validate_credentials("user-1", {"api_key": "secret"}) + assert provider.last_validation == ("user-1", {"api_key": "secret"}) + + +def test_builtin_tool_provider_unauthorized_schema_is_empty(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + assert provider.get_credentials_schema_by_type(CredentialType.UNAUTHORIZED) == [] + + +def test_builtin_tool_provider_init_raises_when_provider_yaml_missing(): + with patch("core.tools.builtin_tool.provider.load_yaml_file_cached", side_effect=RuntimeError("boom")): + with pytest.raises(ToolProviderNotFoundError, match="can not load provider yaml"): + _ConcreteBuiltinProvider() + + +def test_builtin_tool_provider_handles_empty_credentials_and_oauth(): + provider = object.__new__(_ConcreteBuiltinProvider) + provider.tools = [] + provider.entity = ToolProviderEntity.model_validate( + { + "identity": { + "author": "Dify", + "name": "fake_provider", + "label": {"en_US": "Fake Provider"}, + "description": {"en_US": "Fake description"}, + "icon": "icon.svg", + "tags": None, + }, + "credentials_schema": [], + "oauth_schema": None, + }, + ) + + assert provider.get_oauth_client_schema() == [] + assert provider.get_supported_credential_types() == [] + assert provider.need_credentials is False + assert provider._get_tool_labels() == [] + + +def test_builtin_tool_provider_forked_tool_runtime_is_initialized(): + with ( + patch( + "core.tools.builtin_tool.provider.load_yaml_file_cached", + side_effect=[_provider_yaml(), _tool_yaml()], + ), + patch("core.tools.builtin_tool.provider.listdir", return_value=["tool_a.yaml"]), + patch( + "core.tools.builtin_tool.provider.load_single_subclass_from_source", + return_value=_FakeBuiltinTool, + ), + ): + provider = _ConcreteBuiltinProvider() + + tool = provider.get_tool("tool_a") + assert tool is not None + assert isinstance(tool.runtime, ToolRuntime) + assert tool.runtime.tenant_id == "" + tool.runtime.invoke_from = InvokeFrom.DEBUGGER + assert tool.runtime.invoke_from == InvokeFrom.DEBUGGER diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py new file mode 100644 index 0000000000..62cfb6ce5b --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import math +from types import SimpleNamespace + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort +from core.tools.builtin_tool.providers.audio.audio import AudioToolProvider +from core.tools.builtin_tool.providers.audio.tools.asr import ASRTool +from core.tools.builtin_tool.providers.audio.tools.tts import TTSTool +from core.tools.builtin_tool.providers.code.code import CodeToolProvider +from core.tools.builtin_tool.providers.code.tools.simple_code import SimpleCode +from core.tools.builtin_tool.providers.time.time import WikiPediaProvider +from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool +from core.tools.builtin_tool.providers.time.tools.localtime_to_timestamp import LocaltimeToTimestampTool +from core.tools.builtin_tool.providers.time.tools.timestamp_to_localtime import TimestampToLocaltimeTool +from core.tools.builtin_tool.providers.time.tools.timezone_conversion import TimezoneConversionTool +from core.tools.builtin_tool.providers.time.tools.weekday import WeekdayTool +from core.tools.builtin_tool.providers.webscraper.tools.webscraper import WebscraperTool +from core.tools.builtin_tool.providers.webscraper.webscraper import WebscraperProvider +from core.tools.builtin_tool.tool import BuiltinTool +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.errors import ToolInvokeError +from dify_graph.file.enums import FileType +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey + + +def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + return tool_cls(provider="provider-a", entity=entity, runtime=runtime) + + +def _raise_runtime_error(*_args: object, **_kwargs: object) -> None: + raise RuntimeError("boom") + + +def test_current_time_tool(): + current_tool = _build_builtin_tool(CurrentTimeTool) + utc_text = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "UTC"}))[0].message.text + assert utc_text + + invalid_tz = list(current_tool.invoke(user_id="u", tool_parameters={"timezone": "Invalid/TZ"}))[0].message.text + assert "Invalid timezone" in invalid_tz + + +def test_localtime_to_timestamp_tool(): + localtime_tool = _build_builtin_tool(LocaltimeToTimestampTool) + ts_message = list( + localtime_tool.invoke(user_id="u", tool_parameters={"localtime": "2024-01-01 10:00:00", "timezone": "UTC"}) + )[0].message.text + ts_value = float(ts_message.strip()) + assert math.isfinite(ts_value) + assert ts_value >= 0 + with pytest.raises(ToolInvokeError): + LocaltimeToTimestampTool.localtime_to_timestamp("bad", "%Y-%m-%d %H:%M:%S", "UTC") + + +def test_timestamp_to_localtime_tool(): + to_local_tool = _build_builtin_tool(TimestampToLocaltimeTool) + local_text = list(to_local_tool.invoke(user_id="u", tool_parameters={"timestamp": 1704067200, "timezone": "UTC"}))[ + 0 + ].message.text + assert "2024" in local_text + with pytest.raises(ToolInvokeError): + TimestampToLocaltimeTool.timestamp_to_localtime("bad", "UTC") # type: ignore[arg-type] + + +def test_timezone_conversion_tool(): + timezone_tool = _build_builtin_tool(TimezoneConversionTool) + converted = list( + timezone_tool.invoke( + user_id="u", + tool_parameters={ + "current_time": "2024-01-01 08:00:00", + "current_timezone": "UTC", + "target_timezone": "Asia/Tokyo", + }, + ) + )[0].message.text + assert converted.startswith("2024-01-01") + with pytest.raises(ToolInvokeError): + TimezoneConversionTool.timezone_convert("bad", "UTC", "Asia/Tokyo") + + +def test_weekday_tool(): + weekday_tool = _build_builtin_tool(WeekdayTool) + valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text + assert "January 1, 2024" in valid + invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ + 0 + ].message.text + assert "Invalid date" in invalid + with pytest.raises(ValueError, match="Month is required"): + list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "day": 1})) + + +def test_simple_code_valid_execution(monkeypatch): + simple_code = _build_builtin_tool(SimpleCode) + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code", + lambda *a: "ok", + ) + result = list( + simple_code.invoke( + user_id="u", + tool_parameters={"language": "python3", "code": "print(1)"}, + ) + )[0].message.text + assert result == "ok" + + +def test_simple_code_invalid_language(): + simple_code = _build_builtin_tool(SimpleCode) + + with pytest.raises(ValueError, match="Only python3 and javascript"): + list(simple_code.invoke(user_id="u", tool_parameters={"language": "go", "code": "fmt.Println(1)"})) + + +def test_simple_code_execution_error(monkeypatch): + simple_code = _build_builtin_tool(SimpleCode) + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.code.tools.simple_code.CodeExecutor.execute_code", + _raise_runtime_error, + ) + with pytest.raises(ToolInvokeError, match="boom"): + list(simple_code.invoke(user_id="u", tool_parameters={"language": "python3", "code": "print(1)"})) + + +def test_webscraper_empty_url(): + webscraper = _build_builtin_tool(WebscraperTool) + empty = list(webscraper.invoke(user_id="u", tool_parameters={"url": ""}))[0].message.text + assert empty == "Please input url" + + +def test_webscraper_fetch(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page") + full = list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"}))[0].message.text + assert full == "page" + + +def test_webscraper_summary(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr("core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", lambda *a, **k: "page") + monkeypatch.setattr(webscraper, "summary", lambda user_id, content: "summary") + summarized = list( + webscraper.invoke( + user_id="u", + tool_parameters={"url": "https://example.com", "generate_summary": True}, + ) + )[0].message.text + assert summarized == "summary" + + +def test_webscraper_fetch_error(monkeypatch): + webscraper = _build_builtin_tool(WebscraperTool) + monkeypatch.setattr( + "core.tools.builtin_tool.providers.webscraper.tools.webscraper.get_url", + _raise_runtime_error, + ) + with pytest.raises(ToolInvokeError, match="boom"): + list(webscraper.invoke(user_id="u", tool_parameters={"url": "https://example.com"})) + + +def test_asr_invalid_file(): + asr = _build_builtin_tool(ASRTool) + file_obj = SimpleNamespace(type=FileType.DOCUMENT) + invalid_file = list(asr.invoke(user_id="u", tool_parameters={"audio_file": file_obj}))[0].message.text + assert "not a valid audio file" in invalid_file + + +def test_asr_valid_file_invocation(monkeypatch): + asr = _build_builtin_tool(ASRTool) + model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() + monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") + monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + audio_file = SimpleNamespace(type=FileType.AUDIO) + ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text + assert ok == "transcript" + + +def test_asr_available_models_and_runtime_parameters(monkeypatch): + asr = _build_builtin_tool(ASRTool) + provider_model = type("PM", (), {"provider": "p", "models": [type("Model", (), {"model": "m"})()]})() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelProviderService.get_models_by_model_type", + lambda *a, **k: [provider_model], + ) + assert asr.get_available_models() == [("p", "m")] + assert asr.get_runtime_parameters()[0].name == "model" + + +def test_tts_invoke_returns_messages(monkeypatch): + tts = _build_builtin_tool(TTSTool) + voices_model_instance = type( + "TTSM", + (), + { + "get_tts_voices": lambda self: [{"value": "voice-1"}], + "invoke_tts": lambda self, **kwargs: [b"a", b"b"], + }, + )() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", + lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + ) + messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + + +def test_tts_get_available_models_requires_runtime(): + tts = _build_builtin_tool(TTSTool) + tts.runtime = None + with pytest.raises(ValueError, match="Runtime is required"): + tts.get_available_models() + + +def test_tts_tool_raises_when_runtime_missing(): + tts = _build_builtin_tool(TTSTool) + tts.runtime = None + with pytest.raises(ValueError, match="Runtime is required"): + list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + + +@pytest.mark.parametrize( + "voices", + [[{"value": None}], []], +) +def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): + tts = _build_builtin_tool(TTSTool) + tts.runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + model_without_voice = type( + "TTSModelNoVoice", + (), + { + "get_tts_voices": lambda self: voices, + "invoke_tts": lambda self, **kwargs: [b"x"], + }, + )() + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", + lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + ) + with pytest.raises(ValueError, match="no voice available"): + list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) + + +def test_tts_tool_get_available_models_and_runtime_parameters(monkeypatch): + tts = _build_builtin_tool(TTSTool) + + model_1 = SimpleNamespace( + model="model-a", + model_properties={ModelPropertyKey.VOICES: [{"mode": "v1", "name": "Voice 1"}]}, + ) + model_2 = SimpleNamespace(model="model-b", model_properties={}) + provider_models = [SimpleNamespace(provider="provider-a", models=[model_1, model_2])] + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.tts.ModelProviderService.get_models_by_model_type", + lambda *args, **kwargs: provider_models, + ) + + available_models = tts.get_available_models() + assert available_models == [ + ("provider-a", "model-a", [{"mode": "v1", "name": "Voice 1"}]), + ("provider-a", "model-b", []), + ] + + runtime_parameters = tts.get_runtime_parameters() + assert runtime_parameters[0].name == "model" + assert runtime_parameters[0].required is True + assert runtime_parameters[0].options[0].value == "provider-a#model-a" + assert runtime_parameters[1].name == "voice#provider-a#model-a" + + +def test_provider_classes_and_builtin_sort(monkeypatch): + # Use object.__new__ to avoid YAML-loading __init__; only pass-through validation is exercised. + # Ensure pass-through _validate_credentials methods are executed. + AudioToolProvider._validate_credentials(object.__new__(AudioToolProvider), "u", {}) + CodeToolProvider._validate_credentials(object.__new__(CodeToolProvider), "u", {}) + WikiPediaProvider._validate_credentials(object.__new__(WikiPediaProvider), "u", {}) + WebscraperProvider._validate_credentials(object.__new__(WebscraperProvider), "u", {}) + + providers = [SimpleNamespace(name="b"), SimpleNamespace(name="a")] + monkeypatch.setattr(BuiltinToolProviderSort, "_position", {}) + monkeypatch.setattr( + "core.tools.builtin_tool.providers._positions.get_tool_position_map", + lambda _: {"a": 0, "b": 1}, + ) + monkeypatch.setattr( + "core.tools.builtin_tool.providers._positions.sort_by_position_map", + lambda position, values, name_func: sorted(values, key=lambda x: name_func(x)), + ) + sorted_providers = BuiltinToolProviderSort.sort(providers) + assert [p.name for p in sorted_providers] == ["a", "b"] diff --git a/api/tests/unit_tests/core/tools/test_custom_tool.py b/api/tests/unit_tests/core/tools/test_custom_tool.py new file mode 100644 index 0000000000..79b8eaaa87 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_custom_tool.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +from types import SimpleNamespace + +import httpx +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.custom_tool.tool import ApiTool, ParsedResponse +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError + + +def _build_tool(*, openapi: dict | None = None) -> ApiTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + bundle = ApiToolBundle( + server_url="https://api.example.com/items/{id}", + method="GET", + summary="summary", + operation_id="op-id", + parameters=[], + author="author", + openapi=openapi or {"parameters": []}, + ) + runtime = ToolRuntime( + tenant_id="tenant-1", + invoke_from=InvokeFrom.DEBUGGER, + credentials={"auth_type": "api_key_header", "api_key_value": "k"}, + ) + return ApiTool(entity=entity, api_bundle=bundle, runtime=runtime, provider_id="provider-id") + + +def test_parsed_response_to_string(): + assert ParsedResponse({"a": 1}, True).to_string() == '{"a": 1}' + assert ParsedResponse("ok", False).to_string() == "ok" + + +def test_api_tool_fork_runtime_and_validate_credentials(monkeypatch): + tool = _build_tool() + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, ApiTool) + assert forked.runtime.tenant_id == "tenant-2" + + tool.api_bundle = None # type: ignore[assignment] + with pytest.raises(ValueError, match="api_bundle is required"): + tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + + tool = _build_tool() + assert tool.validate_credentials(credentials={}, parameters={}, format_only=True) == "" + monkeypatch.setattr(tool, "assembling_request", lambda parameters: {"Authorization": "Bearer x"}) + monkeypatch.setattr( + tool, + "do_http_request", + lambda url, method, headers, parameters: httpx.Response(200, json={"ok": True}), + ) + result = tool.validate_credentials(credentials={}, parameters={"a": 1}, format_only=False) + assert result == '{"ok": true}' + + +def test_assembling_request_auth_header_assembly(): + tool = _build_tool() + + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "k" + + tool.runtime.credentials = { + "auth_type": "api_key_header", + "api_key_header_prefix": "bearer", + "api_key_value": "abc", + } + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Bearer abc" + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_header_prefix": "basic", "api_key_value": "abc"} + headers = tool.assembling_request(parameters={}) + assert headers["Authorization"] == "Basic abc" + + tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_value": "abc"} + assert tool.assembling_request(parameters={}) == {} + + +def test_assembling_request_runtime_auth_errors(): + tool = _build_tool() + + tool.runtime = None + with pytest.raises(ToolProviderCredentialValidationError, match="runtime not initialized"): + tool.assembling_request(parameters={}) + + tool.runtime = ToolRuntime(tenant_id="tenant", credentials={}) + with pytest.raises(ToolProviderCredentialValidationError, match="Missing auth_type"): + tool.assembling_request(parameters={}) + + tool.runtime.credentials = {"auth_type": "api_key_header"} + with pytest.raises(ToolProviderCredentialValidationError, match="Missing api_key_value"): + tool.assembling_request(parameters={}) + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": 123} + with pytest.raises(ToolProviderCredentialValidationError, match="must be a string"): + tool.assembling_request(parameters={}) + + +def test_assembling_request_parameter_validation_and_defaults(): + tool = _build_tool() + + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "x"} + tool.api_bundle.parameters = [ + SimpleNamespace(required=True, name="required_param", default=None), + ] + with pytest.raises(ToolParameterValidationError, match="Missing required parameter required_param"): + tool.assembling_request(parameters={}) + + tool.api_bundle.parameters = [ + SimpleNamespace(required=True, name="required_param", default="d"), + ] + params = {} + tool.assembling_request(parameters=params) + assert params["required_param"] == "d" + + +def test_validate_and_parse_response_branches(): + tool = _build_tool() + + with pytest.raises(ToolInvokeError, match="status code 500"): + tool.validate_and_parse_response(httpx.Response(500, text="boom")) + + empty = tool.validate_and_parse_response(httpx.Response(200, content=b"")) + assert empty.is_json is False + assert "Empty response from the tool" in str(empty.content) + + json_resp = tool.validate_and_parse_response( + httpx.Response(200, json={"a": 1}, headers={"content-type": "application/json"}) + ) + assert json_resp.is_json is True + assert json_resp.content == {"a": 1} + + non_json_type = tool.validate_and_parse_response( + httpx.Response(200, text='{"a": 1}', headers={"content-type": "text/plain"}) + ) + assert non_json_type.is_json is False + assert non_json_type.content == '{"a": 1}' + + plain_resp = tool.validate_and_parse_response(httpx.Response(200, text="plain")) + assert plain_resp.is_json is False + assert plain_resp.content == "plain" + + with pytest.raises(ValueError, match="Invalid response type"): + tool.validate_and_parse_response("invalid") # type: ignore[arg-type] + + +def test_get_parameter_value_and_type_conversion_helpers(): + tool = _build_tool() + + assert tool.get_parameter_value({"name": "x"}, {"x": 1}) == 1 + assert tool.get_parameter_value({"name": "x", "required": False, "schema": {"default": "d"}}, {}) == "d" + with pytest.raises(ToolParameterValidationError, match="Missing required parameter x"): + tool.get_parameter_value({"name": "x", "required": True}, {}) + + assert tool._convert_body_property_any_of({}, "12", [{"type": "integer"}]) == 12 + assert tool._convert_body_property_any_of({}, "1.5", [{"type": "number"}]) == 1.5 + assert tool._convert_body_property_any_of({}, "true", [{"type": "boolean"}]) is True + assert tool._convert_body_property_any_of({}, "", [{"type": "null"}]) is None + assert tool._convert_body_property_any_of({}, "x", [{"anyOf": [{"type": "string"}]}]) == "x" + + assert tool._convert_body_property_type({"type": "integer"}, "1") == 1 + assert tool._convert_body_property_type({"type": "number"}, "1.2") == 1.2 + assert tool._convert_body_property_type({"type": "string"}, 1) == "1" + assert tool._convert_body_property_type({"type": "boolean"}, 1) is True + assert tool._convert_body_property_type({"type": "null"}, None) is None + assert tool._convert_body_property_type({"type": "object"}, '{"a":1}') == {"a": 1} + assert tool._convert_body_property_type({"type": "array"}, "[1,2]") == [1, 2] + assert tool._convert_body_property_type({"type": "invalid"}, "v") == "v" + assert tool._convert_body_property_type({"anyOf": [{"type": "integer"}]}, "2") == 2 + + +def test_do_http_request_builds_arguments_and_handles_invalid_method(monkeypatch): + openapi = { + "parameters": [ + {"name": "id", "in": "path", "required": True, "schema": {"type": "string"}}, + {"name": "q", "in": "query", "required": False, "schema": {"default": ""}}, + {"name": "X-Extra", "in": "header", "required": False, "schema": {"default": "x"}}, + {"name": "sid", "in": "cookie", "required": False, "schema": {"default": "cookie1"}}, + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "type": "object", + "required": ["count"], + "properties": { + "count": {"type": "integer"}, + "name": {"type": "string", "default": "n"}, + }, + } + } + } + }, + } + tool = _build_tool(openapi=openapi) + tool.runtime.credentials = {"auth_type": "api_key_query", "api_key_query_param": "key", "api_key_value": "v"} + headers = {} + captured = {} + + def _fake_get(url, **kwargs): + captured["url"] = url + captured["kwargs"] = kwargs + return httpx.Response(200, text="ok") + + monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.get", _fake_get) + response = tool.do_http_request( + "https://api.example.com/items/{id}", + "GET", + headers=headers, + parameters={"id": "123", "count": "2", "q": "search"}, + ) + + assert isinstance(response, httpx.Response) + assert captured["url"].endswith("/items/123") + assert captured["kwargs"]["params"]["q"] == "search" + assert captured["kwargs"]["params"]["key"] == "v" + assert captured["kwargs"]["headers"]["Content-Type"] == "application/json" + + invalid_method_tool = _build_tool(openapi={"parameters": []}) + with pytest.raises(ValueError, match="Invalid http method"): + invalid_method_tool.do_http_request("https://api.example.com", "TRACE", headers={}, parameters={}) + + +def test_do_http_request_handles_file_upload_and_invoke_paths(monkeypatch): + openapi = { + "parameters": [], + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": {"file": {"format": "binary"}}, + } + } + } + }, + } + tool = _build_tool(openapi=openapi) + tool.runtime.credentials = {"auth_type": "api_key_header", "api_key_value": "k"} + fake_file = SimpleNamespace(filename="a.txt", mime_type="text/plain") + captured = {} + + def _fake_post(url, **kwargs): + captured["headers"] = kwargs["headers"] + captured["files"] = kwargs["files"] + return httpx.Response(200, text="ok") + + monkeypatch.setattr("core.tools.custom_tool.tool.download", lambda _: b"file-bytes") + monkeypatch.setattr("core.tools.custom_tool.tool.ssrf_proxy.post", _fake_post) + response = tool.do_http_request( + "https://api.example.com/upload", + "POST", + headers={}, + parameters={"file": fake_file}, + ) + assert isinstance(response, httpx.Response) + assert "Content-Type" not in captured["headers"] + assert captured["files"][0][0] == "file" + + # _invoke JSON path + monkeypatch.setattr(tool, "assembling_request", lambda parameters: {}) + monkeypatch.setattr(tool, "do_http_request", lambda *args, **kwargs: httpx.Response(200, text='{"a":1}')) + monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse({"a": 1}, True)) + messages = list(tool.invoke(user_id="u1", tool_parameters={})) + assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.TEXT] + + # _invoke text path + monkeypatch.setattr(tool, "validate_and_parse_response", lambda _: ParsedResponse("plain", False)) + messages = list(tool.invoke(user_id="u1", tool_parameters={})) + assert len(messages) == 1 + assert messages[0].message.text == "plain" diff --git a/api/tests/unit_tests/core/tools/test_custom_tool_provider.py b/api/tests/unit_tests/core/tools/test_custom_tool_provider.py new file mode 100644 index 0000000000..93ae217e24 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_custom_tool_provider.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.custom_tool.tool import ApiTool +from core.tools.entities.tool_bundle import ApiToolBundle +from core.tools.entities.tool_entities import ApiProviderAuthType, ToolProviderType + + +def _db_provider() -> SimpleNamespace: + bundle = ApiToolBundle( + server_url="https://api.example.com/items", + method="GET", + summary="List items", + operation_id="list_items", + parameters=[], + author="author", + openapi={"parameters": []}, + ) + return SimpleNamespace( + id="provider-id", + tenant_id="tenant-1", + name="provider-a", + description="desc", + icon="icon.svg", + user=SimpleNamespace(name="Alice"), + tools=[bundle], + ) + + +def test_api_tool_provider_from_db_and_parse_tool_bundle(): + controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_HEADER) + assert controller.provider_type == ToolProviderType.API + assert any(c.name == "api_key_value" for c in controller.entity.credentials_schema) + + tool = controller._parse_tool_bundle(_db_provider().tools[0]) + assert isinstance(tool, ApiTool) + assert tool.entity.identity.provider == "provider-id" + + +def test_api_tool_provider_from_db_query_auth_and_none_auth(): + query_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.API_KEY_QUERY) + assert any(c.name == "api_key_query_param" for c in query_controller.entity.credentials_schema) + + none_controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE) + assert [c.name for c in none_controller.entity.credentials_schema] == ["auth_type"] + + +def test_api_tool_provider_load_get_tools_and_get_tool(): + controller = ApiToolProviderController.from_db(_db_provider(), ApiProviderAuthType.NONE) + loaded = controller.load_bundled_tools(_db_provider().tools) + assert len(loaded) == 1 + + assert isinstance(controller.get_tool("list_items"), ApiTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + # Return cached tools without querying database. + cached = controller.get_tools("tenant-1") + assert len(cached) == 1 + + # Force DB fetch branch. + controller.tools = [] + provider_with_tools = _db_provider() + with patch("core.tools.custom_tool.provider.db") as mock_db: + scalars_result = Mock() + scalars_result.all.return_value = [provider_with_tools] + mock_db.session.scalars.return_value = scalars_result + tools = controller.get_tools("tenant-1") + assert len(tools) == 1 diff --git a/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py b/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py new file mode 100644 index 0000000000..23c0be9487 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_dataset_retriever_tool.py @@ -0,0 +1,145 @@ +"""Unit tests for DatasetRetrieverTool behavior and retrieval wiring.""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool + + +def _retrieve_config() -> DatasetRetrieveConfigEntity: + return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE) + + +def test_get_dataset_tools_returns_empty_for_empty_dataset_ids() -> None: + # Arrange + retrieve_config = _retrieve_config() + + # Act + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=[], + retrieve_config=retrieve_config, + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + + # Assert + assert tools == [] + + +def test_get_dataset_tools_returns_empty_for_missing_retrieve_config() -> None: + # Arrange + dataset_ids = ["d1"] + + # Act + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=dataset_ids, + retrieve_config=None, # type: ignore[arg-type] + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + + # Assert + assert tools == [] + + +def test_get_dataset_tools_builds_tool_and_restores_strategy() -> None: + # Arrange + retrieve_config = _retrieve_config() + retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}") + feature = Mock() + feature.to_dataset_retriever_tool.return_value = [retrieval_tool] + + # Act + with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature): + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=["d1"], + retrieve_config=retrieve_config, + return_resource=True, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={"x": 1}, + ) + + # Assert + assert len(tools) == 1 + assert tools[0].entity.identity.name == "dataset_tool" + assert retrieve_config.retrieve_strategy == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE + + +def _build_dataset_tool() -> tuple[DatasetRetrieverTool, SimpleNamespace]: + retrieval_tool = SimpleNamespace(name="dataset_tool", description="desc", run=lambda query: f"result:{query}") + feature = Mock() + feature.to_dataset_retriever_tool.return_value = [retrieval_tool] + with patch("core.tools.utils.dataset_retriever_tool.DatasetRetrieval", return_value=feature): + tools = DatasetRetrieverTool.get_dataset_tools( + tenant_id="tenant", + dataset_ids=["d1"], + retrieve_config=_retrieve_config(), + return_resource=False, + invoke_from=InvokeFrom.DEBUGGER, + hit_callback=Mock(), + user_id="u", + inputs={}, + ) + return tools[0], retrieval_tool + + +def test_runtime_parameters_shape() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + params = tool.get_runtime_parameters() + + # Assert + assert len(params) == 1 + assert params[0].name == "query" + + +def test_empty_query_behavior() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + empty_query = list(tool.invoke(user_id="u", tool_parameters={})) + + # Assert + assert len(empty_query) == 1 + assert empty_query[0].message.text == "please input query" + + +def test_query_invocation_result() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + result = list(tool.invoke(user_id="u", tool_parameters={"query": "hello"})) + + # Assert + assert len(result) == 1 + assert result[0].message.text == "result:hello" + + +def test_validate_credentials() -> None: + # Arrange + tool, _ = _build_dataset_tool() + + # Act + result = tool.validate_credentials(credentials={}, parameters={}, format_only=False) + + # Assert + assert result is None diff --git a/api/tests/unit_tests/core/tools/test_mcp_tool.py b/api/tests/unit_tests/core/tools/test_mcp_tool.py new file mode 100644 index 0000000000..eaf054de59 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_mcp_tool.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +import base64 +from unittest.mock import patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.mcp.types import ( + BlobResourceContents, + CallToolResult, + EmbeddedResource, + ImageContent, + TextContent, + TextResourceContents, +) +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.mcp_tool.tool import MCPTool + + +def _build_mcp_tool(*, with_output_schema: bool = True) -> MCPTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="remote-tool", + label=I18nObject(en_US="remote-tool"), + provider="provider-id", + ), + parameters=[], + output_schema={"type": "object"} if with_output_schema else {}, + ) + return MCPTool( + entity=entity, + runtime=ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER), + tenant_id="tenant-1", + icon="icon.svg", + server_url="https://mcp.example.com", + provider_id="provider-id", + headers={"x-auth": "token"}, + ) + + +def test_mcp_tool_provider_type_and_fork_runtime(): + tool = _build_mcp_tool() + assert tool.tool_provider_type() == ToolProviderType.MCP + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, MCPTool) + assert forked.runtime.tenant_id == "tenant-2" + assert forked.provider_id == "provider-id" + + +def test_mcp_tool_text_and_json_processing_helpers(): + tool = _build_mcp_tool() + + json_messages = list(tool._process_text_content(TextContent(type="text", text='{"a": 1}'))) + assert json_messages[0].type == ToolInvokeMessage.MessageType.JSON + + plain_messages = list(tool._process_text_content(TextContent(type="text", text="not-json"))) + assert plain_messages[0].type == ToolInvokeMessage.MessageType.TEXT + assert plain_messages[0].message.text == "not-json" + + list_messages = list(tool._process_json_content([{"k": 1}, {"k": 2}])) + assert [m.type for m in list_messages] == [ToolInvokeMessage.MessageType.JSON, ToolInvokeMessage.MessageType.JSON] + + mixed_list_messages = list(tool._process_json_list([{"k": 1}, 2])) + assert len(mixed_list_messages) == 1 + assert mixed_list_messages[0].type == ToolInvokeMessage.MessageType.TEXT + + primitive_messages = list(tool._process_json_content(123)) + assert primitive_messages[0].message.text == "123" + + +def test_mcp_tool_usage_extraction_helpers(): + usage = MCPTool._extract_usage_dict({"usage": {"total_tokens": 9}}) + assert usage == {"total_tokens": 9} + + usage = MCPTool._extract_usage_dict({"metadata": {"usage": {"prompt_tokens": 3, "completion_tokens": 2}}}) + assert usage == {"prompt_tokens": 3, "completion_tokens": 2} + + usage = MCPTool._extract_usage_dict({"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3}) + assert usage == {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} + + usage = MCPTool._extract_usage_dict({"nested": [{"deep": {"usage": {"total_tokens": 7}}}]}) + assert usage == {"total_tokens": 7} + + result_with_usage = CallToolResult(content=[], _meta={"usage": {"prompt_tokens": 1, "completion_tokens": 2}}) + derived = MCPTool._derive_usage_from_result(result_with_usage) + assert derived.prompt_tokens == 1 + assert derived.completion_tokens == 2 + + result_without_usage = CallToolResult(content=[], _meta=None) + derived = MCPTool._derive_usage_from_result(result_without_usage) + assert derived.total_tokens == 0 + + +def test_mcp_tool_invoke_handles_content_types_and_structured_output(): + tool = _build_mcp_tool() + img_data = base64.b64encode(b"img").decode() + blob_data = base64.b64encode(b"blob").decode() + result = CallToolResult( + content=[ + TextContent(type="text", text='{"a": 1}'), + ImageContent(type="image", data=img_data, mimeType="image/png"), + EmbeddedResource( + type="resource", + resource=TextResourceContents(uri="file:///tmp/a.txt", text="embedded-text"), + ), + EmbeddedResource( + type="resource", + resource=BlobResourceContents( + uri="file:///tmp/b.bin", + blob=blob_data, + mimeType="application/octet-stream", + ), + ), + ], + structuredContent={"x": 1}, + _meta={"usage": {"prompt_tokens": 2, "completion_tokens": 3}}, + ) + + with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result): + messages = list(tool.invoke(user_id="user-1", tool_parameters={"a": 1})) + + types = [m.type for m in messages] + assert ToolInvokeMessage.MessageType.JSON in types + assert ToolInvokeMessage.MessageType.BLOB in types + assert ToolInvokeMessage.MessageType.TEXT in types + assert ToolInvokeMessage.MessageType.VARIABLE in types + assert tool.latest_usage.total_tokens == 5 + + +def test_mcp_tool_invoke_raises_for_unsupported_embedded_resource(): + tool = _build_mcp_tool() + # Use model_construct to bypass pydantic validation and force unsupported resource path. + bad_resource = EmbeddedResource.model_construct(type="resource", resource=object()) + result = CallToolResult(content=[bad_resource], _meta=None) + + with patch.object(MCPTool, "invoke_remote_mcp_tool", return_value=result): + with pytest.raises(ToolInvokeError, match="Unsupported embedded resource type"): + list(tool.invoke(user_id="user-1", tool_parameters={})) + + +def test_mcp_tool_handle_none_parameter_filters_empty_values(): + tool = _build_mcp_tool() + cleaned = tool._handle_none_parameter({"a": 1, "b": None, "c": "", "d": " ", "e": "ok"}) + assert cleaned == {"a": 1, "e": "ok"} diff --git a/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py b/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py new file mode 100644 index 0000000000..1060d19ab1 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_mcp_tool_provider.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from datetime import datetime +from unittest.mock import Mock, patch + +import pytest + +from core.entities.mcp_provider import MCPProviderEntity +from core.tools.entities.tool_entities import ToolProviderType +from core.tools.mcp_tool.provider import MCPToolProviderController +from core.tools.mcp_tool.tool import MCPTool + + +def _build_mcp_entity(*, icon: str = "icon.svg") -> MCPProviderEntity: + now = datetime.now() + return MCPProviderEntity( + id="db-id", + provider_id="provider-id", + name="MCP Provider", + tenant_id="tenant-1", + user_id="user-1", + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + authed=False, + credentials={}, + tools=[ + { + "name": "remote-tool", + "description": "remote tool", + "inputSchema": {}, + "outputSchema": {"type": "object"}, + } + ], + icon=icon, + created_at=now, + updated_at=now, + ) + + +def test_mcp_tool_provider_controller_from_entity_and_get_tools(): + entity = _build_mcp_entity() + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + controller = MCPToolProviderController.from_entity(entity) + + assert controller.provider_type == ToolProviderType.MCP + tool = controller.get_tool("remote-tool") + assert isinstance(tool, MCPTool) + assert tool.tenant_id == "tenant-1" + + tools = controller.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], MCPTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + +def test_mcp_tool_provider_controller_from_entity_requires_icon(): + entity = _build_mcp_entity(icon="") + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + with pytest.raises(ValueError, match="icon is required"): + MCPToolProviderController.from_entity(entity) + + +def test_mcp_tool_provider_controller_from_db_delegates_to_entity(): + entity = _build_mcp_entity() + db_provider = Mock() + db_provider.to_entity.return_value = entity + with patch("core.tools.mcp_tool.provider.ToolTransformService.convert_mcp_schema_to_parameter", return_value=[]): + controller = MCPToolProviderController.from_db(db_provider) + assert isinstance(controller, MCPToolProviderController) diff --git a/api/tests/unit_tests/core/tools/test_plugin_tool.py b/api/tests/unit_tests/core/tools/test_plugin_tool.py new file mode 100644 index 0000000000..4378432a0f --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_plugin_tool.py @@ -0,0 +1,91 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolParameter +from core.tools.plugin_tool.tool import PluginTool + + +def _build_plugin_tool(*, has_runtime_parameters: bool) -> PluginTool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[ + ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ], + has_runtime_parameters=has_runtime_parameters, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, credentials={"api_key": "x"}) + return PluginTool( + entity=entity, + runtime=runtime, + tenant_id="tenant-1", + icon="icon.svg", + plugin_unique_identifier="plugin-uid", + ) + + +def test_plugin_tool_invoke_and_fork_runtime(): + tool = _build_plugin_tool(has_runtime_parameters=False) + manager = Mock() + manager.invoke.return_value = iter([tool.create_text_message("ok")]) + + with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager): + with patch( + "core.tools.plugin_tool.tool.convert_parameters_to_plugin_format", + return_value={"converted": 1}, + ): + messages = list(tool.invoke(user_id="user-1", tool_parameters={"raw": 1})) + + assert [m.message.text for m in messages] == ["ok"] + manager.invoke.assert_called_once() + assert manager.invoke.call_args.kwargs["tool_parameters"] == {"converted": 1} + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2")) + assert isinstance(forked, PluginTool) + assert forked.runtime.tenant_id == "tenant-2" + assert forked.plugin_unique_identifier == "plugin-uid" + + +def test_plugin_tool_get_runtime_parameters_branches(): + tool = _build_plugin_tool(has_runtime_parameters=False) + assert tool.get_runtime_parameters() == tool.entity.parameters + + tool = _build_plugin_tool(has_runtime_parameters=True) + cached = [ + ToolParameter.get_simple_instance( + name="k", + llm_description="k", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + tool.runtime_parameters = cached + assert tool.get_runtime_parameters() == cached + + tool.runtime_parameters = None + manager = Mock() + returned = [ + ToolParameter.get_simple_instance( + name="dyn", + llm_description="dyn", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + manager.get_runtime_parameters.return_value = returned + with patch("core.tools.plugin_tool.tool.PluginToolManager", return_value=manager): + assert tool.get_runtime_parameters(conversation_id="c1", app_id="a1", message_id="m1") == returned + assert tool.runtime_parameters == returned diff --git a/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py b/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py new file mode 100644 index 0000000000..5ef03cc6ca --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_plugin_tool_provider.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch + +import pytest + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolProviderEntityWithPlugin, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.plugin_tool.tool import PluginTool + + +def _build_controller() -> PluginToolProviderController: + tool_entity = ToolEntity( + identity=ToolIdentity( + author="author", + name="tool-a", + label=I18nObject(en_US="tool-a"), + provider="provider-a", + ), + parameters=[], + ) + entity = ToolProviderEntityWithPlugin( + identity=ToolProviderIdentity( + author="author", + name="provider-a", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="Provider"), + ), + credentials_schema=[], + plugin_id="plugin-id", + tools=[tool_entity], + ) + return PluginToolProviderController( + entity=entity, + plugin_id="plugin-id", + plugin_unique_identifier="plugin-uid", + tenant_id="tenant-1", + ) + + +def test_plugin_tool_provider_controller_basic_behaviors(): + controller = _build_controller() + assert controller.provider_type == ToolProviderType.PLUGIN + + tool = controller.get_tool("tool-a") + assert isinstance(tool, PluginTool) + assert tool.runtime.tenant_id == "tenant-1" + + tools = controller.get_tools() + assert len(tools) == 1 + assert isinstance(tools[0], PluginTool) + + with pytest.raises(ValueError, match="not found"): + controller.get_tool("missing") + + +def test_validate_credentials_success(): + controller = _build_controller() + manager = Mock() + manager.validate_provider_credentials.return_value = True + + with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager): + controller._validate_credentials(user_id="u1", credentials={"api_key": "x"}) + + manager.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant-1", + user_id="u1", + provider="provider-a", + credentials={"api_key": "x"}, + ) + + +def test_validate_credentials_failure(): + controller = _build_controller() + manager = Mock() + manager.validate_provider_credentials.return_value = False + + with patch("core.tools.plugin_tool.provider.PluginToolManager", return_value=manager): + with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"): + controller._validate_credentials(user_id="u1", credentials={"api_key": "x"}) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py new file mode 100644 index 0000000000..a5242a78c5 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -0,0 +1,119 @@ +"""Unit tests for core.tools.signature covering signing and verification invariants.""" + +from __future__ import annotations + +from urllib.parse import parse_qs, urlparse + +import pytest + +from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature + + +def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120) + + url = sign_tool_file("tool-file-id", ".png", for_external=False) + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + assert parsed.scheme == "https" + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/tools/tool-file-id.png" + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is True + + +def test_sign_tool_file_for_external_uses_files_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x04" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 120) + + url = sign_tool_file("tool-file-id", ".png", for_external=True) + parsed = urlparse(url) + + assert parsed.scheme == "https" + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/tools/tool-file-id.png" + + +def test_verify_tool_file_signature_rejects_invalid_sign(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10) + + url = sign_tool_file("tool-file-id", ".txt") + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, "bad-signature") is False + + +def test_verify_tool_file_signature_rejects_expired_signature(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x02" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 10) + + url = sign_tool_file("tool-file-id", ".txt") + parsed = urlparse(url) + query = parse_qs(parsed.query) + timestamp = query["timestamp"][0] + nonce = query["nonce"][0] + sign = query["sign"][0] + + monkeypatch.setattr("core.tools.signature.time.time", lambda: int(timestamp) + 99) + assert verify_tool_file_signature("tool-file-id", timestamp, nonce, sign) is False + + +def test_sign_upload_file_prefers_internal_url(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x03" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + + url = sign_upload_file("upload-id", ".png") + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-id/image-preview" + assert query["timestamp"][0] + assert query["nonce"][0] + assert query["sign"][0] + + +def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x05" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + + url = sign_upload_file("upload-id", ".png") + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "files.example.com" + assert parsed.path == "/files/upload-id/image-preview" + assert query["timestamp"][0] + assert query["nonce"][0] + assert query["sign"][0] diff --git a/api/tests/unit_tests/core/tools/test_tool_engine.py b/api/tests/unit_tests/core/tools/test_tool_engine.py new file mode 100644 index 0000000000..40c107667c --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_engine.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +from collections.abc import Generator +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolInvokeMessageBinary, + ToolInvokeMeta, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ( + ToolEngineInvokeError, + ToolInvokeError, + ToolParameterValidationError, +) +from core.tools.tool_engine import ToolEngine + + +class _DummyTool(Tool): + result: Any + raise_error: Exception | None + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime): + super().__init__(entity=entity, runtime=runtime) + self.result = [self.create_text_message("ok")] + self.raise_error = None + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + if self.raise_error: + raise self.raise_error + if isinstance(self.result, list | Generator): + yield from self.result + else: + yield self.result + + +def _build_tool(with_llm_parameter: bool = False) -> _DummyTool: + parameters = [] + if with_llm_parameter: + parameters = [ + ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + ] + entity = ToolEntity( + identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=parameters, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER, runtime_parameters={"rt": 1}) + return _DummyTool(entity=entity, runtime=runtime) + + +def test_convert_tool_response_to_str_and_extract_binary_messages(): + tool = _build_tool() + messages = [ + tool.create_text_message("hello"), + tool.create_link_message("https://example.com"), + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, + message=ToolInvokeMessage.TextMessage(text="https://example.com/a.png"), + meta={"mime_type": "image/png"}, + ), + tool.create_json_message({"a": 1}), + tool.create_json_message({"a": 1}, suppress_output=True), + ] + text = ToolEngine._convert_tool_response_to_str(messages) + assert "hello" in text + assert "result link: https://example.com." in text + assert '"a": 1' in text + + blob_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=ToolInvokeMessage.TextMessage(text="https://example.com/blob.bin"), + meta={"mime_type": "application/octet-stream"}, + ) + link_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://example.com/file.pdf"), + meta={"mime_type": "application/pdf"}, + ) + binaries = list(ToolEngine._extract_tool_response_binary_and_text([messages[2], blob_message, link_message])) + assert [b.mimetype for b in binaries] == ["image/png", "application/octet-stream", "application/pdf"] + + with pytest.raises(ValueError, match="missing meta data"): + list( + ToolEngine._extract_tool_response_binary_and_text( + [ + ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE, + message=ToolInvokeMessage.TextMessage(text="x"), + ) + ] + ) + ) + + +def test_create_message_files_and_invoke_generator(): + binaries = [ + ToolInvokeMessageBinary(mimetype="image/png", url="https://example.com/abc.png"), + ToolInvokeMessageBinary(mimetype="audio/wav", url="https://example.com/def.wav"), + ] + created = [] + + def _message_file_factory(**kwargs): + obj = SimpleNamespace(id=f"mf-{len(created) + 1}", **kwargs) + created.append(obj) + return obj + + with patch("core.tools.tool_engine.MessageFile", side_effect=_message_file_factory): + with patch("core.tools.tool_engine.db") as mock_db: + ids = ToolEngine._create_message_files( + tool_messages=binaries, + agent_message=SimpleNamespace(id="msg-1"), + invoke_from=InvokeFrom.DEBUGGER, + user_id="user-1", + ) + + assert ids == ["mf-1", "mf-2"] + assert mock_db.session.add.call_count == 2 + mock_db.session.close.assert_called_once() + + tool = _build_tool() + invoked = list(ToolEngine._invoke(tool, {"a": 1}, user_id="u")) + assert invoked[0].type == ToolInvokeMessage.MessageType.TEXT + assert isinstance(invoked[-1], ToolInvokeMeta) + assert invoked[-1].error is None + + +def test_generic_invoke_success_and_error_paths(): + tool = _build_tool() + callback = Mock() + callback.on_tool_execution.side_effect = lambda **kwargs: kwargs["tool_outputs"] + response = list( + ToolEngine.generic_invoke( + tool=tool, + tool_parameters={"x": 1}, + user_id="u1", + workflow_tool_callback=callback, + workflow_call_depth=0, + conversation_id="c1", + app_id="a1", + message_id="m1", + ) + ) + assert response[0].message.text == "ok" + callback.on_tool_start.assert_called_once() + callback.on_tool_execution.assert_called_once() + + tool.raise_error = RuntimeError("boom") + error_callback = Mock() + error_callback.on_tool_execution.side_effect = lambda **kwargs: list(kwargs["tool_outputs"]) + with pytest.raises(RuntimeError, match="boom"): + list( + ToolEngine.generic_invoke( + tool=tool, + tool_parameters={"x": 1}, + user_id="u1", + workflow_tool_callback=error_callback, + workflow_call_depth=0, + ) + ) + error_callback.on_tool_error.assert_called_once() + + +def test_agent_invoke_success(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + meta = ToolInvokeMeta.empty() + + with patch.object(ToolEngine, "_invoke", return_value=iter([tool.create_text_message("ok"), meta])): + with patch( + "core.tools.tool_engine.ToolFileMessageTransformer.transform_tool_invoke_messages", + side_effect=lambda messages, **kwargs: messages, + ): + with patch.object(ToolEngine, "_extract_tool_response_binary_and_text", return_value=iter([])): + with patch.object(ToolEngine, "_create_message_files", return_value=[]): + result_text, message_files, result_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters="hello", + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert result_text == "ok" + assert message_files == [] + assert result_meta.error is None + callback.on_tool_start.assert_called_once() + callback.on_tool_end.assert_called_once() + + +def test_agent_invoke_param_validation_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + + with patch.object(ToolEngine, "_invoke", side_effect=ToolParameterValidationError("bad-param")): + error_text, files, error_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "tool parameters validation error" in error_text + assert files == [] + assert error_meta.error + + +def test_agent_invoke_engine_meta_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + engine_error = ToolEngineInvokeError(ToolInvokeMeta.error_instance("meta failure")) + + with patch.object(ToolEngine, "_invoke", side_effect=engine_error): + error_text, files, error_meta = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "meta failure" in error_text + assert files == [] + assert error_meta.error == "meta failure" + + +def test_agent_invoke_tool_invoke_error(): + tool = _build_tool(with_llm_parameter=True) + callback = Mock() + message = SimpleNamespace(id="m1", conversation_id="c1") + + with patch.object(ToolEngine, "_invoke", side_effect=ToolInvokeError("invoke boom")): + error_text, files, _ = ToolEngine.agent_invoke( + tool=tool, + tool_parameters={"a": 1}, + user_id="u1", + tenant_id="tenant-1", + message=message, + invoke_from=InvokeFrom.DEBUGGER, + agent_tool_callback=callback, + ) + + assert "tool invoke error" in error_text + assert files == [] diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py new file mode 100644 index 0000000000..cca8254dd6 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -0,0 +1,249 @@ +"""Unit tests for `ToolFileManager` behavior. + +Covers signing/verification, file persistence flows, and retrieval APIs with +mocked storage/session boundaries (httpx, SimpleNamespace, Mock/patch) to +avoid real IO. +""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import httpx +import pytest + +from core.tools.tool_file_manager import ToolFileManager + + +def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: + monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.tool_file_manager.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.SECRET_KEY", "secret") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 100) + + url = ToolFileManager.sign_file("tf-1", ".png") + return dict(part.split("=", 1) for part in url.split("?", 1)[1].split("&")) + + +def _patch_session_factory(session: Mock): + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return patch("core.tools.tool_file_manager.session_factory.create_session", return_value=session_cm) + + +def test_tool_file_manager_sign_verify_valid(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + url = ToolFileManager.sign_file("tf-1", ".png") + assert "/files/tools/tf-1.png" in url + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is True + + +def test_tool_file_manager_sign_verify_bad_signature(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], "bad") is False + + +def test_tool_file_manager_sign_verify_expired_timestamp(monkeypatch: pytest.MonkeyPatch) -> None: + query = _setup_tool_file_signing(monkeypatch) + monkeypatch.setattr("core.tools.tool_file_manager.dify_config.FILES_ACCESS_TIMEOUT", 0) + monkeypatch.setattr("core.tools.tool_file_manager.time.time", lambda: 1700000100) + + assert ToolFileManager.verify_file("tf-1", query["timestamp"], query["nonce"], query["sign"]) is False + + +def test_create_file_by_raw_stores_file_and_persists_record() -> None: + manager = ToolFileManager() + session = Mock() + session.refresh.side_effect = lambda model: setattr(model, "id", "tf-1") + + def tool_file_factory(**kwargs): + return SimpleNamespace(**kwargs) + + with ( + patch("core.tools.tool_file_manager.storage") as storage, + patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), + patch("core.tools.tool_file_manager.guess_extension", return_value=".txt"), + patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="abc")), + _patch_session_factory(session), + ): + file_model = manager.create_file_by_raw( + user_id="u1", + tenant_id="t1", + conversation_id="c1", + file_binary=b"hello", + mimetype="text/plain", + filename="readme", + ) + + assert file_model.name.endswith(".txt") + storage.save.assert_called_once() + session.add.assert_called_once() + session.commit.assert_called_once() + session.refresh.assert_called_once_with(file_model) + + +def test_create_file_by_url_downloads_and_persists_record() -> None: + manager = ToolFileManager() + response = Mock() + response.content = b"binary" + response.headers = {"Content-Type": "application/octet-stream"} + response.raise_for_status.return_value = None + session = Mock() + + def tool_file_factory(**kwargs): + return SimpleNamespace(**kwargs) + + session.refresh.side_effect = lambda model: setattr(model, "id", "tf-2") + with ( + patch("core.tools.tool_file_manager.storage") as storage, + patch("core.tools.tool_file_manager.ToolFile", side_effect=tool_file_factory), + patch("core.tools.tool_file_manager.uuid4", return_value=SimpleNamespace(hex="def")), + _patch_session_factory(session), + patch("core.tools.tool_file_manager.ssrf_proxy.get", return_value=response), + ): + file_model = manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") + + assert file_model.file_key.startswith("tools/t1/") + storage.save.assert_called_once() + session.add.assert_called_once_with(file_model) + session.commit.assert_called_once() + session.refresh.assert_called_once_with(file_model) + + +def test_create_file_by_url_raises_on_timeout() -> None: + manager = ToolFileManager() + + with patch("core.tools.tool_file_manager.ssrf_proxy.get", side_effect=httpx.TimeoutException("timeout")): + with pytest.raises(ValueError, match="timeout when downloading file"): + manager.create_file_by_url("u1", "t1", "https://example.com/f.bin", "c1") + + +def test_get_file_binary_returns_none_when_not_found() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + session.query.return_value.where.return_value.first.return_value = None + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary("missing") + + # Assert + assert result is None + + +def test_get_file_binary_returns_bytes_when_found() -> None: + # Arrange + manager = ToolFileManager() + tool_file = SimpleNamespace(file_key="k1", mimetype="text/plain") + session = Mock() + session.query.return_value.where.return_value.first.return_value = tool_file + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + storage.load_once.return_value = b"hello" + with _patch_session_factory(session): + result = manager.get_file_binary("id1") + + # Assert + assert result == (b"hello", "text/plain") + + +def test_get_file_binary_by_message_file_id_when_messagefile_missing() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = None + second_query.where.return_value.first.return_value = None + session.query.side_effect = [first_query, second_query] + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result is None + + +def test_get_file_binary_by_message_file_id_when_url_is_none() -> None: + # Arrange + manager = ToolFileManager() + message_file = SimpleNamespace(url=None) + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = message_file + second_query.where.return_value.first.return_value = None + session.query.side_effect = [first_query, second_query] + + # Act + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result is None + + +def test_get_file_binary_by_message_file_id_returns_bytes_when_found() -> None: + # Arrange + manager = ToolFileManager() + message_file = SimpleNamespace(url="https://x/files/tools/tool123.png") + tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + session = Mock() + first_query = Mock() + second_query = Mock() + first_query.where.return_value.first.return_value = message_file + second_query.where.return_value.first.return_value = tool_file + session.query.side_effect = [first_query, second_query] + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + storage.load_once.return_value = b"img" + with _patch_session_factory(session): + result = manager.get_file_binary_by_message_file_id("mf-1") + + # Assert + assert result == (b"img", "image/png") + + +def test_get_file_generator_returns_none_when_toolfile_missing() -> None: + # Arrange + manager = ToolFileManager() + session = Mock() + session.query.return_value.where.return_value.first.return_value = None + + # Act + with _patch_session_factory(session): + stream, tool_file = manager.get_file_generator_by_tool_file_id("tool123") + + # Assert + assert stream is None + assert tool_file is None + + +def test_get_file_generator_returns_stream_when_found() -> None: + # Arrange + manager = ToolFileManager() + tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + session = Mock() + session.query.return_value.where.return_value.first.return_value = tool_file + + # Act + with patch("core.tools.tool_file_manager.storage") as storage: + stream = iter([b"a", b"b"]) + storage.load_stream.return_value = stream + with ( + _patch_session_factory(session), + patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), + ): + result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") + assert list(result_stream) == [b"a", b"b"] + assert result_file == "validated-file" diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py new file mode 100644 index 0000000000..857f4aa178 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -0,0 +1,92 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any +from unittest.mock import PropertyMock, patch + +import pytest + +from core.tools.builtin_tool.provider import BuiltinToolProviderController +from core.tools.custom_tool.provider import ApiToolProviderController +from core.tools.tool_label_manager import ToolLabelManager +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController + + +class _ConcreteBuiltinToolProviderController(BuiltinToolProviderController): + def _validate_credentials(self, user_id: str, credentials: dict[str, Any]): + return None + + +def _api_controller(provider_id: str = "api-1") -> ApiToolProviderController: + controller = object.__new__(ApiToolProviderController) + controller.provider_id = provider_id + return controller + + +def _workflow_controller(provider_id: str = "wf-1") -> WorkflowToolProviderController: + controller = object.__new__(WorkflowToolProviderController) + controller.provider_id = provider_id + return controller + + +def test_tool_label_manager_filter_tool_labels(): + filtered = ToolLabelManager.filter_tool_labels(["search", "search", "invalid", "news"]) + assert set(filtered) == {"search", "news"} + assert len(filtered) == 2 + + +def test_tool_label_manager_update_tool_labels_db(): + controller = _api_controller("api-1") + with patch("core.tools.tool_label_manager.db") as mock_db: + delete_query = mock_db.session.query.return_value.where.return_value + delete_query.delete.return_value = None + ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) + + delete_query.delete.assert_called_once() + # only one valid unique label should be inserted. + assert mock_db.session.add.call_count == 1 + mock_db.session.commit.assert_called_once() + + +def test_tool_label_manager_update_tool_labels_unsupported(): + with pytest.raises(ValueError, match="Unsupported tool type"): + ToolLabelManager.update_tool_labels(object(), ["search"]) # type: ignore[arg-type] + + +def test_tool_label_manager_get_tool_labels_for_builtin_and_db(): + with patch.object( + _ConcreteBuiltinToolProviderController, + "tool_labels", + new_callable=PropertyMock, + return_value=["search", "news"], + ): + builtin = object.__new__(_ConcreteBuiltinToolProviderController) + assert ToolLabelManager.get_tool_labels(builtin) == ["search", "news"] + + api = _api_controller("api-1") + with patch("core.tools.tool_label_manager.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = ["search", "news"] + labels = ToolLabelManager.get_tool_labels(api) + assert labels == ["search", "news"] + + with pytest.raises(ValueError, match="Unsupported tool type"): + ToolLabelManager.get_tool_labels(object()) # type: ignore[arg-type] + + +def test_tool_label_manager_get_tools_labels_batch(): + assert ToolLabelManager.get_tools_labels([]) == {} + + api = _api_controller("api-1") + wf = _workflow_controller("wf-1") + records = [ + SimpleNamespace(tool_id="api-1", label_name="search"), + SimpleNamespace(tool_id="api-1", label_name="news"), + SimpleNamespace(tool_id="wf-1", label_name="utilities"), + ] + with patch("core.tools.tool_label_manager.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = records + labels = ToolLabelManager.get_tools_labels([api, wf]) + assert labels == {"api-1": ["search", "news"], "wf-1": ["utilities"]} + + with pytest.raises(ValueError, match="Unsupported tool type"): + ToolLabelManager.get_tools_labels([api, object()]) # type: ignore[list-item] diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py new file mode 100644 index 0000000000..0f73e22654 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -0,0 +1,899 @@ +from __future__ import annotations + +"""Unit tests for ToolManager behavior with mocked providers and collaborators.""" + +import json +import threading +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.plugin.entities.plugin_daemon import CredentialType +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.tool_entities import ( + ApiProviderAuthType, + ToolParameter, + ToolProviderType, +) +from core.tools.errors import ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController +from core.tools.tool_manager import ToolManager + + +class _SimpleContextVar: + def __init__(self): + self._is_set = False + self._value: Any = None + + def get(self): + if not self._is_set: + raise LookupError + return self._value + + def set(self, value: Any): + self._value = value + self._is_set = True + + +def _cm(session: Any): + context = Mock() + context.__enter__ = Mock(return_value=session) + context.__exit__ = Mock(return_value=False) + return context + + +def _setup_list_providers_from_api_mocks( + monkeypatch, + *, + session: Mock, + hardcoded_controller: SimpleNamespace, + plugin_controller: PluginToolProviderController, + api_controller: SimpleNamespace, + workflow_controller: SimpleNamespace, +): + mock_db = Mock() + mock_db.engine = object() + monkeypatch.setattr("core.tools.tool_manager.db", mock_db) + monkeypatch.setattr("core.tools.tool_manager.Session", lambda *args, **kwargs: _cm(session)) + monkeypatch.setattr( + ToolManager, + "list_builtin_providers", + Mock(return_value=[hardcoded_controller, plugin_controller]), + ) + monkeypatch.setattr( + ToolManager, + "list_default_builtin_providers", + Mock(return_value=[SimpleNamespace(provider="hardcoded")]), + ) + monkeypatch.setattr("core.tools.tool_manager.is_filtered", lambda *args, **kwargs: False) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.builtin_provider_to_user_provider", + lambda **kwargs: SimpleNamespace(name=kwargs["provider_controller"].entity.identity.name), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.api_provider_to_controller", + Mock(side_effect=[api_controller, RuntimeError("invalid")]), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.api_provider_to_user_provider", + Mock(return_value=SimpleNamespace(name="api-provider")), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller", + Mock(side_effect=[workflow_controller, RuntimeError("deleted app")]), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_user_provider", + Mock(return_value=SimpleNamespace(name="workflow-provider")), + ) + monkeypatch.setattr( + "core.tools.tool_manager.ToolLabelManager.get_tools_labels", + Mock(side_effect=[{"api-1": ["search"]}, {"wf-1": ["utility"]}]), + ) + mock_mcp_service = Mock() + mock_mcp_service.list_providers.return_value = [SimpleNamespace(name="mcp-provider")] + monkeypatch.setattr("core.tools.tool_manager.MCPToolManageService", Mock(return_value=mock_mcp_service)) + monkeypatch.setattr("core.tools.tool_manager.BuiltinToolProviderSort.sort", lambda providers: providers) + + +@pytest.fixture(autouse=True) +def _reset_tool_manager_state(): + old_hardcoded = ToolManager._hardcoded_providers.copy() + old_loaded = ToolManager._builtin_providers_loaded + old_labels = ToolManager._builtin_tools_labels.copy() + try: + yield + finally: + ToolManager._hardcoded_providers = old_hardcoded + ToolManager._builtin_providers_loaded = old_loaded + ToolManager._builtin_tools_labels = old_labels + + +def test_get_hardcoded_provider_loads_cache_when_empty(): + provider = Mock() + ToolManager._hardcoded_providers = {} + + def _load(): + ToolManager._hardcoded_providers["weather"] = provider + + with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load) as mock_load: + assert ToolManager.get_hardcoded_provider("weather") is provider + + mock_load.assert_called_once() + + +def test_get_builtin_provider_returns_plugin_for_missing_hardcoded(): + hardcoded = Mock() + plugin_provider = Mock() + ToolManager._hardcoded_providers = {"time": hardcoded} + + with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider): + assert ToolManager.get_builtin_provider("time", "tenant-1") is hardcoded + assert ToolManager.get_builtin_provider("plugin/time", "tenant-1") is plugin_provider + + +def test_get_plugin_provider_uses_context_cache(): + provider_context = _SimpleContextVar() + lock_context = _SimpleContextVar() + lock_context.set(threading.Lock()) + provider_entity = SimpleNamespace(declaration=Mock(), plugin_id="pid", plugin_unique_identifier="uid") + + with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context): + with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context): + with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls: + mock_manager_cls.return_value.fetch_tool_provider.return_value = provider_entity + controller = SimpleNamespace(name="controller") + with patch("core.tools.tool_manager.PluginToolProviderController", return_value=controller): + built = ToolManager.get_plugin_provider("provider-a", "tenant-1") + cached = ToolManager.get_plugin_provider("provider-a", "tenant-1") + + assert built is controller + assert cached is controller + mock_manager_cls.return_value.fetch_tool_provider.assert_called_once() + + +def test_get_plugin_provider_raises_when_provider_missing(): + provider_context = _SimpleContextVar() + lock_context = _SimpleContextVar() + lock_context.set(threading.Lock()) + + with patch("core.tools.tool_manager.contexts.plugin_tool_providers", provider_context): + with patch("core.tools.tool_manager.contexts.plugin_tool_providers_lock", lock_context): + with patch("core.tools.tool_manager.PluginToolManager") as mock_manager_cls: + mock_manager_cls.return_value.fetch_tool_provider.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="plugin provider provider-a not found"): + ToolManager.get_plugin_provider("provider-a", "tenant-1") + + +def test_get_tool_runtime_builtin_without_credentials(): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace(get_tool=Mock(return_value=tool), need_credentials=False) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="current_time", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"] + assert runtime.tenant_id == "tenant-1" + assert runtime.credentials == {} + + +def test_get_tool_runtime_builtin_missing_tool_raises(): + controller = SimpleNamespace(get_tool=Mock(return_value=None), need_credentials=False) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + with pytest.raises(ToolProviderNotFoundError, match="builtin tool missing not found"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="missing", + tenant_id="tenant-1", + ) + + +def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace( + get_tool=Mock(return_value=tool), + need_credentials=True, + get_credentials_schema_by_type=Mock(return_value=[]), + ) + builtin_provider = SimpleNamespace( + id="cred-1", + credential_type=CredentialType.API_KEY.value, + credentials={"encrypted": "value"}, + expires_at=-1, + user_id="user-1", + ) + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + with patch("core.helper.credential_utils.check_credential_policy_compliance"): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + builtin_provider + ) + encrypter = Mock() + encrypter.decrypt.return_value = {"api_key": "secret"} + cache = Mock() + with patch("core.tools.tool_manager.create_provider_encrypter", return_value=(encrypter, cache)): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + runtime = tool.fork_tool_runtime.call_args.kwargs["runtime"] + assert runtime.credentials == {"api_key": "secret"} + assert runtime.credential_type == CredentialType.API_KEY + + +@patch("core.tools.tool_manager.create_provider_encrypter") +@patch("core.plugin.impl.oauth.OAuthHandler") +@patch( + "services.tools.builtin_tools_manage_service.BuiltinToolManageService.get_oauth_client", + return_value={"client_id": "id"}, +) +@patch("core.tools.tool_manager.db") +@patch("core.tools.tool_manager.time.time", return_value=1000) +@patch("core.helper.credential_utils.check_credential_policy_compliance") +def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( + mock_check, + mock_time, + mock_db, + mock_get_oauth_client, + mock_oauth_handler_cls, + mock_create_provider_encrypter, +): + tool = Mock() + tool.fork_tool_runtime.return_value = "runtime-tool" + controller = SimpleNamespace( + get_tool=Mock(return_value=tool), + need_credentials=True, + get_credentials_schema_by_type=Mock(return_value=[]), + ) + builtin_provider = SimpleNamespace( + id="cred-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"encrypted": "value"}, + encrypted_credentials=None, + expires_at=1, + user_id="user-1", + ) + refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) + + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + encrypter = Mock() + encrypter.decrypt.return_value = {"token": "old"} + encrypter.encrypt.return_value = {"token": "encrypted"} + cache = Mock() + mock_create_provider_encrypter.return_value = (encrypter, cache) + mock_oauth_handler_cls.return_value.refresh_credentials.return_value = refreshed + + with patch.object(ToolManager, "get_builtin_provider", return_value=controller): + result = ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + ) + + assert result == "runtime-tool" + assert builtin_provider.expires_at == refreshed.expires_at + assert builtin_provider.encrypted_credentials == json.dumps({"token": "encrypted"}) + mock_db.session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_get_tool_runtime_builtin_plugin_provider_deleted_raises(): + plugin_controller = object.__new__(PluginToolProviderController) + plugin_controller.entity = SimpleNamespace(credentials_schema=[{"name": "k"}], oauth_schema=None) + plugin_controller.get_tool = Mock(return_value=Mock()) + plugin_controller.get_credentials_schema_by_type = Mock(return_value=[]) + + with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_controller): + with patch("core.tools.tool_manager.is_valid_uuid", return_value=True): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.scalar.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="provider has been deleted"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.BUILT_IN, + provider_id="time", + tool_name="weekday", + tenant_id="tenant-1", + credential_id="uuid-id", + ) + + +def test_get_tool_runtime_api_path(): + api_tool = Mock() + api_tool.fork_tool_runtime.return_value = "api-runtime" + api_provider = Mock() + api_provider.get_tool.return_value = api_tool + + with patch.object(ToolManager, "get_api_provider_controller", return_value=(api_provider, {"c": "enc"})): + encrypter = Mock() + encrypter.decrypt.return_value = {"c": "dec"} + with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + ) + == "api-runtime" + ) + + +def test_get_tool_runtime_workflow_path(): + workflow_provider = SimpleNamespace(tenant_id="tenant-1") + workflow_tool = Mock() + workflow_tool.fork_tool_runtime.return_value = "wf-runtime" + workflow_controller = Mock() + workflow_controller.get_tools.return_value = [workflow_tool] + session = Mock() + session.begin.return_value = _cm(None) + session.scalar.return_value = workflow_provider + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch( + "core.tools.tool_manager.ToolTransformService.workflow_provider_to_controller", + return_value=workflow_controller, + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.WORKFLOW, + provider_id="wf-1", + tool_name="wf", + tenant_id="tenant-1", + ) + == "wf-runtime" + ) + + +def test_get_tool_runtime_plugin_path(): + with patch.object( + ToolManager, + "get_plugin_provider", + return_value=SimpleNamespace(get_tool=lambda _: "plugin-tool"), + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.PLUGIN, + provider_id="plugin-1", + tool_name="p", + tenant_id="tenant-1", + ) + == "plugin-tool" + ) + + +def test_get_tool_runtime_mcp_path(): + with patch.object( + ToolManager, + "get_mcp_provider_controller", + return_value=SimpleNamespace(get_tool=lambda _: "mcp-tool"), + ): + assert ( + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.MCP, + provider_id="mcp-1", + tool_name="m", + tenant_id="tenant-1", + ) + == "mcp-tool" + ) + + +def test_get_tool_runtime_app_not_implemented(): + with pytest.raises(NotImplementedError, match="app provider not implemented"): + ToolManager.get_tool_runtime( + provider_type=ToolProviderType.APP, + provider_id="app", + tool_name="x", + tenant_id="tenant-1", + ) + + +def test_get_agent_runtime_apply_runtime_parameters(): + parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + parameter.form = ToolParameter.ToolParameterForm.FORM + + tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) + + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): + manager = Mock() + manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager): + agent_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_parameters={"query": "hello"}, + credential_id=None, + ) + result = ToolManager.get_agent_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + agent_tool=agent_tool, + invoke_from=InvokeFrom.DEBUGGER, + variable_pool=None, + ) + + assert result is tool_runtime + assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + + +def test_get_workflow_runtime_apply_runtime_parameters(): + parameter = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + parameter.form = ToolParameter.ToolParameterForm.FORM + + workflow_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_configurations={"query": "hello"}, + credential_id=None, + ) + tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): + manager = Mock() + manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=manager): + workflow_result = ToolManager.get_workflow_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + node_id="node-1", + workflow_tool=workflow_tool, + invoke_from=InvokeFrom.DEBUGGER, + variable_pool=None, + ) + + assert workflow_result is tool_runtime2 + assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + + +def test_get_agent_runtime_raises_when_runtime_missing(): + tool_runtime = SimpleNamespace(runtime=None, get_merged_runtime_parameters=lambda: []) + agent_tool = SimpleNamespace( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tool_parameters={}, + credential_id=None, + ) + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={}): + with patch("core.tools.tool_manager.ToolParameterConfigurationManager", return_value=Mock()): + with pytest.raises(ValueError, match="runtime not found"): + ToolManager.get_agent_tool_runtime( + tenant_id="tenant-1", + app_id="app-1", + agent_tool=agent_tool, + ) + + +def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): + form_param = ToolParameter.get_simple_instance( + name="q", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + form_param.form = ToolParameter.ToolParameterForm.FORM + llm_param = ToolParameter.get_simple_instance( + name="llm", + llm_description="llm", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + llm_param.form = ToolParameter.ToolParameterForm.LLM + + tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) + tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) + + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + result = ToolManager.get_tool_runtime_from_plugin( + tool_type=ToolProviderType.API, + tenant_id="tenant-1", + provider="api-1", + tool_name="search", + tool_parameters={"q": "hello", "llm": "ignore"}, + ) + + assert result is tool_entity + assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + + +def test_hardcoded_provider_icon_success(): + provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg"))) + with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider): + with patch("core.tools.tool_manager.path.exists", return_value=True): + with patch("core.tools.tool_manager.mimetypes.guess_type", return_value=("image/svg+xml", None)): + icon_path, mime = ToolManager.get_hardcoded_provider_icon("time") + assert icon_path.endswith("icon.svg") + assert mime == "image/svg+xml" + + +def test_hardcoded_provider_icon_missing_raises(): + provider = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(icon="icon.svg"))) + with patch.object(ToolManager, "get_hardcoded_provider", return_value=provider): + with patch("core.tools.tool_manager.path.exists", return_value=False): + with pytest.raises(ToolProviderNotFoundError, match="icon not found"): + ToolManager.get_hardcoded_provider_icon("time") + + +def test_list_hardcoded_providers_cache_hit(): + ToolManager._hardcoded_providers = {"p": Mock()} + ToolManager._builtin_providers_loaded = True + assert list(ToolManager.list_hardcoded_providers()) == list(ToolManager._hardcoded_providers.values()) + + +def test_clear_hardcoded_providers_cache_resets(): + ToolManager._hardcoded_providers = {"p": Mock()} + ToolManager._builtin_providers_loaded = True + ToolManager.clear_hardcoded_providers_cache() + assert ToolManager._hardcoded_providers == {} + assert ToolManager._builtin_providers_loaded is False + + +def test_list_hardcoded_providers_internal_loader(): + good_provider = SimpleNamespace( + entity=SimpleNamespace(identity=SimpleNamespace(name="good")), + get_tools=lambda: [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="tool-a", label="A")))], + ) + provider_class = Mock(return_value=good_provider) + + with patch("core.tools.tool_manager.listdir", return_value=["good", "bad", "__skip"]): + with patch("core.tools.tool_manager.path.isdir", side_effect=lambda p: "good" in p or "bad" in p): + with patch( + "core.tools.tool_manager.load_single_subclass_from_source", + side_effect=[provider_class, RuntimeError("boom")], + ): + ToolManager._hardcoded_providers = {} + ToolManager._builtin_tools_labels = {} + providers = list(ToolManager._list_hardcoded_providers()) + + assert providers == [good_provider] + assert ToolManager._hardcoded_providers["good"] is good_provider + assert ToolManager._builtin_tools_labels["tool-a"] == "A" + assert ToolManager._builtin_providers_loaded is True + + +def test_get_tool_label_loads_cache_and_handles_missing(): + ToolManager._builtin_tools_labels = {} + + def _load(): + ToolManager._builtin_tools_labels["tool-a"] = "Label A" + + with patch.object(ToolManager, "load_hardcoded_providers_cache", side_effect=_load): + assert ToolManager.get_tool_label("tool-a") == "Label A" + assert ToolManager.get_tool_label("missing") is None + + +def test_list_default_builtin_providers_for_postgres_and_mysql(): + provider_records = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] + + for scheme in ("postgresql", "mysql"): + session = Mock() + session.execute.return_value.all.return_value = [SimpleNamespace(id="id-1"), SimpleNamespace(id="id-2")] + session.query.return_value.where.return_value.all.return_value = provider_records + + with patch("core.tools.tool_manager.dify_config", SimpleNamespace(SQLALCHEMY_DATABASE_URI_SCHEME=scheme)): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + providers = ToolManager.list_default_builtin_providers("tenant-1") + + assert providers == provider_records + + +def test_list_providers_from_api_covers_builtin_api_workflow_and_mcp(monkeypatch): + hardcoded_controller = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="hardcoded"))) + plugin_controller = object.__new__(PluginToolProviderController) + plugin_controller.entity = SimpleNamespace(identity=SimpleNamespace(name="plugin-provider")) + + api_db_provider_good = SimpleNamespace(id="api-1") + api_db_provider_bad = SimpleNamespace(id="api-2") + api_controller = SimpleNamespace(provider_id="api-1") + + workflow_db_provider_good = SimpleNamespace(id="wf-1") + workflow_db_provider_bad = SimpleNamespace(id="wf-2") + workflow_controller = SimpleNamespace(provider_id="wf-1") + + session = Mock() + session.scalars.side_effect = [ + SimpleNamespace(all=lambda: [api_db_provider_good, api_db_provider_bad]), + SimpleNamespace(all=lambda: [workflow_db_provider_good, workflow_db_provider_bad]), + ] + + _setup_list_providers_from_api_mocks( + monkeypatch, + session=session, + hardcoded_controller=hardcoded_controller, + plugin_controller=plugin_controller, + api_controller=api_controller, + workflow_controller=workflow_controller, + ) + providers = ToolManager.list_providers_from_api(user_id="user-1", tenant_id="tenant-1", typ="") + + names = {provider.name for provider in providers} + assert {"hardcoded", "plugin-provider", "api-provider", "workflow-provider", "mcp-provider"} <= names + + +def test_get_api_provider_controller_returns_controller_and_credentials(): + provider = SimpleNamespace( + id="api-1", + tenant_id="tenant-1", + name="api-provider", + description="desc", + credentials={"auth_type": "api_key_query"}, + credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}', + schema_type="openapi", + schema="schema", + tools=[], + icon='{"background": "#000", "content": "A"}', + privacy_policy="privacy", + custom_disclaimer="disclaimer", + ) + db_query = Mock() + db_query.where.return_value.first.return_value = provider + controller = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value = db_query + with patch( + "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller + ) as mock_from_db: + built_controller, credentials = ToolManager.get_api_provider_controller("tenant-1", "api-1") + + assert built_controller is controller + assert credentials == provider.credentials + mock_from_db.assert_called_with(provider, ApiProviderAuthType.API_KEY_QUERY) + controller.load_bundled_tools.assert_called_once_with(provider.tools) + + +def test_user_get_api_provider_masks_credentials_and_adds_labels(): + provider = SimpleNamespace( + id="api-1", + tenant_id="tenant-1", + name="api-provider", + description="desc", + credentials={"auth_type": "api_key_query"}, + credentials_str='{"auth_type": "api_key_query", "api_key_value": "secret"}', + schema_type="openapi", + schema="schema", + tools=[], + icon='{"background": "#000", "content": "A"}', + privacy_policy="privacy", + custom_disclaimer="disclaimer", + ) + db_query = Mock() + db_query.where.return_value.first.return_value = provider + controller = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value = db_query + with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): + encrypter = Mock() + encrypter.decrypt.return_value = {"api_key_value": "secret"} + encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} + with patch("core.tools.tool_manager.create_tool_provider_encrypter", return_value=(encrypter, Mock())): + with patch("core.tools.tool_manager.ToolLabelManager.get_tool_labels", return_value=["search"]): + user_payload = ToolManager.user_get_api_provider("api-provider", "tenant-1") + + assert user_payload["credentials"]["api_key_value"] == "***" + assert user_payload["labels"] == ["search"] + + +def test_get_api_provider_controller_not_found_raises(): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = None + with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): + ToolManager.get_api_provider_controller("tenant-1", "missing") + + +def test_get_mcp_provider_controller_returns_controller(): + provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"}) + controller = Mock() + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service = mock_service_cls.return_value + mock_service.get_provider.return_value = provider_entity + with patch("core.tools.tool_manager.MCPToolProviderController.from_db", return_value=controller): + built = ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1") + assert built is controller + + +def test_generate_mcp_tool_icon_url_returns_provider_icon(): + provider_entity = SimpleNamespace(provider_icon={"background": "#111", "content": "M"}) + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service = mock_service_cls.return_value + mock_service.get_provider_entity.return_value = provider_entity + assert ToolManager.generate_mcp_tool_icon_url("tenant-1", "mcp-1") == provider_entity.provider_icon + + +def test_get_mcp_provider_controller_missing_raises(): + session = Mock() + + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.tool_manager.Session", return_value=_cm(session)): + with patch("core.tools.tool_manager.MCPToolManageService") as mock_service_cls: + mock_service_cls.return_value.get_provider.side_effect = ValueError("missing") + with pytest.raises(ToolProviderNotFoundError, match="mcp provider mcp-1 not found"): + ToolManager.get_mcp_provider_controller("tenant-1", "mcp-1") + + +def test_generate_tool_icon_urls_for_builtin_and_plugin(): + with patch("core.tools.tool_manager.dify_config.CONSOLE_API_URL", "https://console.example.com"): + builtin_url = ToolManager.generate_builtin_tool_icon_url("time") + plugin_url = ToolManager.generate_plugin_tool_icon_url("tenant-1", "icon.svg") + + assert builtin_url.endswith("/tool-provider/builtin/time/icon") + assert "/plugin/icon" in plugin_url + + +def test_generate_tool_icon_urls_for_workflow_and_api(): + workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') + api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} + assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} + + +def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): + with patch("core.tools.tool_manager.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = None + assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" + assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" + + +def test_get_tool_icon_for_builtin_provider_variants(): + plugin_provider = object.__new__(PluginToolProviderController) + plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg")) + + with patch.object(ToolManager, "get_builtin_provider", return_value=plugin_provider): + with patch.object(ToolManager, "generate_plugin_tool_icon_url", return_value="plugin-icon"): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "plugin-provider") == "plugin-icon" + + with patch.object(ToolManager, "get_builtin_provider", return_value=SimpleNamespace()): + with patch.object(ToolManager, "generate_builtin_tool_icon_url", return_value="builtin-icon"): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.BUILT_IN, "time") == "builtin-icon" + + +def test_get_tool_icon_for_api_workflow_and_mcp(): + with patch.object(ToolManager, "generate_api_tool_icon_url", return_value={"background": "#000"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.API, "api-1") == {"background": "#000"} + + with patch.object(ToolManager, "generate_workflow_tool_icon_url", return_value={"background": "#111"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.WORKFLOW, "wf-1") == {"background": "#111"} + + with patch.object(ToolManager, "generate_mcp_tool_icon_url", return_value={"background": "#222"}): + assert ToolManager.get_tool_icon("tenant-1", ToolProviderType.MCP, "mcp-1") == {"background": "#222"} + + +def test_get_tool_icon_plugin_error_returns_default(): + plugin_provider = object.__new__(PluginToolProviderController) + plugin_provider.entity = SimpleNamespace(identity=SimpleNamespace(icon="plugin.svg")) + + with patch.object(ToolManager, "get_plugin_provider", return_value=plugin_provider): + with patch.object(ToolManager, "generate_plugin_tool_icon_url", side_effect=RuntimeError("fail")): + icon = ToolManager.get_tool_icon("tenant-1", ToolProviderType.PLUGIN, "plugin-provider") + assert icon["background"] == "#252525" + + +def test_get_tool_icon_invalid_provider_type_raises(): + with pytest.raises(ValueError, match="provider type"): + ToolManager.get_tool_icon("tenant-1", "invalid", "x") # type: ignore[arg-type] + + +def test_convert_tool_parameters_type_agent_and_workflow_branches(): + file_param = ToolParameter.get_simple_instance( + name="file", + llm_description="file", + typ=ToolParameter.ToolParameterType.FILE, + required=True, + ) + file_param.form = ToolParameter.ToolParameterForm.FORM + + with pytest.raises(ValueError, match="file type parameter file not supported in agent"): + ToolManager._convert_tool_parameters_type( + parameters=[file_param], + variable_pool=None, + tool_configurations={"file": "x"}, + typ="agent", + ) + + text_param = ToolParameter.get_simple_instance( + name="text", + llm_description="text", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + plain = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=None, + tool_configurations={"text": "hello"}, + typ="workflow", + ) + assert plain == {"text": "hello"} + + variable_pool = Mock() + variable_pool.get.return_value = SimpleNamespace(value="from-variable") + variable_pool.convert_template.return_value = SimpleNamespace(text="from-template") + + mixed = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "mixed", "value": "Hello {{name}}"}}, + typ="workflow", + ) + assert mixed == {"text": "from-template"} + + variable = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "variable", "value": ["sys", "query"]}}, + typ="workflow", + ) + assert variable == {"text": "from-variable"} + + +def test_convert_tool_parameters_type_constant_branch(): + text_param = ToolParameter.get_simple_instance( + name="text", + llm_description="text", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + variable_pool = Mock() + + constant = ToolManager._convert_tool_parameters_type( + parameters=[text_param], + variable_pool=variable_pool, + tool_configurations={"text": {"type": "constant", "value": "fixed"}}, + typ="workflow", + ) + + assert constant == {"text": "fixed"} diff --git a/api/tests/unit_tests/core/tools/test_tool_provider_controller.py b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py new file mode 100644 index 0000000000..30b8494c92 --- /dev/null +++ b/api/tests/unit_tests/core/tools/test_tool_provider_controller.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any + +import pytest + +from core.entities.provider_entities import ProviderConfig +from core.tools.__base.tool import Tool +from core.tools.__base.tool_provider import ToolProviderController +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.errors import ToolProviderCredentialValidationError + + +class _DummyTool(Tool): + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + +class _DummyController(ToolProviderController): + def get_tool(self, tool_name: str) -> Tool: + entity = ToolEntity( + identity=ToolIdentity( + author="author", + name=tool_name, + label=I18nObject(en_US=tool_name), + provider="provider", + ), + parameters=[], + ) + return _DummyTool(entity=entity, runtime=ToolRuntime(tenant_id="tenant")) + + +def _provider_identity() -> ToolProviderIdentity: + return ToolProviderIdentity( + author="author", + name="provider", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="Provider"), + ) + + +def test_tool_provider_controller_get_credentials_schema_returns_deep_copy(): + entity = ToolProviderEntity( + identity=_provider_identity(), + credentials_schema=[ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="api_key", required=False)], + ) + controller = _DummyController(entity=entity) + + schema = controller.get_credentials_schema() + schema[0].name = "changed" + + assert controller.entity.credentials_schema[0].name == "api_key" + + +def test_tool_provider_controller_default_provider_type(): + entity = ToolProviderEntity(identity=_provider_identity(), credentials_schema=[]) + controller = _DummyController(entity=entity) + + assert controller.provider_type == ToolProviderType.BUILT_IN + + +def test_validate_credentials_format_covers_required_default_and_type_rules(): + select_options = [ProviderConfig.Option(value="opt-a", label=I18nObject(en_US="A"))] + entity = ToolProviderEntity( + identity=_provider_identity(), + credentials_schema=[ + ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="required_text", required=True), + ProviderConfig(type=ProviderConfig.Type.SECRET_INPUT, name="secret", required=False), + ProviderConfig(type=ProviderConfig.Type.SELECT, name="choice", required=False, options=select_options), + ProviderConfig(type=ProviderConfig.Type.TEXT_INPUT, name="with_default", required=False, default="x"), + ], + ) + controller = _DummyController(entity=entity) + + credentials = {"required_text": "value", "secret": None, "choice": "opt-a"} + controller.validate_credentials_format(credentials) + assert credentials["with_default"] == "x" + + with pytest.raises(ToolProviderCredentialValidationError, match="not found"): + controller.validate_credentials_format({"required_text": "value", "unknown": "v"}) + + with pytest.raises(ToolProviderCredentialValidationError, match="is required"): + controller.validate_credentials_format({"secret": "s"}) + + with pytest.raises(ToolProviderCredentialValidationError, match="should be string"): + controller.validate_credentials_format({"required_text": 123}) # type: ignore[arg-type] + + with pytest.raises(ToolProviderCredentialValidationError, match="should be one of"): + controller.validate_credentials_format({"required_text": "value", "choice": "opt-b"}) diff --git a/api/tests/unit_tests/core/tools/utils/test_configuration.py b/api/tests/unit_tests/core/tools/utils/test_configuration.py new file mode 100644 index 0000000000..5ceaa08893 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_configuration.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from collections.abc import Generator +from typing import Any +from unittest.mock import patch + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.tool_parameter_cache import ToolParameterCache +from core.tools.__base.tool import Tool +from core.tools.__base.tool_runtime import ToolRuntime +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) +from core.tools.utils.configuration import ToolParameterConfigurationManager + + +class _DummyTool(Tool): + runtime_overrides: list[ToolParameter] + + def __init__(self, entity: ToolEntity, runtime: ToolRuntime, runtime_overrides: list[ToolParameter]): + super().__init__(entity=entity, runtime=runtime) + self.runtime_overrides = runtime_overrides + + def tool_provider_type(self) -> ToolProviderType: + return ToolProviderType.BUILT_IN + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> Generator[ToolInvokeMessage, None, None]: + yield self.create_text_message("ok") + + def get_runtime_parameters( + self, + conversation_id: str | None = None, + app_id: str | None = None, + message_id: str | None = None, + ) -> list[ToolParameter]: + return self.runtime_overrides + + +def _param( + name: str, + *, + typ: ToolParameter.ToolParameterType, + form: ToolParameter.ToolParameterForm, + required: bool = False, +) -> ToolParameter: + return ToolParameter( + name=name, + label=I18nObject(en_US=name), + placeholder=I18nObject(en_US=""), + human_description=I18nObject(en_US=""), + type=typ, + form=form, + required=required, + default=None, + ) + + +def _build_manager() -> ToolParameterConfigurationManager: + base_params = [ + _param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM), + _param("plain", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM), + ] + runtime_overrides = [ + _param("secret", typ=ToolParameter.ToolParameterType.SECRET_INPUT, form=ToolParameter.ToolParameterForm.FORM), + _param("runtime_only", typ=ToolParameter.ToolParameterType.STRING, form=ToolParameter.ToolParameterForm.FORM), + ] + entity = ToolEntity( + identity=ToolIdentity(author="a", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), + parameters=base_params, + ) + runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + tool = _DummyTool(entity=entity, runtime=runtime, runtime_overrides=runtime_overrides) + return ToolParameterConfigurationManager( + tenant_id="tenant-1", + tool_runtime=tool, + provider_name="provider-a", + provider_type=ToolProviderType.BUILT_IN, + identity_id="ID.1", + ) + + +def test_merge_and_mask_parameters(): + manager = _build_manager() + + masked = manager.mask_tool_parameters({"secret": "abcdefghi", "plain": "x", "runtime_only": "y"}) + assert masked["secret"] == "ab*****hi" + assert masked["plain"] == "x" + assert masked["runtime_only"] == "y" + + +def test_encrypt_tool_parameters(): + manager = _build_manager() + + with patch("core.tools.utils.configuration.encrypter.encrypt_token", return_value="enc"): + encrypted = manager.encrypt_tool_parameters({"secret": "raw", "plain": "x"}) + + assert encrypted["secret"] == "enc" + assert encrypted["plain"] == "x" + + +def test_decrypt_tool_parameters_cache_hit_and_miss(): + manager = _build_manager() + + with ( + patch.object(ToolParameterCache, "get", return_value={"secret": "cached"}), + patch.object(ToolParameterCache, "set") as mock_set, + ): + assert manager.decrypt_tool_parameters({"secret": "enc"}) == {"secret": "cached"} + mock_set.assert_not_called() + + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch.object(ToolParameterCache, "set") as mock_set, + patch("core.tools.utils.configuration.encrypter.decrypt_token", return_value="dec"), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc", "plain": "x"}) + assert decrypted["secret"] == "dec" + mock_set.assert_called_once() + + +def test_delete_tool_parameters_cache(): + manager = _build_manager() + + with patch.object(ToolParameterCache, "delete") as mock_delete: + manager.delete_tool_parameters_cache() + + mock_delete.assert_called_once() + + +def test_configuration_manager_decrypt_suppresses_errors(): + manager = _build_manager() + with ( + patch.object(ToolParameterCache, "get", return_value=None), + patch("core.tools.utils.configuration.encrypter.decrypt_token", side_effect=RuntimeError("boom")), + ): + decrypted = manager.decrypt_tool_parameters({"secret": "enc"}) + # decryption failure is suppressed, original value is retained. + assert decrypted["secret"] == "enc" diff --git a/api/tests/unit_tests/core/tools/utils/test_encryption.py b/api/tests/unit_tests/core/tools/utils/test_encryption.py index 94be0bb573..ce77473dbd 100644 --- a/api/tests/unit_tests/core/tools/utils/test_encryption.py +++ b/api/tests/unit_tests/core/tools/utils/test_encryption.py @@ -1,10 +1,13 @@ import copy -from unittest.mock import patch +from types import SimpleNamespace +from typing import Any +from unittest.mock import Mock, patch import pytest from core.entities.provider_entities import BasicProviderConfig from core.helper.provider_encryption import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter # --------------------------- @@ -13,13 +16,13 @@ from core.helper.provider_encryption import ProviderConfigEncrypter class NoopCache: """Simple cache stub: always returns None, does nothing for set/delete.""" - def get(self): + def get(self) -> Any | None: return None - def set(self, config): + def set(self, config: Any) -> None: pass - def delete(self): + def delete(self) -> None: pass @@ -179,3 +182,35 @@ def test_decrypt_swallow_exception_and_keep_original(encrypter_obj): out = encrypter_obj.decrypt({"password": "ENC_ERR"}) assert out["password"] == "ENC_ERR" + + +def test_create_tool_provider_encrypter_builds_cache_and_encrypter(): + basic_config = BasicProviderConfig(name="key", type=BasicProviderConfig.Type.TEXT_INPUT) + credential_schema_item = SimpleNamespace(to_basic_provider_config=lambda: basic_config) + controller = SimpleNamespace( + provider_type=SimpleNamespace(value="builtin"), + entity=SimpleNamespace(identity=SimpleNamespace(name="provider-a")), + get_credentials_schema=lambda: [credential_schema_item], + ) + + cache_instance = Mock() + encrypter_instance = Mock() + + with patch( + "core.tools.utils.encryption.SingletonProviderCredentialsCache", return_value=cache_instance + ) as cache_cls: + with patch("core.tools.utils.encryption.ProviderConfigEncrypter", return_value=encrypter_instance) as enc_cls: + encrypter, cache = create_tool_provider_encrypter("tenant-1", controller) + + assert encrypter is encrypter_instance + assert cache is cache_instance + cache_cls.assert_called_once_with( + tenant_id="tenant-1", + provider_type="builtin", + provider_identity="provider-a", + ) + enc_cls.assert_called_once_with( + tenant_id="tenant-1", + config=[basic_config], + provider_config_cache=cache_instance, + ) diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py new file mode 100644 index 0000000000..4ce73272bf --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -0,0 +1,478 @@ +from __future__ import annotations + +import uuid +from contextlib import nullcontext +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest +from yaml import YAMLError + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.rag.models.document import Document as RagDocument +from core.tools.utils.dataset_retriever import dataset_multi_retriever_tool as multi_retriever_module +from core.tools.utils.dataset_retriever import dataset_retriever_tool as single_retriever_module +from core.tools.utils.dataset_retriever.dataset_multi_retriever_tool import DatasetMultiRetrieverTool +from core.tools.utils.dataset_retriever.dataset_retriever_tool import DatasetRetrieverTool as SingleDatasetRetrieverTool +from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.tools.utils.uuid_utils import is_valid_uuid +from core.tools.utils.yaml_utils import _load_yaml_file, load_yaml_file_cached + + +def _retrieve_config() -> DatasetRetrieveConfigEntity: + return DatasetRetrieveConfigEntity(retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE) + + +class _FakeFlaskApp: + def app_context(self): + return nullcontext() + + +class _ImmediateThread: + def __init__(self, target=None, kwargs=None, **_kwargs): + self._target = target + self._kwargs = kwargs or {} + + def start(self): + if self._target is not None: + self._target(**self._kwargs) + + def join(self): + return None + + +class _TestHitCallback(DatasetIndexToolCallbackHandler): + def __init__(self): + self.queries: list[tuple[str, str]] = [] + self.documents: list[RagDocument] | None = None + self.resources = None + + def on_query(self, query: str, dataset_id: str): + self.queries.append((query, dataset_id)) + + def on_tool_end(self, documents: list[RagDocument]): + self.documents = documents + + def return_retriever_resource_info(self, resource): + self.resources = list(resource) + + +def test_remove_leading_symbols_preserves_markdown_link_and_strips_punctuation(): + markdown = "[Example](https://example.com) content" + assert remove_leading_symbols(markdown) == markdown + + assert remove_leading_symbols("...Hello world") == "Hello world" + + +def test_is_valid_uuid_handles_valid_invalid_and_empty_values(): + assert is_valid_uuid(str(uuid.uuid4())) is True + assert is_valid_uuid("not-a-uuid") is False + assert is_valid_uuid("") is False + assert is_valid_uuid(None) is False + + +def test_load_yaml_file_valid(tmp_path): + valid_file = tmp_path / "valid.yaml" + valid_file.write_text("a: 1\nb: two\n", encoding="utf-8") + + loaded = _load_yaml_file(file_path=str(valid_file)) + + assert loaded == {"a": 1, "b": "two"} + + +def test_load_yaml_file_missing(tmp_path): + with pytest.raises(FileNotFoundError): + _load_yaml_file(file_path=str(tmp_path / "missing.yaml")) + + +def test_load_yaml_file_invalid(tmp_path): + invalid_file = tmp_path / "invalid.yaml" + invalid_file.write_text("a: [1, 2\n", encoding="utf-8") + + with pytest.raises(YAMLError): + _load_yaml_file(file_path=str(invalid_file)) + + +def test_load_yaml_file_cached_hits(tmp_path): + valid_file = tmp_path / "valid.yaml" + valid_file.write_text("a: 1\nb: two\n", encoding="utf-8") + + load_yaml_file_cached.cache_clear() + assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"} + + assert load_yaml_file_cached(str(valid_file)) == {"a": 1, "b": "two"} + assert load_yaml_file_cached.cache_info().hits == 1 + + +def test_single_dataset_retriever_from_dataset_builds_name_and_description(): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="Knowledge", description=None) + + tool = SingleDatasetRetrieverTool.from_dataset( + dataset=dataset, + retrieve_config=_retrieve_config(), + return_resource=False, + retriever_from="prod", + inputs={}, + ) + + assert tool.name == "dataset_dataset_1" + assert tool.description == "useful for when you want to answer queries about the Knowledge" + + +def test_single_dataset_retriever_external_run_returns_content_and_resources(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="external", + indexing_technique="high_quality", + retrieval_model={}, + ) + callback = _TestHitCallback() + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = ( + {"dataset-1": ["doc-a"]}, + {"logical_operator": "and"}, + ) + db_session = Mock() + db_session.scalar.return_value = dataset + external_documents = [ + {"content": "first", "metadata": {"document_id": "doc-a"}, "score": 0.9, "title": "Doc A"}, + {"content": "second", "metadata": {"document_id": "doc-b"}, "score": 0.8, "title": "Doc B"}, + ] + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + inputs={"x": 1}, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object( + single_retriever_module.ExternalDatasetService, + "fetch_external_knowledge_retrieval", + return_value=external_documents, + ) as fetch_mock: + result = tool.run(query="hello") + + assert result == "first\nsecond" + assert callback.queries == [("hello", "dataset-1")] + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].dataset_id == "dataset-1" + fetch_mock.assert_called_once() + + +def test_single_dataset_retriever_returns_empty_when_metadata_filter_finds_no_documents(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="internal", + indexing_technique="high_quality", + retrieval_model=None, + ) + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = ({"dataset-1": []}, {"logical_operator": "and"}) + db_session = Mock() + db_session.scalar.return_value = dataset + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=False, + retriever_from="prod", + hit_callbacks=[_TestHitCallback()], + inputs={}, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object(single_retriever_module.RetrievalService, "retrieve") as retrieve_mock: + result = tool.run(query="hello") + + assert result == "" + retrieve_mock.assert_not_called() + + +def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="Knowledge Base", + provider="internal", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "score_threshold_enabled": True, + "score_threshold": 0.2, + "reranking_enable": True, + "reranking_model": {"reranking_provider_name": "provider", "reranking_model_name": "model"}, + "reranking_mode": "reranking_model", + "weights": {"vector_setting": {"vector_weight": 0.6}}, + }, + ) + callback = _TestHitCallback() + dataset_retrieval = Mock() + dataset_retrieval.get_metadata_filter_condition.return_value = (None, None) + low_segment = SimpleNamespace( + id="seg-low", + dataset_id="dataset-1", + document_id="doc-low", + content="raw low", + answer="low answer", + hit_count=1, + word_count=10, + position=3, + index_node_hash="hash-low", + get_sign_content=lambda: "signed low", + ) + high_segment = SimpleNamespace( + id="seg-high", + dataset_id="dataset-1", + document_id="doc-high", + content="raw high", + answer=None, + hit_count=9, + word_count=25, + position=1, + index_node_hash="hash-high", + get_sign_content=lambda: "signed high", + ) + records = [ + SimpleNamespace(segment=low_segment, score=0.2, summary="summary low"), + SimpleNamespace(segment=high_segment, score=0.9, summary=None), + ] + documents = [ + RagDocument(page_content="first", metadata={"doc_id": "node-low", "score": 0.2}), + RagDocument(page_content="second", metadata={"doc_id": "node-high", "score": 0.9}), + ] + lookup_doc_low = SimpleNamespace( + id="doc-low", name="Document Low", data_source_type="upload_file", doc_metadata={"lang": "en"} + ) + lookup_doc_high = SimpleNamespace( + id="doc-high", name="Document High", data_source_type="notion", doc_metadata={"lang": "fr"} + ) + db_session = Mock() + db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] + db_session.query.return_value.filter_by.return_value.first.return_value = dataset + + tool = SingleDatasetRetrieverTool( + tenant_id="tenant-1", + dataset_id="dataset-1", + retrieve_config=_retrieve_config(), + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + inputs={}, + top_k=2, + ) + + with patch.object(single_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(single_retriever_module, "DatasetRetrieval", return_value=dataset_retrieval): + with patch.object(single_retriever_module.RetrievalService, "retrieve", return_value=documents): + with patch.object( + single_retriever_module.RetrievalService, + "format_retrieval_documents", + return_value=records, + ): + result = tool.run(query="hello") + + assert result == "signed high\nsummary low\nquestion:signed low answer:low answer" + assert callback.documents == documents + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].segment_id == "seg-high" + assert resource_info[0].hit_count == 9 + assert resource_info[1].summary == "summary low" + assert resource_info[1].content == "question:raw low \nanswer:low answer" + + +def test_multi_dataset_retriever_from_dataset_sets_tool_name(): + tool = DatasetMultiRetrieverTool.from_dataset( + dataset_ids=["dataset-1"], + tenant_id="tenant-1", + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + ) + + assert tool.name == "dataset_tenant_1" + + +def test_multi_dataset_retriever_retriever_returns_early_when_dataset_is_missing(): + callback = _TestHitCallback() + all_documents: list[RagDocument] = [] + db_session = Mock() + db_session.scalar.return_value = None + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + ) + + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(multi_retriever_module.RetrievalService, "retrieve") as retrieve_mock: + result = tool._retriever( + flask_app=_FakeFlaskApp(), + dataset_id="dataset-1", + query="hello", + all_documents=all_documents, + hit_callbacks=[callback], + ) + + assert result == [] + assert all_documents == [] + assert callback.queries == [] + retrieve_mock.assert_not_called() + + +def test_multi_dataset_retriever_retriever_non_economy_uses_retrieval_model(): + dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + indexing_technique="high_quality", + retrieval_model={ + "search_method": "semantic_search", + "top_k": 6, + "score_threshold_enabled": True, + "score_threshold": 0.4, + "reranking_enable": False, + "reranking_mode": None, + "weights": {"balanced": True}, + }, + ) + callback = _TestHitCallback() + documents = [RagDocument(page_content="retrieved", metadata={"doc_id": "node-1", "score": 0.4})] + all_documents: list[RagDocument] = [] + db_session = Mock() + db_session.scalar.return_value = dataset + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=False, + retriever_from="prod", + top_k=2, + ) + + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + with patch.object(multi_retriever_module.RetrievalService, "retrieve", return_value=documents) as retrieve_mock: + tool._retriever( + flask_app=_FakeFlaskApp(), + dataset_id="dataset-1", + query="hello", + all_documents=all_documents, + hit_callbacks=[callback], + ) + + assert all_documents == documents + assert callback.queries == [("hello", "dataset-1")] + retrieve_mock.assert_called_once_with( + retrieval_method="semantic_search", + dataset_id="dataset-1", + query="hello", + top_k=6, + score_threshold=0.4, + reranking_model=None, + reranking_mode="reranking_model", + weights={"balanced": True}, + ) + + +def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): + callback = _TestHitCallback() + tool = DatasetMultiRetrieverTool( + tenant_id="tenant-1", + dataset_ids=["dataset-1", "dataset-2"], + reranking_provider_name="provider", + reranking_model_name="model", + return_resource=True, + retriever_from="dev", + hit_callbacks=[callback], + top_k=2, + score_threshold=0.1, + ) + first_doc = RagDocument(page_content="first", metadata={"doc_id": "node-2", "score": 0.4}) + second_doc = RagDocument(page_content="second", metadata={"doc_id": "node-1", "score": 0.9}) + + def fake_retriever(**kwargs): + if kwargs["dataset_id"] == "dataset-1": + kwargs["all_documents"].append(first_doc) + else: + kwargs["all_documents"].append(second_doc) + + segment_for_node_2 = SimpleNamespace( + id="seg-2", + dataset_id="dataset-1", + document_id="doc-2", + index_node_id="node-2", + content="raw two", + answer="answer two", + hit_count=2, + word_count=20, + position=2, + index_node_hash="hash-2", + get_sign_content=lambda: "signed two", + ) + segment_for_node_1 = SimpleNamespace( + id="seg-1", + dataset_id="dataset-2", + document_id="doc-1", + index_node_id="node-1", + content="raw one", + answer=None, + hit_count=7, + word_count=30, + position=1, + index_node_hash="hash-1", + get_sign_content=lambda: "signed one", + ) + db_session = Mock() + db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] + db_session.query.return_value.filter_by.return_value.first.side_effect = [ + SimpleNamespace(id="dataset-2", name="Dataset Two"), + SimpleNamespace(id="dataset-1", name="Dataset One"), + ] + db_session.scalar.side_effect = [ + SimpleNamespace(id="doc-1", name="Doc One", data_source_type="upload_file", doc_metadata={"p": 1}), + SimpleNamespace(id="doc-2", name="Doc Two", data_source_type="notion", doc_metadata={"p": 2}), + ] + model_manager = Mock() + model_manager.get_model_instance.return_value = Mock() + rerank_runner = Mock() + rerank_runner.run.return_value = [second_doc, first_doc] + fake_current_app = SimpleNamespace(_get_current_object=lambda: _FakeFlaskApp()) + + with patch.object(tool, "_retriever", side_effect=fake_retriever) as retriever_mock: + with patch.object(multi_retriever_module, "current_app", fake_current_app): + with patch.object(multi_retriever_module.threading, "Thread", _ImmediateThread): + with patch.object(multi_retriever_module, "ModelManager", return_value=model_manager): + with patch.object(multi_retriever_module, "RerankModelRunner", return_value=rerank_runner): + with patch.object(multi_retriever_module, "db", SimpleNamespace(session=db_session)): + result = tool.run(query="hello") + + assert result == "signed one\nquestion:signed two answer:answer two" + assert retriever_mock.call_count == 2 + assert callback.documents == [second_doc, first_doc] + assert callback.resources is not None + resource_info = callback.resources + assert [item.position for item in resource_info] == [1, 2] + assert resource_info[0].score == 0.9 + assert resource_info[0].content == "raw one" + assert resource_info[1].score == 0.4 + assert resource_info[1].content == "question:raw two \nanswer:answer two" diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py new file mode 100644 index 0000000000..2acae889b2 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -0,0 +1,158 @@ +"""Unit tests for ModelInvocationUtils. + +Covers success and error branches for ModelInvocationUtils, including +InvokeModelError and invoke error mappings for InvokeAuthorizationError, +InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, and +InvokeServerUnavailableError. Assumes mocked model instances and managers. +""" + +from __future__ import annotations + +from decimal import Decimal +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest + +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: + model_type_instance = Mock() + model_type_instance.get_model_schema.return_value = ( + SimpleNamespace(model_properties=schema or {}) if schema is not None else None + ) + return SimpleNamespace( + provider="provider", + model="model-a", + model_name="model-a", + credentials={"api_key": "x"}, + model_type_instance=model_type_instance, + get_llm_num_tokens=lambda prompt_messages: 5, + invoke_llm=Mock(), + ) + + +@pytest.mark.parametrize( + ("model_instance", "expected", "error_match"), + [ + (None, None, "Model not found"), + (_mock_model_instance(schema=None), None, "No model schema found"), + (_mock_model_instance(schema={}), 2048, None), + (_mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 8192}), 8192, None), + ], + ids=[ + "missing-model", + "missing-schema", + "default-context-size", + "schema-context-size", + ], +) +def test_get_max_llm_context_tokens_branches(model_instance, expected, error_match): + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + if error_match: + with pytest.raises(InvokeModelError, match=error_match): + ModelInvocationUtils.get_max_llm_context_tokens("tenant") + else: + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + + +def test_calculate_tokens_handles_missing_model(): + manager = Mock() + manager.get_default_model_instance.return_value = None + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with pytest.raises(InvokeModelError, match="Model not found"): + ModelInvocationUtils.calculate_tokens("tenant", []) + + +def test_invoke_success_and_error_mappings(): + model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048}) + model_instance.invoke_llm.return_value = SimpleNamespace( + message=SimpleNamespace(content="ok"), + usage=SimpleNamespace( + completion_tokens=7, + completion_unit_price=Decimal("0.1"), + completion_price_unit=Decimal(1), + latency=0.3, + total_price=Decimal("0.7"), + currency="USD", + ), + ) + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + class _ToolModelInvoke: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + db_mock = SimpleNamespace(session=Mock()) + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): + with patch("core.tools.utils.model_invocation_utils.db", db_mock): + response = ModelInvocationUtils.invoke( + user_id="u1", + tenant_id="tenant", + tool_type="builtin", + tool_name="tool-a", + prompt_messages=[], + ) + + assert response.message.content == "ok" + assert db_mock.session.add.call_count == 1 + assert db_mock.session.commit.call_count == 2 + + +@pytest.mark.parametrize( + ("exc", "expected"), + [ + (InvokeRateLimitError("rate"), "Invoke rate limit error"), + (InvokeBadRequestError("bad"), "Invoke bad request error"), + (InvokeConnectionError("conn"), "Invoke connection error"), + (InvokeAuthorizationError("auth"), "Invoke authorization error"), + (InvokeServerUnavailableError("down"), "Invoke server unavailable error"), + (RuntimeError("oops"), "Invoke error"), + ], + ids=[ + "rate-limit", + "bad-request", + "connection", + "authorization", + "server-unavailable", + "generic-error", + ], +) +def test_invoke_error_mappings(exc, expected): + model_instance = _mock_model_instance(schema={ModelPropertyKey.CONTEXT_SIZE: 2048}) + model_instance.invoke_llm.side_effect = exc + manager = Mock() + manager.get_default_model_instance.return_value = model_instance + + class _ToolModelInvoke: + def __init__(self, **kwargs): + self.__dict__.update(kwargs) + + db_mock = SimpleNamespace(session=Mock()) + + with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): + with patch("core.tools.utils.model_invocation_utils.db", db_mock): + with pytest.raises(InvokeModelError, match=expected): + ModelInvocationUtils.invoke( + user_id="u1", + tenant_id="tenant", + tool_type="builtin", + tool_name="tool-a", + prompt_messages=[], + ) diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py index f39158aa59..40f91b12a0 100644 --- a/api/tests/unit_tests/core/tools/utils/test_parser.py +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -1,6 +1,12 @@ +from json.decoder import JSONDecodeError +from unittest.mock import Mock, patch + import pytest from flask import Flask +from yaml import YAMLError +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter +from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError from core.tools.utils.parser import ApiBasedToolSchemaParser @@ -189,3 +195,225 @@ def test_parse_openapi_to_tool_bundle_default_value_type_casting(app): available_param = params_by_name["available"] assert available_param.type == "boolean" assert available_param.default is True + + +def test_sanitize_default_value_and_type_detection(): + assert ApiBasedToolSchemaParser._sanitize_default_value([]) is None + assert ApiBasedToolSchemaParser._sanitize_default_value({}) is None + assert ApiBasedToolSchemaParser._sanitize_default_value("ok") == "ok" + + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"format": "binary"}) == ToolParameter.ToolParameterType.FILE + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "integer"}) == ToolParameter.ToolParameterType.NUMBER + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"schema": {"type": "boolean"}}) + == ToolParameter.ToolParameterType.BOOLEAN + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"format": "binary"}}) + == ToolParameter.ToolParameterType.FILES + ) + assert ( + ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "array", "items": {"type": "string"}}) + == ToolParameter.ToolParameterType.ARRAY + ) + assert ApiBasedToolSchemaParser._get_tool_parameter_type({"type": "object"}) is None + + +def test_parse_openapi_to_tool_bundle_server_env_and_refs(app): + openapi = { + "openapi": "3.0.0", + "info": {"title": "API", "version": "1.0.0", "description": "API description"}, + "servers": [ + {"url": "https://dev.example.com", "env": "dev"}, + {"url": "https://prod.example.com", "env": "prod"}, + ], + "paths": { + "/items": { + "post": { + "description": "Create item", + "parameters": [ + {"$ref": "#/components/parameters/token"}, + {"name": "token", "schema": {"type": "string"}}, + ], + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ItemRequest"}}} + }, + } + } + }, + "components": { + "parameters": { + "token": {"name": "token", "required": True, "schema": {"type": "string"}}, + }, + "schemas": { + "ItemRequest": { + "type": "object", + "required": ["age"], + "properties": {"age": {"type": "integer", "description": "Age", "default": 18}}, + } + }, + }, + } + + extra_info: dict = {} + warning: dict = {} + with app.test_request_context(headers={"X-Request-Env": "prod"}): + bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning) + + assert len(bundles) == 1 + assert bundles[0].server_url == "https://prod.example.com/items" + assert warning["duplicated_parameter"].startswith("Parameter token") + assert extra_info["description"] == "API description" + + +def test_parse_openapi_to_tool_bundle_no_server_raises(app): + openapi = {"info": {"title": "x"}, "servers": [], "paths": {}} + with app.test_request_context(): + with pytest.raises(ToolProviderNotFoundError, match="No server found"): + ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + +def test_parse_openapi_yaml_to_tool_bundle_invalid_yaml(app): + with app.test_request_context(): + with pytest.raises(ToolApiSchemaError, match="Invalid openapi yaml"): + ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle("null") + + +def test_parse_swagger_to_openapi_branches(): + with pytest.raises(ToolApiSchemaError, match="No server found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi({"info": {}, "paths": {}}) + + with pytest.raises(ToolApiSchemaError, match="No paths found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi({"servers": [{"url": "https://x"}], "paths": {}}) + + with pytest.raises(ToolApiSchemaError, match="No operationId found"): + ApiBasedToolSchemaParser.parse_swagger_to_openapi( + { + "servers": [{"url": "https://x"}], + "paths": {"/a": {"get": {"summary": "x", "responses": {}}}}, + } + ) + + warning: dict = {"seed": True} + converted = ApiBasedToolSchemaParser.parse_swagger_to_openapi( + { + "servers": [{"url": "https://x"}], + "paths": {"/a": {"get": {"operationId": "getA", "responses": {}}}}, + "definitions": {"A": {"type": "object"}}, + }, + warning=warning, + ) + assert converted["openapi"] == "3.0.0" + assert converted["components"]["schemas"]["A"]["type"] == "object" + assert warning["missing_summary"].startswith("No summary or description found") + + +def test_parse_openai_plugin_json_branches(app): + with app.test_request_context(): + with pytest.raises(ToolProviderNotFoundError, match="Invalid openai plugin json"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle("{bad") + + with pytest.raises(ToolNotSupportedError, match="Only openapi is supported"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "graphql"}}' + ) + + +def test_parse_openai_plugin_json_http_branches(app): + with app.test_request_context(): + response = type("Resp", (), {"status_code": 500, "text": "", "close": Mock()})() + with patch("core.tools.utils.parser.httpx.get", return_value=response): + with pytest.raises(ToolProviderNotFoundError, match="cannot get openapi yaml"): + ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "openapi"}}' + ) + response.close.assert_called_once() + + success_response = type("Resp", (), {"status_code": 200, "text": "openapi: 3.0.0", "close": Mock()})() + with patch("core.tools.utils.parser.httpx.get", return_value=success_response): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle", + return_value=["bundle"], + ) as mock_parse: + bundles = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle( + '{"api": {"url": "https://x", "type": "openapi"}}' + ) + assert bundles == ["bundle"] + mock_parse.assert_called_once() + success_response.close.assert_called_once() + + +def test_auto_parse_json_yaml_failure(): + with patch("core.tools.utils.parser.json_loads", side_effect=JSONDecodeError("bad", "x", 0)): + with patch("core.tools.utils.parser.safe_load", side_effect=YAMLError("bad yaml")): + with pytest.raises(ToolApiSchemaError, match="Invalid api schema, schema is neither json nor yaml"): + ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(":::") + + +def test_auto_parse_openapi_success(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + return_value=["openapi-bundle"], + ): + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["openapi-bundle"] + assert schema_type == ApiProviderSchemaType.OPENAPI + + +def test_auto_parse_openapi_then_swagger(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + loaded_content = { + "openapi": "3.0.0", + "servers": [{"url": "https://x"}], + "info": {"title": "x"}, + "paths": {}, + } + converted_swagger = { + "openapi": "3.0.0", + "servers": [{"url": "https://x"}], + "info": {"title": "x"}, + "paths": {}, + } + + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + side_effect=[ToolApiSchemaError("openapi error"), ["swagger-bundle"]], + ) as mock_parse_openapi: + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi", + return_value=converted_swagger, + ) as mock_parse_swagger: + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["swagger-bundle"] + assert schema_type == ApiProviderSchemaType.SWAGGER + mock_parse_swagger.assert_called_once_with(loaded_content, extra_info={}, warning={}) + assert mock_parse_openapi.call_count == 2 + mock_parse_openapi.assert_any_call(loaded_content, extra_info={}, warning={}) + mock_parse_openapi.assert_any_call(converted_swagger, extra_info={}, warning={}) + + +def test_auto_parse_openapi_swagger_then_plugin(): + openapi_content = '{"openapi": "3.0.0", "servers": [{"url": "https://x"}], "info": {"title": "x"}, "paths": {}}' + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle", + side_effect=ToolApiSchemaError("openapi error"), + ): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_swagger_to_openapi", + side_effect=ToolApiSchemaError("swagger error"), + ): + with patch( + "core.tools.utils.parser.ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle", + return_value=["plugin-bundle"], + ): + bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(openapi_content) + + assert bundles == ["plugin-bundle"] + assert schema_type == ApiProviderSchemaType.OPENAI_PLUGIN diff --git a/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py new file mode 100644 index 0000000000..5691f33e65 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_system_oauth_encryption.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +import pytest + +from core.tools.utils import system_oauth_encryption as oauth_encryption +from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter + + +def test_system_oauth_encrypter_roundtrip(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"} + + encrypted = encrypter.encrypt_oauth_params(payload) + decrypted = encrypter.decrypt_oauth_params(encrypted) + + assert encrypted + assert dict(decrypted) == payload + + +def test_system_oauth_encrypter_decrypt_validates_input(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + + with pytest.raises(ValueError, match="must be a string"): + encrypter.decrypt_oauth_params(123) # type: ignore[arg-type] + + with pytest.raises(ValueError, match="cannot be empty"): + encrypter.decrypt_oauth_params("") + + +def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext(): + encrypter = SystemOAuthEncrypter(secret_key="test-secret") + + with pytest.raises(OAuthEncryptionError, match="Decryption failed"): + encrypter.decrypt_oauth_params("not-base64") + + +def test_system_oauth_helpers_use_global_cached_instance(monkeypatch): + monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None) + monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret") + + first = oauth_encryption.get_system_oauth_encrypter() + second = oauth_encryption.get_system_oauth_encrypter() + assert first is second + + encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"}) + assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"} + + +def test_create_system_oauth_encrypter_factory(): + encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret") + assert isinstance(encrypter, SystemOAuthEncrypter) diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index c46e31d90f..dd79b79718 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,7 +1,9 @@ import pytest +from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): @@ -31,3 +33,91 @@ def test_ensure_no_human_input_nodes_raises_for_human_input(): WorkflowToolConfigurationUtils.ensure_no_human_input_nodes(graph) assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + + +def test_get_workflow_graph_variables_and_outputs(): + graph = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "variables": [ + { + "variable": "query", + "label": "Query", + "type": "text-input", + "required": True, + } + ], + }, + }, + { + "id": "end-1", + "data": { + "type": "end", + "outputs": [ + {"variable": "answer", "value_type": "string", "value_selector": ["n1", "answer"]}, + {"variable": "score", "value_type": "number", "value_selector": ["n1", "score"]}, + ], + }, + }, + { + "id": "end-2", + "data": { + "type": "end", + "outputs": [ + {"variable": "answer", "value_type": "object", "value_selector": ["n2", "answer"]}, + ], + }, + }, + ] + } + + variables = WorkflowToolConfigurationUtils.get_workflow_graph_variables(graph) + assert len(variables) == 1 + assert variables[0].variable == "query" + assert variables[0].type == VariableEntityType.TEXT_INPUT + + outputs = WorkflowToolConfigurationUtils.get_workflow_graph_output(graph) + assert [output.variable for output in outputs] == ["answer", "score"] + assert outputs[0].value_type == "object" + assert outputs[1].value_type == "number" + + no_start = WorkflowToolConfigurationUtils.get_workflow_graph_variables({"nodes": []}) + assert no_start == [] + + +def test_check_is_synced_validation(): + variables = [ + VariableEntity( + variable="query", + label="Query", + type=VariableEntityType.TEXT_INPUT, + required=True, + ) + ] + configs = [ + WorkflowToolParameterConfiguration( + name="query", + description="desc", + form=ToolParameter.ToolParameterForm.FORM, + ) + ] + + WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=configs) + + with pytest.raises(ValueError, match="parameter configuration mismatch"): + WorkflowToolConfigurationUtils.check_is_synced(variables=variables, tool_configurations=[]) + + with pytest.raises(ValueError, match="parameter configuration mismatch"): + WorkflowToolConfigurationUtils.check_is_synced( + variables=variables, + tool_configurations=[ + WorkflowToolParameterConfiguration( + name="other", + description="desc", + form=ToolParameter.ToolParameterForm.FORM, + ) + ], + ) diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py new file mode 100644 index 0000000000..dd140cbb27 --- /dev/null +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch + +import pytest + +from core.tools.entities.common_entities import I18nObject +from core.tools.entities.tool_entities import ( + ToolParameter, + ToolProviderEntity, + ToolProviderIdentity, + ToolProviderType, +) +from core.tools.workflow_as_tool.provider import WorkflowToolProviderController +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType + + +def _controller() -> WorkflowToolProviderController: + entity = ToolProviderEntity( + identity=ToolProviderIdentity( + author="author", + name="wf-provider", + description=I18nObject(en_US="desc"), + icon="icon.svg", + label=I18nObject(en_US="WF"), + ), + credentials_schema=[], + ) + return WorkflowToolProviderController(entity=entity, provider_id="provider-1") + + +def _mock_session_with_begin() -> Mock: + session = Mock() + begin_cm = Mock() + begin_cm.__enter__ = Mock(return_value=None) + begin_cm.__exit__ = Mock(return_value=False) + session.begin.return_value = begin_cm + return session + + +def test_get_db_provider_tool_builds_entity(): + controller = _controller() + session = Mock() + workflow = SimpleNamespace(graph_dict={"nodes": []}, features_dict={}) + session.query.return_value.where.return_value.first.return_value = workflow + app = SimpleNamespace(id="app-1") + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + user_id="user-1", + parameter_configurations=[ + SimpleNamespace(name="country", description="Country", form=ToolParameter.ToolParameterForm.FORM), + SimpleNamespace(name="files", description="files", form=ToolParameter.ToolParameterForm.FORM), + ], + ) + user = SimpleNamespace(name="Alice") + variables = [ + VariableEntity( + variable="country", + label="Country", + description="Country", + type=VariableEntityType.SELECT, + required=True, + options=["US", "IN"], + ) + ] + outputs = [ + SimpleNamespace(variable="json", value_type="string"), + SimpleNamespace(variable="answer", value_type="string"), + ] + + with ( + patch( + "core.tools.workflow_as_tool.provider.WorkflowAppConfigManager.convert_features", + return_value=SimpleNamespace(file_upload=True), + ), + patch( + "core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_variables", + return_value=variables, + ), + patch( + "core.tools.workflow_as_tool.provider.WorkflowToolConfigurationUtils.get_workflow_graph_output", + return_value=outputs, + ), + ): + tool = controller._get_db_provider_tool(db_provider, app, session=session, user=user) + + assert tool.entity.identity.name == "workflow_tool" + # "json" output is reserved for ToolInvokeMessage.VariableMessage and filtered out. + assert tool.entity.output_schema["properties"] == {"answer": {"type": "string", "description": ""}} + assert "json" not in tool.entity.output_schema["properties"] + assert tool.entity.parameters[0].type == ToolParameter.ToolParameterType.SELECT + assert tool.entity.parameters[1].type == ToolParameter.ToolParameterType.SYSTEM_FILES + assert controller.provider_type == ToolProviderType.WORKFLOW + + +def test_get_tool_returns_hit_or_none(): + controller = _controller() + tool = SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="workflow_tool"))) + controller.tools = [tool] + + assert controller.get_tool("workflow_tool") is tool + assert controller.get_tool("missing") is None + + +def test_get_tools_returns_cached(): + controller = _controller() + cached_tools = [SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf-cached")))] + controller.tools = cached_tools # type: ignore[assignment] + + assert controller.get_tools("tenant-1") == cached_tools + + +def test_from_db_builds_controller(): + controller = _controller() + + app = SimpleNamespace(id="app-1") + user = SimpleNamespace(name="Alice") + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + user_id="user-1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + parameter_configurations=[], + ) + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = db_provider + session.get.side_effect = [app, user] + fake_cm = MagicMock() + fake_cm.__enter__.return_value = session + fake_cm.__exit__.return_value = False + fake_session_factory = Mock() + fake_session_factory.create_session.return_value = fake_cm + + with patch("core.tools.workflow_as_tool.provider.session_factory", fake_session_factory): + with patch.object( + WorkflowToolProviderController, + "_get_db_provider_tool", + return_value=SimpleNamespace(entity=SimpleNamespace(identity=SimpleNamespace(name="wf"))), + ): + built = WorkflowToolProviderController.from_db(db_provider) + assert isinstance(built, WorkflowToolProviderController) + assert built.tools + + +def test_get_tools_returns_empty_when_provider_missing(): + controller = _controller() + controller.tools = None # type: ignore[assignment] + + with patch("core.tools.workflow_as_tool.provider.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = None + session_cls.return_value.__enter__.return_value = session + + assert controller.get_tools("tenant-1") == [] + + +def test_get_tools_raises_when_app_missing(): + controller = _controller() + controller.tools = None # type: ignore[assignment] + db_provider = SimpleNamespace( + id="provider-1", + app_id="app-1", + version="1", + user_id="user-1", + label="WF Provider", + description="desc", + icon="icon.svg", + name="workflow_tool", + tenant_id="tenant-1", + parameter_configurations=[], + ) + + with patch("core.tools.workflow_as_tool.provider.db") as mock_db: + mock_db.engine = object() + with patch("core.tools.workflow_as_tool.provider.Session") as session_cls: + session = _mock_session_with_begin() + session.query.return_value.where.return_value.first.return_value = db_provider + session.get.return_value = None + session_cls.return_value.__enter__.return_value = session + with pytest.raises(ValueError, match="app not found"): + controller.get_tools("tenant-1") diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index 36fdb0218c..cc00f79698 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -1,20 +1,85 @@ +"""Unit tests for workflow-as-tool behavior. + +StubSession/StubScalars emulate SQLAlchemy session/scalars with minimal methods +(`scalar`, `scalars`, `expunge`, `commit`, `refresh`, context manager) to keep +database access mocked and predictable in tests. +""" + +import json from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock, Mock, patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage +from core.tools.entities.tool_entities import ( + ToolEntity, + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool +from dify_graph.file import FILE_MODEL_IDENTITY -def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): - """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when - `WorkflowAppGenerator.generate` returns a result with `error` key inside - the `data` element. - """ +class StubScalars: + """Minimal stub for SQLAlchemy scalar results.""" + + _value: Any + + def __init__(self, value: Any) -> None: + self._value = value + + def first(self) -> Any: + return self._value + + +class StubSession: + """Minimal stub for session_factory-created sessions.""" + + scalar_results: list[Any] + scalars_results: list[Any] + expunge_calls: list[object] + + def __init__(self, *, scalar_results: list[Any] | None = None, scalars_results: list[Any] | None = None) -> None: + self.scalar_results = list(scalar_results or []) + self.scalars_results = list(scalars_results or []) + self.expunge_calls: list[object] = [] + + def scalar(self, _stmt: Any) -> Any: + return self.scalar_results.pop(0) + + def scalars(self, _stmt: Any) -> StubScalars: + return StubScalars(self.scalars_results.pop(0)) + + def expunge(self, value: Any) -> None: + self.expunge_calls.append(value) + + def begin(self) -> "StubSession": + return self + + def commit(self) -> None: + pass + + def refresh(self, _value: Any) -> None: + pass + + def close(self) -> None: + pass + + def __enter__(self) -> "StubSession": + return self + + def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> bool: + return False + + +def _build_tool() -> WorkflowTool: entity = ToolEntity( identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), parameters=[], @@ -22,9 +87,9 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel has_runtime_parameters=False, ) runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", + return WorkflowTool( + workflow_app_id="app-1", + workflow_as_tool_id="wf-tool-1", version="1", workflow_entities={}, workflow_call_depth=1, @@ -32,13 +97,19 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel runtime=runtime, ) + +def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_field(monkeypatch: pytest.MonkeyPatch): + """Ensure that WorkflowTool will throw a `ToolInvokeError` exception when + `WorkflowAppGenerator.generate` returns a result with `error` key inside + the `data` element. + """ + tool = _build_tool() + # needs to patch those methods to avoid database access. monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -56,28 +127,12 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.MonkeyPatch): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + """Ensure pause_state_config is passed as None.""" + tool = _build_tool() monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) - from unittest.mock import MagicMock, Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -94,22 +149,7 @@ def test_workflow_tool_does_not_use_pause_state_config(monkeypatch: pytest.Monke def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should generate variable messages when there are outputs""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() # Mock workflow outputs mock_outputs = {"result": "success", "count": 42, "data": {"key": "value"}} @@ -119,8 +159,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -134,10 +172,6 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch # Execute tool invocation messages = list(tool.invoke("test_user", {})) - # Verify generated messages - # Should contain: 3 variable messages + 1 text message + 1 JSON message = 5 messages - assert len(messages) == 5 - # Verify variable messages variable_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.VARIABLE] assert len(variable_messages) == 3 @@ -151,7 +185,7 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch # Verify text message text_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.TEXT] assert len(text_messages) == 1 - assert '{"result": "success", "count": 42, "data": {"key": "value"}}' in text_messages[0].message.text + assert json.loads(text_messages[0].message.text) == mock_outputs # Verify JSON message json_messages = [msg for msg in messages if msg.type == ToolInvokeMessage.MessageType.JSON] @@ -161,30 +195,13 @@ def test_workflow_tool_should_generate_variable_messages_for_outputs(monkeypatch def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPatch): """Test that WorkflowTool should handle empty outputs correctly""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() # needs to patch those methods to avoid database access. monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) # Mock user resolution to avoid database access - from unittest.mock import Mock - mock_user = Mock() monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: mock_user) @@ -217,61 +234,32 @@ def test_workflow_tool_should_handle_empty_outputs(monkeypatch: pytest.MonkeyPat assert json_messages[0].message.json_object == {} -def test_create_variable_message(): - """Test the functionality of creating variable messages""" - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) - - # Test different types of variable values - test_cases = [ +@pytest.mark.parametrize( + ("var_name", "var_value"), + [ ("string_var", "test string"), ("int_var", 42), ("float_var", 3.14), ("bool_var", True), ("list_var", [1, 2, 3]), ("dict_var", {"key": "value"}), - ] + ], +) +def test_create_variable_message(var_name, var_value): + """Create variable messages for multiple value types.""" + tool = _build_tool() - for var_name, var_value in test_cases: - message = tool.create_variable_message(var_name, var_value) + message = tool.create_variable_message(var_name, var_value) - assert message.type == ToolInvokeMessage.MessageType.VARIABLE - assert message.message.variable_name == var_name - assert message.message.variable_value == var_value - assert message.message.stream is False + assert message.type == ToolInvokeMessage.MessageType.VARIABLE + assert message.message.variable_name == var_name + assert message.message.variable_value == var_value + assert message.message.stream is False def test_create_file_message_should_include_file_marker(): - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="test_tool", invoke_from=InvokeFrom.EXPLORE) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + """Ensure file message includes marker and meta payload.""" + tool = _build_tool() file_obj = object() message = tool.create_file_message(file_obj) # type: ignore[arg-type] @@ -284,103 +272,247 @@ def test_create_file_message_should_include_file_marker(): def test_resolve_user_from_database_falls_back_to_end_user(monkeypatch: pytest.MonkeyPatch): """Ensure worker context can resolve EndUser when Account is missing.""" - class StubSession: - def __init__(self, results: list): - self.results = results - - def scalar(self, _stmt): - return self.results.pop(0) - - # SQLAlchemy Session APIs used by code under test - def expunge(self, *_args, **_kwargs): - pass - - def close(self): - pass - - # support `with session_factory.create_session() as session:` - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - tenant = SimpleNamespace(id="tenant_id") end_user = SimpleNamespace(id="end_user_id", tenant_id="tenant_id") # Monkeypatch session factory to return our stub session + stub_session = StubSession(scalar_results=[tenant, None, end_user]) monkeypatch.setattr( "core.tools.workflow_as_tool.tool.session_factory.create_session", - lambda: StubSession([tenant, None, end_user]), + lambda: stub_session, ) - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="tenant_id", invoke_from=InvokeFrom.SERVICE_API) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() + tool.runtime.invoke_from = InvokeFrom.SERVICE_API + tool.runtime.tenant_id = "tenant_id" resolved_user = tool._resolve_user_from_database(user_id=end_user.id) assert resolved_user is end_user + assert stub_session.expunge_calls == [end_user] def test_resolve_user_from_database_returns_none_when_no_tenant(monkeypatch: pytest.MonkeyPatch): """Return None if tenant cannot be found in worker context.""" - class StubSession: - def __init__(self, results: list): - self.results = results - - def scalar(self, _stmt): - return self.results.pop(0) - - def expunge(self, *_args, **_kwargs): - pass - - def close(self): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - self.close() - # Monkeypatch session factory to return our stub session with no tenant monkeypatch.setattr( "core.tools.workflow_as_tool.tool.session_factory.create_session", - lambda: StubSession([None]), + lambda: StubSession(scalar_results=[None]), ) - entity = ToolEntity( - identity=ToolIdentity(author="test", name="test tool", label=I18nObject(en_US="test tool"), provider="test"), - parameters=[], - description=None, - has_runtime_parameters=False, - ) - runtime = ToolRuntime(tenant_id="missing_tenant", invoke_from=InvokeFrom.SERVICE_API) - tool = WorkflowTool( - workflow_app_id="", - workflow_as_tool_id="", - version="1", - workflow_entities={}, - workflow_call_depth=1, - entity=entity, - runtime=runtime, - ) + tool = _build_tool() + tool.runtime.invoke_from = InvokeFrom.SERVICE_API + tool.runtime.tenant_id = "missing_tenant" resolved_user = tool._resolve_user_from_database(user_id="any") assert resolved_user is None + + +def test_workflow_tool_provider_type_and_fork_runtime(): + """Verify provider type and forked runtime behavior.""" + tool = _build_tool() + assert tool.tool_provider_type() == ToolProviderType.WORKFLOW + assert tool.latest_usage.total_tokens == 0 + + forked = tool.fork_tool_runtime(ToolRuntime(tenant_id="tenant-2", invoke_from=InvokeFrom.DEBUGGER)) + assert isinstance(forked, WorkflowTool) + assert forked.workflow_app_id == tool.workflow_app_id + assert forked.runtime.tenant_id == "tenant-2" + + +def test_derive_usage_from_top_level_usage_key(): + """Derive usage from top-level usage dict.""" + usage = WorkflowTool._derive_usage_from_result({"usage": {"total_tokens": 12, "total_price": "0.2"}}) + assert usage.total_tokens == 12 + + +def test_derive_usage_from_metadata_usage(): + """Derive usage from metadata usage dict.""" + metadata_usage = WorkflowTool._derive_usage_from_result({"metadata": {"usage": {"total_tokens": 7}}}) + assert metadata_usage.total_tokens == 7 + + +def test_derive_usage_from_totals(): + """Derive usage from top-level totals fields.""" + totals_usage = WorkflowTool._derive_usage_from_result( + {"total_tokens": "9", "total_price": "1.3", "currency": "USD"} + ) + assert totals_usage.total_tokens == 9 + assert str(totals_usage.total_price) == "1.3" + + +def test_derive_usage_from_empty(): + """Default usage values when result is empty.""" + empty_usage = WorkflowTool._derive_usage_from_result({}) + assert empty_usage.total_tokens == 0 + + +def test_extract_usage_from_nested(): + """Extract nested usage dict from result payloads.""" + nested = WorkflowTool._extract_usage_dict({"nested": [{"data": {"usage": {"total_tokens": 3}}}]}) + assert nested == {"total_tokens": 3} + + +def test_invoke_raises_when_user_not_found(monkeypatch: pytest.MonkeyPatch): + """Raise ToolInvokeError when user resolution fails.""" + tool = _build_tool() + monkeypatch.setattr(tool, "_get_app", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_get_workflow", lambda *args, **kwargs: None) + monkeypatch.setattr(tool, "_resolve_user", lambda *args, **kwargs: None) + + with pytest.raises(ToolInvokeError, match="User not found"): + list(tool.invoke("missing", {})) + + +def test_resolve_user_from_database_returns_account(monkeypatch: pytest.MonkeyPatch): + """Resolve Account and set tenant in worker context.""" + tenant = SimpleNamespace(id="tenant_id") + account = SimpleNamespace(id="account_id", current_tenant=None) + session = StubSession(scalar_results=[tenant, account]) + + monkeypatch.setattr("core.tools.workflow_as_tool.tool.session_factory.create_session", lambda: session) + tool = _build_tool() + tool.runtime.tenant_id = "tenant_id" + + resolved = tool._resolve_user_from_database(user_id="account_id") + assert resolved is account + assert account.current_tenant is tenant + assert session.expunge_calls == [account] + + +def test_get_workflow_and_get_app_db_branches(monkeypatch: pytest.MonkeyPatch): + """Cover workflow/app retrieval branches and error cases.""" + tool = _build_tool() + latest_workflow = SimpleNamespace(id="wf-latest") + specific_workflow = SimpleNamespace(id="wf-v1") + app = SimpleNamespace(id="app-1") + sessions = iter( + [ + StubSession(scalar_results=[], scalars_results=[latest_workflow]), + StubSession(scalar_results=[specific_workflow], scalars_results=[]), + StubSession(scalar_results=[app], scalars_results=[]), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: next(sessions), + ) + + assert tool._get_workflow("app-1", "") is latest_workflow + assert tool._get_workflow("app-1", "1") is specific_workflow + assert tool._get_app("app-1") is app + + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.session_factory.create_session", + lambda: StubSession(scalar_results=[None, None], scalars_results=[None]), + ) + with pytest.raises(ValueError, match="workflow not found"): + tool._get_workflow("app-1", "1") + with pytest.raises(ValueError, match="app not found"): + tool._get_app("app-1") + + +def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: + """Build a WorkflowTool and stub merged runtime parameters for files/query.""" + tool = _build_tool() + files_param = ToolParameter.get_simple_instance( + name="files", + llm_description="files", + typ=ToolParameter.ToolParameterType.SYSTEM_FILES, + required=False, + ) + files_param.form = ToolParameter.ToolParameterForm.FORM + text_param = ToolParameter.get_simple_instance( + name="query", + llm_description="query", + typ=ToolParameter.ToolParameterType.STRING, + required=False, + ) + text_param.form = ToolParameter.ToolParameterForm.FORM + + monkeypatch.setattr(tool, "get_merged_runtime_parameters", lambda: [files_param, text_param]) + return tool + + +def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): + """Transform args into parameters and files payloads.""" + tool = _setup_transform_args_tool(monkeypatch) + + params, files = tool._transform_args( + { + "query": "hello", + "files": [ + { + "tenant_id": "tenant-1", + "type": "image", + "transfer_method": "tool_file", + "related_id": "tool-1", + "extension": ".png", + }, + { + "tenant_id": "tenant-1", + "type": "document", + "transfer_method": "local_file", + "related_id": "upload-1", + }, + { + "tenant_id": "tenant-1", + "type": "document", + "transfer_method": "remote_url", + "remote_url": "https://example.com/a.pdf", + }, + ], + } + ) + assert params == {"query": "hello"} + assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) + assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) + assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + + +def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): + """Ignore invalid file entries while keeping params.""" + tool = _setup_transform_args_tool(monkeypatch) + invalid_params, invalid_files = tool._transform_args({"query": "hello", "files": [{"invalid": True}]}) + assert invalid_params == {"query": "hello"} + assert invalid_files == [] + + +def test_extract_files(): + """Extract file outputs into result and file list.""" + tool = _build_tool() + built_files = [ + SimpleNamespace(id="file-1"), + SimpleNamespace(id="file-2"), + ] + with patch("core.tools.workflow_as_tool.tool.build_from_mapping", side_effect=built_files): + outputs = { + "attachments": [ + { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": "tool_file", + "related_id": "r1", + } + ], + "single_file": { + "dify_model_identity": FILE_MODEL_IDENTITY, + "transfer_method": "local_file", + "related_id": "r2", + }, + "text": "ok", + } + result, extracted_files = tool._extract_files(outputs) + + assert result["text"] == "ok" + assert len(extracted_files) == 2 + + +def test_update_file_mapping(): + """Map tool/local file transfer methods into output shape.""" + tool = _build_tool() + tool_file = tool._update_file_mapping({"transfer_method": "tool_file", "related_id": "tool-1"}) + assert tool_file["tool_file_id"] == "tool-1" + local_file = tool._update_file_mapping({"transfer_method": "local_file", "related_id": "upload-1"}) + assert local_file["upload_file_id"] == "upload-1" diff --git a/api/tests/unit_tests/core/trigger/__init__.py b/api/tests/unit_tests/core/trigger/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/conftest.py b/api/tests/unit_tests/core/trigger/conftest.py new file mode 100644 index 0000000000..d9da80a8b7 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/conftest.py @@ -0,0 +1,93 @@ +"""Shared factory helpers for core.trigger test suite.""" + +from __future__ import annotations + +from typing import Any + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.entities.entities import ( + EventEntity, + EventIdentity, + EventParameter, + OAuthSchema, + Subscription, + SubscriptionConstructor, + TriggerProviderEntity, + TriggerProviderIdentity, +) +from core.trigger.provider import PluginTriggerProviderController +from models.provider_ids import TriggerProviderID + +# Valid format for TriggerProviderID: org/plugin/provider +VALID_PROVIDER_ID = "testorg/testplugin/testprovider" + + +def i18n(text: str = "test") -> I18nObject: + return I18nObject(en_US=text, zh_Hans=text) + + +def make_event(name: str = "test_event", parameters: list[EventParameter] | None = None) -> EventEntity: + return EventEntity( + identity=EventIdentity(author="a", name=name, label=i18n(name)), + description=i18n(name), + parameters=parameters or [], + ) + + +def make_provider_entity( + name: str = "test_provider", + events: list[EventEntity] | None = None, + constructor: SubscriptionConstructor | None = None, + subscription_schema: list[ProviderConfig] | None = None, + icon: str | None = "icon.png", + icon_dark: str | None = None, +) -> TriggerProviderEntity: + return TriggerProviderEntity( + identity=TriggerProviderIdentity( + author="a", + name=name, + label=i18n(name), + description=i18n(name), + icon=icon, + icon_dark=icon_dark, + ), + events=events if events is not None else [make_event()], + subscription_constructor=constructor, + subscription_schema=subscription_schema or [], + ) + + +def make_controller( + entity: TriggerProviderEntity | None = None, + tenant_id: str = "tenant-1", + provider_id: str = VALID_PROVIDER_ID, +) -> PluginTriggerProviderController: + return PluginTriggerProviderController( + entity=entity or make_provider_entity(), + plugin_id="plugin-1", + plugin_unique_identifier="uid-1", + provider_id=TriggerProviderID(provider_id), + tenant_id=tenant_id, + ) + + +def make_subscription(**overrides: Any) -> Subscription: + defaults = {"expires_at": 9999999999, "endpoint": "https://hook.test", "properties": {"k": "v"}, "parameters": {}} + defaults.update(overrides) + return Subscription(**defaults) + + +def make_provider_config( + name: str = "api_key", required: bool = True, config_type: str = "secret-input" +) -> ProviderConfig: + return ProviderConfig(name=name, label=i18n(name), type=config_type, required=required) + + +def make_constructor( + credentials_schema: list[ProviderConfig] | None = None, + oauth_schema: OAuthSchema | None = None, +) -> SubscriptionConstructor: + return SubscriptionConstructor( + parameters=[], credentials_schema=credentials_schema or [], oauth_schema=oauth_schema + ) diff --git a/api/tests/unit_tests/core/trigger/debug/__init__.py b/api/tests/unit_tests/core/trigger/debug/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py new file mode 100644 index 0000000000..d557c20f5e --- /dev/null +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_bus.py @@ -0,0 +1,93 @@ +""" +Tests for core.trigger.debug.event_bus.TriggerDebugEventBus. + +Covers: Lua-script dispatch/poll with Redis error resilience. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from redis import RedisError + +from core.trigger.debug.event_bus import TriggerDebugEventBus +from core.trigger.debug.events import PluginTriggerDebugEvent + + +class TestDispatch: + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_dispatch_count(self, mock_redis): + mock_redis.eval.return_value = 3 + event = MagicMock() + event.model_dump_json.return_value = '{"test": true}' + + result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key") + + assert result == 3 + mock_redis.eval.assert_called_once() + + @patch("core.trigger.debug.event_bus.redis_client") + def test_redis_error_returns_zero(self, mock_redis): + mock_redis.eval.side_effect = RedisError("connection lost") + event = MagicMock() + event.model_dump_json.return_value = "{}" + + result = TriggerDebugEventBus.dispatch("tenant-1", event, "pool:key") + + assert result == 0 + + +class TestPoll: + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_deserialized_event(self, mock_redis): + event_json = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ).model_dump_json() + mock_redis.eval.return_value = event_json + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is not None + assert result.name == "push" + + @patch("core.trigger.debug.event_bus.redis_client") + def test_returns_none_when_no_event(self, mock_redis): + mock_redis.eval.return_value = None + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is None + + @patch("core.trigger.debug.event_bus.redis_client") + def test_redis_error_returns_none(self, mock_redis): + mock_redis.eval.side_effect = RedisError("timeout") + + result = TriggerDebugEventBus.poll( + event_type=PluginTriggerDebugEvent, + pool_key="pool:key", + tenant_id="t1", + user_id="u1", + app_id="a1", + node_id="n1", + ) + + assert result is None diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py new file mode 100644 index 0000000000..bcb1d745e3 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -0,0 +1,281 @@ +""" +Tests for core.trigger.debug.event_selectors. + +Covers: Plugin/Webhook/Schedule pollers, create_event_poller factory, +and select_trigger_debug_events orchestrator. +""" + +from __future__ import annotations + +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from core.trigger.debug.event_selectors import ( + PluginTriggerDebugEventPoller, + ScheduleTriggerDebugEventPoller, + WebhookTriggerDebugEventPoller, + create_event_poller, + select_trigger_debug_events, +) +from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent +from dify_graph.enums import BuiltinNodeTypes, NodeType +from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID + + +def _make_poller_args(node_config: dict | None = None) -> dict: + return { + "tenant_id": "t1", + "user_id": "u1", + "app_id": "a1", + "node_config": node_config or {"data": {}}, + "node_id": "n1", + } + + +def _plugin_node_config(provider_id: str = VALID_PROVIDER_ID) -> dict: + """Valid node config for TriggerEventNodeData.model_validate.""" + return { + "data": { + "title": "test", + "plugin_id": "org/testplugin", + "provider_id": provider_id, + "event_name": "push", + "subscription_id": "s1", + "plugin_unique_identifier": "uid-1", + } + } + + +class TestPluginTriggerDebugEventPoller: + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_workflow_args_on_success(self, mock_bus): + event = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc: + mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse( + variables={"repo": "dify"}, + cancelled=False, + ) + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"repo": "dify"} + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_no_event(self, mock_bus): + mock_bus.poll.return_value = None + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + + assert poller.poll() is None + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_invoke_cancelled(self, mock_bus): + event = PluginTriggerDebugEvent( + timestamp=100, + name="push", + user_id="u1", + request_id="r1", + subscription_id="s1", + provider_id="p1", + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.trigger_service.TriggerService") as mock_trigger_svc: + mock_trigger_svc.invoke_trigger_event.return_value = TriggerInvokeEventResponse( + variables={}, + cancelled=True, + ) + + poller = PluginTriggerDebugEventPoller(**_make_poller_args(_plugin_node_config())) + + assert poller.poll() is None + + +class TestWebhookTriggerDebugEventPoller: + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_uses_inputs_directly_when_present(self, mock_bus): + event = WebhookDebugEvent( + timestamp=100, + request_id="r1", + node_id="n1", + payload={"inputs": {"key": "val"}, "webhook_data": {}}, + ) + mock_bus.poll.return_value = event + + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"key": "val"} + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_falls_back_to_webhook_data(self, mock_bus): + event = WebhookDebugEvent( + timestamp=100, + request_id="r1", + node_id="n1", + payload={"webhook_data": {"body": "raw"}}, + ) + mock_bus.poll.return_value = event + + with patch("services.trigger.webhook_service.WebhookService") as mock_webhook_svc: + mock_webhook_svc.build_workflow_inputs.return_value = {"parsed": "data"} + + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + assert result.workflow_args["inputs"] == {"parsed": "data"} + mock_webhook_svc.build_workflow_inputs.assert_called_once_with({"body": "raw"}) + + @patch("core.trigger.debug.event_selectors.TriggerDebugEventBus") + def test_returns_none_when_no_event(self, mock_bus): + mock_bus.poll.return_value = None + poller = WebhookTriggerDebugEventPoller(**_make_poller_args()) + + assert poller.poll() is None + + +class TestScheduleTriggerDebugEventPoller: + def _make_schedule_poller(self, mock_redis, mock_schedule_svc, next_run_at: datetime): + """Set up mocks and create a schedule poller.""" + mock_redis.get.return_value = None + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + return ScheduleTriggerDebugEventPoller(**_make_poller_args()) + + @patch("core.trigger.debug.event_selectors.redis_client") + @patch("core.trigger.debug.event_selectors.naive_utc_now") + @patch("core.trigger.debug.event_selectors.calculate_next_run_at") + @patch("core.trigger.debug.event_selectors.ensure_naive_utc") + def test_returns_none_when_not_yet_due(self, mock_ensure, mock_calc, mock_now, mock_redis): + now = datetime(2025, 1, 1, 12, 0, 0) + next_run = datetime(2025, 1, 1, 13, 0, 0) # future + mock_now.return_value = now + mock_calc.return_value = next_run + mock_ensure.return_value = next_run + mock_redis.get.return_value = None + + with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc: + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + + poller = ScheduleTriggerDebugEventPoller(**_make_poller_args()) + + assert poller.poll() is None + + @patch("core.trigger.debug.event_selectors.redis_client") + @patch("core.trigger.debug.event_selectors.naive_utc_now") + @patch("core.trigger.debug.event_selectors.calculate_next_run_at") + @patch("core.trigger.debug.event_selectors.ensure_naive_utc") + def test_fires_event_when_due(self, mock_ensure, mock_calc, mock_now, mock_redis): + now = datetime(2025, 1, 1, 14, 0, 0) + next_run = datetime(2025, 1, 1, 12, 0, 0) # past + mock_now.return_value = now + mock_calc.return_value = next_run + mock_ensure.return_value = next_run + mock_redis.get.return_value = None + + with patch("services.trigger.schedule_service.ScheduleService") as mock_schedule_svc: + mock_schedule_config = MagicMock() + mock_schedule_config.cron_expression = "0 * * * *" + mock_schedule_config.timezone = "UTC" + mock_schedule_svc.to_schedule_config.return_value = mock_schedule_config + + poller = ScheduleTriggerDebugEventPoller(**_make_poller_args()) + result = poller.poll() + + assert result is not None + mock_redis.delete.assert_called_once() + + +class TestCreateEventPoller: + def _workflow_with_node(self, node_type: NodeType): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = node_type + return wf + + def test_creates_plugin_poller(self): + wf = self._workflow_with_node(TRIGGER_PLUGIN_NODE_TYPE) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, PluginTriggerDebugEventPoller) + + def test_creates_webhook_poller(self): + wf = self._workflow_with_node(TRIGGER_WEBHOOK_NODE_TYPE) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, WebhookTriggerDebugEventPoller) + + def test_creates_schedule_poller(self): + wf = self._workflow_with_node(TRIGGER_SCHEDULE_NODE_TYPE) + poller = create_event_poller(wf, "t1", "u1", "a1", "n1") + assert isinstance(poller, ScheduleTriggerDebugEventPoller) + + def test_raises_for_unknown_type(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = BuiltinNodeTypes.START + + with pytest.raises(ValueError): + create_event_poller(wf, "t1", "u1", "a1", "n1") + + def test_raises_when_node_config_missing(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = None + + with pytest.raises(ValueError): + create_event_poller(wf, "t1", "u1", "a1", "n1") + + +class TestSelectTriggerDebugEvents: + def test_returns_first_non_none_event(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = TRIGGER_WEBHOOK_NODE_TYPE + app_model = MagicMock() + app_model.tenant_id = "t1" + app_model.id = "a1" + + with patch.object(WebhookTriggerDebugEventPoller, "poll") as mock_poll: + expected = MagicMock() + mock_poll.return_value = expected + + result = select_trigger_debug_events(wf, app_model, "u1", ["n1", "n2"]) + + assert result is expected + + def test_returns_none_when_no_events(self): + wf = MagicMock() + wf.get_node_config_by_id.return_value = {"data": {}} + wf.get_node_type_from_node_config.return_value = TRIGGER_WEBHOOK_NODE_TYPE + app_model = MagicMock() + app_model.tenant_id = "t1" + app_model.id = "a1" + + with patch.object(WebhookTriggerDebugEventPoller, "poll", return_value=None): + result = select_trigger_debug_events(wf, app_model, "u1", ["n1"]) + + assert result is None diff --git a/api/tests/unit_tests/core/trigger/test_provider.py b/api/tests/unit_tests/core/trigger/test_provider.py new file mode 100644 index 0000000000..3c2f297e90 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/test_provider.py @@ -0,0 +1,332 @@ +""" +Tests for core.trigger.provider.PluginTriggerProviderController. + +Covers: to_api_entity creation-method logic, credential validation pipeline, +schema resolution by type, event lookup, dispatch/invoke/subscribe delegation. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import CredentialType +from core.trigger.entities.entities import ( + EventParameter, + EventParameterType, + OAuthSchema, + TriggerCreationMethod, +) +from core.trigger.errors import TriggerProviderCredentialValidationError +from tests.unit_tests.core.trigger.conftest import ( + i18n, + make_constructor, + make_controller, + make_event, + make_provider_config, + make_provider_entity, + make_subscription, +) + +ICON_URL = "https://cdn/icon.png" + + +class TestToApiEntity: + @patch("core.trigger.provider.PluginService") + def test_includes_icons_when_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller(entity=make_provider_entity(icon="icon.png", icon_dark="dark.png")) + + api = ctrl.to_api_entity() + + assert api.icon == ICON_URL + assert api.icon_dark == ICON_URL + + @patch("core.trigger.provider.PluginService") + def test_icons_none_when_absent(self, mock_plugin_svc): + ctrl = make_controller(entity=make_provider_entity(icon=None, icon_dark=None)) + + api = ctrl.to_api_entity() + + assert api.icon is None + assert api.icon_dark is None + mock_plugin_svc.get_plugin_icon_url.assert_not_called() + + @patch("core.trigger.provider.PluginService") + def test_manual_only_without_schemas(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + + api = ctrl.to_api_entity() + + assert api.supported_creation_methods == [TriggerCreationMethod.MANUAL] + + @patch("core.trigger.provider.PluginService") + def test_adds_oauth_when_oauth_schema_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + oauth = OAuthSchema(client_schema=[], credentials_schema=[]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + api = ctrl.to_api_entity() + + assert TriggerCreationMethod.OAUTH in api.supported_creation_methods + assert TriggerCreationMethod.MANUAL in api.supported_creation_methods + + @patch("core.trigger.provider.PluginService") + def test_adds_apikey_when_credentials_schema_present(self, mock_plugin_svc): + mock_plugin_svc.get_plugin_icon_url.return_value = ICON_URL + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + + api = ctrl.to_api_entity() + + assert TriggerCreationMethod.APIKEY in api.supported_creation_methods + + +class TestGetEvent: + def test_returns_matching_event(self): + evt = make_event("push") + ctrl = make_controller(entity=make_provider_entity(events=[evt, make_event("pr")])) + + assert ctrl.get_event("push") is evt + + def test_returns_none_for_unknown(self): + ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")])) + + assert ctrl.get_event("nonexistent") is None + + +class TestGetSubscriptionDefaultProperties: + def test_returns_defaults_skipping_none(self): + config1 = make_provider_config("key1") + config1.default = "val1" + config2 = make_provider_config("key2") + config2.default = None + ctrl = make_controller(entity=make_provider_entity(subscription_schema=[config1, config2])) + + props = ctrl.get_subscription_default_properties() + + assert props == {"key1": "val1"} + + +class TestValidateCredentials: + def test_raises_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + + with pytest.raises(ValueError, match="Subscription constructor not found"): + ctrl.validate_credentials("u1", {"key": "val"}) + + def test_raises_for_missing_required_field(self): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + + with pytest.raises(TriggerProviderCredentialValidationError, match="Missing required"): + ctrl.validate_credentials("u1", {}) + + @patch("core.trigger.provider.PluginTriggerClient") + def test_passes_with_valid_credentials(self, mock_client): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + mock_client.return_value.validate_provider_credentials.return_value = True + + ctrl.validate_credentials("u1", {"api_key": "secret123"}) # should not raise + + @patch("core.trigger.provider.PluginTriggerClient") + def test_raises_when_plugin_rejects(self, mock_client): + required_cfg = make_provider_config("api_key", required=True) + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[required_cfg])) + ) + mock_client.return_value.validate_provider_credentials.return_value = None + + with pytest.raises(TriggerProviderCredentialValidationError, match="Invalid credentials"): + ctrl.validate_credentials("u1", {"api_key": "bad"}) + + +class TestGetSupportedCredentialTypes: + def test_empty_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + assert ctrl.get_supported_credential_types() == [] + + def test_oauth_only(self): + oauth = OAuthSchema(client_schema=[], credentials_schema=[]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.OAUTH2 in types + assert CredentialType.API_KEY not in types + + def test_apikey_only(self): + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.API_KEY in types + assert CredentialType.OAUTH2 not in types + + def test_both(self): + oauth = OAuthSchema(client_schema=[], credentials_schema=[make_provider_config("oauth_secret")]) + ctrl = make_controller( + entity=make_provider_entity( + constructor=make_constructor(credentials_schema=[make_provider_config()], oauth_schema=oauth) + ) + ) + + types = ctrl.get_supported_credential_types() + + assert CredentialType.OAUTH2 in types + assert CredentialType.API_KEY in types + + +class TestGetCredentialsSchema: + def test_returns_empty_when_no_constructor(self): + ctrl = make_controller(entity=make_provider_entity(constructor=None)) + assert ctrl.get_credentials_schema(CredentialType.API_KEY) == [] + + def test_returns_apikey_credentials(self): + cfg = make_provider_config("token") + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(credentials_schema=[cfg]))) + + result = ctrl.get_credentials_schema(CredentialType.API_KEY) + + assert len(result) == 1 + assert result[0].name == "token" + + def test_returns_oauth_credentials(self): + oauth_cred = make_provider_config("oauth_token") + oauth = OAuthSchema(client_schema=[], credentials_schema=[oauth_cred]) + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor(oauth_schema=oauth))) + + result = ctrl.get_credentials_schema(CredentialType.OAUTH2) + + assert len(result) == 1 + assert result[0].name == "oauth_token" + + def test_unauthorized_returns_empty(self): + ctrl = make_controller( + entity=make_provider_entity(constructor=make_constructor(credentials_schema=[make_provider_config()])) + ) + assert ctrl.get_credentials_schema(CredentialType.UNAUTHORIZED) == [] + + def test_invalid_type_raises(self): + ctrl = make_controller(entity=make_provider_entity(constructor=make_constructor())) + with pytest.raises(ValueError, match="Invalid credential type"): + ctrl.get_credentials_schema("bogus_type") + + +class TestGetEventParameters: + def test_returns_params_for_known_event(self): + param = EventParameter(name="branch", label=i18n("branch"), type=EventParameterType.STRING) + evt = make_event("push", parameters=[param]) + ctrl = make_controller(entity=make_provider_entity(events=[evt])) + + result = ctrl.get_event_parameters("push") + + assert "branch" in result + assert result["branch"].name == "branch" + + def test_returns_empty_for_unknown_event(self): + ctrl = make_controller(entity=make_provider_entity(events=[make_event("push")])) + + assert ctrl.get_event_parameters("nonexistent") == {} + + +class TestDispatch: + @patch("core.trigger.provider.PluginTriggerClient") + def test_delegates_to_client(self, mock_client): + ctrl = make_controller() + expected = MagicMock() + mock_client.return_value.dispatch_event.return_value = expected + + result = ctrl.dispatch( + request=MagicMock(), + subscription=make_subscription(), + credentials={"k": "v"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + mock_client.return_value.dispatch_event.assert_called_once() + + +class TestInvokeTriggerEvent: + @patch("core.trigger.provider.PluginTriggerClient") + def test_delegates_to_client(self, mock_client): + ctrl = make_controller() + expected = MagicMock() + mock_client.return_value.invoke_trigger_event.return_value = expected + + result = ctrl.invoke_trigger_event( + user_id="u1", + event_name="push", + parameters={}, + credentials={}, + credential_type=CredentialType.API_KEY, + subscription=make_subscription(), + request=MagicMock(), + payload={}, + ) + + assert result is expected + + +class TestSubscribeTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_returns_validated_subscription(self, mock_client): + ctrl = make_controller() + mock_client.return_value.subscribe.return_value.subscription = { + "expires_at": 123, + "endpoint": "https://e", + "properties": {}, + } + + result = ctrl.subscribe_trigger( + user_id="u1", + endpoint="https://e", + parameters={}, + credentials={}, + credential_type=CredentialType.API_KEY, + ) + + assert result.endpoint == "https://e" + + +class TestUnsubscribeTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_returns_validated_result(self, mock_client): + ctrl = make_controller() + mock_client.return_value.unsubscribe.return_value.subscription = {"success": True, "message": "ok"} + + result = ctrl.unsubscribe_trigger( + user_id="u1", + subscription=make_subscription(), + credentials={}, + credential_type=CredentialType.API_KEY, + ) + + assert result.success is True + + +class TestRefreshTrigger: + @patch("core.trigger.provider.PluginTriggerClient") + def test_uses_system_user_id(self, mock_client): + ctrl = make_controller() + mock_client.return_value.refresh.return_value.subscription = { + "expires_at": 456, + "endpoint": "https://e", + "properties": {}, + } + + ctrl.refresh_trigger(subscription=make_subscription(), credentials={}, credential_type=CredentialType.API_KEY) + + call_kwargs = mock_client.return_value.refresh.call_args[1] + assert call_kwargs["user_id"] == "system" diff --git a/api/tests/unit_tests/core/trigger/test_trigger_manager.py b/api/tests/unit_tests/core/trigger/test_trigger_manager.py new file mode 100644 index 0000000000..612be25ec9 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/test_trigger_manager.py @@ -0,0 +1,307 @@ +""" +Tests for core.trigger.trigger_manager.TriggerManager. + +Covers: icon URL construction, provider listing with error resilience, +double-check lock caching, error translation, EventIgnoreError -> cancelled, +and delegation to provider controller. +""" + +from __future__ import annotations + +from threading import Lock +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import CredentialType +from core.plugin.entities.request import TriggerInvokeEventResponse +from core.plugin.impl.exc import PluginDaemonError, PluginNotFoundError +from core.trigger.errors import EventIgnoreError +from core.trigger.trigger_manager import TriggerManager +from models.provider_ids import TriggerProviderID +from tests.unit_tests.core.trigger.conftest import ( + VALID_PROVIDER_ID, + make_controller, + make_provider_entity, + make_subscription, +) + +PID = TriggerProviderID(VALID_PROVIDER_ID) +PID_STR = str(PID) + + +class TestGetTriggerPluginIcon: + @patch("core.trigger.trigger_manager.dify_config") + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_builds_correct_url(self, mock_client, mock_config): + mock_config.CONSOLE_API_URL = "https://console.example.com" + provider = MagicMock() + provider.declaration.identity.icon = "my-icon.svg" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + url = TriggerManager.get_trigger_plugin_icon("tenant-1", VALID_PROVIDER_ID) + + assert "tenant_id=tenant-1" in url + assert "filename=my-icon.svg" in url + assert url.startswith("https://console.example.com/console/api/workspaces/current/plugin/icon") + + +class TestListPluginTriggerProviders: + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_wraps_entities_into_controllers(self, mock_client): + entity = MagicMock() + entity.declaration = make_provider_entity("p1") + entity.plugin_id = "plugin-1" + entity.plugin_unique_identifier = "uid-1" + entity.provider = VALID_PROVIDER_ID + mock_client.return_value.fetch_trigger_providers.return_value = [entity] + + controllers = TriggerManager.list_plugin_trigger_providers("tenant-1") + + assert len(controllers) == 1 + assert controllers[0].plugin_id == "plugin-1" + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + def test_skips_failing_providers(self, mock_client): + good = MagicMock() + good.declaration = make_provider_entity("good") + good.plugin_id = "good-plugin" + good.plugin_unique_identifier = "uid-good" + good.provider = VALID_PROVIDER_ID + + bad = MagicMock() + bad.declaration = make_provider_entity("bad") + bad.plugin_id = "bad-plugin" + bad.plugin_unique_identifier = "uid-bad" + bad.provider = "bad/format" # 2-part: fails TriggerProviderID validation + + mock_client.return_value.fetch_trigger_providers.return_value = [bad, good] + + controllers = TriggerManager.list_plugin_trigger_providers("tenant-1") + + assert len(controllers) == 1 + assert controllers[0].plugin_id == "good-plugin" + + +class TestGetTriggerProvider: + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_initializes_context_on_first_call(self, mock_ctx, mock_client): + # get() called 3 times: (1) try block, (2) after set, (3) under lock + mock_ctx.plugin_trigger_providers.get.side_effect = [LookupError, {}, {}] + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + provider = MagicMock() + provider.declaration = make_provider_entity() + provider.plugin_id = "p1" + provider.plugin_unique_identifier = "uid-1" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + result = TriggerManager.get_trigger_provider("t1", PID) + + mock_ctx.plugin_trigger_providers.set.assert_called_once_with({}) + mock_ctx.plugin_trigger_providers_lock.set.assert_called_once() + assert result is not None + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_returns_cached_without_fetch(self, mock_ctx, mock_client): + cached = make_controller() + mock_ctx.plugin_trigger_providers.get.return_value = {PID_STR: cached} + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is cached + mock_client.return_value.fetch_trigger_provider.assert_not_called() + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_double_check_lock_uses_cached_from_other_thread(self, mock_ctx, mock_client): + cached = make_controller() + mock_ctx.plugin_trigger_providers.get.side_effect = [ + {}, # first check misses + {PID_STR: cached}, # under-lock check hits + ] + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is cached + mock_client.return_value.fetch_trigger_provider.assert_not_called() + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_fetches_and_caches_on_miss(self, mock_ctx, mock_client): + cache: dict = {} + mock_ctx.plugin_trigger_providers.get.return_value = cache + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + provider = MagicMock() + provider.declaration = make_provider_entity() + provider.plugin_id = "p1" + provider.plugin_unique_identifier = "uid-1" + mock_client.return_value.fetch_trigger_provider.return_value = provider + + result = TriggerManager.get_trigger_provider("t1", PID) + + assert result is not None + assert PID_STR in cache + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_none_fetch_raises_value_error(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.return_value = None + + with pytest.raises(ValueError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/missing")) + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_plugin_not_found_becomes_value_error(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.side_effect = PluginNotFoundError("gone") + + with pytest.raises(ValueError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss")) + + @patch("core.trigger.trigger_manager.PluginTriggerClient") + @patch("core.trigger.trigger_manager.contexts") + def test_plugin_daemon_error_propagates(self, mock_ctx, mock_client): + mock_ctx.plugin_trigger_providers.get.return_value = {} + mock_ctx.plugin_trigger_providers_lock.get.return_value = Lock() + mock_client.return_value.fetch_trigger_provider.side_effect = PluginDaemonError("test error") + + with pytest.raises(PluginDaemonError): + TriggerManager.get_trigger_provider("t1", TriggerProviderID("org/plug/miss")) + + +class TestListAllTriggerProviders: + @patch.object(TriggerManager, "list_plugin_trigger_providers") + def test_delegates_to_list_plugin(self, mock_list): + expected = [make_controller()] + mock_list.return_value = expected + + assert TriggerManager.list_all_trigger_providers("t1") is expected + mock_list.assert_called_once_with("t1") + + +class TestListTriggersByProvider: + @patch.object(TriggerManager, "get_trigger_provider") + def test_returns_provider_events(self, mock_get): + ctrl = make_controller() + mock_get.return_value = ctrl + + result = TriggerManager.list_triggers_by_provider("t1", PID) + + assert result == ctrl.get_events() + + +class TestInvokeTriggerEvent: + def _args(self): + return { + "tenant_id": "t1", + "user_id": "u1", + "provider_id": PID, + "event_name": "on_push", + "parameters": {"branch": "main"}, + "credentials": {"token": "abc"}, + "credential_type": CredentialType.API_KEY, + "subscription": make_subscription(), + "request": MagicMock(), + "payload": {"action": "push"}, + } + + @patch.object(TriggerManager, "get_trigger_provider") + def test_returns_invoke_response(self, mock_get): + ctrl = MagicMock() + expected = TriggerInvokeEventResponse(variables={"v": "1"}, cancelled=False) + ctrl.invoke_trigger_event.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.invoke_trigger_event(**self._args()) + + assert result is expected + assert result.cancelled is False + + @patch.object(TriggerManager, "get_trigger_provider") + def test_event_ignore_returns_cancelled(self, mock_get): + ctrl = MagicMock() + ctrl.invoke_trigger_event.side_effect = EventIgnoreError("skip") + mock_get.return_value = ctrl + + result = TriggerManager.invoke_trigger_event(**self._args()) + + assert result.cancelled is True + assert result.variables == {} + + @patch.object(TriggerManager, "get_trigger_provider") + def test_other_errors_propagate(self, mock_get): + ctrl = MagicMock() + ctrl.invoke_trigger_event.side_effect = RuntimeError("boom") + mock_get.return_value = ctrl + + with pytest.raises(RuntimeError, match="boom"): + TriggerManager.invoke_trigger_event(**self._args()) + + +class TestSubscribeTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = make_subscription() + ctrl.subscribe_trigger.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.subscribe_trigger( + tenant_id="t1", + user_id="u1", + provider_id=PID, + endpoint="https://hook.test", + parameters={"f": "all"}, + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + ctrl.subscribe_trigger.assert_called_once() + + +class TestUnsubscribeTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = MagicMock() + ctrl.unsubscribe_trigger.return_value = expected + mock_get.return_value = ctrl + sub = make_subscription() + + result = TriggerManager.unsubscribe_trigger( + tenant_id="t1", + user_id="u1", + provider_id=PID, + subscription=sub, + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected + + +class TestRefreshTrigger: + @patch.object(TriggerManager, "get_trigger_provider") + def test_delegates_with_correct_args(self, mock_get): + ctrl = MagicMock() + expected = make_subscription() + ctrl.refresh_trigger.return_value = expected + mock_get.return_value = ctrl + + result = TriggerManager.refresh_trigger( + tenant_id="t1", + provider_id=PID, + subscription=make_subscription(), + credentials={"token": "x"}, + credential_type=CredentialType.API_KEY, + ) + + assert result is expected diff --git a/api/tests/unit_tests/core/trigger/utils/__init__.py b/api/tests/unit_tests/core/trigger/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py new file mode 100644 index 0000000000..8804526e2e --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_encryption.py @@ -0,0 +1,62 @@ +"""Tests for core.trigger.utils.encryption — masking logic and cache key generation.""" + +from __future__ import annotations + +from core.entities.provider_entities import ProviderConfig +from core.tools.entities.common_entities import I18nObject +from core.trigger.utils.encryption import ( + TriggerProviderCredentialsCache, + TriggerProviderOAuthClientParamsCache, + TriggerProviderPropertiesCache, + masked_credentials, +) + + +def _make_schema(name: str, field_type: str = "secret-input") -> ProviderConfig: + return ProviderConfig( + name=name, + label=I18nObject(en_US=name, zh_Hans=name), + type=field_type, + ) + + +class TestMaskedCredentials: + def test_short_secret_fully_masked(self): + schema = [_make_schema("key", "secret-input")] + result = masked_credentials(schema, {"key": "ab"}) + assert result["key"] == "**" + + def test_long_secret_partially_masked(self): + schema = [_make_schema("key", "secret-input")] + result = masked_credentials(schema, {"key": "abcdef"}) + assert result["key"].startswith("ab") + assert result["key"].endswith("ef") + assert "**" in result["key"] + + def test_non_secret_field_unchanged(self): + schema = [_make_schema("host", "text-input")] + result = masked_credentials(schema, {"host": "example.com"}) + assert result["host"] == "example.com" + + def test_unknown_key_passes_through(self): + result = masked_credentials([], {"unknown": "value"}) + assert result["unknown"] == "value" + + +class TestCacheKeyGeneration: + def test_credentials_cache_key_contains_ids(self): + cache = TriggerProviderCredentialsCache(tenant_id="t1", provider_id="p1", credential_id="c1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + assert "c1" in cache.cache_key + + def test_oauth_client_cache_key_contains_ids(self): + cache = TriggerProviderOAuthClientParamsCache(tenant_id="t1", provider_id="p1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + + def test_properties_cache_key_contains_ids(self): + cache = TriggerProviderPropertiesCache(tenant_id="t1", provider_id="p1", subscription_id="s1") + assert "t1" in cache.cache_key + assert "p1" in cache.cache_key + assert "s1" in cache.cache_key diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py new file mode 100644 index 0000000000..e5879aea0a --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_endpoint.py @@ -0,0 +1,31 @@ +"""Tests for core.trigger.utils.endpoint — URL generation.""" + +from __future__ import annotations + +from unittest.mock import patch + +from yarl import URL + +from core.trigger.utils import endpoint + + +class TestGeneratePluginTriggerEndpointUrl: + def test_builds_correct_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_plugin_trigger_endpoint_url("endpoint-123") + + assert url == "https://api.example.com/triggers/plugin/endpoint-123" + + +class TestGenerateWebhookTriggerEndpoint: + def test_non_debug_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=False) + + assert url == "https://api.example.com/triggers/webhook/sub-456" + + def test_debug_url(self): + with patch.object(endpoint, "base_url", URL("https://api.example.com")): + url = endpoint.generate_webhook_trigger_endpoint("sub-456", debug=True) + + assert url == "https://api.example.com/triggers/webhook-debug/sub-456" diff --git a/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py new file mode 100644 index 0000000000..4fa202b164 --- /dev/null +++ b/api/tests/unit_tests/core/trigger/utils/test_utils_locks.py @@ -0,0 +1,23 @@ +"""Tests for core.trigger.utils.locks — Redis lock key builders.""" + +from __future__ import annotations + +from core.trigger.utils.locks import build_trigger_refresh_lock_key, build_trigger_refresh_lock_keys + + +class TestBuildTriggerRefreshLockKey: + def test_correct_format(self): + key = build_trigger_refresh_lock_key("tenant-1", "sub-1") + + assert key == "trigger_provider_refresh_lock:tenant-1_sub-1" + + +class TestBuildTriggerRefreshLockKeys: + def test_maps_over_pairs(self): + pairs = [("t1", "s1"), ("t2", "s2")] + + keys = build_trigger_refresh_lock_keys(pairs) + + assert len(keys) == 2 + assert keys[0] == "trigger_provider_refresh_lock:t1_s1" + assert keys[1] == "trigger_provider_refresh_lock:t2_s2" diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index a9af8bea1d..91259c9a45 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,12 +1,15 @@ import dataclasses +import orjson +import pytest from pydantic import BaseModel from core.helper import encrypter -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import ( +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -22,8 +25,13 @@ from core.workflow.variables.segments import ( StringSegment, get_segment_discriminator, ) -from core.workflow.variables.types import SegmentType -from core.workflow.variables.variables import ( +from dify_graph.variables.types import SegmentType +from dify_graph.variables.utils import ( + dumps_with_segments, + segment_orjson_default, + to_selector, +) +from dify_graph.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, @@ -379,3 +387,125 @@ class TestSegmentDumpAndLoad: assert get_segment_discriminator("not_a_dict") is None assert get_segment_discriminator(42) is None assert get_segment_discriminator(object) is None + + +class TestSegmentAdditionalProperties: + def test_base_segment_text_log_markdown_size_and_to_object(self): + """Ensure StringSegment exposes text, log, markdown, size and to_object.""" + segment = StringSegment(value="hello") + + assert segment.text == "hello" + assert segment.log == "hello" + assert segment.markdown == "hello" + assert segment.size > 0 + assert segment.to_object() == "hello" + + def test_none_segment_empty_outputs(self): + """Ensure NoneSegment renders empty text, log and markdown.""" + segment = NoneSegment() + + assert segment.text == "" + assert segment.log == "" + assert segment.markdown == "" + + def test_object_segment_json_outputs(self): + """Ensure ObjectSegment renders JSON output for text, log and markdown.""" + segment = ObjectSegment(value={"key": "值", "n": 1}) + + assert segment.text == '{"key": "值", "n": 1}' + assert segment.log == '{\n "key": "值",\n "n": 1\n}' + assert segment.markdown == '{\n "key": "值",\n "n": 1\n}' + + def test_array_segment_text_and_markdown(self): + """Ensure ArrayAnySegment handles empty/non-empty text and markdown rendering.""" + empty_segment = ArrayAnySegment(value=[]) + non_empty_segment = ArrayAnySegment(value=[1, "two"]) + + assert empty_segment.text == "" + assert non_empty_segment.text == "[1, 'two']" + assert non_empty_segment.markdown == "- 1\n- two" + + def test_file_segment_properties(self): + """Ensure FileSegment markdown, text and log fields match expected behavior.""" + file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="doc.txt") + segment = FileSegment(value=file) + + assert segment.markdown == "[doc.txt](https://example.com/file.txt)" + assert segment.log == "" + assert segment.text == "" + + def test_array_string_segment_text_branches(self): + """Ensure ArrayStringSegment text handling for empty and non-empty values.""" + empty_segment = ArrayStringSegment(value=[]) + non_empty_segment = ArrayStringSegment(value=["hello", "世界"]) + + assert empty_segment.text == "" + assert non_empty_segment.text == '["hello", "世界"]' + + def test_array_file_segment_markdown_and_empty_text_log(self): + """Ensure ArrayFileSegment markdown renders links and text/log stay empty.""" + file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt") + file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt") + segment = ArrayFileSegment(value=[file1, file2]) + + assert segment.markdown == "[a.txt](https://example.com/file.txt)\n[b.txt](https://example.com/file.txt)" + assert segment.log == "" + assert segment.text == "" + + +class TestSegmentGroupAdditional: + def test_segment_group_markdown_and_to_object(self): + group = SegmentGroup(value=[StringSegment(value="A"), NoneSegment(), StringSegment(value="B")]) + + assert group.markdown == "AB" + assert group.to_object() == ["A", None, "B"] + + +class TestSegmentUtils: + def test_to_selector_without_paths(self): + assert to_selector("node-1", "output") == ["node-1", "output"] + + def test_to_selector_with_paths(self): + assert to_selector("node-1", "output", ("a", "b")) == ["node-1", "output", "a", "b"] + + def test_array_file_segment_serialization(self): + file1 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="a.txt") + file2 = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="b.txt") + + result = segment_orjson_default(ArrayFileSegment(value=[file1, file2])) + + assert len(result) == 2 + assert result[0]["filename"] == "a.txt" + assert result[1]["filename"] == "b.txt" + + def test_file_segment_serialization(self): + file = create_test_file(transfer_method=FileTransferMethod.REMOTE_URL, filename="single.txt") + + result = segment_orjson_default(FileSegment(value=file)) + + assert result["filename"] == "single.txt" + assert result["remote_url"] == "https://example.com/file.txt" + + def test_segment_group_and_segment_serialization(self): + group = SegmentGroup(value=[StringSegment(value="a"), StringSegment(value="b")]) + + assert segment_orjson_default(group) == ["a", "b"] + assert segment_orjson_default(StringSegment(value="value")) == "value" + + def test_segment_orjson_default_unsupported_type(self): + with pytest.raises(TypeError, match="not JSON serializable"): + segment_orjson_default(object()) + + def test_dumps_with_segments(self): + data = { + "segment": StringSegment(value="hello"), + "group": SegmentGroup(value=[StringSegment(value="x"), StringSegment(value="y")]), + 1: "numeric-key", + } + + dumped = dumps_with_segments(data) + loaded = orjson.loads(dumped) + + assert loaded["segment"] == "hello" + assert loaded["group"] == ["x", "y"] + assert loaded["1"] == "numeric-key" diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index e28fed187b..bb234d9bbd 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,6 +1,8 @@ import pytest -from core.workflow.variables.types import ArrayValidation, SegmentType +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: @@ -69,22 +71,36 @@ class TestSegmentTypeIsValidArrayValidation: """ def test_array_validation_all_success(self): + # Arrange value = ["hello", "world", "foo"] - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Assert + assert is_valid def test_array_validation_all_fail(self): + # Arrange value = ["hello", 123, "world"] - # Should return False, since 123 is not a string - assert not SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.ALL) + # Assert + assert not is_valid def test_array_validation_first(self): + # Arrange value = ["hello", 123, None] - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.FIRST) + # Assert + assert is_valid def test_array_validation_none(self): + # Arrange value = [1, 2, 3] - # validation is None, skip - assert SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) + # Act + is_valid = SegmentType.ARRAY_STRING.is_valid(value, array_validation=ArrayValidation.NONE) + # Assert + assert is_valid class TestSegmentTypeGetZeroValue: @@ -163,3 +179,62 @@ class TestSegmentTypeGetZeroValue: for seg_type in unsupported_types: with pytest.raises(ValueError, match="unsupported variable type"): SegmentType.get_zero_value(seg_type) + + +class TestSegmentTypeInferSegmentType: + @pytest.mark.parametrize( + ("value", "expected"), + [ + ([], SegmentType.ARRAY_NUMBER), + ([1, 2, 3], SegmentType.ARRAY_NUMBER), + ([1, 2.5], SegmentType.ARRAY_NUMBER), + (["a", "b"], SegmentType.ARRAY_STRING), + ([{"k": "v"}], SegmentType.ARRAY_OBJECT), + ([None], SegmentType.ARRAY_ANY), + ([True, False], SegmentType.ARRAY_BOOLEAN), + ([[1], [2]], SegmentType.ARRAY_ANY), + ([1, "a"], SegmentType.ARRAY_ANY), + (None, SegmentType.NONE), + (True, SegmentType.BOOLEAN), + (1, SegmentType.INTEGER), + (1.2, SegmentType.FLOAT), + ("abc", SegmentType.STRING), + ({"k": "v"}, SegmentType.OBJECT), + ], + ) + def test_infer_segment_type_supported_values(self, value, expected): + assert SegmentType.infer_segment_type(value) == expected + + +class TestSegmentTypeAdditionalMethods: + def test_cast_value_for_bool_number_and_array_number(self): + assert SegmentType.cast_value(True, SegmentType.INTEGER) == 1 + assert SegmentType.cast_value(False, SegmentType.NUMBER) == 0 + assert SegmentType.cast_value([True, False], SegmentType.ARRAY_NUMBER) == [1, 0] + + mixed = [True, 1] + assert SegmentType.cast_value(mixed, SegmentType.ARRAY_NUMBER) is mixed + assert SegmentType.cast_value("x", SegmentType.STRING) == "x" + + def test_exposed_type_and_element_type(self): + assert SegmentType.INTEGER.exposed_type() == SegmentType.NUMBER + assert SegmentType.FLOAT.exposed_type() == SegmentType.NUMBER + assert SegmentType.STRING.exposed_type() == SegmentType.STRING + + assert SegmentType.ARRAY_STRING.element_type() == SegmentType.STRING + assert SegmentType.ARRAY_ANY.element_type() is None + + with pytest.raises(ValueError, match="element_type is only supported by array type"): + SegmentType.STRING.element_type() + + def test_group_validation_for_segment_group_and_list(self): + valid_group = SegmentGroup(value=[StringSegment(value="a")]) + assert SegmentType.GROUP.is_valid(valid_group) is True + assert SegmentType.GROUP.is_valid([StringSegment(value="b")]) is True + assert SegmentType.GROUP.is_valid(["not-segment"]) is False + + def test_unreachable_assertion_branch(self, monkeypatch): + monkeypatch.setattr(SegmentType, "is_array_type", lambda self: False) + + with pytest.raises(AssertionError, match="unreachable"): + SegmentType.ARRAY_STRING.is_valid(["a"]) diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 52e5dd180c..41ce483447 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -10,10 +10,10 @@ from typing import Any import pytest -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.variables.segment_group import SegmentGroup -from core.workflow.variables.segments import ( +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.variables.segment_group import SegmentGroup +from dify_graph.variables.segments import ( ArrayFileSegment, BooleanSegment, FileSegment, @@ -22,7 +22,7 @@ from core.workflow.variables.segments import ( ObjectSegment, StringSegment, ) -from core.workflow.variables.types import ArrayValidation, SegmentType +from dify_graph.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 6fc162e533..dd0fe2e65a 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,7 @@ import pytest from pydantic import ValidationError -from core.workflow.variables import ( +from dify_graph.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +11,7 @@ from core.workflow.variables import ( SegmentType, StringVariable, ) -from core.workflow.variables.variables import VariableBase +from dify_graph.variables.variables import VariableBase def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index 8dd669e17f..d09b8397c3 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from core.workflow.context.execution_context import ( +from dify_graph.context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from core.workflow.context.execution_context import capture_current_context + from dify_graph.context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from core.workflow.context.execution_context import capture_current_context + from dify_graph.context.execution_context import capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from core.workflow.context import reset_context_provider + from dify_graph.context import reset_context_provider reset_context_provider() def teardown_method(self): - from core.workflow.context import reset_context_provider + from dify_graph.context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from core.workflow.context import ContextProviderNotFoundError + from dify_graph.context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 8d49394653..22792eb5b3 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,8 +4,10 @@ from unittest.mock import MagicMock, patch import pytest -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool +from dify_graph.variables.variables import StringVariable class StubCoordinator: @@ -115,7 +117,7 @@ class TestGraphRuntimeState: queue = state.ready_queue - from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue + from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue assert isinstance(queue, InMemoryReadyQueue) @@ -124,7 +126,7 @@ class TestGraphRuntimeState: execution = state.graph_execution - from core.workflow.graph_engine.domain.graph_execution import GraphExecution + from dify_graph.graph_engine.domain.graph_execution import GraphExecution assert isinstance(execution, GraphExecution) assert execution.workflow_id == "" @@ -139,7 +141,7 @@ class TestGraphRuntimeState: mock_graph = MagicMock() with patch( - "core.workflow.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True + "dify_graph.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True ) as coordinator_cls: coordinator_instance = coordinator_cls.return_value state.configure(graph=mock_graph) @@ -278,3 +280,17 @@ class TestGraphRuntimeState: assert restored_execution.started is True assert new_stub.state == "configured" + + def test_snapshot_restore_preserves_updated_conversation_variable(self): + variable_pool = VariablePool( + conversation_variables=[StringVariable(name="session_name", value="before")], + ) + variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + snapshot = state.dumps() + restored = GraphRuntimeState.from_snapshot(snapshot) + + restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) + assert restored_value is not None + assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py index 6144df06e0..158f7018b5 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py @@ -5,7 +5,7 @@ Tests for PauseReason discriminated union serialization/deserialization. import pytest from pydantic import BaseModel, ValidationError -from core.workflow.entities.pause_reason import ( +from dify_graph.entities.pause_reason import ( HumanInputRequired, PauseReason, SchedulingPause, diff --git a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py index f3197ea282..2d4c7f7b77 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ b/api/tests/unit_tests/core/workflow/entities/test_template.py @@ -1,6 +1,6 @@ """Tests for template module.""" -from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment +from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment class TestTemplate: diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index d4254df319..6100ebede5 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -1,5 +1,5 @@ -from core.workflow.runtime import VariablePool -from core.workflow.variables.segments import ( +from dify_graph.runtime import VariablePool +from dify_graph.variables.segments import ( BooleanSegment, IntegerSegment, NoneSegment, diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py index a4b1189a1c..216e64db8d 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py @@ -8,8 +8,8 @@ from typing import Any import pytest -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution -from core.workflow.enums import NodeType +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.enums import BuiltinNodeTypes class TestWorkflowNodeExecutionProcessDataTruncation: @@ -25,7 +25,7 @@ class TestWorkflowNodeExecutionProcessDataTruncation: workflow_id="test-workflow-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", process_data=process_data, created_at=datetime.now(), @@ -212,7 +212,7 @@ class TestWorkflowNodeExecutionProcessDataScenarios: workflow_id="test-workflow-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", process_data=scenario.original_data, created_at=datetime.now(), diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py index 01b514ed7c..24bd9ccbed 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph.py @@ -2,10 +2,10 @@ from unittest.mock import Mock -from core.workflow.enums import NodeExecutionType, NodeState, NodeType -from core.workflow.graph.edge import Edge -from core.workflow.graph.graph import Graph -from core.workflow.nodes.base.node import Node +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState +from dify_graph.graph.edge import Edge +from dify_graph.graph.graph import Graph +from dify_graph.nodes.base.node import Node def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: @@ -14,7 +14,7 @@ def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: Nod node.id = node_id node.execution_type = execution_type node.state = state - node.node_type = NodeType.START + node.node_type = BuiltinNodeTypes.START return node diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py index 15d1dcb48d..64c2eee776 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py @@ -2,12 +2,12 @@ from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.graph import Graph -from core.workflow.nodes.base.node import Node +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.graph import Graph +from dify_graph.nodes.base.node import Node -def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node: +def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: node = MagicMock(spec=Node) node.id = node_id node.node_type = node_type @@ -17,9 +17,9 @@ def _make_node(node_id: str, node_type: NodeType = NodeType.START) -> Node: def test_graph_builder_creates_linear_graph(): builder = Graph.new() - root = _make_node("root", NodeType.START) - mid = _make_node("mid", NodeType.LLM) - end = _make_node("end", NodeType.END) + root = _make_node("root", BuiltinNodeTypes.START) + mid = _make_node("mid", BuiltinNodeTypes.LLM) + end = _make_node("end", BuiltinNodeTypes.END) graph = builder.add_root(root).add_node(mid).add_node(end).build() diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index 6858120335..75de07bd8b 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -4,15 +4,13 @@ from typing import Any import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph.validation import GraphValidationError -from core.workflow.nodes import NodeType -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.graph import Graph +from dify_graph.graph.validation import GraphValidationError +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params def _build_iteration_graph(node_id: str) -> dict[str, Any]: @@ -53,14 +51,14 @@ def _build_loop_graph(node_id: str) -> dict[str, Any]: def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + user_from="account", + invoke_from="debugger", call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -94,7 +92,7 @@ def test_iteration_root_requires_skip_validation(): ) assert graph.root_node.id == node_id - assert graph.root_node.node_type == NodeType.ITERATION + assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION def test_loop_root_requires_skip_validation(): @@ -117,4 +115,4 @@ def test_loop_root_requires_skip_validation(): ) assert graph.root_node.id == node_id - assert graph.root_node.node_type == NodeType.LOOP + assert graph.root_node.node_type == BuiltinNodeTypes.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index 5716aae4c7..e94ad74eb0 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,25 +6,24 @@ from dataclasses import dataclass import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams -from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType -from core.workflow.graph import Graph -from core.workflow.graph.validation import GraphValidationError -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType +from dify_graph.graph import Graph +from dify_graph.graph.validation import GraphValidationError +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _TestNodeData(BaseNodeData): - type: NodeType | str | None = None + type: NodeType | None = None execution_type: NodeExecutionType | str | None = None class _TestNode(Node[_TestNodeData]): - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER execution_type = NodeExecutionType.EXECUTABLE @classmethod @@ -47,13 +46,8 @@ class _TestNode(Node[_TestNodeData]): ) node_type_value = self.data.get("type") - if isinstance(node_type_value, NodeType): + if isinstance(node_type_value, str): self.node_type = node_type_value - elif isinstance(node_type_value, str): - try: - self.node_type = NodeType(node_type_value) - except ValueError: - pass def _run(self): raise NotImplementedError @@ -92,14 +86,14 @@ class _SimpleNodeFactory: @pytest.fixture def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + user_from="account", + invoke_from="service-api", call_depth=0, ) variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) @@ -113,14 +107,17 @@ def test_graph_initialization_runs_default_validators( ): node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, - {"id": "answer", "data": {"type": NodeType.ANSWER, "title": "Answer"}}, + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, + {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, ] graph_config["edges"] = [ {"source": "start", "target": "answer", "sourceHandle": "success"}, ] - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert graph.root_node.id == "start" assert "answer" in graph.nodes @@ -131,14 +128,17 @@ def test_graph_validation_fails_for_unknown_edge_targets( ) -> None: node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, ] graph_config["edges"] = [ {"source": "start", "target": "missing", "sourceHandle": "success"}, ] with pytest.raises(GraphValidationError) as exc: - Graph.init(graph_config=graph_config, node_factory=node_factory) + Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) @@ -148,11 +148,14 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( ) -> None: node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, { "id": "branch", "data": { - "type": NodeType.IF_ELSE, + "type": BuiltinNodeTypes.IF_ELSE, "title": "Branch", "error_strategy": ErrorStrategy.FAIL_BRANCH, }, @@ -162,25 +165,55 @@ def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( {"source": "start", "target": "branch", "sourceHandle": "success"}, ] - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH -def test_graph_validation_blocks_start_and_trigger_coexistence( +def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], ) -> None: node_factory, graph_config = graph_init_dependencies graph_config["nodes"] = [ - {"id": "start", "data": {"type": NodeType.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}}, { - "id": "trigger", - "data": {"type": NodeType.TRIGGER_WEBHOOK, "title": "Webhook", "execution_type": NodeExecutionType.ROOT}, + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, + }, + {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, + { + "id": "note", + "type": "custom-note", + "data": { + "type": "", + "title": "", + "desc": "", + "text": "{}", + "theme": "blue", + }, + }, + ] + graph_config["edges"] = [ + {"source": "start", "target": "answer", "sourceHandle": "success"}, + ] + + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + + assert graph.root_node.id == "start" + assert "answer" in graph.nodes + assert "note" not in graph.nodes + + +def test_graph_init_fails_for_unknown_root_node_id( + graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], +) -> None: + node_factory, graph_config = graph_init_dependencies + graph_config["nodes"] = [ + { + "id": "start", + "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, }, ] graph_config["edges"] = [] - with pytest.raises(GraphValidationError) as exc_info: - Graph.init(graph_config=graph_config, node_factory=node_factory) - - assert any(issue.code == "TRIGGER_START_NODE_CONFLICT" for issue in exc_info.value.issues) + with pytest.raises(ValueError, match="Root node id missing not found in the graph"): + Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 3fff4cf6a9..40ed61eb02 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -68,7 +68,7 @@ print(f"Success rate: {suite_result.success_rate:.1f}%") #### Event Sequence Validation ```python -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, @@ -376,39 +376,39 @@ See `test_mock_example.py` for comprehensive examples including: ```bash # Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py # Run with specific test patterns -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py -k "test_echo" +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -k "test_echo" # Run with verbose output -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py -v +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -v ``` ### Mock System Tests ```bash # Run auto-mock system tests -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_auto_mock_system.py # Run examples -uv run python api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_example.py # Run simple validation -uv run python api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_simple.py ``` ### All Tests ```bash # Run all graph engine tests -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ # Run with coverage -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ --cov=core.workflow.graph_engine +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ --cov=dify_graph.graph_engine # Run in parallel -uv run pytest api/tests/unit_tests/core/workflow/graph_engine/ -n auto +uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ -n auto ``` ## Troubleshooting diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index db9b977e4a..4dec618e49 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,15 +3,15 @@ import json from unittest.mock import MagicMock -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import ( +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.entities.commands import ( AbortCommand, CommandType, GraphEngineCommand, UpdateVariablesCommand, VariableUpdate, ) -from core.workflow.variables import IntegerVariable, StringVariable +from dify_graph.variables import IntegerVariable, StringVariable class TestRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py index 5d17b7a243..6f821ba799 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py @@ -2,18 +2,18 @@ from __future__ import annotations -from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine.domain.graph_execution import GraphExecution -from core.workflow.graph_engine.event_management.event_handlers import EventHandler -from core.workflow.graph_engine.event_management.event_manager import EventManager -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from core.workflow.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.base.entities import RetryConfig -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.entities.base_node_data import RetryConfig +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine.domain.graph_execution import GraphExecution +from dify_graph.graph_engine.event_management.event_handlers import EventHandler +from dify_graph.graph_engine.event_management.event_manager import EventManager +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue +from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator +from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.runtime import GraphRuntimeState, VariablePool from libs.datetime_utils import naive_utc_now @@ -73,7 +73,7 @@ def test_retry_does_not_emit_additional_start_event() -> None: handler, event_manager, graph_execution = _build_event_handler(node_id) execution_id = "exec-1" - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE start_time = naive_utc_now() start_event = NodeRunStartedEvent( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py index 15eac6b537..25494dc647 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py @@ -4,9 +4,9 @@ from __future__ import annotations import logging -from core.workflow.graph_engine.event_management.event_manager import EventManager -from core.workflow.graph_engine.layers.base import GraphEngineLayer -from core.workflow.graph_events import GraphEngineEvent +from dify_graph.graph_engine.event_management.event_manager import EventManager +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import GraphEngineEvent class _FaultyLayer(GraphEngineLayer): diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py index 0019020ede..73d59ea4e9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py @@ -2,9 +2,9 @@ from unittest.mock import MagicMock, create_autospec -from core.workflow.graph import Edge, Graph -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.graph_traversal.skip_propagator import SkipPropagator +from dify_graph.graph import Edge, Graph +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.graph_traversal.skip_propagator import SkipPropagator class TestSkipPropagator: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index 2ef23c7f0f..fc8133f5e1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -7,8 +7,8 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRecipientEntity, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 903800ce88..9e7b3654b7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -10,7 +10,7 @@ from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes @pytest.fixture @@ -44,7 +44,7 @@ def mock_start_node(): node.id = "test-start-node-id" node.title = "Start Node" node.execution_id = "test-start-execution-id" - node.node_type = NodeType.START + node.node_type = BuiltinNodeTypes.START return node @@ -55,7 +55,7 @@ def mock_llm_node(): node.id = "test-llm-node-id" node.title = "LLM Node" node.execution_id = "test-llm-execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM return node @@ -63,13 +63,13 @@ def mock_llm_node(): def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" from core.tools.entities.tool_entities import ToolProviderType - from core.workflow.nodes.tool.entities import ToolNodeData + from dify_graph.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" node.title = "Test Tool Node" node.execution_id = "test-tool-execution-id" - node.node_type = NodeType.TOOL + node.node_type = BuiltinNodeTypes.TOOL tool_data = ToolNodeData( title="Test Tool Node", @@ -108,7 +108,7 @@ def mock_retrieval_node(): node.id = "test-retrieval-node-id" node.title = "Retrieval Node" node.execution_id = "test-retrieval-execution-id" - node.node_type = NodeType.KNOWLEDGE_RETRIEVAL + node.node_type = BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL return node @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from core.workflow.graph_events.node import NodeRunSucceededEvent - from core.workflow.node_events.base import NodeRunResult + from dify_graph.graph_events.node import NodeRunSucceededEvent + from dify_graph.node_events.base import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -130,7 +130,7 @@ def mock_result_event(): return NodeRunSucceededEvent( id="test-execution-id", node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.now(), node_run_result=node_run_result, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py index f1086c9936..db32527849 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py @@ -2,13 +2,13 @@ from __future__ import annotations import pytest -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.layers.base import ( +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import ( GraphEngineLayer, GraphEngineLayerNotInitializedError, ) -from core.workflow.graph_events import GraphEngineEvent +from dify_graph.graph_events import GraphEngineEvent from ..test_table_runner import WorkflowRunner diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 9a491d24e1..2a36f712fd 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -4,18 +4,18 @@ from unittest.mock import MagicMock, patch from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.entities.commands import CommandType -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.entities.commands import CommandType +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult def _build_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="execution-id", node_id="llm-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, start_at=datetime.now(), node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -30,8 +30,9 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -50,8 +51,9 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node = MagicMock() node.id = "question-classifier-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.QUESTION_CLASSIFIER + node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -70,8 +72,9 @@ def test_non_llm_node_is_ignored() -> None: node = MagicMock() node.id = "start-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.START + node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node._model_instance = object() result_event = _build_succeeded_event() @@ -86,8 +89,9 @@ def test_quota_error_is_handled_in_layer() -> None: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() result_event = _build_succeeded_event() @@ -107,8 +111,9 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.execution_id = "execution-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" + node.require_dify_context.return_value.tenant_id = "tenant-id" node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -135,7 +140,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -161,7 +166,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" - node.node_type = NodeType.LLM + node.node_type = BuiltinNodeTypes.LLM node.model_instance = object() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index ade846df28..478a2b592e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -16,7 +16,7 @@ import pytest from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -29,7 +29,7 @@ class TestObservabilityLayerInitialization: layer = ObservabilityLayer() assert not layer._is_disabled assert layer._tracer is not None - assert NodeType.TOOL in layer._parsers + assert BuiltinNodeTypes.TOOL in layer._parsers assert layer._default_parser is not None @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", False) @@ -39,7 +39,7 @@ class TestObservabilityLayerInitialization: layer = ObservabilityLayer() assert not layer._is_disabled assert layer._tracer is not None - assert NodeType.TOOL in layer._parsers + assert BuiltinNodeTypes.TOOL in layer._parsers assert layer._default_parser is not None @@ -117,7 +117,7 @@ class TestObservabilityLayerParserIntegration: attrs = spans[0].attributes assert attrs["node.id"] == mock_start_node.id assert attrs["node.execution_id"] == mock_start_node.execution_id - assert attrs["node.type"] == mock_start_node.node_type.value + assert attrs["node.type"] == mock_start_node.node_type @patch("core.app.workflow.layers.observability.dify_config.ENABLE_OTEL", True) @pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false") @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from core.workflow.node_events.base import NodeRunResult + from dify_graph.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from core.workflow.node_events.base import NodeRunResult + from dify_graph.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from core.workflow.node_events.base import NodeRunResult + from dify_graph.node_events.base import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py index c1fc4acd73..548c10ce8d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py @@ -5,18 +5,18 @@ from __future__ import annotations import queue from unittest import mock -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.event_management.event_handlers import EventHandler -from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher -from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from core.workflow.graph_events import ( +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.event_management.event_handlers import EventHandler +from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher +from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from dify_graph.graph_events import ( GraphNodeEventBase, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult +from dify_graph.node_events import NodeRunResult from libs.datetime_utils import naive_utc_now @@ -26,7 +26,7 @@ def test_dispatcher_should_consume_remains_events_after_pause(): GraphNodeEventBase( id="test", node_id="test", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, ) ) event_handler = mock.Mock(spec=EventHandler) @@ -107,7 +107,7 @@ def _make_started_event() -> NodeRunStartedEvent: return NodeRunStartedEvent( id="start-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="Test Node", start_at=naive_utc_now(), ) @@ -117,7 +117,7 @@ def _make_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="success-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="Test Node", start_at=naive_utc_now(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), @@ -151,20 +151,20 @@ def test_dispatcher_drain_event_queue(): NodeRunStartedEvent( id="start-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, node_title="Code", start_at=naive_utc_now(), ), NodeRunPauseRequestedEvent( id="pause-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, reason=SchedulingPause(message="test pause"), ), NodeRunSucceededEvent( id="success-event", node_id="node-1", - node_type=NodeType.CODE, + node_type=BuiltinNodeTypes.CODE, start_at=naive_utc_now(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), ), diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py index fd1e6fc6dc..7af6b26d87 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py index b291f95e0f..fc0d22f739 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py @@ -7,7 +7,8 @@ for workflows containing nodes that require third-party services. import pytest -from core.workflow.enums import NodeType +from dify_graph.enums import BuiltinNodeTypes +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from .test_table_runner import TableTestRunner, WorkflowTestCase @@ -199,22 +200,19 @@ def test_mock_config_builder(): def test_mock_factory_node_type_detection(): """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), @@ -229,23 +227,23 @@ def test_mock_factory_node_type_detection(): ) # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(NodeType.LLM) - assert factory.should_mock_node(NodeType.AGENT) - assert factory.should_mock_node(NodeType.TOOL) - assert factory.should_mock_node(NodeType.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(NodeType.HTTP_REQUEST) - assert factory.should_mock_node(NodeType.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(NodeType.DOCUMENT_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.LLM) + assert factory.should_mock_node(BuiltinNodeTypes.AGENT) + assert factory.should_mock_node(BuiltinNodeTypes.TOOL) + assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) + assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) + assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(NodeType.CODE) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Test that non-service nodes are not mocked - assert not factory.should_mock_node(NodeType.START) - assert not factory.should_mock_node(NodeType.END) - assert not factory.should_mock_node(NodeType.IF_ELSE) - assert not factory.should_mock_node(NodeType.VARIABLE_AGGREGATOR) + assert not factory.should_mock_node(BuiltinNodeTypes.START) + assert not factory.should_mock_node(BuiltinNodeTypes.END) + assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) + assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) def test_custom_mock_handler(): @@ -309,11 +307,9 @@ def test_workflow_without_auto_mock(): def test_register_custom_mock_node(): """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.nodes.template_transform import TemplateTransformNode - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.nodes.template_transform import TemplateTransformNode + from dify_graph.runtime import GraphRuntimeState, VariablePool from .test_mock_factory import MockNodeFactory @@ -323,15 +319,14 @@ def test_register_custom_mock_node(): # Custom mock implementation pass - graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", + graph_init_params = build_test_graph_init_params( workflow_id="test", graph_config={}, + tenant_id="test", + app_id="test", user_id="test", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, - call_depth=0, ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), @@ -346,15 +341,15 @@ def test_register_custom_mock_node(): ) # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Unregister mock - factory.unregister_mock_node_type(NodeType.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) + assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Re-register custom mock - factory.register_mock_node_type(NodeType.TEMPLATE_TRANSFORM, MockTemplateTransformNode) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) def test_default_config_by_node_type(): @@ -363,7 +358,7 @@ def test_default_config_by_node_type(): # Set default config for all LLM nodes mock_config.set_default_config( - NodeType.LLM, + BuiltinNodeTypes.LLM, { "default_response": "Default LLM response for all nodes", "temperature": 0.7, @@ -372,23 +367,23 @@ def test_default_config_by_node_type(): # Set default config for all HTTP nodes mock_config.set_default_config( - NodeType.HTTP_REQUEST, + BuiltinNodeTypes.HTTP_REQUEST, { "default_status": 200, "default_timeout": 30, }, ) - llm_config = mock_config.get_default_config(NodeType.LLM) + llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM) assert llm_config["default_response"] == "Default LLM response for all nodes" assert llm_config["temperature"] == 0.7 - http_config = mock_config.get_default_config(NodeType.HTTP_REQUEST) + http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST) assert http_config["default_status"] == 200 assert http_config["default_timeout"] == 30 # Non-configured node type should return empty dict - tool_config = mock_config.get_default_config(NodeType.TOOL) + tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL) assert tool_config == {} diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py index b04643b78a..30acbdaf3d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 6c3700ea2b..765c4deba3 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,24 +3,23 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.pause_reason import SchedulingPause -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.pause_reason import SchedulingPause +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.entities.commands import ( AbortCommand, CommandType, PauseCommand, UpdateVariablesCommand, VariableUpdate, ) -from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.variables import IntegerVariable, StringVariable -from models.enums import UserFrom +from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.variables import IntegerVariable, StringVariable def test_abort_command(): @@ -41,13 +40,17 @@ def test_abort_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -99,7 +102,7 @@ def test_redis_channel_serialization(): mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel + from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel # Create channel with a specific key channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") @@ -151,13 +154,17 @@ def test_pause_command(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, @@ -207,13 +214,17 @@ def test_update_variables_command_updates_pool(): id="start", config={"id": "start", "data": {"title": "start", "variables": []}}, graph_init_params=GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ), graph_runtime_state=shared_runtime_state, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py index 96926797ec..3a9a0b18bc 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py @@ -7,7 +7,7 @@ This test suite validates the behavior of a workflow that: 3. Handles multiple answer nodes with different outputs """ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py index ee944c8e3e..76bf179f33 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py @@ -6,10 +6,10 @@ This test validates that: - When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) """ -from core.workflow.enums import NodeType -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, @@ -74,7 +74,11 @@ def test_streaming_output_with_blocking_equals_one(): # Find indices of first LLM success event and first stream chunk event llm2_start_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM), + ( + i + for i, e in enumerate(events) + if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM + ), -1, ) first_chunk_index = next( @@ -96,16 +100,16 @@ def test_streaming_output_with_blocking_equals_one(): # Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.TEMPLATE_TRANSFORM + e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == NodeType.TEMPLATE_TRANSFORM] + template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( "Expected all Template chunk events to have same id with Template's NodeRunStartedEvent" ) # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.END] + end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" @@ -168,7 +172,11 @@ def test_streaming_output_with_blocking_not_equals_one(): # Find indices of first LLM success event and first stream chunk event llm2_start_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM), + ( + i + for i, e in enumerate(events) + if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM + ), -1, ) first_chunk_index = next( @@ -194,15 +202,15 @@ def test_streaming_output_with_blocking_not_equals_one(): assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == NodeType.LLM] + start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] + llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] llm_node_ids = {se.node_id for se in start_events} assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( "Expected all LLM chunk events to be from LLM nodes" ) # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.END] + end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py index bf8034487c..778dad5952 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py @@ -1,10 +1,10 @@ import queue from datetime import datetime -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher -from core.workflow.graph_events import NodeRunSucceededEvent -from core.workflow.node_events import NodeRunResult +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher +from dify_graph.graph_events import NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult class StubExecutionCoordinator: @@ -51,7 +51,7 @@ def test_dispatcher_drains_events_when_paused() -> None: event = NodeRunSucceededEvent( id="exec-1", node_id="node-1", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, start_at=datetime.utcnow(), node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py index b1380cd6d2..c87dc75b95 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py @@ -6,7 +6,7 @@ field is missing from the output configuration, ensuring backward compatibility with older workflow definitions. """ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py index 53de8908a8..35406997ed 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock import pytest -from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor -from core.workflow.graph_engine.domain.graph_execution import GraphExecution -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from core.workflow.graph_engine.worker_management.worker_pool import WorkerPool +from dify_graph.graph_engine.command_processing.command_processor import CommandProcessor +from dify_graph.graph_engine.domain.graph_execution import GraphExecution +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator +from dify_graph.graph_engine.worker_management.worker_pool import WorkerPool def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index 5a55d7086e..4e13177d2b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -10,15 +10,15 @@ import time from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from core.workflow.enums import ErrorStrategy -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType +from dify_graph.enums import ErrorStrategy +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunPartialSucceededEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.nodes.base.entities import DefaultValue, DefaultValueType # Import the test framework from the new module from .test_mock_config import MockConfigBuilder @@ -455,7 +455,7 @@ def test_if_else_workflow_property_diverse_inputs(query_input): # Tests for the Layer system def test_layer_system_basic(): """Test basic layer functionality with DebugLoggingLayer.""" - from core.workflow.graph_engine.layers import DebugLoggingLayer + from dify_graph.graph_engine.layers import DebugLoggingLayer runner = WorkflowRunner() @@ -495,7 +495,7 @@ def test_layer_system_basic(): def test_layer_chaining(): """Test chaining multiple layers.""" - from core.workflow.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer + from dify_graph.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer # Create a custom test layer class TestLayer(GraphEngineLayer): @@ -549,7 +549,7 @@ def test_layer_chaining(): def test_layer_error_handling(): """Test that layer errors don't crash the engine.""" - from core.workflow.graph_engine.layers import GraphEngineLayer + from dify_graph.graph_engine.layers import GraphEngineLayer # Create a layer that throws errors class FaultyLayer(GraphEngineLayer): @@ -591,7 +591,7 @@ def test_layer_error_handling(): def test_event_sequence_validation(): """Test the new event sequence validation feature.""" - from core.workflow.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() @@ -678,7 +678,7 @@ def test_event_sequence_validation(): def test_event_sequence_validation_with_table_tests(): """Test event sequence validation with table-driven tests.""" - from core.workflow.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent + from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent runner = TableTestRunner() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py index 6385b0b91f..255784b77d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py @@ -6,13 +6,13 @@ import json from collections import deque from unittest.mock import MagicMock -from core.workflow.enums import NodeExecutionType, NodeState, NodeType -from core.workflow.graph_engine.domain import GraphExecution -from core.workflow.graph_engine.response_coordinator import ResponseStreamCoordinator -from core.workflow.graph_engine.response_coordinator.path import Path -from core.workflow.graph_engine.response_coordinator.session import ResponseSession -from core.workflow.graph_events import NodeRunStreamChunkEvent -from core.workflow.nodes.base.template import Template, TextSegment, VariableSegment +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState +from dify_graph.graph_engine.domain import GraphExecution +from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator +from dify_graph.graph_engine.response_coordinator.path import Path +from dify_graph.graph_engine.response_coordinator.session import ResponseSession +from dify_graph.graph_events import NodeRunStreamChunkEvent +from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment class CustomGraphExecutionError(Exception): @@ -101,7 +101,9 @@ def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> No class DummyNode: def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: self.id = node_id - self.node_type = NodeType.ANSWER if execution_type == NodeExecutionType.RESPONSE else NodeType.LLM + self.node_type = ( + BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM + ) self.execution_type = execution_type self.state = NodeState.UNKNOWN self.title = node_id @@ -160,7 +162,7 @@ def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> No event = NodeRunStreamChunkEvent( id="exec-1", node_id="response-1", - node_type=NodeType.ANSWER, + node_type=BuiltinNodeTypes.ANSWER, selector=["node-source", "text"], chunk="chunk-1", is_final=False, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index 65d34c2009..d54f0be190 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -1,26 +1,27 @@ import time from collections.abc import Mapping -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeState -from core.workflow.graph import Graph -from core.workflow.graph_engine.graph_state_manager import GraphStateManager -from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.llm.entities import ( +from dify_graph.entities import GraphInitParams +from dify_graph.enums import NodeState +from dify_graph.graph import Graph +from dify_graph.graph_engine.graph_state_manager import GraphStateManager +from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -73,11 +74,11 @@ def _build_llm_node( def _build_graph(runtime_state: GraphRuntimeState) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index b117b26b4c..538f53c603 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,10 +4,8 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -16,25 +14,27 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent -from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -47,11 +47,11 @@ def _build_branching_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 45505909ea..36bba6deb6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,10 +3,8 @@ import time from unittest import mock from unittest.mock import MagicMock -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, @@ -15,25 +13,27 @@ from core.workflow.graph_events import ( NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.node import NodeRunHumanInputFormFilledEvent -from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -46,11 +46,11 @@ def _build_llm_human_llm_graph( graph_runtime_state: GraphRuntimeState | None = None, ) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index f33d37e8ff..8da179c15e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,34 +1,34 @@ import time from unittest import mock -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events import ( +from dify_graph.graph import Graph +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.base.entities import OutputVariableEntity, OutputVariableType -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.if_else.entities import IfElseNodeData -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.if_else.entities import IfElseNodeData +from dify_graph.nodes.if_else.if_else_node import IfElseNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.utils.condition.entities import Condition +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.utils.condition.entities import Condition +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig from .test_mock_nodes import MockLLMNode @@ -37,15 +37,10 @@ from .test_table_runner import TableTestRunner, WorkflowTestCase def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + graph_init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py index 3e21a5b44d..733fd53bc8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py @@ -5,7 +5,7 @@ This test validates the behavior of a loop containing an answer node inside the loop that may produce output errors. """ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py index d88c1d9f9e..6ff2722f78 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunLoopNextEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py index 5ceb8dd7f7..8a4649693d 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_config.py @@ -11,8 +11,6 @@ from collections.abc import Callable from dataclasses import dataclass, field from typing import Any -from core.workflow.enums import NodeType - @dataclass class NodeMockConfig: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index b862cbe89e..93010eea54 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -5,12 +5,12 @@ This module provides a MockNodeFactory that automatically detects and mocks node requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request). """ -from collections.abc import Mapping from typing import TYPE_CHECKING, Any -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.enums import NodeType -from core.workflow.nodes.base.node import Node +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.node import Node from .test_mock_nodes import ( MockAgentNode, @@ -28,8 +28,8 @@ from .test_mock_nodes import ( ) if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -61,63 +61,51 @@ class MockNodeFactory(DifyNodeFactory): # Map of node types that should be mocked self._mock_node_types = { - NodeType.LLM: MockLLMNode, - NodeType.AGENT: MockAgentNode, - NodeType.TOOL: MockToolNode, - NodeType.KNOWLEDGE_RETRIEVAL: MockKnowledgeRetrievalNode, - NodeType.HTTP_REQUEST: MockHttpRequestNode, - NodeType.QUESTION_CLASSIFIER: MockQuestionClassifierNode, - NodeType.PARAMETER_EXTRACTOR: MockParameterExtractorNode, - NodeType.DOCUMENT_EXTRACTOR: MockDocumentExtractorNode, - NodeType.ITERATION: MockIterationNode, - NodeType.LOOP: MockLoopNode, - NodeType.TEMPLATE_TRANSFORM: MockTemplateTransformNode, - NodeType.CODE: MockCodeNode, + BuiltinNodeTypes.LLM: MockLLMNode, + BuiltinNodeTypes.AGENT: MockAgentNode, + BuiltinNodeTypes.TOOL: MockToolNode, + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL: MockKnowledgeRetrievalNode, + BuiltinNodeTypes.HTTP_REQUEST: MockHttpRequestNode, + BuiltinNodeTypes.QUESTION_CLASSIFIER: MockQuestionClassifierNode, + BuiltinNodeTypes.PARAMETER_EXTRACTOR: MockParameterExtractorNode, + BuiltinNodeTypes.DOCUMENT_EXTRACTOR: MockDocumentExtractorNode, + BuiltinNodeTypes.ITERATION: MockIterationNode, + BuiltinNodeTypes.LOOP: MockLoopNode, + BuiltinNodeTypes.TEMPLATE_TRANSFORM: MockTemplateTransformNode, + BuiltinNodeTypes.CODE: MockCodeNode, } - def create_node(self, node_config: Mapping[str, Any]) -> Node: + def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ Create a node instance, using mock implementations for third-party service nodes. :param node_config: Node configuration dictionary :return: Node instance (real or mocked) """ - # Get node type from config - node_data = node_config.get("data", {}) - node_type_str = node_data.get("type") - - if not node_type_str: - # Fall back to parent implementation for nodes without type - return super().create_node(node_config) - - try: - node_type = NodeType(node_type_str) - except ValueError: - # Unknown node type, use parent implementation - return super().create_node(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + node_data = typed_node_config["data"] + node_type = node_data.type # Check if this node type should be mocked if node_type in self._mock_node_types: - node_id = node_config.get("id") - if not node_id: - raise ValueError("Node config missing id") + node_id = typed_node_config["id"] # Create mock node instance mock_class = self._mock_node_types[node_type] - if node_type == NodeType.CODE: + if node_type == BuiltinNodeTypes.CODE: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, code_executor=self._code_executor, code_limits=self._code_limits, ) - elif node_type == NodeType.HTTP_REQUEST: + elif node_type == BuiltinNodeTypes.HTTP_REQUEST: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -126,10 +114,14 @@ class MockNodeFactory(DifyNodeFactory): tool_file_manager_factory=self._http_request_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) - elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}: + elif node_type in { + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + }: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -139,7 +131,7 @@ class MockNodeFactory(DifyNodeFactory): else: mock_instance = mock_class( id=node_id, - config=node_config, + config=typed_node_config, graph_init_params=self.graph_init_params, graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, @@ -148,7 +140,7 @@ class MockNodeFactory(DifyNodeFactory): return mock_instance # For non-mocked node types, use parent implementation - return super().create_node(node_config) + return super().create_node(typed_node_config) def should_mock_node(self, node_type: NodeType) -> bool: """ diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index aae4de9a27..3e4247f33f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,34 +2,31 @@ Simple test to verify MockNodeFactory works with iteration nodes. """ -import sys -from pathlib import Path - -# Add api directory to path -api_dir = Path(__file__).parent.parent.parent.parent.parent.parent -sys.path.insert(0, str(api_dir)) - -from core.workflow.enums import NodeType +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory def test_mock_factory_registers_iteration_node(): """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create a MockNodeFactory instance graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -45,30 +42,29 @@ def test_mock_factory_registers_iteration_node(): ) # Check that iteration node is registered - assert NodeType.ITERATION in factory._mock_node_types + assert BuiltinNodeTypes.ITERATION in factory._mock_node_types print("✓ Iteration node is registered in MockNodeFactory") # Check that loop node is registered - assert NodeType.LOOP in factory._mock_node_types + assert BuiltinNodeTypes.LOOP in factory._mock_node_types print("✓ Loop node is registered in MockNodeFactory") # Check the class types from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - assert factory._mock_node_types[NodeType.ITERATION] == MockIterationNode + assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode print("✓ Iteration node maps to MockIterationNode class") - assert factory._mock_node_types[NodeType.LOOP] == MockLoopNode + assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode print("✓ Loop node maps to MockLoopNode class") def test_mock_iteration_node_preserves_config(): """Test that MockIterationNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode # Create mock config @@ -76,13 +72,17 @@ def test_mock_iteration_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) @@ -127,10 +127,9 @@ def test_mock_iteration_node_preserves_config(): def test_mock_loop_node_preserves_config(): """Test that MockLoopNode preserves mock configuration.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode # Create mock config @@ -138,13 +137,17 @@ def test_mock_loop_node_preserves_config(): # Create minimal graph init params graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={"nodes": [], "edges": []}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 5aed463a45..454263bef9 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,28 +11,45 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from core.workflow.nodes.agent import AgentNode -from core.workflow.nodes.code import CodeNode -from core.workflow.nodes.document_extractor import DocumentExtractorNode -from core.workflow.nodes.http_request import HttpRequestNode -from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode -from core.workflow.nodes.llm import LLMNode -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.nodes.parameter_extractor import ParameterExtractorNode -from core.workflow.nodes.question_classifier import QuestionClassifierNode -from core.workflow.nodes.template_transform import TemplateTransformNode -from core.workflow.nodes.tool import ToolNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from dify_graph.nodes.code import CodeNode +from dify_graph.nodes.document_extractor import DocumentExtractorNode +from dify_graph.nodes.http_request import HttpRequestNode +from dify_graph.nodes.llm import LLMNode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.parameter_extractor import ParameterExtractorNode +from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.question_classifier import QuestionClassifierNode +from dify_graph.nodes.template_transform import TemplateTransformNode +from dify_graph.nodes.template_transform.template_renderer import ( + Jinja2TemplateRenderer, + TemplateRenderError, +) +from dify_graph.nodes.tool import ToolNode if TYPE_CHECKING: - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState from .test_mock_config import MockConfig +class _TestJinja2Renderer(Jinja2TemplateRenderer): + """Simple Jinja2 renderer for tests (avoids code executor).""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + from jinja2 import Template as _Jinja2Template + + try: + return _Jinja2Template(template).render(**variables) + except Exception as exc: # pragma: no cover - pass through as contract error + raise TemplateRenderError(str(exc)) from exc + + class MockNodeMixin: """Mixin providing common mock functionality.""" @@ -49,6 +66,28 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + # LLM-like nodes now require an http_client; provide a mock by default for tests. + kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + # Ensure TemplateTransformNode receives a renderer now required by constructor + if isinstance(self, TemplateTransformNode): + kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + + # Provide default tool_file_manager_factory for ToolNode subclasses + from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + + if isinstance(self, _ToolNode): + kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + + if isinstance(self, AgentNode): + presentation_provider = MagicMock() + presentation_provider.get_icon.return_value = None + kwargs.setdefault("strategy_resolver", MagicMock()) + kwargs.setdefault("presentation_provider", presentation_provider) + kwargs.setdefault("runtime_support", MagicMock()) + kwargs.setdefault("message_transformer", MagicMock()) super().__init__( id=id, @@ -557,8 +596,8 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): ) -from core.workflow.nodes.iteration import IterationNode -from core.workflow.nodes.loop import LoopNode +from dify_graph.nodes.iteration import IterationNode +from dify_graph.nodes.loop import LoopNode class MockIterationNode(MockNodeMixin, IterationNode): @@ -572,24 +611,20 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.graph import Graph + from dify_graph.graph_engine import GraphEngine, GraphEngineConfig + from dify_graph.graph_engine.command_channels import InMemoryChannel + from dify_graph.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) @@ -621,7 +656,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): ) if not iteration_graph: - from core.workflow.nodes.iteration.exc import IterationGraphNotFoundError + from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError raise IterationGraphNotFoundError("iteration graph not found") @@ -648,24 +683,20 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from core.workflow.entities import GraphInitParams - from core.workflow.graph import Graph - from core.workflow.graph_engine import GraphEngine, GraphEngineConfig - from core.workflow.graph_engine.command_channels import InMemoryChannel - from core.workflow.runtime import GraphRuntimeState + from dify_graph.entities import GraphInitParams + from dify_graph.graph import Graph + from dify_graph.graph_engine import GraphEngine, GraphEngineConfig + from dify_graph.graph_engine.command_channels import InMemoryChannel + from dify_graph.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory # Create GraphInitParams from node attributes graph_init_params = GraphInitParams( - tenant_id=self.tenant_id, - app_id=self.app_id, workflow_id=self.workflow_id, graph_config=self.graph_config, - user_id=self.user_id, - user_from=self.user_from.value, - invoke_from=self.invoke_from.value, + run_context=self.run_context, call_depth=self.workflow_call_depth, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index 6c4178dfed..a8398e8f79 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,8 +6,9 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.nodes.code.limits import CodeNodeLimits +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode @@ -39,18 +40,22 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_default_output(self): """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -98,18 +103,22 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_custom_output(self): """Test that MockTemplateTransformNode returns custom configured output.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -158,18 +167,22 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_error_simulation(self): """Test that MockTemplateTransformNode can simulate errors.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -215,19 +228,23 @@ class TestMockTemplateTransformNode: def test_mock_template_transform_node_with_variables(self): """Test that MockTemplateTransformNode processes templates with variables.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from core.workflow.variables import StringVariable + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool + from dify_graph.variables import StringVariable # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -281,18 +298,22 @@ class TestMockCodeNode: def test_mock_code_node_default_output(self): """Test that MockCodeNode returns default output.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -343,18 +364,22 @@ class TestMockCodeNode: def test_mock_code_node_with_output_schema(self): """Test that MockCodeNode generates outputs based on schema.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -413,18 +438,22 @@ class TestMockCodeNode: def test_mock_code_node_custom_output(self): """Test that MockCodeNode returns custom configured output.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -485,18 +514,22 @@ class TestMockNodeFactory: def test_code_and_template_nodes_mocked_by_default(self): """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -517,27 +550,31 @@ class TestMockNodeFactory: ) # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(NodeType.CODE) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Verify that other third-party service nodes ARE also mocked by default - assert factory.should_mock_node(NodeType.LLM) - assert factory.should_mock_node(NodeType.AGENT) + assert factory.should_mock_node(BuiltinNodeTypes.LLM) + assert factory.should_mock_node(BuiltinNodeTypes.AGENT) def test_factory_creates_mock_template_transform_node(self): """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -573,22 +610,26 @@ class TestMockNodeFactory: # Verify the correct mock type was created assert isinstance(node, MockTemplateTransformNode) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) def test_factory_creates_mock_code_node(self): """Test that MockNodeFactory creates MockCodeNode for code type.""" - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool # Create test parameters graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config={}, - user_id="test_user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -626,4 +667,4 @@ class TestMockNodeFactory: # Verify the correct mock type was created assert isinstance(node, MockCodeNode) - assert factory.should_mock_node(NodeType.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 1b781545f5..5b35b3310a 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -3,13 +3,9 @@ Simple test to validate the auto-mock system without external dependencies. """ import sys -from pathlib import Path -# Add api directory to path -api_dir = Path(__file__).parent.parent.parent.parent.parent.parent -sys.path.insert(0, str(api_dir)) - -from core.workflow.enums import NodeType +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory @@ -68,8 +64,8 @@ def test_mock_config_operations(): assert error_config.error == "Test error" # Test default configs by node type - config.set_default_config(NodeType.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(NodeType.LLM) + config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) + llm_config = config.get_default_config(BuiltinNodeTypes.LLM) assert llm_config == {"temperature": 0.7} print("✓ MockConfig operations test passed") @@ -101,21 +97,24 @@ def test_node_mock_config(): def test_mock_factory_detection(): """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory detection...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -131,44 +130,47 @@ def test_mock_factory_detection(): ) # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(NodeType.LLM) - assert factory.should_mock_node(NodeType.AGENT) - assert factory.should_mock_node(NodeType.TOOL) - assert factory.should_mock_node(NodeType.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(NodeType.HTTP_REQUEST) - assert factory.should_mock_node(NodeType.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(NodeType.DOCUMENT_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.LLM) + assert factory.should_mock_node(BuiltinNodeTypes.AGENT) + assert factory.should_mock_node(BuiltinNodeTypes.TOOL) + assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) + assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) + assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) + assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(NodeType.CODE) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.CODE) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Test that non-service nodes are not mocked - assert not factory.should_mock_node(NodeType.START) - assert not factory.should_mock_node(NodeType.END) - assert not factory.should_mock_node(NodeType.IF_ELSE) - assert not factory.should_mock_node(NodeType.VARIABLE_AGGREGATOR) + assert not factory.should_mock_node(BuiltinNodeTypes.START) + assert not factory.should_mock_node(BuiltinNodeTypes.END) + assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) + assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) print("✓ MockNodeFactory detection test passed") def test_mock_factory_registration(): """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom - from core.workflow.entities import GraphInitParams - from core.workflow.runtime import GraphRuntimeState, VariablePool - from models.enums import UserFrom + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + from dify_graph.entities import GraphInitParams + from dify_graph.runtime import GraphRuntimeState, VariablePool print("Testing MockNodeFactory registration...") graph_init_params = GraphInitParams( - tenant_id="test", - app_id="test", workflow_id="test", graph_config={}, - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) graph_runtime_state = GraphRuntimeState( @@ -184,18 +186,18 @@ def test_mock_factory_registration(): ) # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Unregister mock - factory.unregister_mock_node_type(NodeType.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) + assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) # Register custom mock (using a dummy class for testing) class DummyMockNode: pass - factory.register_mock_node_type(NodeType.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(NodeType.TEMPLATE_TRANSFORM) + factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) + assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) print("✓ MockNodeFactory registration test passed") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index a6aab81f6c..e681b39cc7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,34 +4,34 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.config import GraphEngineConfig +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params class PauseStateStore(Protocol): @@ -126,11 +126,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 62aa56fc57..60167c0441 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,41 +4,41 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.config import GraphEngineConfig +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunPauseRequestedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -129,11 +129,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index a93d03c87e..b954a4faac 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -12,25 +12,24 @@ import time from unittest.mock import MagicMock, patch from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, ) -from core.workflow.node_events import NodeRunResult, StreamCompletedEvent -from core.workflow.nodes.llm.node import LLMNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from dify_graph.node_events import NodeRunResult, StreamCompletedEvent +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -87,11 +86,11 @@ def test_parallel_streaming_workflow(): graph_config = workflow_config.get("graph", {}) # Create graph initialization parameters - init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + init_params = build_test_graph_init_params( workflow_id="test_workflow", graph_config=graph_config, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.WEB_APP, @@ -100,8 +99,8 @@ def test_parallel_streaming_workflow(): # Create variable pool with system variables system_variables = SystemVariable( - user_id=init_params.user_id, - app_id=init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=init_params.workflow_id, files=[], query="Tell me about yourself", # User query @@ -119,7 +118,11 @@ def test_parallel_streaming_workflow(): with patch.object( DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True ): - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=get_default_root_node_id(graph_config), + ) # Create the graph engine engine = GraphEngine( @@ -165,7 +168,9 @@ def test_parallel_streaming_workflow(): stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] # Get Answer node start event - answer_start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == NodeType.ANSWER] + answer_start_events = [ + e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER + ] assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" answer_start_event = answer_start_events[0] @@ -212,7 +217,9 @@ def test_parallel_streaming_workflow(): # Get LLM completion events llm_completed_events = [ - (i, e) for i, e in enumerate(events) if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.LLM + (i, e) + for i, e in enumerate(events) + if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM ] # Check LLM completion order - in the current implementation, LLMs run sequentially @@ -264,7 +271,7 @@ def test_parallel_streaming_workflow(): # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}' # This means LLM 2 output should come first, then LLM 1 output answer_complete_events = [ - e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == NodeType.ANSWER + e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER ] assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index 156cfefcd6..7328ce443f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,42 +4,42 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.config import GraphEngineConfig -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.config import GraphEngineConfig +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, ModelConfig, VisionConfig, ) -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig, NodeMockConfig from .test_mock_nodes import MockLLMNode @@ -121,11 +121,11 @@ def _build_runtime_state() -> GraphRuntimeState: def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 700b3f4b8b..15a7de3c52 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,33 +3,33 @@ import time from typing import Any from unittest.mock import MagicMock -from core.workflow.entities import GraphInitParams -from core.workflow.entities.workflow_start_reason import WorkflowStartReason -from core.workflow.graph import Graph -from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from core.workflow.graph_engine.graph_engine import GraphEngine -from core.workflow.graph_events import ( +from dify_graph.entities.workflow_start_reason import WorkflowStartReason +from dify_graph.graph import Graph +from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel +from dify_graph.graph_engine.graph_engine import GraphEngine +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, NodeRunSucceededEvent, ) -from core.workflow.graph_events.graph import GraphRunStartedEvent -from core.workflow.nodes.base.entities import OutputVariableEntity -from core.workflow.nodes.end.end_node import EndNode -from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.repositories.human_input_form_repository import ( +from dify_graph.graph_events.graph import GraphRunStartedEvent +from dify_graph.nodes.base.entities import OutputVariableEntity +from dify_graph.nodes.end.end_node import EndNode +from dify_graph.nodes.end.entities import EndNodeData +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.repositories.human_input_form_repository import ( HumanInputFormEntity, HumanInputFormRepository, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now +from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: @@ -79,11 +79,11 @@ def _build_human_input_graph( form_repository: HumanInputFormRepository, ) -> Graph: graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = GraphInitParams( - tenant_id="tenant", - app_id="app", + params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from="account", invoke_from="service-api", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py index 0920940e51..9c84f42db6 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py @@ -12,9 +12,9 @@ import pytest import redis from core.app.apps.base_app_queue_manager import AppQueueManager -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from core.workflow.graph_engine.manager import GraphEngineManager +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from dify_graph.graph_engine.manager import GraphEngineManager class TestRedisStopIntegration: diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py new file mode 100644 index 0000000000..cd9d56f683 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py @@ -0,0 +1,55 @@ +"""Unit tests for response session creation.""" + +from __future__ import annotations + +import pytest + +from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType +from dify_graph.graph_engine.response_coordinator.session import ResponseSession +from dify_graph.nodes.base.template import Template, TextSegment + + +class DummyResponseNode: + """Minimal response-capable node for session tests.""" + + def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None: + self.id = node_id + self.node_type = node_type + self.execution_type = NodeExecutionType.RESPONSE + self.state = NodeState.UNKNOWN + self._template = template + + def get_streaming_template(self) -> Template: + return self._template + + +class DummyNodeWithoutStreamingTemplate: + """Minimal node that violates the response-session contract.""" + + def __init__(self, *, node_id: str, node_type: NodeType) -> None: + self.id = node_id + self.node_type = node_type + self.execution_type = NodeExecutionType.RESPONSE + self.state = NodeState.UNKNOWN + + +def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None: + """Session creation depends on the streaming-template contract rather than node type.""" + node = DummyResponseNode( + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + template=Template(segments=[TextSegment(text="hello")]), + ) + + session = ResponseSession.from_node(node) + + assert session.node_id == "llm-node" + assert session.template.segments == [TextSegment(text="hello")] + + +def test_response_session_from_node_requires_streaming_template_method() -> None: + """Allowed node types still need to implement the streaming-template contract.""" + node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER) + + with pytest.raises(TypeError, match="get_streaming_template"): + ResponseSession.from_node(node) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 99157a7c3e..4f1741d4fb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -1,4 +1,4 @@ -from core.workflow.graph_events import ( +from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunStartedEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index 5cbb7cf36e..ab8fb346b8 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,27 +12,29 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Sequence +from collections.abc import Callable, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, cast -from core.app.workflow.node_factory import DifyNodeFactory +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ( +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -48,6 +50,47 @@ from .test_mock_factory import MockNodeFactory logger = logging.getLogger(__name__) +class _TableTestChildEngineBuilder: + def __init__(self, *, use_mock_factory: bool, mock_config: MockConfig | None) -> None: + self._use_mock_factory = use_mock_factory + self._mock_config = mock_config + + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> GraphEngine: + if self._use_mock_factory: + node_factory = MockNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + mock_config=self._mock_config, + ) + else: + node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + + child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) + if not child_graph: + raise ValueError("child graph not found") + + child_engine = GraphEngine( + workflow_id=workflow_id, + graph=child_graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + child_engine_builder=self, + ) + for layer in layers: + child_engine.layer(cast(GraphEngineLayer, layer)) + return child_engine + + @dataclass class WorkflowTestCase: """Represents a single test case for table-driven testing.""" @@ -149,19 +192,23 @@ class WorkflowRunner: raise ValueError("Fixture missing workflow.graph configuration") graph_init_params = GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", workflow_id="test_workflow", graph_config=graph_config, - user_id="test_user", - user_from="account", - invoke_from="debugger", # Set to debugger to avoid conversation_id requirement + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, # Set to debugger to avoid conversation_id requirement + } + }, call_depth=0, ) system_variables = SystemVariable( - user_id=graph_init_params.user_id, - app_id=graph_init_params.app_id, + user_id="test_user", + app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, @@ -210,7 +257,11 @@ class WorkflowRunner: else: node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init( + graph_config=graph_config, + node_factory=node_factory, + root_node_id=get_default_root_node_id(graph_config), + ) return graph, graph_runtime_state @@ -315,6 +366,10 @@ class TableTestRunner: scale_up_threshold=self.graph_engine_scale_up_threshold, scale_down_idle_time=self.graph_engine_scale_down_idle_time, ), + child_engine_builder=_TableTestChildEngineBuilder( + use_mock_factory=test_case.use_auto_mock, + mock_config=test_case.mock_config, + ), ) # Execute and collect events diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index bfcc6e1a5f..7f26bc11a7 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,6 +1,6 @@ -from core.workflow.graph_engine import GraphEngine, GraphEngineConfig -from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_events import ( +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py index 221e1291d1..f63e8ff4ce 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py @@ -2,9 +2,9 @@ from unittest.mock import patch import pytest -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from .test_table_runner import TableTestRunner, WorkflowTestCase diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py new file mode 100644 index 0000000000..bc00b49fba --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py @@ -0,0 +1,145 @@ +import queue +from collections.abc import Generator +from datetime import UTC, datetime, timedelta +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue +from dify_graph.graph_engine.worker import Worker +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent + + +def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: + fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) + + worker = Worker( + ready_queue=InMemoryReadyQueue(), + event_queue=queue.Queue(), + graph=MagicMock(), + layers=[], + ) + node = SimpleNamespace( + execution_id="exec-1", + id="node-1", + node_type=BuiltinNodeTypes.LLM, + ) + + event = worker._build_fallback_failure_event(node, RuntimeError("boom")) + + assert event.start_at == fixed_time + assert event.finished_at == fixed_time + assert event.error == "boom" + assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert event.node_run_result.error == "boom" + assert event.node_run_result.error_type == "RuntimeError" + + +def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: + start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + failure_time = start_at + timedelta(seconds=5) + captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] + + class FakeNode: + execution_id = "exec-1" + id = "node-1" + node_type = BuiltinNodeTypes.LLM + + def ensure_execution_id(self) -> str: + return self.execution_id + + def run(self) -> Generator[NodeRunStartedEvent, None, None]: + yield NodeRunStartedEvent( + id=self.execution_id, + node_id=self.id, + node_type=self.node_type, + node_title="LLM", + start_at=start_at, + ) + + worker = Worker( + ready_queue=MagicMock(), + event_queue=MagicMock(), + graph=MagicMock(nodes={"node-1": FakeNode()}), + layers=[], + ) + + worker._ready_queue.get.side_effect = ["node-1"] + + def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: + captured_events.append(event) + if len(captured_events) == 1: + raise RuntimeError("queue boom") + worker.stop() + + worker._event_queue.put.side_effect = put_side_effect + + with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + worker.run() + + fallback_event = captured_events[-1] + + assert isinstance(fallback_event, NodeRunFailedEvent) + assert fallback_event.start_at == start_at + assert fallback_event.finished_at == failure_time + assert fallback_event.error == "queue boom" + assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED + + +def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: + parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) + child_start = parent_start + timedelta(seconds=3) + failure_time = parent_start + timedelta(seconds=5) + captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] + + class FakeIterationNode: + execution_id = "iteration-exec" + id = "iteration-node" + node_type = BuiltinNodeTypes.ITERATION + + def ensure_execution_id(self) -> str: + return self.execution_id + + def run(self) -> Generator[NodeRunStartedEvent, None, None]: + yield NodeRunStartedEvent( + id=self.execution_id, + node_id=self.id, + node_type=self.node_type, + node_title="Iteration", + start_at=parent_start, + ) + yield NodeRunStartedEvent( + id="child-exec", + node_id="child-node", + node_type=BuiltinNodeTypes.LLM, + node_title="LLM", + start_at=child_start, + in_iteration_id=self.id, + ) + + worker = Worker( + ready_queue=MagicMock(), + event_queue=MagicMock(), + graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), + layers=[], + ) + + worker._ready_queue.get.side_effect = ["iteration-node"] + + def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: + captured_events.append(event) + if len(captured_events) == 2: + raise RuntimeError("queue boom") + worker.stop() + + worker._event_queue.put.side_effect = put_side_effect + + with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + worker.run() + + fallback_event = captured_events[-1] + + assert isinstance(fallback_event, NodeRunFailedEvent) + assert fallback_event.start_at == parent_start + assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index 1e95ec1970..fd563d1be2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,16 +2,15 @@ import time import uuid from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.graph import Graph -from core.workflow.nodes.answer.answer_node import AnswerNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.nodes.answer.answer_node import AnswerNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_answer(): @@ -36,11 +35,11 @@ def test_execute_answer(): ], } - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -65,7 +64,7 @@ def test_execute_answer(): graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "answer", diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 21a642c2f8..81d3f5be9c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,14 +1,12 @@ import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node +from core.workflow.node_factory import get_node_type_classes_mapping +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.node import Node -# Ensures that all node classes are imported. -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING - -# Ensure `NODE_TYPE_CLASSES_MAPPING` is used and not automatically removed. -_ = NODE_TYPE_CLASSES_MAPPING +# Ensures that all production node classes are imported and registered. +_ = get_node_type_classes_mapping() class _TestNodeData(BaseNodeData): @@ -43,7 +41,7 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined node_type = cls.node_type node_version = cls.version() - assert isinstance(cls.node_type, NodeType) + assert isinstance(cls.node_type, str) assert isinstance(node_version, str) node_type_and_version = (node_type, node_version) assert node_type_and_version not in type_version_set, ( @@ -56,7 +54,7 @@ def test_extract_node_data_type_from_generic_extracts_type(): """When a class inherits from Node[T], it should extract T.""" class _ConcreteNode(Node[_TestNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: @@ -108,7 +106,7 @@ def test_init_subclass_rejects_explicit_node_data_type_without_generic(): class _ExplicitNode(Node): _node_data_type = _TestNodeData - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: @@ -119,10 +117,27 @@ def test_init_subclass_sets_node_data_type_from_generic(): """Verify that __init_subclass__ sets _node_data_type from the generic parameter.""" class _AutoNode(Node[_TestNodeData]): - node_type = NodeType.CODE + node_type = BuiltinNodeTypes.CODE @staticmethod def version() -> str: return "1" assert _AutoNode._node_data_type is _TestNodeData + + +def test_validate_node_data_uses_declared_node_data_type(): + """Public validation should hydrate the subclass-declared node data model.""" + + class _AutoNode(Node[_TestNodeData]): + node_type = BuiltinNodeTypes.CODE + + @staticmethod + def version() -> str: + return "1" + + base_node_data = BaseNodeData.model_validate({"type": BuiltinNodeTypes.CODE, "title": "Test"}) + + validated = _AutoNode.validate_node_data(base_node_data) + + assert isinstance(validated, _TestNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 45d222b98c..972a945ca0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,26 +1,27 @@ import types from collections.abc import Mapping -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node +from core.workflow.node_factory import get_node_type_classes_mapping +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.enums import BuiltinNodeTypes, NodeType +from dify_graph.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) -from core.workflow.nodes.variable_assigner.v1.node import ( +from dify_graph.nodes.variable_assigner.v1.node import ( VariableAssignerNode as VariableAssignerV1, ) -from core.workflow.nodes.variable_assigner.v2.node import ( +from dify_graph.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act - mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() # Assert basic presence - assert NodeType.VARIABLE_ASSIGNER in mapping - va_versions = mapping[NodeType.VARIABLE_ASSIGNER] + assert BuiltinNodeTypes.VARIABLE_ASSIGNER in mapping + va_versions = mapping[BuiltinNodeTypes.VARIABLE_ASSIGNER] # Both concrete versions must be present assert va_versions.get("1") is VariableAssignerV1 @@ -34,7 +35,7 @@ def test_latest_prefers_highest_numeric_version(): # Arrange: define two ephemeral subclasses with numeric versions under a NodeType # that has no concrete implementations in production to avoid interference. class _Version1(Node[BaseNodeData]): # type: ignore[misc] - node_type = NodeType.LEGACY_VARIABLE_AGGREGATOR + node_type = BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR def init_node_data(self, data): pass @@ -73,11 +74,11 @@ def test_latest_prefers_highest_numeric_version(): return "version2" # Act: build a fresh mapping (it should now see our ephemeral subclasses) - mapping: Mapping[NodeType, Mapping[str, type[Node]]] = Node.get_node_type_classes_mapping() + mapping: Mapping[NodeType, Mapping[str, type[Node]]] = get_node_type_classes_mapping() # Assert: both numeric versions exist for this NodeType; 'latest' points to the higher numeric version - assert NodeType.LEGACY_VARIABLE_AGGREGATOR in mapping - legacy_versions = mapping[NodeType.LEGACY_VARIABLE_AGGREGATOR] + assert BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR in mapping + legacy_versions = mapping[BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR] assert legacy_versions.get("1") is _Version1 assert legacy_versions.get("2") is _Version2 diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 00c8cb3779..784e08edd2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,13 +1,13 @@ from configs import dify_config -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData -from core.workflow.nodes.code.exc import ( +from dify_graph.nodes.code.code_node import CodeNode +from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData +from dify_graph.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.variables.types import SegmentType CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, @@ -272,7 +272,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_1", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert result == {} @@ -292,7 +292,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_1", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert "node_1.input_text" in result @@ -315,7 +315,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="code_node", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert len(result) == 3 @@ -338,7 +338,7 @@ class TestCodeNodeExtractVariableSelector: result = CodeNode._extract_variable_selector_to_variable_mapping( graph_config={}, node_id="node_x", - node_data=node_data, + node_data=CodeNodeData.model_validate(node_data, from_attributes=True), ) assert result["node_x.deep_var"] == ["node", "obj", "nested", "value"] @@ -437,7 +437,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node._node_data = node._hydrate_node_data(data) + node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True) assert node._node_data.title == "Test Node" assert node._node_data.code_language == CodeLanguage.PYTHON3 @@ -453,7 +453,7 @@ class TestCodeNodeInitialization: "outputs": {"x": {"type": "number"}}, } - node._node_data = node._hydrate_node_data(data) + node._node_data = CodeNode._node_data_type.model_validate(data, from_attributes=True) assert node._node_data.code_language == CodeLanguage.JAVASCRIPT diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py index 28d59c3568..de7ed0815e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from core.workflow.nodes.code.entities import CodeLanguage, CodeNodeData -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData +from dify_graph.variables.types import SegmentType class TestCodeNodeDataOutput: diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 584ed23e91..859115ceb3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,6 +1,7 @@ -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from core.workflow.nodes.datasource.datasource_node import DatasourceNode +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: @@ -28,13 +29,17 @@ class _GraphState: class _GraphParams: - tenant_id = "t1" - app_id = "app-1" workflow_id = "wf-1" graph_config = {} - user_id = "u1" - user_from = "account" - invoke_from = "debugger" + run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "t1", + "app_id": "app-1", + "user_id": "u1", + "user_from": "account", + "invoke_from": "debugger", + } + } call_depth = 0 @@ -69,6 +74,8 @@ def test_datasource_node_delegates_to_manager_stream(mocker): def get_upload_file_by_id(cls, **_): raise AssertionError("not called") + mocker.patch("core.workflow.nodes.datasource.datasource_node.DatasourceManager", new=_Mgr) + node = DatasourceNode( id="n", config={ @@ -85,7 +92,6 @@ def test_datasource_node_delegates_to_manager_stream(mocker): }, graph_init_params=gp, graph_runtime_state=gs, - datasource_manager=_Mgr, ) evts = list(node._run()) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py index 90f4cd018b..cd822a6f89 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.http_request import build_http_request_config +from dify_graph.nodes.http_request import build_http_request_config def test_build_http_request_config_uses_literal_defaults(): diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py index 47a5df92a4..fec6ad90eb 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py @@ -4,7 +4,7 @@ from unittest.mock import Mock, PropertyMock, patch import httpx import pytest -from core.workflow.nodes.http_request.entities import Response +from dify_graph.nodes.http_request.entities import Response @pytest.fixture @@ -104,7 +104,7 @@ def test_mimetype_based_detection(mock_response, content_type, expected_main_typ mock_response.headers = {"content-type": content_type} type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - with patch("core.workflow.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: + with patch("dify_graph.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: # Mock the return value based on expected_main_type if expected_main_type: mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index 67da890eb2..cea7195417 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -2,19 +2,19 @@ import pytest from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy -from core.workflow.file.file_manager import file_manager -from core.workflow.nodes.http_request import ( +from dify_graph.file.file_manager import file_manager +from dify_graph.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeConfig, HttpRequestNodeData, ) -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout -from core.workflow.nodes.http_request.exc import AuthorizationConfigError -from core.workflow.nodes.http_request.executor import Executor -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout +from dify_graph.nodes.http_request.exc import AuthorizationConfigError +from dify_graph.nodes.http_request.executor import Executor +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index cad0466809..5e34bf1d94 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -4,17 +4,16 @@ from typing import Any import httpx import pytest -from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file.file_manager import file_manager -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file.file_manager import file_manager +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=10, @@ -99,11 +98,11 @@ def _build_http_node( ], "edges": [], } - graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params = build_test_graph_init_params( workflow_id="workflow", graph_config=graph_config, + tenant_id="tenant", + app_id="app", user_id="user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -162,7 +161,7 @@ def test_run_passes_node_data_ssl_verify_to_executor(monkeypatch: pytest.MonkeyP ) ) - monkeypatch.setattr("core.workflow.nodes.http_request.node.Executor", FakeExecutor) + monkeypatch.setattr("dify_graph.nodes.http_request.node.Executor", FakeExecutor) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index ca4a887d20..d52dfa2a65 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,5 +1,5 @@ -from core.workflow.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from core.workflow.runtime import VariablePool +from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients +from dify_graph.runtime import VariablePool def test_render_body_template_replaces_variable_values(): @@ -14,3 +14,64 @@ def test_render_body_template_replaces_variable_values(): result = config.render_body_template(body=config.body, url="https://example.com", variable_pool=variable_pool) assert result == "Hello World https://example.com" + + +def test_render_markdown_body_renders_markdown_to_html(): + rendered = EmailDeliveryConfig.render_markdown_body("**Bold** and [link](https://example.com)") + + assert "Bold" in rendered + assert 'link' in rendered + + +def test_render_markdown_body_sanitizes_unsafe_html(): + rendered = EmailDeliveryConfig.render_markdown_body( + 'Click' + ) + + assert "bad" in rendered + assert 'ok' in rendered + + +def test_render_markdown_body_does_not_allow_raw_html_tags(): + rendered = EmailDeliveryConfig.render_markdown_body("raw html and **markdown**") + + assert "" not in rendered + assert "raw html" in rendered + assert "markdown" in rendered + + +def test_render_markdown_body_supports_table_syntax(): + rendered = EmailDeliveryConfig.render_markdown_body("| h1 | h2 |\n| --- | ---: |\n| v1 | v2 |") + + assert "" in rendered + assert "" in rendered + assert "" in rendered + assert 'align="right"' in rendered + assert "style=" not in rendered + + +def test_sanitize_subject_removes_crlf(): + sanitized = EmailDeliveryConfig.sanitize_subject("Notice\r\nBCC:attacker@example.com") + + assert "\r" not in sanitized + assert "\n" not in sanitized + assert sanitized == "Notice BCC:attacker@example.com" + + +def test_sanitize_subject_removes_html_tags(): + sanitized = EmailDeliveryConfig.sanitize_subject("Alert") + + assert "<" not in sanitized + assert ">" not in sanitized + assert sanitized == "Alert" diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index bfe7b03c13..55aa62a1c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -8,10 +8,11 @@ from unittest.mock import MagicMock import pytest from pydantic import ValidationError -from core.workflow.entities import GraphInitParams -from core.workflow.node_events import PauseRequestedEvent -from core.workflow.node_events.node import StreamCompletedEvent -from core.workflow.nodes.human_input.entities import ( +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.node_events import PauseRequestedEvent +from dify_graph.node_events.node import StreamCompletedEvent +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -24,7 +25,7 @@ from core.workflow.nodes.human_input.entities import ( WebAppDeliveryMethod, _WebAppDeliveryConfig, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( ButtonStyle, DeliveryMethodType, EmailRecipientType, @@ -32,10 +33,10 @@ from core.workflow.nodes.human_input.enums import ( PlaceholderType, TimeoutUnit, ) -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.repositories.human_input_form_repository import HumanInputFormRepository -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository @@ -314,13 +315,17 @@ class TestHumanInputNodeVariableResolution: variable_pool.add(("start", "name"), "Jane Doe") runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -384,13 +389,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -439,13 +448,17 @@ class TestHumanInputNodeVariableResolution: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user-123", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user-123", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) @@ -550,13 +563,17 @@ class TestHumanInputNodeRenderedContent: ) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from="account", - invoke_from="debugger", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": "account", + "invoke_from": "debugger", + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index a19ee4dee3..b0ed47158d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,20 +1,19 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.enums import NodeType -from core.workflow.graph_events import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus -from core.workflow.nodes.human_input.human_input_node import HumanInputNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.human_input_node import HumanInputNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now -from models.enums import UserFrom class _FakeFormRepository: @@ -32,19 +31,23 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) config = { "id": "node-1", - "type": NodeType.HUMAN_INPUT.value, + "type": BuiltinNodeTypes.HUMAN_INPUT, "data": { "title": "Human Input", "form_content": form_content, @@ -92,19 +95,23 @@ def _build_timeout_node() -> HumanInputNode: start_at=0.0, ) graph_init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", workflow_id="workflow", graph_config={"nodes": [], "edges": []}, - user_id="user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) config = { "id": "node-1", - "type": NodeType.HUMAN_INPUT.value, + "type": BuiltinNodeTypes.HUMAN_INPUT, "data": { "title": "Human Input", "form_content": "Please enter your name:\n\n{{#$output.name#}}", diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py index d669cc7465..93c199514e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py @@ -1,4 +1,4 @@ -from core.workflow.nodes.iteration.entities import ( +from dify_graph.nodes.iteration.entities import ( ErrorHandleMode, IterationNodeData, IterationStartNodeData, diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py index b67e84d1d4..fdf5f4d1f8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py @@ -1,6 +1,7 @@ -from core.workflow.enums import NodeType -from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from core.workflow.nodes.iteration.exc import ( +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.nodes.iteration.exc import ( InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -8,7 +9,7 @@ from core.workflow.nodes.iteration.exc import ( IteratorVariableNotFoundError, StartNodeIdNotFoundError, ) -from core.workflow.nodes.iteration.iteration_node import IterationNode +from dify_graph.nodes.iteration.iteration_node import IterationNode class TestIterationNodeExceptions: @@ -90,7 +91,7 @@ class TestIterationNodeClassAttributes: def test_node_type(self): """Test IterationNode node_type attribute.""" - assert IterationNode.node_type == NodeType.ITERATION + assert IterationNode.node_type == BuiltinNodeTypes.ITERATION def test_version(self): """Test IterationNode version method.""" @@ -388,3 +389,50 @@ class TestIterationNodeErrorStrategies: result = node._get_default_value_dict() assert isinstance(result, dict) + + +def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + child_node_config = { + "id": "answer-node", + "data": { + "type": "answer", + "title": "Answer", + "answer": "", + "iteration_id": "iteration-node", + }, + } + + IterationNode._extract_variable_selector_to_variable_mapping( + graph_config={ + "nodes": [ + { + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration", "result"], + }, + }, + child_node_config, + ], + "edges": [], + }, + node_id="iteration-node", + node_data=IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration", "result"], + ), + ) + + assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py new file mode 100644 index 0000000000..2eb4feef5f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -0,0 +1,100 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +import pytest + +from dify_graph.entities import GraphInitParams +from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError +from dify_graph.nodes.iteration.iteration_node import IterationNode +from dify_graph.runtime import ( + ChildEngineBuilderNotConfiguredError, + ChildGraphNotFoundError, + GraphRuntimeState, + VariablePool, +) +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +class _MissingGraphBuilder: + def build_child_engine( + self, + *, + workflow_id: str, + graph_init_params: GraphInitParams, + graph_runtime_state: GraphRuntimeState, + graph_config: Mapping[str, Any], + root_node_id: str, + layers: Sequence[object] = (), + ) -> object: + raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") + + +def _build_runtime_state() -> GraphRuntimeState: + return GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + start_at=0.0, + ) + + +def _build_iteration_node( + *, + graph_config: Mapping[str, Any], + runtime_state: GraphRuntimeState, + start_node_id: str, +) -> IterationNode: + init_params = build_test_graph_init_params(graph_config=graph_config) + return IterationNode( + id="iteration-node", + config={ + "id": "iteration-node", + "data": { + "type": "iteration", + "title": "Iteration", + "iterator_selector": ["start", "items"], + "output_selector": ["iteration-node", "output"], + "start_node_id": start_node_id, + }, + }, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + +def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing(): + runtime_state = _build_runtime_state() + graph_init_params = build_test_graph_init_params() + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + runtime_state.create_child_engine( + workflow_id="workflow", + graph_init_params=graph_init_params, + graph_runtime_state=_build_runtime_state(), + graph_config={}, + root_node_id="root", + ) + + +def test_iteration_node_only_translates_child_graph_not_found_error(): + runtime_state = _build_runtime_state() + runtime_state.bind_child_engine_builder(_MissingGraphBuilder()) + node = _build_iteration_node( + graph_config={"nodes": [{"id": "present-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="missing-node", + ) + + with pytest.raises(IterationGraphNotFoundError): + node._create_graph_engine(index=0, item="item") + + +def test_iteration_node_propagates_non_graph_not_found_errors(): + runtime_state = _build_runtime_state() + node = _build_iteration_node( + graph_config={"nodes": [{"id": "start-node"}], "edges": []}, + runtime_state=runtime_state, + start_node_id="start-node", + ) + + with pytest.raises(ChildEngineBuilderNotConfiguredError): + node._create_graph_engine(index=0, item="item") diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py new file mode 100644 index 0000000000..8660449032 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py @@ -0,0 +1,63 @@ +import time +from contextlib import nullcontext +from datetime import UTC, datetime + +import pytest + +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.graph_events import NodeRunSucceededEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.nodes.iteration.iteration_node import IterationNode + + +def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: + node = IterationNode.__new__(IterationNode) + node._node_data = IterationNodeData( + title="Parallel Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration", "output"], + is_parallel=True, + parallel_nums=2, + error_handle_mode=ErrorHandleMode.TERMINATED, + ) + node._capture_execution_context = lambda: nullcontext() + node._sync_conversation_variables_from_snapshot = lambda snapshot: None + node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) + + def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): + return ( + 0.1 + (index * 0.1), + [ + NodeRunSucceededEvent( + id=f"exec-{index}", + node_id=f"llm-{index}", + node_type=BuiltinNodeTypes.LLM, + start_at=datetime.now(UTC).replace(tzinfo=None), + ), + ], + f"output-{item}", + {}, + LLMUsage.empty_usage(), + ) + + node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel + + outputs: list[object] = [] + iter_run_map: dict[str, float] = {} + usage_accumulator = [LLMUsage.empty_usage()] + + generator = node._execute_parallel_iterations( + iterator_list_value=["a", "b"], + outputs=outputs, + iter_run_map=iter_run_map, + usage_accumulator=usage_accumulator, + ) + + for _ in generator: + # Simulate a slow consumer replaying buffered events. + time.sleep(0.02) + + assert outputs == ["output-a", "output-b"] + assert iter_run_map["0"] == pytest.approx(0.1) + assert iter_run_map["1"] == pytest.approx(0.2) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/__init__.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py new file mode 100644 index 0000000000..33f7ace5ab --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -0,0 +1,650 @@ +import time +import uuid +from unittest.mock import Mock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.nodes.knowledge_index.entities import KnowledgeIndexNodeData +from core.workflow.nodes.knowledge_index.exc import KnowledgeIndexNodeError +from core.workflow.nodes.knowledge_index.knowledge_index_node import KnowledgeIndexNode +from core.workflow.nodes.knowledge_index.protocols import ( + IndexProcessorProtocol, + Preview, + PreviewItem, + SummaryIndexServiceProtocol, +) +from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params + + +@pytest.fixture +def mock_graph_init_params(): + """Create mock GraphInitParams.""" + return build_test_graph_init_params( + workflow_id=str(uuid.uuid4()), + graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), + user_id=str(uuid.uuid4()), + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + """Create mock GraphRuntimeState.""" + variable_pool = VariablePool( + system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +@pytest.fixture +def mock_index_processor(mocker): + """Create mock IndexProcessorProtocol.""" + mock_processor = Mock(spec=IndexProcessorProtocol) + mocker.patch( + "core.workflow.nodes.knowledge_index.knowledge_index_node.IndexProcessor", + return_value=mock_processor, + ) + return mock_processor + + +@pytest.fixture +def mock_summary_index_service(mocker): + """Create mock SummaryIndexServiceProtocol.""" + mock_service = Mock(spec=SummaryIndexServiceProtocol) + mocker.patch( + "core.workflow.nodes.knowledge_index.knowledge_index_node.SummaryIndex", + return_value=mock_service, + ) + return mock_service + + +@pytest.fixture +def sample_node_data(): + """Create sample KnowledgeIndexNodeData.""" + return KnowledgeIndexNodeData( + title="Knowledge Index", + type="knowledge-index", + chunk_structure="general_structure", + index_chunk_variable_selector=["start", "chunks"], + indexing_technique="high_quality", + summary_index_setting=None, + ) + + +@pytest.fixture +def sample_chunks(): + """Create sample chunks data.""" + return { + "general_chunks": ["Chunk 1 content", "Chunk 2 content"], + "data_source_info": {"file_id": str(uuid.uuid4())}, + } + + +class TestKnowledgeIndexNode: + """ + Test suite for KnowledgeIndexNode. + """ + + def test_node_initialization( + self, mock_graph_init_params, mock_graph_runtime_state, mock_index_processor, mock_summary_index_service + ): + """Test KnowledgeIndexNode initialization.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": { + "title": "Knowledge Index", + "type": "knowledge-index", + "chunk_structure": "general_structure", + "index_chunk_variable_selector": ["start", "chunks"], + }, + } + + # Act + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Assert + assert node.id == node_id + assert node.index_processor == mock_index_processor + assert node.summary_index_service == mock_summary_index_service + + def test_run_without_dataset_id( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + ): + """Test _run raises KnowledgeIndexNodeError when dataset_id is not provided.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act & Assert + with pytest.raises(KnowledgeIndexNodeError, match="Dataset ID is required"): + node._run() + + def test_run_without_index_chunk_variable( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + ): + """Test _run raises KnowledgeIndexNodeError when index chunk variable is not provided.""" + # Arrange + dataset_id = str(uuid.uuid4()) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act & Assert + with pytest.raises(KnowledgeIndexNodeError, match="Index chunk variable is required"): + node._run() + + def test_run_with_empty_chunks( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + ): + """Test _run fails when chunks is empty.""" + # Arrange + dataset_id = str(uuid.uuid4()) + chunks_selector = ["start", "chunks"] + + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + mock_graph_runtime_state.variable_pool.add(chunks_selector, StringSegment(value="")) + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Chunks is required" in result.error + + def test_run_preview_mode_success( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + sample_chunks, + ): + """Test _run succeeds in preview mode.""" + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + chunks_selector = ["start", "chunks"] + + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DOCUMENT_ID], + StringSegment(value=document_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.INVOKE_FROM], + StringSegment(value=InvokeFrom.DEBUGGER), + ) + mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks) + + # Mock preview output + mock_preview = Preview( + chunk_structure="general_structure", + preview=[PreviewItem(content="Chunk 1"), PreviewItem(content="Chunk 2")], + total_segments=2, + ) + mock_index_processor.get_preview_output.return_value = mock_preview + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert mock_index_processor.get_preview_output.called + + def test_run_production_mode_success( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + sample_chunks, + ): + """Test _run succeeds in production mode.""" + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + original_document_id = str(uuid.uuid4()) + batch = "batch_123" + chunks_selector = ["start", "chunks"] + + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DOCUMENT_ID], + StringSegment(value=document_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID], + StringSegment(value=original_document_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.BATCH], + StringSegment(value=batch), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.INVOKE_FROM], + StringSegment(value=InvokeFrom.SERVICE_API), + ) + mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks) + + # Mock index_and_clean output + mock_index_processor.index_and_clean.return_value = {"status": "indexed"} + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs is not None + assert mock_summary_index_service.generate_and_vectorize_summary.called + assert mock_index_processor.index_and_clean.called + + def test_run_production_mode_without_batch( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + sample_chunks, + ): + """Test _run fails when batch is not provided in production mode.""" + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + chunks_selector = ["start", "chunks"] + + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DOCUMENT_ID], + StringSegment(value=document_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.INVOKE_FROM], + StringSegment(value=InvokeFrom.SERVICE_API), + ) + mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks) + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Batch is required" in result.error + + def test_run_with_knowledge_index_node_error( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + sample_chunks, + ): + """Test _run handles KnowledgeIndexNodeError properly.""" + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + batch = "batch_123" + chunks_selector = ["start", "chunks"] + + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DOCUMENT_ID], + StringSegment(value=document_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.BATCH], + StringSegment(value=batch), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.INVOKE_FROM], + StringSegment(value=InvokeFrom.SERVICE_API), + ) + mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks) + + # Mock to raise KnowledgeIndexNodeError + mock_index_processor.index_and_clean.side_effect = KnowledgeIndexNodeError("Indexing failed") + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Indexing failed" in result.error + assert result.error_type == "KnowledgeIndexNodeError" + + def test_run_with_generic_exception( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + sample_chunks, + ): + """Test _run handles generic exceptions properly.""" + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + batch = "batch_123" + chunks_selector = ["start", "chunks"] + + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DATASET_ID], + StringSegment(value=dataset_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.DOCUMENT_ID], + StringSegment(value=document_id), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.BATCH], + StringSegment(value=batch), + ) + mock_graph_runtime_state.variable_pool.add( + ["sys", SystemVariableKey.INVOKE_FROM], + StringSegment(value=InvokeFrom.SERVICE_API), + ) + mock_graph_runtime_state.variable_pool.add(chunks_selector, sample_chunks) + + # Mock to raise generic exception + mock_index_processor.index_and_clean.side_effect = Exception("Unexpected error") + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Unexpected error" in result.error + assert result.error_type == "Exception" + + def test_invoke_knowledge_index( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + ): + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + original_document_id = str(uuid.uuid4()) + batch = "batch_123" + chunks = {"general_chunks": ["content"]} + + mock_index_processor.index_and_clean.return_value = {"status": "indexed"} + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._invoke_knowledge_index( + dataset_id=dataset_id, + document_id=document_id, + original_document_id=original_document_id, + is_preview=False, + batch=batch, + chunks=chunks, + summary_index_setting=None, + ) + + # Assert + assert mock_summary_index_service.generate_and_vectorize_summary.called + assert mock_index_processor.index_and_clean.called + assert result == {"status": "indexed"} + + def test_version_method(self): + """Test version class method.""" + # Act + version = KnowledgeIndexNode.version() + + # Assert + assert version == "1" + + def test_get_streaming_template( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + ): + """Test get_streaming_template method.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + template = node.get_streaming_template() + + # Assert + assert template is not None + assert template.segments == [] + + +class TestInvokeKnowledgeIndex: + def test_invoke_with_summary_index_setting( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_index_processor, + mock_summary_index_service, + sample_node_data, + ): + # Arrange + dataset_id = str(uuid.uuid4()) + document_id = str(uuid.uuid4()) + original_document_id = str(uuid.uuid4()) + batch = "batch_123" + chunks = {"general_chunks": ["content"]} + summary_setting = {"enabled": True} + + mock_index_processor.index_and_clean.return_value = {"status": "indexed"} + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeIndexNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + # Act + result = node._invoke_knowledge_index( + dataset_id=dataset_id, + document_id=document_id, + original_document_id=original_document_id, + is_preview=False, + batch=batch, + chunks=chunks, + summary_index_setting=summary_setting, + ) + + # Assert + mock_summary_index_service.generate_and_vectorize_summary.assert_called_once_with( + dataset_id, document_id, False, summary_setting + ) + mock_index_processor.index_and_clean.assert_called_once_with( + dataset_id, document_id, original_document_id, chunks, batch, summary_setting + ) + assert result == {"status": "indexed"} diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index a60dde199d..99997db6b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -4,33 +4,34 @@ from unittest.mock import Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.model_runtime.entities.llm_entities import LLMUsage -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( + Condition, KnowledgeRetrievalNodeData, + MetadataFilteringCondition, MultipleRetrievalConfig, RerankingModelConfig, SingleRetrievalConfig, ) from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import StringSegment -from models.enums import UserFrom +from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import StringSegment +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def mock_graph_init_params(): """Create mock GraphInitParams.""" - return GraphInitParams( - tenant_id=str(uuid.uuid4()), - app_id=str(uuid.uuid4()), + return build_test_graph_init_params( workflow_id=str(uuid.uuid4()), graph_config={}, + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), user_id=str(uuid.uuid4()), user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -51,11 +52,15 @@ def mock_graph_runtime_state(): @pytest.fixture -def mock_rag_retrieval(): +def mock_rag_retrieval(mocker): """Create mock RAGRetrievalProtocol.""" mock_retrieval = Mock(spec=RAGRetrievalProtocol) mock_retrieval.knowledge_retrieval.return_value = [] mock_retrieval.llm_usage = LLMUsage.empty_usage() + mocker.patch( + "core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node.DatasetRetrieval", + return_value=mock_retrieval, + ) return mock_retrieval @@ -105,13 +110,11 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Assert assert node.id == node_id assert node._rag_retrieval == mock_rag_retrieval - assert node._llm_file_saver is not None def test_run_with_no_query_or_attachment( self, @@ -136,7 +139,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -155,7 +157,7 @@ class TestKnowledgeRetrievalNode: ): """Test _run with query variable in single mode.""" # Arrange - from core.workflow.nodes.llm.entities import ModelConfig + from dify_graph.nodes.llm.entities import ModelConfig query = "What is Python?" query_selector = ["start", "query"] @@ -196,7 +198,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -206,6 +207,7 @@ class TestKnowledgeRetrievalNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert "result" in result.outputs assert mock_rag_retrieval.knowledge_retrieval.called + mock_source.model_dump.assert_called_once_with(by_alias=True) def test_run_with_query_variable_multiple_mode( self, @@ -240,7 +242,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -277,7 +278,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -313,7 +313,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -355,7 +354,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -395,7 +393,6 @@ class TestKnowledgeRetrievalNode: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -409,14 +406,14 @@ class TestKnowledgeRetrievalNode: """Test _extract_variable_selector_to_variable_mapping class method.""" # Arrange node_id = "knowledge_node_1" - node_data = { - "type": "knowledge-retrieval", - "title": "Knowledge Retrieval", - "dataset_ids": [str(uuid.uuid4())], - "retrieval_mode": "multiple", - "query_variable_selector": ["start", "query"], - "query_attachment_selector": ["start", "attachments"], - } + node_data = KnowledgeRetrievalNodeData( + type="knowledge-retrieval", + title="Knowledge Retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + query_variable_selector=["start", "query"], + query_attachment_selector=["start", "attachments"], + ) graph_config = {} # Act @@ -444,7 +441,7 @@ class TestFetchDatasetRetriever: ): """Test _fetch_dataset_retriever in single mode.""" # Arrange - from core.workflow.nodes.llm.entities import ModelConfig + from dify_graph.nodes.llm.entities import ModelConfig query = "What is Python?" variables = {"query": query} @@ -477,7 +474,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -515,7 +511,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -571,7 +566,6 @@ class TestFetchDatasetRetriever: config=config, graph_init_params=mock_graph_init_params, graph_runtime_state=mock_graph_runtime_state, - rag_retrieval=mock_rag_retrieval, ) # Act @@ -593,3 +587,104 @@ class TestFetchDatasetRetriever: # Assert assert version == "1" + + def test_resolve_metadata_filtering_conditions_templates( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """_resolve_metadata_filtering_conditions should expand {{#...#}} and keep numbers/None unchanged.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": { + "title": "Knowledge Retrieval", + "type": "knowledge-retrieval", + "dataset_ids": [str(uuid.uuid4())], + "retrieval_mode": "multiple", + }, + } + # Variable in pool used by template + mock_graph_runtime_state.variable_pool.add(["start", "query"], StringSegment(value="readme")) + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + conditions = MetadataFilteringCondition( + logical_operator="and", + conditions=[ + Condition(name="document_name", comparison_operator="is", value="{{#start.query#}}"), + Condition(name="tags", comparison_operator="in", value=["x", "{{#start.query#}}"]), + Condition(name="year", comparison_operator="=", value=2025), + ], + ) + + # Act + resolved = node._resolve_metadata_filtering_conditions(conditions) + + # Assert + assert resolved.logical_operator == "and" + assert resolved.conditions[0].value == "readme" + assert isinstance(resolved.conditions[1].value, list) + assert resolved.conditions[1].value[1] == "readme" + assert resolved.conditions[2].value == 2025 + + def test_fetch_passes_resolved_metadata_conditions( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """_fetch_dataset_retriever should pass resolved metadata conditions into request.""" + # Arrange + query = "hi" + variables = {"query": query} + mock_graph_runtime_state.variable_pool.add(["start", "q"], StringSegment(value="readme")) + + node_data = KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + multiple_retrieval_config=MultipleRetrievalConfig( + top_k=4, + score_threshold=0.0, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RerankingModelConfig(provider="cohere", model="rerank-v2"), + ), + metadata_filtering_mode="manual", + metadata_filtering_conditions=MetadataFilteringCondition( + logical_operator="and", + conditions=[ + Condition(name="document_name", comparison_operator="is", value="{{#start.q#}}"), + ], + ), + ) + + node_id = str(uuid.uuid4()) + config = {"id": node_id, "data": node_data.model_dump()} + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + ) + + mock_rag_retrieval.knowledge_retrieval.return_value = [] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + # Act + node._fetch_dataset_retriever(node_data=node_data, variables=variables) + + # Assert the passed request has resolved value + call_args = mock_rag_retrieval.knowledge_retrieval.call_args + request = call_args[1]["request"] + assert request.metadata_filtering_conditions is not None + assert request.metadata_filtering_conditions.conditions[0].value == "readme" diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index 63a87623da..d71e0921c1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock import pytest -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.nodes.list_operator.node import ListOperatorNode -from core.workflow.variables import ArrayNumberSegment, ArrayStringSegment -from models.workflow import WorkflowType +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.nodes.list_operator.node import ListOperatorNode +from dify_graph.runtime import GraphRuntimeState +from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment class TestListOperatorNode: @@ -22,43 +21,40 @@ class TestListOperatorNode: mock_state.variable_pool = mock_variable_pool return mock_state - @pytest.fixture - def mock_graph(self): - """Create mock Graph.""" - return MagicMock(spec=Graph) - @pytest.fixture def graph_init_params(self): """Create GraphInitParams fixture.""" return GraphInitParams( - tenant_id="test", - app_id="test", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test", graph_config={}, - user_id="test", - user_from="test", - invoke_from="test", + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test", + "app_id": "test", + "user_id": "test", + "user_from": "test", + "invoke_from": "test", + } + }, call_depth=0, ) @pytest.fixture - def list_operator_node_factory(self, graph_init_params, mock_graph, mock_graph_runtime_state): + def list_operator_node_factory(self, graph_init_params, mock_graph_runtime_state): """Factory fixture for creating ListOperatorNode instances.""" def _create_node(config, mock_variable): mock_graph_runtime_state.variable_pool.get.return_value = mock_variable return ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) return _create_node - def test_node_initialization(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, mock_graph_runtime_state, graph_init_params): """Test node initializes correctly.""" config = { "title": "List Operator", @@ -70,13 +66,12 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) - assert node.node_type == NodeType.LIST_OPERATOR + assert node.node_type == BuiltinNodeTypes.LIST_OPERATOR assert node._node_data.title == "List Operator" def test_version(self): @@ -101,7 +96,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_empty_array(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_array(self, mock_graph_runtime_state, graph_init_params): """Test with empty array.""" config = { "title": "Test", @@ -116,9 +111,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -129,7 +123,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] is None assert result.outputs["last_record"] is None - def test_run_with_filter_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with contains condition.""" config = { "title": "Test", @@ -148,9 +142,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -159,7 +152,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple"] - def test_run_with_filter_not_contains(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_not_contains(self, mock_graph_runtime_state, graph_init_params): """Test filter with not contains condition.""" config = { "title": "Test", @@ -178,9 +171,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -189,7 +181,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["banana", "cherry"] - def test_run_with_number_filter_greater_than(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_greater_than(self, mock_graph_runtime_state, graph_init_params): """Test filter with greater than condition on numbers.""" config = { "title": "Test", @@ -208,9 +200,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -219,7 +210,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [7, 9, 11] - def test_run_with_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in ascending order.""" config = { "title": "Test", @@ -237,9 +228,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -248,7 +238,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana", "cherry"] - def test_run_with_order_descending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_order_descending(self, mock_graph_runtime_state, graph_init_params): """Test ordering in descending order.""" config = { "title": "Test", @@ -266,9 +256,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -277,7 +266,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["cherry", "banana", "apple"] - def test_run_with_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_limit(self, mock_graph_runtime_state, graph_init_params): """Test with limit enabled.""" config = { "title": "Test", @@ -295,9 +284,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -306,7 +294,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "banana"] - def test_run_with_filter_order_and_limit(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_order_and_limit(self, mock_graph_runtime_state, graph_init_params): """Test with filter, order, and limit combined.""" config = { "title": "Test", @@ -331,9 +319,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -342,7 +329,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [9, 8, 7] - def test_run_with_variable_not_found(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_variable_not_found(self, mock_graph_runtime_state, graph_init_params): """Test when variable is not found.""" config = { "title": "Test", @@ -356,9 +343,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -367,7 +353,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Variable not found" in result.error - def test_run_with_first_and_last_record(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_first_and_last_record(self, mock_graph_runtime_state, graph_init_params): """Test first_record and last_record outputs.""" config = { "title": "Test", @@ -382,9 +368,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -394,7 +379,7 @@ class TestListOperatorNode: assert result.outputs["first_record"] == "first" assert result.outputs["last_record"] == "last" - def test_run_with_filter_startswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_startswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with startswith condition.""" config = { "title": "Test", @@ -413,9 +398,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -424,7 +408,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "application"] - def test_run_with_filter_endswith(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_filter_endswith(self, mock_graph_runtime_state, graph_init_params): """Test filter with endswith condition.""" config = { "title": "Test", @@ -443,9 +427,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -454,7 +437,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == ["apple", "pineapple", "table"] - def test_run_with_number_filter_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with equals condition.""" config = { "title": "Test", @@ -473,9 +456,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -484,7 +466,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [5, 5] - def test_run_with_number_filter_not_equals(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_filter_not_equals(self, mock_graph_runtime_state, graph_init_params): """Test number filter with not equals condition.""" config = { "title": "Test", @@ -503,9 +485,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) @@ -514,7 +495,7 @@ class TestListOperatorNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["result"].value == [1, 3, 7, 9] - def test_run_with_number_order_ascending(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_number_order_ascending(self, mock_graph_runtime_state, graph_init_params): """Test number ordering in ascending order.""" config = { "title": "Test", @@ -532,9 +513,8 @@ class TestListOperatorNode: node = ListOperatorNode( id="test", - config=config, + config={"id": "test", "data": config}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index 0677f1bb52..b0f0fd428b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -1,16 +1,16 @@ import uuid from typing import NamedTuple from unittest import mock +from unittest.mock import MagicMock import httpx import pytest -from sqlalchemy import Engine from core.helper import ssrf_proxy from core.tools import signature from core.tools.tool_file_manager import ToolFileManager -from core.workflow.file import FileTransferMethod, FileType, models -from core.workflow.nodes.llm.file_saver import ( +from dify_graph.file import FileTransferMethod, FileType, models +from dify_graph.nodes.llm.file_saver import ( FileSaverImpl, _extract_content_type_and_extension, _get_extension, @@ -44,7 +44,6 @@ class TestFileSaverImpl: ) mock_tool_file.id = _gen_id() mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - mocked_engine = mock.MagicMock(spec=Engine) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) @@ -53,11 +52,12 @@ class TestFileSaverImpl: # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) mocked_sign_file.return_value = mock_signed_url + http_client = MagicMock() storage_file_manager = FileSaverImpl( user_id=user_id, tenant_id=tenant_id, - engine_factory=mocked_engine, + http_client=http_client, ) file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) @@ -87,16 +87,18 @@ class TestFileSaverImpl: status_code=401, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response + file_saver = FileSaverImpl( user_id=_gen_id(), tenant_id=_gen_id(), + http_client=http_client, ) - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) with pytest.raises(httpx.HTTPStatusError) as exc: file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_get.assert_called_once_with(_TEST_URL) + http_client.get.assert_called_once_with(_TEST_URL) assert exc.value.response.status_code == 401 def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): @@ -112,8 +114,10 @@ class TestFileSaverImpl: headers={"Content-Type": mime_type}, request=mock_request, ) + http_client = MagicMock() + http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id) + file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) mock_tool_file = ToolFile( user_id=user_id, tenant_id=tenant_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py new file mode 100644 index 0000000000..618a498659 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -0,0 +1,106 @@ +from unittest import mock + +import pytest + +from core.model_manager import ModelInstance +from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent +from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage +from dify_graph.nodes.llm import llm_utils +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage +from dify_graph.nodes.llm.exc import NoPromptFoundError +from dify_graph.runtime import VariablePool + + +def _fetch_prompt_messages_with_mocked_content(content): + variable_pool = VariablePool.empty() + model_instance = mock.MagicMock(spec=ModelInstance) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "dify_graph.nodes.llm.llm_utils.fetch_model_schema", + return_value=mock.MagicMock(features=[]), + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_list_messages", + return_value=[SystemPromptMessage(content=content)], + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", + return_value=[], + ), + ): + return llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context=None, + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=["END"], + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + template_renderer=None, + ) + + +def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out(): + with pytest.raises(NoPromptFoundError): + _fetch_prompt_messages_with_mocked_content( + [ + ImagePromptMessageContent( + format="url", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + + +def test_fetch_prompt_messages_flattens_single_text_content_after_filtering_unsupported_multimodal_items(): + prompt_messages, stop = _fetch_prompt_messages_with_mocked_content( + [ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="url", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + + assert stop == ["END"] + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_remain(): + prompt_messages, stop = _fetch_prompt_messages_with_mocked_content( + [ + TextPromptMessageContent(data="You are"), + TextPromptMessageContent(data=" a classifier."), + ImagePromptMessageContent( + format="url", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + + assert stop == ["END"] + assert prompt_messages == [ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are"), + TextPromptMessageContent(data=" a classifier."), + ] + ) + ] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 94b5b72ee1..fc96088af1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,13 +5,16 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.message_entities import ( +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from dify_graph.entities import GraphInitParams +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -19,13 +22,10 @@ from core.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.workflow.entities import GraphInitParams -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.nodes.llm import llm_utils -from core.workflow.nodes.llm.entities import ( +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.nodes.llm import llm_utils +from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, LLMNodeData, @@ -33,14 +33,14 @@ from core.workflow.nodes.llm.entities import ( VisionConfig, VisionConfigOptions, ) -from core.workflow.nodes.llm.file_saver import LLMFileSaver -from core.workflow.nodes.llm.node import LLMNode, _handle_memory_completion_mode -from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment -from models.enums import UserFrom +from dify_graph.nodes.llm.file_saver import LLMFileSaver +from dify_graph.nodes.llm.node import LLMNode +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType +from tests.workflow_test_utils import build_test_graph_init_params class MockTokenBufferMemory: @@ -76,11 +76,11 @@ def llm_node_data() -> LLMNodeData: @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="1", - app_id="1", + return build_test_graph_init_params( workflow_id="1", graph_config={}, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.SERVICE_API, @@ -107,10 +107,12 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) node_config = { "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -120,6 +122,8 @@ def llm_node( model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + template_renderer=mock_template_renderer, + http_client=http_client, ) return node @@ -588,6 +592,33 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] +def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): + llm_node._template_renderer.render_jinja2.return_value = "Hello, world" + messages = [ + LLMNodeChatModelMessage( + text="", + jinja2_text="Hello, {{ name }}", + role=PromptMessageRole.USER, + edition_type="jinja2", + ) + ] + + result = llm_node.handle_list_messages( + messages=messages, + context=None, + jinja2_variables=[], + variable_pool=llm_node.graph_runtime_state.variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + template_renderer=llm_node._template_renderer, + ) + + assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] + llm_node._template_renderer.render_jinja2.assert_called_once_with( + template="Hello, {{ name }}", + inputs={}, + ) + + def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) memory.get_history_prompt_messages.return_value = [ @@ -611,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("core.workflow.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = _handle_memory_completion_mode( + with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = llm_utils.handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -628,10 +659,12 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) node_config = { "id": "1", "data": llm_node_data.model_dump(), } + http_client = mock.MagicMock() node = LLMNode( id="1", config=node_config, @@ -641,6 +674,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat model_factory=mock_model_factory, model_instance=mock.MagicMock(spec=ModelInstance), llm_file_saver=mock_file_saver, + template_renderer=mock_template_renderer, + http_client=http_client, ) return node, mock_file_saver diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py index ac0c1df9c5..e40d565ef5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py @@ -2,10 +2,10 @@ from collections.abc import Mapping, Sequence from pydantic import BaseModel, Field -from core.model_runtime.entities.message_entities import PromptMessage -from core.model_runtime.entities.model_entities import ModelFeature -from core.workflow.file import File -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage +from dify_graph.file import File +from dify_graph.model_runtime.entities.message_entities import PromptMessage +from dify_graph.model_runtime.entities.model_entities import ModelFeature +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage class LLMNodeTestScenario(BaseModel): diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py index 2742b7dab0..fd48edc58c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py @@ -1,5 +1,5 @@ -from core.workflow.nodes.parameter_extractor.entities import ParameterConfig -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.parameter_extractor.entities import ParameterConfig +from dify_graph.variables.types import SegmentType class TestParameterConfig: diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index ae229bbe2e..7eca531b62 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -7,17 +7,17 @@ from typing import Any import pytest -from core.model_runtime.entities import LLMMode -from core.workflow.nodes.llm import ModelConfig, VisionConfig -from core.workflow.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData -from core.workflow.nodes.parameter_extractor.exc import ( +from dify_graph.model_runtime.entities import LLMMode +from dify_graph.nodes.llm import ModelConfig, VisionConfig +from dify_graph.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from dify_graph.nodes.parameter_extractor.exc import ( InvalidNumberOfParametersError, InvalidSelectValueError, InvalidValueTypeError, RequiredParameterMissingError, ) -from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from core.workflow.variables.types import SegmentType +from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from dify_graph.variables.types import SegmentType from factories.variable_factory import build_segment_with_type diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py index 5eb302798f..e57ebbd83e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py @@ -1,8 +1,8 @@ import pytest from pydantic import ValidationError -from core.workflow.enums import ErrorStrategy -from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData +from dify_graph.enums import ErrorStrategy +from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData class TestTemplateTransformNodeData: diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 0fb76fb7e7..332a8761f9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,14 +1,14 @@ -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -from core.workflow.graph_engine.entities.graph import Graph -from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams -from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.enums import ErrorStrategy, NodeType, WorkflowNodeExecutionStatus -from core.workflow.nodes.template_transform.template_renderer import TemplateRenderError -from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode -from models.workflow import WorkflowType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus +from dify_graph.graph import Graph +from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError +from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode +from dify_graph.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params class TestTemplateTransformNode: @@ -24,21 +24,20 @@ class TestTemplateTransformNode: @pytest.fixture def mock_graph(self): - """Create a mock Graph.""" + """Create a mock Graph (kept for backward compat in other tests).""" return MagicMock(spec=Graph) @pytest.fixture def graph_init_params(self): """Create a mock GraphInitParams.""" - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", - workflow_type=WorkflowType.WORKFLOW, + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", - user_from="test", - invoke_from="test", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, call_depth=0, ) @@ -55,46 +54,49 @@ class TestTemplateTransformNode: "template": "Hello {{ name }}, you are {{ age }} years old!", } - def test_node_initialization(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_node_initialization(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test that TemplateTransformNode initializes correctly.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) - assert node.node_type == NodeType.TEMPLATE_TRANSFORM + assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM assert node._node_data.title == "Template Transform" assert len(node._node_data.variables) == 2 assert node._node_data.template == "Hello {{ name }}, you are {{ age }} years old!" - def test_get_title(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_title(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_title method.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" - def test_get_description(self, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_description(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _get_description method.""" + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" - def test_get_error_strategy(self, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_get_error_strategy(self, mock_graph_runtime_state, graph_init_params): """Test _get_error_strategy method.""" node_data = { "title": "Test", @@ -103,12 +105,13 @@ class TestTemplateTransformNode: "error_strategy": "fail-branch", } + mock_renderer = MagicMock() node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -127,14 +130,8 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_simple_template( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): - """Test _run with simple template transformation.""" + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): + """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool mock_name_value = MagicMock() mock_name_value.to_object.return_value = "Alice" @@ -147,15 +144,16 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - # Setup mock executor - mock_execute.return_value = "Hello Alice, you are 30 years old!" + # Setup mock renderer + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Hello Alice, you are 30 years old!" node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -165,11 +163,7 @@ class TestTemplateTransformNode: assert result.inputs["name"] == "Alice" assert result.inputs["age"] == 30 - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_none_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_none_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with None variable values.""" node_data = { "title": "Test", @@ -178,14 +172,16 @@ class TestTemplateTransformNode: } mock_graph_runtime_state.variable_pool.get.return_value = None - mock_execute.return_value = "Value: " + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Value: " node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -193,23 +189,19 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs["value"] is None - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_code_execution_error( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): - """Test _run when code execution fails.""" + def test_run_with_render_error(self, basic_node_data, mock_graph_runtime_state, graph_init_params): + """Test _run when template rendering fails.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.side_effect = TemplateRenderError("Template syntax error") + + mock_renderer = MagicMock() + mock_renderer.render_template.side_effect = TemplateRenderError("Template syntax error") node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -217,23 +209,19 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Template syntax error" in result.error - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_output_length_exceeds_limit( - self, mock_execute, basic_node_data, mock_graph, mock_graph_runtime_state, graph_init_params - ): + def test_run_output_length_exceeds_limit(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run when output exceeds maximum length.""" mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() - mock_execute.return_value = "This is a very long output that exceeds the limit" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "This is a very long output that exceeds the limit" node = TemplateTransformNode( id="test_node", - config=basic_node_data, + config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, max_output_length=10, ) @@ -242,13 +230,7 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_complex_jinja2_template( - self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params - ): + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { "title": "Complex Template", @@ -272,14 +254,16 @@ class TestTemplateTransformNode: ("sys", "show_total"): mock_show_total, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = "apple, banana, orange (Total: 3)" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "apple, banana, orange (Total: 3)" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -307,11 +291,7 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_empty_variables(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { "title": "Static Template", @@ -319,14 +299,15 @@ class TestTemplateTransformNode: "template": "This is a static message.", } - mock_execute.return_value = "This is a static message." + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "This is a static message." node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -335,11 +316,7 @@ class TestTemplateTransformNode: assert result.outputs["output"] == "This is a static message." assert result.inputs == {} - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_numeric_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_numeric_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with numeric variable values.""" node_data = { "title": "Numeric Template", @@ -360,14 +337,16 @@ class TestTemplateTransformNode: ("sys", "quantity"): mock_quantity, } mock_graph_runtime_state.variable_pool.get.side_effect = lambda selector: variable_map.get(tuple(selector)) - mock_execute.return_value = "Total: $31.5" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Total: $31.5" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -375,11 +354,7 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.outputs["output"] == "Total: $31.5" - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_dict_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_dict_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with dictionary variable values.""" node_data = { "title": "Dict Template", @@ -391,14 +366,16 @@ class TestTemplateTransformNode: mock_user.to_object.return_value = {"name": "John Doe", "email": "john@example.com"} mock_graph_runtime_state.variable_pool.get.return_value = mock_user - mock_execute.return_value = "Name: John Doe, Email: john@example.com" + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Name: John Doe, Email: john@example.com" node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() @@ -407,11 +384,7 @@ class TestTemplateTransformNode: assert "John Doe" in result.outputs["output"] assert "john@example.com" in result.outputs["output"] - @patch( - "core.workflow.nodes.template_transform.template_transform_node.CodeExecutorJinja2TemplateRenderer.render_template", - autospec=True, - ) - def test_run_with_list_values(self, mock_execute, mock_graph, mock_graph_runtime_state, graph_init_params): + def test_run_with_list_values(self, mock_graph_runtime_state, graph_init_params): """Test _run with list variable values.""" node_data = { "title": "List Template", @@ -423,14 +396,16 @@ class TestTemplateTransformNode: mock_tags.to_object.return_value = ["python", "ai", "workflow"] mock_graph_runtime_state.variable_pool.get.return_value = mock_tags - mock_execute.return_value = "Tags: #python #ai #workflow " + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "Tags: #python #ai #workflow " node = TemplateTransformNode( id="test_node", - config=node_data, + config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, - graph=mock_graph, graph_runtime_state=mock_graph_runtime_state, + template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 1854cca236..2b0205fb7b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -2,12 +2,15 @@ from collections.abc import Mapping import pytest -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.base.node import Node -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities import GraphInitParams +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.base.node import Node +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params class _SampleNodeData(BaseNodeData): @@ -15,7 +18,7 @@ class _SampleNodeData(BaseNodeData): class _SampleNode(Node[_SampleNodeData]): - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER @classmethod def version(cls) -> str: @@ -26,15 +29,10 @@ class _SampleNode(Node[_SampleNodeData]): def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: - init_params = GraphInitParams( - tenant_id="tenant", - app_id="app", - workflow_id="workflow", + init_params = build_test_graph_init_params( graph_config=graph_config, - user_id="user", user_from="account", invoke_from="debugger", - call_depth=0, ) runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), @@ -43,19 +41,62 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, return init_params, runtime_state +def _build_node_config() -> NodeConfigDict: + return NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": BuiltinNodeTypes.ANSWER, + "title": "Sample", + "foo": "bar", + }, + } + ) + + def test_node_hydrates_data_during_initialization(): graph_config: dict[str, object] = {} init_params, runtime_state = _build_context(graph_config) node = _SampleNode( id="node-1", - config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}, + config=_build_node_config(), graph_init_params=init_params, graph_runtime_state=runtime_state, ) assert node.node_data.foo == "bar" assert node.title == "Sample" + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == "account" + assert dify_ctx.invoke_from == "debugger" + + +def test_node_accepts_invoke_from_enum(): + graph_config: dict[str, object] = {} + init_params = build_test_graph_init_params( + graph_config=graph_config, + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + start_at=0.0, + ) + + node = _SampleNode( + id="node-1", + config=_build_node_config(), + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + dify_ctx = node.require_dify_context() + assert dify_ctx.user_from == UserFrom.ACCOUNT + assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER + assert node.get_run_context_value("missing") is None + with pytest.raises(ValueError): + node.require_run_context_value("missing") def test_missing_generic_argument_raises_type_error(): @@ -64,7 +105,7 @@ def test_missing_generic_argument_raises_type_error(): with pytest.raises(TypeError): class _InvalidNode(Node): # type: ignore[type-abstract] - node_type = NodeType.ANSWER + node_type = BuiltinNodeTypes.ANSWER @classmethod def version(cls) -> str: @@ -72,3 +113,17 @@ def test_missing_generic_argument_raises_type_error(): def _run(self): raise NotImplementedError + + +def test_base_node_data_keeps_dict_style_access_compatibility(): + node_data = _SampleNodeData.model_validate( + { + "type": BuiltinNodeTypes.ANSWER, + "title": "Sample", + "foo": "bar", + } + ) + + assert node_data["foo"] == "bar" + assert node_data.get("foo") == "bar" + assert node_data.get("missing", "fallback") == "fallback" diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 35c59b92c4..40754974c1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -5,31 +5,32 @@ import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities import GraphInitParams -from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod -from core.workflow.node_events import NodeRunResult -from core.workflow.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData -from core.workflow.nodes.document_extractor.node import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities import GraphInitParams +from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from dify_graph.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, _extract_text_from_pdf, _extract_text_from_plain_text, + _normalize_docx_zip, ) -from core.workflow.variables import ArrayFileSegment -from core.workflow.variables.segments import ArrayStringSegment -from core.workflow.variables.variables import StringVariable -from models.enums import UserFrom +from dify_graph.variables import ArrayFileSegment +from dify_graph.variables.segments import ArrayStringSegment +from dify_graph.variables.variables import StringVariable +from tests.workflow_test_utils import build_test_graph_init_params @pytest.fixture def graph_init_params() -> GraphInitParams: - return GraphInitParams( - tenant_id="test_tenant", - app_id="test_app", + return build_test_graph_init_params( workflow_id="test_workflow", graph_config={}, + tenant_id="test_tenant", + app_id="test_app", user_id="test_user", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -44,11 +45,13 @@ def document_extractor_node(graph_init_params): variable_selector=["node_id", "variable_name"], ) node_config = {"id": "test_node_id", "data": node_data.model_dump()} + http_client = Mock() node = DocumentExtractorNode( id="test_node_id", config=node_config, graph_init_params=graph_init_params, graph_runtime_state=Mock(), + http_client=http_client, ) return node @@ -84,6 +87,38 @@ def test_run_invalid_variable_type(document_extractor_node, mock_graph_runtime_s assert "is not an ArrayFileSegment" in result.error +def test_run_empty_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state): + """Empty file list should return SUCCEEDED with empty documents and ArrayStringSegment([]).""" + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + # Provide an actual ArrayFileSegment with an empty list + mock_graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(value=[]) + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.process_data.get("documents") == [] + assert result.outputs["text"] == ArrayStringSegment(value=[]) + + +def test_run_none_only_file_list_returns_succeeded(document_extractor_node, mock_graph_runtime_state): + """A file list containing only None (e.g., [None]) should be filtered to [] and succeed.""" + document_extractor_node.graph_runtime_state = mock_graph_runtime_state + + # Use a Mock to bypass type validation for None entries in the list + afs = Mock(spec=ArrayFileSegment) + afs.value = [None] + mock_graph_runtime_state.variable_pool.get.return_value = afs + + result = document_extractor_node._run() + + assert isinstance(result, NodeRunResult) + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED, result.error + assert result.process_data.get("documents") == [] + assert result.outputs["text"] == ArrayStringSegment(value=[]) + + @pytest.mark.parametrize( ("mime_type", "file_content", "expected_text", "transfer_method", "extension"), [ @@ -142,19 +177,20 @@ def test_run_extract_text( mock_graph_runtime_state.variable_pool.get.return_value = mock_array_file_segment mock_download = Mock(return_value=file_content) - mock_ssrf_proxy_get = Mock() - mock_ssrf_proxy_get.return_value.content = file_content - mock_ssrf_proxy_get.return_value.raise_for_status = Mock() - monkeypatch.setattr("core.workflow.file.file_manager.download", mock_download) - monkeypatch.setattr("core.helper.ssrf_proxy.get", mock_ssrf_proxy_get) + mock_response = Mock() + mock_response.content = file_content + mock_response.raise_for_status = Mock() + document_extractor_node._http_client.get = Mock(return_value=mock_response) + + monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("core.workflow.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) result = document_extractor_node._run() @@ -164,7 +200,7 @@ def test_run_extract_text( assert result.outputs["text"] == ArrayStringSegment(value=expected_text) if transfer_method == FileTransferMethod.REMOTE_URL: - mock_ssrf_proxy_get.assert_called_once_with("https://example.com/file.txt") + document_extractor_node._http_client.get.assert_called_once_with("https://example.com/file.txt") elif transfer_method == FileTransferMethod.LOCAL_FILE: mock_download.assert_called_once_with(mock_file) @@ -214,7 +250,7 @@ def test_extract_text_from_docx(mock_document): def test_node_type(document_extractor_node): - assert document_extractor_node.node_type == NodeType.DOCUMENT_EXTRACTOR + assert document_extractor_node.node_type == BuiltinNodeTypes.DOCUMENT_EXTRACTOR @patch("pandas.ExcelFile") @@ -382,3 +418,58 @@ def test_extract_text_from_excel_numeric_type_column(mock_excel_file): expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n" assert expected_manual == result + + +def _make_docx_zip(use_backslash: bool) -> bytes: + """Helper to build a minimal in-memory DOCX zip. + + When use_backslash=True the ZIP entry names use backslash separators + (as produced by Evernote on Windows), otherwise forward slashes are used. + """ + import zipfile + + sep = "\\" if use_backslash else "/" + buf = io.BytesIO() + with zipfile.ZipFile(buf, "w", compression=zipfile.ZIP_DEFLATED) as zf: + zf.writestr("[Content_Types].xml", b"") + zf.writestr(f"_rels{sep}.rels", b"") + zf.writestr(f"word{sep}document.xml", b"") + zf.writestr(f"word{sep}_rels{sep}document.xml.rels", b"") + return buf.getvalue() + + +def test_normalize_docx_zip_replaces_backslashes(): + """ZIP entries with backslash separators must be rewritten to forward slashes.""" + import zipfile + + malformed = _make_docx_zip(use_backslash=True) + fixed = _normalize_docx_zip(malformed) + + with zipfile.ZipFile(io.BytesIO(fixed)) as zf: + names = zf.namelist() + + assert "word/document.xml" in names + assert "word/_rels/document.xml.rels" in names + # No entry should contain a backslash after normalization + assert all("\\" not in name for name in names) + + +def test_normalize_docx_zip_leaves_forward_slash_unchanged(): + """ZIP entries that already use forward slashes must not be modified.""" + import zipfile + + normal = _make_docx_zip(use_backslash=False) + fixed = _normalize_docx_zip(normal) + + with zipfile.ZipFile(io.BytesIO(fixed)) as zf: + names = zf.namelist() + + assert "word/document.xml" in names + assert "word/_rels/document.xml.rels" in names + + +def test_normalize_docx_zip_returns_original_on_bad_zip(): + """Non-zip bytes must be returned as-is without raising.""" + garbage = b"not a zip file at all" + result = _normalize_docx_zip(garbage) + assert result == garbage diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index bc87a64161..c746a945fe 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -4,30 +4,30 @@ from unittest.mock import MagicMock, Mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.graph import Graph -from core.workflow.nodes.if_else.entities import IfElseNodeData -from core.workflow.nodes.if_else.if_else_node import IfElseNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition -from core.workflow.variables import ArrayFileSegment +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.graph import Graph +from dify_graph.nodes.if_else.entities import IfElseNodeData +from dify_graph.nodes.if_else.if_else_node import IfElseNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from dify_graph.variables import ArrayFileSegment from extensions.ext_database import db -from models.enums import UserFrom +from tests.workflow_test_utils import build_test_graph_init_params def test_execute_if_else_result_true(): graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -60,7 +60,7 @@ def test_execute_if_else_result_true(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "if-else", @@ -129,11 +129,11 @@ def test_execute_if_else_result_false(): # Create a simple graph for IfElse node testing graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -154,7 +154,7 @@ def test_execute_if_else_result_false(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "if-else", @@ -230,14 +230,18 @@ def test_array_file_contains_file_name(): # Create properly configured mock for graph_init_params graph_init_params = Mock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = IfElseNode( id=str(uuid.uuid4()), @@ -299,11 +303,11 @@ def test_execute_if_else_boolean_conditions(condition: Condition): """Test IfElseNode with boolean conditions using various operators""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -324,7 +328,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_data = { "title": "Boolean Test", @@ -354,11 +358,11 @@ def test_execute_if_else_boolean_false_conditions(): """Test IfElseNode with boolean conditions that should evaluate to false""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -378,7 +382,7 @@ def test_execute_if_else_boolean_false_conditions(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_data = { "title": "Boolean False Test", @@ -423,11 +427,11 @@ def test_execute_if_else_boolean_cases_structure(): """Test IfElseNode with boolean conditions using the new cases structure""" graph_config = {"edges": [], "nodes": [{"data": {"type": "start", "title": "Start"}, "id": "start"}]} - init_params = GraphInitParams( - tenant_id="1", - app_id="1", + init_params = build_test_graph_init_params( workflow_id="1", graph_config=graph_config, + tenant_id="1", + app_id="1", user_id="1", user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, @@ -446,7 +450,7 @@ def test_execute_if_else_boolean_cases_structure(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_data = { "title": "Boolean Cases Test", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 73c17ee45a..6ca72b64b2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -2,10 +2,11 @@ from unittest.mock import MagicMock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.enums import WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.nodes.list_operator.entities import ( +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.nodes.list_operator.entities import ( ExtractConfig, FilterBy, FilterCondition, @@ -14,10 +15,9 @@ from core.workflow.nodes.list_operator.entities import ( Order, OrderByConfig, ) -from core.workflow.nodes.list_operator.exc import InvalidKeyError -from core.workflow.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func -from core.workflow.variables import ArrayFileSegment -from models.enums import UserFrom +from dify_graph.nodes.list_operator.exc import InvalidKeyError +from dify_graph.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from dify_graph.variables import ArrayFileSegment @pytest.fixture @@ -42,14 +42,18 @@ def list_operator_node(): } # Create properly configured mock for graph_init_params graph_init_params = MagicMock() - graph_init_params.tenant_id = "test_tenant" - graph_init_params.app_id = "test_app" graph_init_params.workflow_id = "test_workflow" graph_init_params.graph_config = {} - graph_init_params.user_id = "test_user" - graph_init_params.user_from = UserFrom.ACCOUNT - graph_init_params.invoke_from = InvokeFrom.SERVICE_API graph_init_params.call_depth = 0 + graph_init_params.run_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "test_tenant", + "app_id": "test_app", + "user_id": "test_user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + } node = ListOperatorNode( id="test_node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py new file mode 100644 index 0000000000..6372583839 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py @@ -0,0 +1,52 @@ +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.nodes.loop.entities import LoopNodeData +from dify_graph.nodes.loop.loop_node import LoopNode + + +def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: + seen_configs: list[object] = [] + original_validate_python = NodeConfigDictAdapter.validate_python + + def record_validate_python(value: object): + seen_configs.append(value) + return original_validate_python(value) + + monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) + + child_node_config = { + "id": "answer-node", + "data": { + "type": "answer", + "title": "Answer", + "answer": "", + "loop_id": "loop-node", + }, + } + + LoopNode._extract_variable_selector_to_variable_mapping( + graph_config={ + "nodes": [ + { + "id": "loop-node", + "data": { + "type": "loop", + "title": "Loop", + "loop_count": 1, + "break_conditions": [], + "logical_operator": "and", + }, + }, + child_node_config, + ], + "edges": [], + }, + node_id="loop-node", + node_data=LoopNodeData( + title="Loop", + loop_count=1, + break_conditions=[], + logical_operator="and", + ), + ) + + assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py index 47ef289ef3..c5a02e87e4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py @@ -1,5 +1,14 @@ -from core.model_runtime.entities import ImagePromptMessageContent -from core.workflow.nodes.question_classifier import QuestionClassifierNodeData +from types import SimpleNamespace +from unittest.mock import MagicMock + +from dify_graph.model_runtime.entities import ImagePromptMessageContent +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.nodes.question_classifier import ( + QuestionClassifierNode, + QuestionClassifierNodeData, +) +from tests.workflow_test_utils import build_test_graph_init_params def test_init_question_classifier_node_data(): @@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config(): assert node_data.vision.enabled == False assert node_data.vision.configs.variable_selector == ["sys", "files"] assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH + + +def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch): + node_data = QuestionClassifierNodeData.model_validate( + { + "title": "test classifier node", + "query_variable_selector": ["id", "name"], + "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, + "classes": [{"id": "1", "name": "class 1"}], + "instruction": "This is a test instruction", + } + ) + template_renderer = MagicMock(spec=TemplateRenderer) + node = QuestionClassifierNode( + id="node-id", + config={"id": "node-id", "data": node_data.model_dump(mode="json")}, + graph_init_params=build_test_graph_init_params( + workflow_id="workflow-id", + graph_config={}, + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + ), + graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()), + credentials_provider=MagicMock(spec=CredentialsProvider), + model_factory=MagicMock(spec=ModelFactory), + model_instance=MagicMock(), + http_client=MagicMock(spec=HttpClientProtocol), + llm_file_saver=MagicMock(), + template_renderer=template_renderer, + ) + fetch_prompt_messages = MagicMock(return_value=([], None)) + monkeypatch.setattr( + "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", + fetch_prompt_messages, + ) + monkeypatch.setattr( + "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", + MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), + ) + + node._calculate_rest_token( + node_data=node_data, + query="hello", + model_instance=MagicMock(stop=(), parameters={}), + context="", + ) + + assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index 8c7dc24868..b8f0e25e91 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -4,12 +4,12 @@ import time import pytest from pydantic import ValidationError as PydanticValidationError -from core.workflow.entities import GraphInitParams -from core.workflow.nodes.start.entities import StartNodeData -from core.workflow.nodes.start.start_node import StartNode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.nodes.start.entities import StartNodeData +from dify_graph.nodes.start.start_node import StartNode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType +from tests.workflow_test_utils import build_test_graph_init_params def make_start_node(user_inputs, variables): @@ -32,11 +32,11 @@ def make_start_node(user_inputs, variables): return StartNode( id="start", config=config, - graph_init_params=GraphInitParams( - tenant_id="tenant", - app_id="app", + graph_init_params=build_test_graph_init_params( workflow_id="wf", graph_config={}, + tenant_id="tenant", + app_id="app", user_id="u", user_from="account", invoke_from="debugger", diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 678691439f..3cbd96dfef 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -8,18 +8,18 @@ from unittest.mock import MagicMock, patch import pytest -from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.utils.message_transformer import ToolFileMessageTransformer -from core.workflow.entities import GraphInitParams -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.node_events import StreamChunkEvent, StreamCompletedEvent -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.segments import ArrayFileSegment +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.segments import ArrayFileSegment +from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only - from core.workflow.nodes.tool.tool_node import ToolNode + from dify_graph.nodes.tool.tool_node import ToolNode @pytest.fixture @@ -31,7 +31,8 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) - from core.workflow.nodes.tool.tool_node import ToolNode + from dify_graph.nodes.protocols import ToolFileManagerProtocol + from dify_graph.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { "nodes": [ @@ -54,11 +55,11 @@ def tool_node(monkeypatch) -> ToolNode: "edges": [], } - init_params = GraphInitParams( - tenant_id="tenant-id", - app_id="app-id", + init_params = build_test_graph_init_params( workflow_id="workflow-id", graph_config=graph_config, + tenant_id="tenant-id", + app_id="app-id", user_id="user-id", user_from="account", invoke_from="debugger", @@ -69,11 +70,16 @@ def tool_node(monkeypatch) -> ToolNode: graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] + + # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor + tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + node = ToolNode( id="node-instance", config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, + tool_file_manager_factory=tool_file_manager_factory, ) return node diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py new file mode 100644 index 0000000000..9aeab0409e --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -0,0 +1,63 @@ +from collections.abc import Mapping + +from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE +from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_graph_init_params + + +def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: + init_params = build_test_graph_init_params( + graph_config=graph_config, + user_from="account", + invoke_from="debugger", + ) + runtime_state = GraphRuntimeState( + variable_pool=VariablePool( + system_variables=SystemVariable(user_id="user", files=[]), + user_inputs={"payload": "value"}, + ), + start_at=0.0, + ) + return init_params, runtime_state + + +def _build_node_config() -> NodeConfigDict: + return NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": TRIGGER_PLUGIN_NODE_TYPE, + "title": "Trigger Event", + "plugin_id": "plugin-id", + "provider_id": "provider-id", + "event_name": "event-name", + "subscription_id": "subscription-id", + "plugin_unique_identifier": "plugin-unique-identifier", + "event_parameters": {}, + }, + } + ) + + +def test_trigger_event_node_run_populates_trigger_info_metadata() -> None: + init_params, runtime_state = _build_context(graph_config={}) + node = TriggerEventNode( + id="node-1", + config=_build_node_config(), + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { + "provider_id": "provider-id", + "event_name": "event-name", + "plugin_unique_identifier": "plugin-unique-identifier", + } diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 8a52f963ef..e69c05dc0b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,18 +2,18 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.graph_events.node import NodeRunSucceededEvent -from core.workflow.nodes.variable_assigner.common import helpers as common_helpers -from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode -from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ArrayStringVariable, StringVariable -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.graph import Graph +from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.nodes.variable_assigner.common import helpers as common_helpers +from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode +from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ArrayStringVariable, StringVariable DEFAULT_NODE_ID = "node_id" @@ -43,13 +43,17 @@ def test_overwrite_string_variable(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -84,7 +88,7 @@ def test_overwrite_string_variable(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -141,13 +145,17 @@ def test_append_variable_to_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -180,7 +188,7 @@ def test_append_variable_to_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -236,13 +244,17 @@ def test_clear_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -265,7 +277,7 @@ def test_clear_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py index 9a874337ed..a7673c5a14 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -1,6 +1,6 @@ -from core.workflow.nodes.variable_assigner.v2.enums import Operation -from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid -from core.workflow.variables import SegmentType +from dify_graph.nodes.variable_assigner.v2.enums import Operation +from dify_graph.nodes.variable_assigner.v2.helpers import is_input_value_valid +from dify_graph.variables import SegmentType def test_is_input_value_valid_overwrite_array_string(): diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 5ed68fe8d0..6874f3fef1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,16 +2,16 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom -from core.app.workflow.node_factory import DifyNodeFactory -from core.workflow.entities import GraphInitParams -from core.workflow.graph import Graph -from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode -from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation -from core.workflow.runtime import GraphRuntimeState, VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import ArrayStringVariable -from models.enums import UserFrom +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from dify_graph.entities import GraphInitParams +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.graph import Graph +from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode +from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation +from dify_graph.runtime import GraphRuntimeState, VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import ArrayStringVariable DEFAULT_NODE_ID = "node_id" @@ -85,13 +85,17 @@ def test_remove_first_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -114,7 +118,7 @@ def test_remove_first_from_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -169,13 +173,17 @@ def test_remove_last_from_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -198,7 +206,7 @@ def test_remove_last_from_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -250,13 +258,17 @@ def test_remove_first_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -279,7 +291,7 @@ def test_remove_first_from_empty_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -331,13 +343,17 @@ def test_remove_last_from_empty_array(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) @@ -360,7 +376,7 @@ def test_remove_last_from_empty_array(): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") node_config = { "id": "node_id", @@ -404,13 +420,17 @@ def test_node_factory_creates_variable_assigner_node(): } init_params = GraphInitParams( - tenant_id="1", - app_id="1", workflow_id="1", graph_config=graph_config, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.DEBUGGER, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, call_depth=0, ) variable_pool = VariablePool( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 4fa9a01b61..6be5bb23e8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -210,9 +210,6 @@ def test_webhook_data_model_dump_with_alias(): def test_webhook_data_validation_errors(): """Test WebhookData validation errors.""" - # Title is required (inherited from BaseNodeData) - with pytest.raises(ValidationError): - WebhookData() # Invalid method with pytest.raises(ValidationError): @@ -254,6 +251,36 @@ def test_webhook_data_sequence_fields(): assert len(data.headers) == 1 # Should still be 1 +def test_webhook_data_rejects_non_string_header_types(): + """Headers should stay string-only because runtime does not coerce header values.""" + for param_type in ["number", "boolean", "object", "array[string]", "file"]: + with pytest.raises(ValidationError): + WebhookData( + title="Test", + headers=[WebhookParameter(name="X-Test", type=param_type)], + ) + + +def test_webhook_data_limits_query_param_types_to_scalar_values(): + """Query params only support scalar conversions in the current runtime.""" + data = WebhookData( + title="Test", + params=[ + WebhookParameter(name="count", type="number"), + WebhookParameter(name="enabled", type="boolean"), + ], + ) + assert data.params[0].type == "number" + assert data.params[1].type == "boolean" + + for param_type in ["object", "array[string]", "array[number]", "array[boolean]", "array[object]", "file"]: + with pytest.raises(ValidationError): + WebhookData( + title="Test", + params=[WebhookParameter(name="test", type=param_type)], + ) + + def test_webhook_data_sync_mode(): """Test WebhookData SyncMode nested enum.""" # Test that SyncMode enum exists and has expected value @@ -297,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from core.workflow.nodes.base import BaseNodeData + from dify_graph.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index 374d5183c8..ddf1af5a59 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,12 +1,12 @@ import pytest -from core.workflow.nodes.base.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, WebhookNodeError, WebhookNotFoundError, WebhookTimeoutError, ) +from dify_graph.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index d8f6b41f89..78dd7ce0f3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,9 +8,7 @@ when passing files to downstream LLM nodes. from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -18,11 +16,11 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from core.workflow.runtime.graph_runtime_state import GraphRuntimeState -from core.workflow.runtime.variable_pool import VariablePool -from core.workflow.system_variable import SystemVariable -from models.enums import UserFrom -from models.workflow import WorkflowType +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.runtime.graph_runtime_state import GraphRuntimeState +from dify_graph.runtime.variable_pool import VariablePool +from dify_graph.system_variable import SystemVariable def create_webhook_node( @@ -37,14 +35,17 @@ def create_webhook_node( } graph_init_params = GraphInitParams( - tenant_id=tenant_id, - app_id="test-app", - workflow_type=WorkflowType.WORKFLOW, workflow_id="test-workflow", graph_config={}, - user_id="test-user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": tenant_id, + "app_id": "test-app", + "user_id": "test-user", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 24d3740b99..139f65d6c3 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,10 +2,8 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.graph_init_params import GraphInitParams -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.file import File, FileTransferMethod, FileType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -14,12 +12,13 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from core.workflow.runtime.graph_runtime_state import GraphRuntimeState -from core.workflow.runtime.variable_pool import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import FileVariable, StringVariable -from models.enums import UserFrom -from models.workflow import WorkflowType +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.runtime.graph_runtime_state import GraphRuntimeState +from dify_graph.runtime.variable_pool import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import FileVariable, StringVariable def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -30,14 +29,17 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) } graph_init_params = GraphInitParams( - tenant_id="1", - app_id="1", - workflow_type=WorkflowType.WORKFLOW, workflow_id="1", graph_config={}, - user_id="1", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.SERVICE_API, + } + }, call_depth=0, ) runtime_state = GraphRuntimeState( @@ -81,7 +83,7 @@ def test_webhook_node_basic_initialization(): node = create_webhook_node(data, variable_pool) - assert node.node_type.value == "trigger-webhook" + assert node.node_type == TRIGGER_WEBHOOK_NODE_TYPE assert node.version() == "1" assert node._get_title() == "Test Webhook" assert node._node_data.method == Method.POST diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py index 078ec5f6ab..e8ce6f60f7 100644 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ b/api/tests/unit_tests/core/workflow/test_enums.py @@ -1,6 +1,6 @@ """Tests for workflow pause related enums and constants.""" -from core.workflow.enums import ( +from dify_graph.enums import ( WorkflowExecutionStatus, ) diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py new file mode 100644 index 0000000000..367e3958ad --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -0,0 +1,634 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom +from core.workflow import node_factory +from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE +from dify_graph.entities.base_node_data import BaseNodeData +from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.variables.segments import StringSegment + + +def _assert_typed_node_config(config, *, node_id: str, node_type: NodeType, version: str = "1") -> None: + assert config["id"] == node_id + assert isinstance(config["data"], BaseNodeData) + assert config["data"].type == node_type + assert config["data"].version == version + + +class TestFetchMemory: + @pytest.mark.parametrize( + ("conversation_id", "memory_config"), + [ + (None, object()), + ("conversation-id", None), + ], + ) + def test_returns_none_when_memory_or_conversation_is_missing(self, conversation_id, memory_config): + result = node_factory.fetch_memory( + conversation_id=conversation_id, + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) + + assert result is None + + def test_returns_none_when_conversation_does_not_exist(self, monkeypatch): + class FakeSelect: + def where(self, *_args): + return self + + class FakeSession: + def __init__(self, *_args, **_kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def scalar(self, _stmt): + return None + + monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) + monkeypatch.setattr(node_factory, "Session", FakeSession) + + result = node_factory.fetch_memory( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=object(), + model_instance=sentinel.model_instance, + ) + + assert result is None + + def test_builds_token_buffer_memory_for_existing_conversation(self, monkeypatch): + conversation = sentinel.conversation + memory = sentinel.memory + + class FakeSelect: + def where(self, *_args): + return self + + class FakeSession: + def __init__(self, *_args, **_kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, *_args): + return False + + def scalar(self, _stmt): + return conversation + + token_buffer_memory = MagicMock(return_value=memory) + monkeypatch.setattr(node_factory, "db", SimpleNamespace(engine=sentinel.engine)) + monkeypatch.setattr(node_factory, "select", MagicMock(return_value=FakeSelect())) + monkeypatch.setattr(node_factory, "Session", FakeSession) + monkeypatch.setattr(node_factory, "TokenBufferMemory", token_buffer_memory) + + result = node_factory.fetch_memory( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=object(), + model_instance=sentinel.model_instance, + ) + + assert result is memory + token_buffer_memory.assert_called_once_with( + conversation=conversation, + model_instance=sentinel.model_instance, + ) + + +class TestDefaultWorkflowCodeExecutor: + def test_execute_delegates_to_code_executor(self, monkeypatch): + executor = node_factory.DefaultWorkflowCodeExecutor() + execute_workflow_code_template = MagicMock(return_value={"answer": "ok"}) + monkeypatch.setattr( + node_factory.CodeExecutor, + "execute_workflow_code_template", + execute_workflow_code_template, + ) + + result = executor.execute( + language=CodeLanguage.PYTHON3, + code="print('ok')", + inputs={"name": "workflow"}, + ) + + assert result == {"answer": "ok"} + execute_workflow_code_template.assert_called_once_with( + language=CodeLanguage.PYTHON3, + code="print('ok')", + inputs={"name": "workflow"}, + ) + + def test_is_execution_error_checks_code_execution_error_type(self): + executor = node_factory.DefaultWorkflowCodeExecutor() + + assert executor.is_execution_error(node_factory.CodeExecutionError("boom")) is True + assert executor.is_execution_error(RuntimeError("boom")) is False + + +class TestDefaultLLMTemplateRenderer: + def test_render_jinja2_delegates_to_code_executor(self, monkeypatch): + renderer = node_factory.DefaultLLMTemplateRenderer() + execute_workflow_code_template = MagicMock(return_value={"result": "hello world"}) + monkeypatch.setattr( + node_factory.CodeExecutor, + "execute_workflow_code_template", + execute_workflow_code_template, + ) + + result = renderer.render_jinja2( + template="Hello {{ name }}", + inputs={"name": "world"}, + ) + + assert result == "hello world" + execute_workflow_code_template.assert_called_once_with( + language=CodeLanguage.JINJA2, + code="Hello {{ name }}", + inputs={"name": "world"}, + ) + + +class TestDifyNodeFactoryInit: + def test_init_builds_default_dependencies(self): + graph_init_params = SimpleNamespace(run_context={"context": "value"}) + graph_runtime_state = sentinel.graph_runtime_state + dify_context = SimpleNamespace(tenant_id="tenant-id") + template_renderer = sentinel.template_renderer + unstructured_api_config = sentinel.unstructured_api_config + http_request_config = sentinel.http_request_config + credentials_provider = sentinel.credentials_provider + model_factory = sentinel.model_factory + llm_template_renderer = sentinel.llm_template_renderer + + with ( + patch.object( + node_factory.DifyNodeFactory, + "_resolve_dify_context", + return_value=dify_context, + ) as resolve_dify_context, + patch.object( + node_factory, + "CodeExecutorJinja2TemplateRenderer", + return_value=template_renderer, + ) as renderer_factory, + patch.object( + node_factory, + "UnstructuredApiConfig", + return_value=unstructured_api_config, + ), + patch.object( + node_factory, + "build_http_request_config", + return_value=http_request_config, + ), + patch.object( + node_factory, + "DefaultLLMTemplateRenderer", + return_value=llm_template_renderer, + ) as llm_renderer_factory, + patch.object( + node_factory, + "build_dify_model_access", + return_value=(credentials_provider, model_factory), + ) as build_dify_model_access, + ): + factory = node_factory.DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + + resolve_dify_context.assert_called_once_with(graph_init_params.run_context) + build_dify_model_access.assert_called_once_with("tenant-id") + renderer_factory.assert_called_once() + llm_renderer_factory.assert_called_once() + assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor + assert factory.graph_init_params is graph_init_params + assert factory.graph_runtime_state is graph_runtime_state + assert factory._dify_context is dify_context + assert factory._template_renderer is template_renderer + + assert factory._llm_template_renderer is llm_template_renderer + assert factory._document_extractor_unstructured_api_config is unstructured_api_config + assert factory._http_request_config is http_request_config + assert factory._llm_credentials_provider is credentials_provider + assert factory._llm_model_factory is model_factory + + +class TestDifyNodeFactoryResolveContext: + def test_requires_reserved_context_key(self): + with pytest.raises(ValueError, match=DIFY_RUN_CONTEXT_KEY): + node_factory.DifyNodeFactory._resolve_dify_context({}) + + def test_returns_existing_dify_context(self): + dify_context = DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + result = node_factory.DifyNodeFactory._resolve_dify_context({DIFY_RUN_CONTEXT_KEY: dify_context}) + + assert result is dify_context + + def test_validates_mapping_context(self): + raw_context = { + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant-id", + "app_id": "app-id", + "user_id": "user-id", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + } + + result = node_factory.DifyNodeFactory._resolve_dify_context(raw_context) + + assert isinstance(result, DifyRunContext) + assert result.tenant_id == "tenant-id" + + +class TestDifyNodeFactoryCreateNode: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory.graph_init_params = sentinel.graph_init_params + factory.graph_runtime_state = sentinel.graph_runtime_state + factory._dify_context = SimpleNamespace(tenant_id="tenant-id", app_id="app-id") + factory._code_executor = sentinel.code_executor + factory._code_limits = sentinel.code_limits + factory._template_renderer = sentinel.template_renderer + factory._llm_template_renderer = sentinel.llm_template_renderer + factory._template_transform_max_output_length = 2048 + factory._http_request_http_client = sentinel.http_client + factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory + factory._http_request_file_manager = sentinel.file_manager + factory._document_extractor_unstructured_api_config = sentinel.unstructured_api_config + factory._http_request_config = sentinel.http_request_config + factory._llm_credentials_provider = sentinel.credentials_provider + factory._llm_model_factory = sentinel.model_factory + return factory + + def test_rejects_unknown_node_type(self, factory): + with pytest.raises(ValueError, match="No class mapping found for node type: missing"): + factory.create_node({"id": "node-id", "data": {"type": "missing"}}) + + def test_rejects_missing_class_mapping(self, monkeypatch, factory): + monkeypatch.setattr( + factory, + "_resolve_node_class", + MagicMock(side_effect=ValueError("No class mapping found for node type: start")), + ) + + with pytest.raises(ValueError, match="No class mapping found for node type: start"): + factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START}}) + + def test_rejects_missing_latest_class(self, monkeypatch, factory): + monkeypatch.setattr( + factory, + "_resolve_node_class", + MagicMock(side_effect=ValueError("No latest version class found for node type: start")), + ) + + with pytest.raises(ValueError, match="No latest version class found for node type: start"): + factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START}}) + + def test_uses_version_specific_class_when_available(self, monkeypatch, factory): + matched_node = sentinel.matched_node + latest_node_class = MagicMock(return_value=sentinel.latest_node) + matched_node_class = MagicMock(return_value=matched_node) + monkeypatch.setattr( + factory, + "_resolve_node_class", + MagicMock(return_value=matched_node_class), + ) + + result = factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START, "version": "9"}}) + + assert result is matched_node + matched_node_class.assert_called_once() + kwargs = matched_node_class.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + latest_node_class.assert_not_called() + + def test_falls_back_to_latest_class_when_version_specific_mapping_is_missing(self, monkeypatch, factory): + latest_node = sentinel.latest_node + latest_node_class = MagicMock(return_value=latest_node) + monkeypatch.setattr( + factory, + "_resolve_node_class", + MagicMock(return_value=latest_node_class), + ) + + result = factory.create_node({"id": "node-id", "data": {"type": BuiltinNodeTypes.START, "version": "9"}}) + + assert result is latest_node + latest_node_class.assert_called_once() + kwargs = latest_node_class.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=BuiltinNodeTypes.START, version="9") + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + + @pytest.mark.parametrize( + ("node_type", "constructor_name"), + [ + (BuiltinNodeTypes.CODE, "CodeNode"), + (BuiltinNodeTypes.TEMPLATE_TRANSFORM, "TemplateTransformNode"), + (BuiltinNodeTypes.HTTP_REQUEST, "HttpRequestNode"), + (BuiltinNodeTypes.HUMAN_INPUT, "HumanInputNode"), + (KNOWLEDGE_INDEX_NODE_TYPE, "KnowledgeIndexNode"), + (BuiltinNodeTypes.DATASOURCE, "DatasourceNode"), + (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"), + (BuiltinNodeTypes.DOCUMENT_EXTRACTOR, "DocumentExtractorNode"), + ], + ) + def test_creates_specialized_nodes(self, monkeypatch, factory, node_type, constructor_name): + created_node = object() + constructor = MagicMock(name=constructor_name, return_value=created_node) + monkeypatch.setattr( + factory, + "_resolve_node_class", + MagicMock(return_value=constructor), + ) + + if constructor_name == "HumanInputNode": + form_repository = sentinel.form_repository + form_repository_impl = MagicMock(return_value=form_repository) + monkeypatch.setattr( + node_factory, + "HumanInputFormRepositoryImpl", + form_repository_impl, + ) + + node_config = {"id": "node-id", "data": {"type": node_type}} + result = factory.create_node(node_config) + + assert result is created_node + kwargs = constructor.call_args.kwargs + assert kwargs["id"] == "node-id" + _assert_typed_node_config(kwargs["config"], node_id="node-id", node_type=node_type) + assert kwargs["graph_init_params"] is sentinel.graph_init_params + assert kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + + if constructor_name == "CodeNode": + assert kwargs["code_executor"] is sentinel.code_executor + assert kwargs["code_limits"] is sentinel.code_limits + elif constructor_name == "TemplateTransformNode": + assert kwargs["template_renderer"] is sentinel.template_renderer + assert kwargs["max_output_length"] == 2048 + elif constructor_name == "HttpRequestNode": + assert kwargs["http_request_config"] is sentinel.http_request_config + assert kwargs["http_client"] is sentinel.http_client + assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory + assert kwargs["file_manager"] is sentinel.file_manager + elif constructor_name == "HumanInputNode": + assert kwargs["form_repository"] is form_repository + form_repository_impl.assert_called_once_with(tenant_id="tenant-id") + elif constructor_name == "DocumentExtractorNode": + assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config + assert kwargs["http_client"] is sentinel.http_client + + @pytest.mark.parametrize( + ("node_type", "constructor_name", "expected_extra_kwargs"), + [ + ( + BuiltinNodeTypes.LLM, + "LLMNode", + { + "http_client": sentinel.http_client, + "template_renderer": sentinel.llm_template_renderer, + }, + ), + ( + BuiltinNodeTypes.QUESTION_CLASSIFIER, + "QuestionClassifierNode", + { + "http_client": sentinel.http_client, + "template_renderer": sentinel.llm_template_renderer, + }, + ), + (BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}), + ], + ) + def test_creates_model_backed_nodes( + self, + monkeypatch, + factory, + node_type, + constructor_name, + expected_extra_kwargs, + ): + created_node = object() + constructor = MagicMock(name=constructor_name, return_value=created_node) + monkeypatch.setattr( + factory, + "_resolve_node_class", + MagicMock(return_value=constructor), + ) + llm_init_kwargs = { + "credentials_provider": sentinel.credentials_provider, + "model_factory": sentinel.model_factory, + "model_instance": sentinel.model_instance, + "memory": sentinel.memory, + **expected_extra_kwargs, + } + build_llm_init_kwargs = MagicMock(return_value=llm_init_kwargs) + factory._build_llm_compatible_node_init_kwargs = build_llm_init_kwargs + + node_config = {"id": "node-id", "data": {"type": node_type}} + result = factory.create_node(node_config) + + assert result is created_node + build_llm_init_kwargs.assert_called_once() + helper_kwargs = build_llm_init_kwargs.call_args.kwargs + assert helper_kwargs["node_class"] is constructor + assert isinstance(helper_kwargs["node_data"], BaseNodeData) + assert helper_kwargs["node_data"].type == node_type + assert helper_kwargs["include_http_client"] is (node_type != BuiltinNodeTypes.PARAMETER_EXTRACTOR) + + constructor_kwargs = constructor.call_args.kwargs + assert constructor_kwargs["id"] == "node-id" + _assert_typed_node_config(constructor_kwargs["config"], node_id="node-id", node_type=node_type) + assert constructor_kwargs["graph_init_params"] is sentinel.graph_init_params + assert constructor_kwargs["graph_runtime_state"] is sentinel.graph_runtime_state + assert constructor_kwargs["credentials_provider"] is sentinel.credentials_provider + assert constructor_kwargs["model_factory"] is sentinel.model_factory + assert constructor_kwargs["model_instance"] is sentinel.model_instance + assert constructor_kwargs["memory"] is sentinel.memory + for key, value in expected_extra_kwargs.items(): + assert constructor_kwargs[key] is value + + +class TestDifyNodeFactoryModelInstance: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory._llm_credentials_provider = MagicMock() + factory._llm_model_factory = MagicMock() + return factory + + @pytest.fixture + def llm_model_setup(self, factory): + def _configure( + *, + completion_params=None, + has_provider_model=True, + model_schema=sentinel.model_schema, + ): + credentials = {"api_key": "secret"} + node_data_model = SimpleNamespace( + provider="provider", + name="model", + mode="chat", + completion_params=completion_params or {}, + ) + node_data = SimpleNamespace(model=node_data_model) + provider_model = MagicMock() if has_provider_model else None + provider_model_bundle = SimpleNamespace( + configuration=SimpleNamespace(get_provider_model=MagicMock(return_value=provider_model)) + ) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + provider_model_bundle=provider_model_bundle, + model_type_instance=model_type_instance, + provider=None, + model_name=None, + credentials=None, + parameters=None, + stop=None, + ) + factory._llm_credentials_provider.fetch.return_value = credentials + factory._llm_model_factory.init_model_instance.return_value = model_instance + return SimpleNamespace( + node_data=node_data, + credentials=credentials, + provider_model=provider_model, + model_type_instance=model_type_instance, + model_instance=model_instance, + ) + + return _configure + + def test_requires_llm_mode(self, factory): + node_data = SimpleNamespace( + model=SimpleNamespace( + provider="provider", + name="model", + mode="", + completion_params={}, + ) + ) + + with pytest.raises(node_factory.LLMModeRequiredError, match="LLM mode is required"): + factory._build_model_instance_for_llm_node(node_data) + + def test_raises_when_provider_model_is_missing(self, factory, llm_model_setup): + setup = llm_model_setup(has_provider_model=False) + + with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): + factory._build_model_instance_for_llm_node(setup.node_data) + + def test_raises_when_model_schema_is_missing(self, factory, llm_model_setup): + setup = llm_model_setup(model_schema=None) + + with pytest.raises(node_factory.ModelNotExistError, match="Model model not exist"): + factory._build_model_instance_for_llm_node(setup.node_data) + + setup.provider_model.raise_for_status.assert_called_once() + + def test_builds_model_instance_and_normalizes_stop_tokens(self, factory, llm_model_setup): + setup = llm_model_setup( + completion_params={"temperature": 0.3, "stop": "not-a-list"}, + model_schema={"schema": "value"}, + ) + + result = factory._build_model_instance_for_llm_node(setup.node_data) + + assert result is setup.model_instance + assert result.provider == "provider" + assert result.model_name == "model" + assert result.credentials == setup.credentials + assert result.parameters == {"temperature": 0.3} + assert result.stop == () + assert result.model_type_instance is setup.model_type_instance + setup.provider_model.raise_for_status.assert_called_once() + + +class TestDifyNodeFactoryMemory: + @pytest.fixture + def factory(self): + factory = object.__new__(node_factory.DifyNodeFactory) + factory._dify_context = SimpleNamespace(app_id="app-id") + factory.graph_runtime_state = SimpleNamespace(variable_pool=MagicMock()) + return factory + + def test_returns_none_when_memory_is_not_configured(self, factory): + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=None), + model_instance=sentinel.model_instance, + ) + + assert result is None + factory.graph_runtime_state.variable_pool.get.assert_not_called() + + def test_uses_string_segment_conversation_id(self, monkeypatch, factory): + memory_config = sentinel.memory_config + factory.graph_runtime_state.variable_pool.get.return_value = StringSegment(value="conversation-id") + fetch_memory = MagicMock(return_value=sentinel.memory) + monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory) + + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=memory_config), + model_instance=sentinel.model_instance, + ) + + assert result is sentinel.memory + factory.graph_runtime_state.variable_pool.get.assert_called_once_with( + ["sys", SystemVariableKey.CONVERSATION_ID] + ) + fetch_memory.assert_called_once_with( + conversation_id="conversation-id", + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) + + def test_ignores_non_string_segment_conversation_ids(self, monkeypatch, factory): + memory_config = sentinel.memory_config + factory.graph_runtime_state.variable_pool.get.return_value = sentinel.segment + fetch_memory = MagicMock(return_value=sentinel.memory) + monkeypatch.setattr(node_factory, "fetch_memory", fetch_memory) + + result = factory._build_memory_for_llm_node( + node_data=SimpleNamespace(memory=memory_config), + model_instance=sentinel.model_instance, + ) + + assert result is sentinel.memory + fetch_memory.assert_called_once_with( + conversation_id=None, + app_id="app-id", + node_data_memory=memory_config, + model_instance=sentinel.model_instance, + ) diff --git a/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py b/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py new file mode 100644 index 0000000000..8de45257ec --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_node_mapping_bootstrap.py @@ -0,0 +1,43 @@ +import os +import subprocess +import sys +import textwrap +from pathlib import Path + + +def test_moved_core_nodes_resolve_after_importing_production_entrypoints(): + api_root = Path(__file__).resolve().parents[4] + script = textwrap.dedent( + """ + from core.app.apps import workflow_app_runner + from core.workflow import workflow_entry + from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE + from core.workflow.node_factory import DifyNodeFactory, NODE_TYPE_CLASSES_MAPPING + from dify_graph.enums import BuiltinNodeTypes + from services import workflow_service + from services.rag_pipeline import rag_pipeline + + _ = workflow_entry, workflow_app_runner, workflow_service, rag_pipeline + + expected = ( + BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, + KNOWLEDGE_INDEX_NODE_TYPE, + BuiltinNodeTypes.DATASOURCE, + ) + + for node_type in expected: + assert node_type in NODE_TYPE_CLASSES_MAPPING, node_type + resolved = DifyNodeFactory._resolve_node_class(node_type=node_type, node_version="1") + assert resolved.__module__.startswith("core.workflow.nodes."), resolved.__module__ + """ + ) + completed = subprocess.run( + [sys.executable, "-c", script], + cwd=api_root, + env=os.environ.copy(), + capture_output=True, + text=True, + check=False, + ) + + assert completed.returncode == 0, completed.stderr or completed.stdout diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py index 93e7c9f68d..8023a0b594 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -4,9 +4,9 @@ from typing import Any import pytest from pydantic import ValidationError -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.system_variable import SystemVariable +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.system_variable import SystemVariable # Test data constants for SystemVariable serialization tests VALID_BASE_DATA: dict[str, Any] = { diff --git a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py index 743fecaed0..b7a8f2551d 100644 --- a/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py +++ b/api/tests/unit_tests/core/workflow/test_system_variable_read_only_view.py @@ -2,8 +2,8 @@ from typing import cast import pytest -from core.workflow.file.models import File, FileTransferMethod, FileType -from core.workflow.system_variable import SystemVariable, SystemVariableReadOnlyView +from dify_graph.file.models import File, FileTransferMethod, FileType +from dify_graph.system_variable import SystemVariable, SystemVariableReadOnlyView class TestSystemVariableReadOnlyView: diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index 7f2b080498..0fa0d26114 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -3,12 +3,12 @@ from collections import defaultdict import pytest -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables import FileSegment, StringSegment -from core.workflow.variables.segments import ( +from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables import FileSegment, StringSegment +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -19,7 +19,7 @@ from core.workflow.variables.segments import ( NoneSegment, ObjectSegment, ) -from core.workflow.variables.variables import ( +from dify_graph.variables.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry.py b/api/tests/unit_tests/core/workflow/test_workflow_entry.py index 4a71692f1e..93ba7f3333 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry.py @@ -4,18 +4,19 @@ import pytest from configs import dify_config from core.helper.code_executor.code_executor import CodeLanguage -from core.workflow.constants import ( +from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.constants import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from core.workflow.file.enums import FileType -from core.workflow.file.models import File, FileTransferMethod -from core.workflow.nodes.code.code_node import CodeNode -from core.workflow.nodes.code.limits import CodeNodeLimits -from core.workflow.runtime import VariablePool -from core.workflow.system_variable import SystemVariable -from core.workflow.variables.variables import StringVariable -from core.workflow.workflow_entry import WorkflowEntry +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.file.enums import FileType +from dify_graph.file.models import File, FileTransferMethod +from dify_graph.nodes.code.code_node import CodeNode +from dify_graph.nodes.code.limits import CodeNodeLimits +from dify_graph.runtime import VariablePool +from dify_graph.system_variable import SystemVariable +from dify_graph.variables.variables import StringVariable @pytest.fixture(autouse=True) @@ -124,7 +125,7 @@ class TestWorkflowEntry: def get_node_config_by_id(self, target_id: str): assert target_id == node_id - return node_config + return NodeConfigDictAdapter.validate_python(node_config) workflow = StubWorkflow() variable_pool = VariablePool(system_variables=SystemVariable.default(), user_inputs={}) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py new file mode 100644 index 0000000000..dc4c7a00c5 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_helpers.py @@ -0,0 +1,657 @@ +from collections import UserString +from types import SimpleNamespace +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from core.app.apps.exc import GenerateTaskStoppedError +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow import workflow_entry +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import NodeType +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.graph_events import GraphRunFailedEvent +from dify_graph.nodes import BuiltinNodeTypes +from dify_graph.runtime import ChildGraphNotFoundError + + +def _build_typed_node_config(node_type: NodeType): + return NodeConfigDictAdapter.validate_python({"id": "node-id", "data": {"type": node_type}}) + + +class TestWorkflowChildEngineBuilder: + @pytest.mark.parametrize( + ("graph_config", "node_id", "expected"), + [ + ({"nodes": [{"id": "root"}]}, "root", True), + ({"nodes": [{"id": "root"}]}, "other", False), + ({"nodes": "invalid"}, "root", None), + ({"nodes": ["invalid"]}, "root", None), + ], + ) + def test_has_node_id(self, graph_config, node_id, expected): + result = workflow_entry._WorkflowChildEngineBuilder._has_node_id(graph_config, node_id) + + assert result is expected + + def test_build_child_engine_raises_when_root_node_is_missing(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + + with patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory): + with pytest.raises(ChildGraphNotFoundError, match="child graph root node 'missing' not found"): + builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": []}, + root_node_id="missing", + ) + + def test_build_child_engine_constructs_graph_engine_and_layers(self): + builder = workflow_entry._WorkflowChildEngineBuilder() + child_graph = sentinel.child_graph + child_engine = MagicMock() + quota_layer = sentinel.quota_layer + additional_layers = [sentinel.layer_one, sentinel.layer_two] + + with ( + patch.object(workflow_entry, "DifyNodeFactory", return_value=sentinel.factory) as dify_node_factory, + patch.object(workflow_entry.Graph, "init", return_value=child_graph) as graph_init, + patch.object(workflow_entry, "GraphEngine", return_value=child_engine) as graph_engine_cls, + patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), + patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "LLMQuotaLayer", return_value=quota_layer), + ): + result = builder.build_child_engine( + workflow_id="workflow-id", + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + graph_config={"nodes": [{"id": "root"}]}, + root_node_id="root", + layers=additional_layers, + ) + + assert result is child_engine + dify_node_factory.assert_called_once_with( + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + ) + graph_init.assert_called_once_with( + graph_config={"nodes": [{"id": "root"}]}, + node_factory=sentinel.factory, + root_node_id="root", + ) + graph_engine_cls.assert_called_once_with( + workflow_id="workflow-id", + graph=child_graph, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=sentinel.command_channel, + config=sentinel.graph_engine_config, + child_engine_builder=builder, + ) + assert child_engine.layer.call_args_list == [ + ((quota_layer,), {}), + ((sentinel.layer_one,), {}), + ((sentinel.layer_two,), {}), + ] + + +class TestWorkflowEntryInit: + def test_rejects_call_depth_above_limit(self): + call_depth = workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH + 1 + + with pytest.raises(ValueError, match="Max workflow call depth"): + workflow_entry.WorkflowEntry( + tenant_id="tenant-id", + app_id="app-id", + workflow_id="workflow-id", + graph_config={"nodes": [], "edges": []}, + graph=sentinel.graph, + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=call_depth, + variable_pool=sentinel.variable_pool, + graph_runtime_state=sentinel.graph_runtime_state, + ) + + def test_applies_debug_and_observability_layers(self): + graph_engine = MagicMock() + debug_layer = sentinel.debug_layer + execution_limits_layer = sentinel.execution_limits_layer + llm_quota_layer = sentinel.llm_quota_layer + observability_layer = sentinel.observability_layer + + with ( + patch.object(workflow_entry.dify_config, "DEBUG", True), + patch.object(workflow_entry.dify_config, "ENABLE_OTEL", False), + patch.object(workflow_entry, "is_instrument_flag_enabled", return_value=True), + patch.object(workflow_entry, "GraphEngine", return_value=graph_engine) as graph_engine_cls, + patch.object(workflow_entry, "GraphEngineConfig", return_value=sentinel.graph_engine_config), + patch.object(workflow_entry, "InMemoryChannel", return_value=sentinel.command_channel), + patch.object(workflow_entry, "DebugLoggingLayer", return_value=debug_layer) as debug_logging_layer, + patch.object( + workflow_entry, + "ExecutionLimitsLayer", + return_value=execution_limits_layer, + ) as execution_limits_layer_cls, + patch.object(workflow_entry, "LLMQuotaLayer", return_value=llm_quota_layer), + patch.object(workflow_entry, "ObservabilityLayer", return_value=observability_layer), + ): + entry = workflow_entry.WorkflowEntry( + tenant_id="tenant-id", + app_id="app-id", + workflow_id="workflow-id-123456", + graph_config={"nodes": [], "edges": []}, + graph=sentinel.graph, + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + variable_pool=sentinel.variable_pool, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=None, + ) + + assert entry.command_channel is sentinel.command_channel + graph_engine_cls.assert_called_once_with( + workflow_id="workflow-id-123456", + graph=sentinel.graph, + graph_runtime_state=sentinel.graph_runtime_state, + command_channel=sentinel.command_channel, + config=sentinel.graph_engine_config, + child_engine_builder=entry._child_engine_builder, + ) + debug_logging_layer.assert_called_once_with( + level="DEBUG", + include_inputs=True, + include_outputs=True, + include_process_data=False, + logger_name="GraphEngine.Debug.workflow", + ) + execution_limits_layer_cls.assert_called_once_with( + max_steps=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_STEPS, + max_time=workflow_entry.dify_config.WORKFLOW_MAX_EXECUTION_TIME, + ) + assert graph_engine.layer.call_args_list == [ + ((debug_layer,), {}), + ((execution_limits_layer,), {}), + ((llm_quota_layer,), {}), + ((observability_layer,), {}), + ] + + +class TestWorkflowEntryRun: + def test_run_swallows_generate_task_stopped_errors(self): + entry = object.__new__(workflow_entry.WorkflowEntry) + entry.graph_engine = MagicMock() + entry.graph_engine.run.side_effect = GenerateTaskStoppedError() + + assert list(entry.run()) == [] + + def test_run_emits_failed_event_for_unexpected_errors(self): + entry = object.__new__(workflow_entry.WorkflowEntry) + entry.graph_engine = MagicMock() + entry.graph_engine.run.side_effect = RuntimeError("boom") + + events = list(entry.run()) + + assert len(events) == 1 + assert isinstance(events[0], GraphRunFailedEvent) + assert events[0].error == "boom" + + +class TestWorkflowEntrySingleStepRun: + def test_uses_empty_mapping_when_selector_extraction_is_not_implemented(self): + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "fake" + + @staticmethod + def version(): + return "1" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + raise NotImplementedError + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(BuiltinNodeTypes.START), + ) + + node, generator = workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + load_into_variable_pool.assert_called_once_with( + variable_loader=workflow_entry.DUMMY_VARIABLE_LOADER, + variable_pool=sentinel.variable_pool, + variable_mapping={}, + user_inputs={"question": "hello"}, + ) + mapping_user_inputs_to_variable_pool.assert_called_once_with( + variable_mapping={}, + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + tenant_id="tenant-id", + ) + + def test_skips_user_input_mapping_for_datasource_nodes(self): + class FakeDatasourceNode: + id = "node-id" + node_type = "datasource" + + @staticmethod + def version(): + return "1" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {"question": ["node", "question"]} + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool") as load_into_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeDatasourceNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(BuiltinNodeTypes.DATASOURCE), + ) + + node, generator = workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + load_into_variable_pool.assert_called_once() + mapping_user_inputs_to_variable_pool.assert_not_called() + + def test_wraps_traced_node_run_failures(self): + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "fake" + + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + + @staticmethod + def version(): + return "1" + + with ( + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory") as dify_node_factory, + patch.object(workflow_entry, "load_into_variable_pool"), + patch.object(workflow_entry.WorkflowEntry, "mapping_user_inputs_to_variable_pool"), + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + side_effect=RuntimeError("boom"), + ), + ): + dify_node_factory.return_value.create_node.return_value = FakeNode() + workflow = SimpleNamespace( + tenant_id="tenant-id", + app_id="app-id", + id="workflow-id", + graph_dict={"nodes": [], "edges": []}, + get_node_config_by_id=lambda _node_id: _build_typed_node_config(BuiltinNodeTypes.START), + ) + + with pytest.raises(WorkflowNodeRunFailedError): + workflow_entry.WorkflowEntry.single_step_run( + workflow=workflow, + node_id="node-id", + user_id="user-id", + user_inputs={}, + variable_pool=sentinel.variable_pool, + ) + + +class TestWorkflowEntryHelpers: + def test_create_single_node_graph_builds_start_edge(self): + graph = workflow_entry.WorkflowEntry._create_single_node_graph( + node_id="target-node", + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}, + node_width=320, + node_height=180, + ) + + assert graph["nodes"][0]["id"] == "start" + assert graph["nodes"][1]["id"] == "target-node" + assert graph["nodes"][1]["width"] == 320 + assert graph["nodes"][1]["height"] == 180 + assert graph["edges"] == [ + { + "source": "start", + "target": "target-node", + "sourceHandle": "source", + "targetHandle": "target", + } + ] + + def test_run_free_node_rejects_unsupported_types(self): + with pytest.raises(ValueError, match="Node type start not supported"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": BuiltinNodeTypes.START}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={}, + ) + + def test_run_free_node_rejects_missing_node_class(self, monkeypatch): + monkeypatch.setattr( + workflow_entry, + "resolve_workflow_node_class", + MagicMock(return_value=None), + ) + + with pytest.raises(ValueError, match="Node class not found for node type parameter-extractor"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={}, + ) + + def test_run_free_node_uses_empty_mapping_when_selector_extraction_is_not_implemented(self, monkeypatch): + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + raise NotImplementedError + + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "parameter-extractor" + + @staticmethod + def version(): + return "1" + + dify_node_factory = MagicMock() + dify_node_factory.create_node.return_value = FakeNode() + monkeypatch.setattr( + workflow_entry, + "resolve_workflow_node_class", + MagicMock(return_value=FakeNodeClass), + ) + + with ( + patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool) as variable_pool_cls, + patch.object( + workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params + ) as graph_init_params, + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object( + workflow_entry, "build_dify_run_context", return_value={"_dify": "context"} + ) as build_dify_run_context, + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory) as dify_node_factory_cls, + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + ) as mapping_user_inputs_to_variable_pool, + patch.object( + workflow_entry.WorkflowEntry, + "_traced_node_run", + return_value=iter(["event"]), + ), + ): + node, generator = workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={"question": "hello"}, + ) + + assert node.id == "node-id" + assert list(generator) == ["event"] + variable_pool_cls.assert_called_once_with( + system_variables=sentinel.system_variables, + user_inputs={}, + environment_variables=[], + ) + build_dify_run_context.assert_called_once_with( + tenant_id="tenant-id", + app_id="", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + graph_init_params.assert_called_once_with( + workflow_id="", + graph_config=workflow_entry.WorkflowEntry._create_single_node_graph( + "node-id", {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"} + ), + run_context={"_dify": "context"}, + call_depth=0, + ) + dify_node_factory_cls.assert_called_once_with( + graph_init_params=sentinel.graph_init_params, + graph_runtime_state=sentinel.graph_runtime_state, + ) + mapping_user_inputs_to_variable_pool.assert_called_once_with( + variable_mapping={}, + user_inputs={"question": "hello"}, + variable_pool=sentinel.variable_pool, + tenant_id="tenant-id", + ) + + def test_run_free_node_wraps_execution_failures(self, monkeypatch): + class FakeNodeClass: + @staticmethod + def extract_variable_selector_to_variable_mapping(**_kwargs): + return {} + + class FakeNode: + id = "node-id" + title = "Node Title" + node_type = "parameter-extractor" + + @staticmethod + def version(): + return "1" + + dify_node_factory = MagicMock() + dify_node_factory.create_node.return_value = FakeNode() + monkeypatch.setattr( + workflow_entry, + "resolve_workflow_node_class", + MagicMock(return_value=FakeNodeClass), + ) + + with ( + patch.object(workflow_entry.SystemVariable, "default", return_value=sentinel.system_variables), + patch.object(workflow_entry, "VariablePool", return_value=sentinel.variable_pool), + patch.object(workflow_entry, "GraphInitParams", return_value=sentinel.graph_init_params), + patch.object(workflow_entry, "GraphRuntimeState", return_value=sentinel.graph_runtime_state), + patch.object(workflow_entry, "build_dify_run_context", return_value={"_dify": "context"}), + patch.object(workflow_entry.time, "perf_counter", return_value=123.0), + patch.object(workflow_entry, "DifyNodeFactory", return_value=dify_node_factory), + patch.object( + workflow_entry.WorkflowEntry, + "mapping_user_inputs_to_variable_pool", + side_effect=RuntimeError("boom"), + ), + ): + with pytest.raises(WorkflowNodeRunFailedError, match="Node Title run failed: boom"): + workflow_entry.WorkflowEntry.run_free_node( + node_data={"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR, "title": "Node"}, + node_id="node-id", + tenant_id="tenant-id", + user_id="user-id", + user_inputs={"question": "hello"}, + ) + + def test_handle_special_values_serializes_nested_files(self): + file = File( + tenant_id="tenant-id", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.png", + filename="image.png", + extension=".png", + ) + + result = workflow_entry.WorkflowEntry.handle_special_values({"file": file, "nested": {"files": [file]}}) + + assert result == { + "file": file.to_dict(), + "nested": {"files": [file.to_dict()]}, + } + + def test_handle_special_values_returns_none_for_none(self): + assert workflow_entry.WorkflowEntry._handle_special_values(None) is None + + def test_handle_special_values_returns_scalar_as_is(self): + assert workflow_entry.WorkflowEntry._handle_special_values("plain-text") == "plain-text" + + +class TestMappingUserInputsBranches: + def test_rejects_invalid_node_variable_key(self): + class EmptySplitKey(UserString): + def split(self, _sep=None): + return [] + + with pytest.raises(ValueError, match="Invalid node variable broken"): + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={EmptySplitKey("broken"): ["node", "input"]}, + user_inputs={}, + variable_pool=MagicMock(), + tenant_id="tenant-id", + ) + + def test_skips_none_user_input_when_variable_already_exists(self): + variable_pool = MagicMock() + variable_pool.get.return_value = None + + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={"node.input": ["target", "input"]}, + user_inputs={"node.input": None}, + variable_pool=variable_pool, + tenant_id="tenant-id", + ) + + variable_pool.add.assert_not_called() + + def test_merges_structured_output_values(self): + variable_pool = MagicMock() + variable_pool.get.side_effect = [ + None, + SimpleNamespace(value={"existing": "value"}), + ] + + workflow_entry.WorkflowEntry.mapping_user_inputs_to_variable_pool( + variable_mapping={"node.answer": ["target", "structured_output", "answer"]}, + user_inputs={"node.answer": "new-value"}, + variable_pool=variable_pool, + tenant_id="tenant-id", + ) + + variable_pool.add.assert_called_once_with( + ["target", "structured_output"], + {"existing": "value", "answer": "new-value"}, + ) + + +class TestWorkflowEntryTracing: + def test_traced_node_run_reports_success(self): + layer = MagicMock() + + class FakeNode: + def ensure_execution_id(self): + return None + + def run(self): + yield "event" + + with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer): + events = list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode())) + + assert events == ["event"] + layer.on_graph_start.assert_called_once_with() + layer.on_node_run_start.assert_called_once() + layer.on_node_run_end.assert_called_once_with( + layer.on_node_run_start.call_args.args[0], + None, + ) + + def test_traced_node_run_reports_errors(self): + layer = MagicMock() + + class FakeNode: + def ensure_execution_id(self): + return None + + def run(self): + raise RuntimeError("boom") + yield + + with patch.object(workflow_entry, "ObservabilityLayer", return_value=layer): + with pytest.raises(RuntimeError, match="boom"): + list(workflow_entry.WorkflowEntry._traced_node_run(FakeNode())) + + assert isinstance(layer.on_node_run_end.call_args.args[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py index 12b9bf5f14..9969c953e8 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_entry_redis_channel.py @@ -2,11 +2,10 @@ from unittest.mock import MagicMock, patch -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.workflow_entry import WorkflowEntry -from models.enums import UserFrom +from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel +from dify_graph.runtime import GraphRuntimeState, VariablePool class TestWorkflowEntryRedisChannel: diff --git a/api/tests/unit_tests/core/workflow/utils/test_condition.py b/api/tests/unit_tests/core/workflow/utils/test_condition.py index efedf88726..324ad5f674 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_condition.py +++ b/api/tests/unit_tests/core/workflow/utils/test_condition.py @@ -1,6 +1,6 @@ -from core.workflow.runtime import VariablePool -from core.workflow.utils.condition.entities import Condition -from core.workflow.utils.condition.processor import ConditionProcessor +from dify_graph.runtime import VariablePool +from dify_graph.utils.condition.entities import Condition +from dify_graph.utils.condition.processor import ConditionProcessor def test_number_formatting(): diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py index 83867e22e4..40df9de7fa 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_template_parser.py @@ -1,7 +1,7 @@ import dataclasses -from core.workflow.nodes.base import variable_template_parser -from core.workflow.nodes.base.entities import VariableSelector +from dify_graph.nodes.base import variable_template_parser +from dify_graph.nodes.base.entities import VariableSelector def test_extract_selectors_from_template(): diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py similarity index 91% rename from api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py index 5fbdabceed..d42b7ca0d9 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_increase_tool_call.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock, patch import pytest -from core.model_runtime.entities.message_entities import AssistantPromptMessage -from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call +from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage +from dify_graph.model_runtime.model_providers.__base.large_language_model import _increase_tool_call ToolCall = AssistantPromptMessage.ToolCall @@ -97,7 +97,9 @@ def test__increase_tool_call(): # case 4: mock_id_generator = MagicMock() mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4] - with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator): + with patch( + "dify_graph.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator + ): _run_case(INPUTS_CASE_4, EXPECTED_CASE_4) @@ -107,6 +109,6 @@ def test__increase_tool_call__no_id_no_name_first_delta_should_raise(): ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='"value"}')), ] actual: list[ToolCall] = [] - with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): + with patch("dify_graph.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", MagicMock()): with pytest.raises(ValueError): _increase_tool_call(inputs, actual) diff --git a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py similarity index 93% rename from api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py rename to api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py index 09d527cb12..8dcfd10ec6 100644 --- a/api/tests/unit_tests/core/model_runtime/__base/test_large_language_model_non_stream_parsing.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/__base/test_large_language_model_non_stream_parsing.py @@ -1,10 +1,10 @@ -from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result +from dify_graph.model_runtime.model_providers.__base.large_language_model import _normalize_non_stream_plugin_result def _make_chunk( diff --git a/api/tests/unit_tests/dify_graph/model_runtime/__init__.py b/api/tests/unit_tests/dify_graph/model_runtime/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py new file mode 100644 index 0000000000..2410d16d63 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_base_callback.py @@ -0,0 +1,964 @@ +"""Comprehensive unit tests for core/model_runtime/callbacks/base_callback.py""" + +from unittest.mock import MagicMock, patch + +import pytest + +from dify_graph.model_runtime.callbacks.base_callback import ( + _TEXT_COLOR_MAPPING, + Callback, +) +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool + +# --------------------------------------------------------------------------- +# Concrete implementation of the abstract Callback for testing +# --------------------------------------------------------------------------- + + +class ConcreteCallback(Callback): + """A minimal concrete subclass that satisfies all abstract methods.""" + + def __init__(self, raise_error: bool = False): + self.raise_error = raise_error + # Track invocations + self.before_invoke_calls: list[dict] = [] + self.new_chunk_calls: list[dict] = [] + self.after_invoke_calls: list[dict] = [] + self.invoke_error_calls: list[dict] = [] + + def on_before_invoke( + self, + llm_instance, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.before_invoke_calls.append( + { + "llm_instance": llm_instance, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + # To cover the 'raise NotImplementedError()' in the base class + try: + super().on_before_invoke( + llm_instance, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_new_chunk( + self, + llm_instance, + chunk, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.new_chunk_calls.append( + { + "llm_instance": llm_instance, + "chunk": chunk, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_new_chunk( + llm_instance, chunk, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_after_invoke( + self, + llm_instance, + result, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.after_invoke_calls.append( + { + "llm_instance": llm_instance, + "result": result, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_after_invoke( + llm_instance, result, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + def on_invoke_error( + self, + llm_instance, + ex, + model, + credentials, + prompt_messages, + model_parameters, + tools=None, + stop=None, + stream=True, + user=None, + ): + self.invoke_error_calls.append( + { + "llm_instance": llm_instance, + "ex": ex, + "model": model, + "credentials": credentials, + "prompt_messages": prompt_messages, + "model_parameters": model_parameters, + "tools": tools, + "stop": stop, + "stream": stream, + "user": user, + } + ) + try: + super().on_invoke_error( + llm_instance, ex, model, credentials, prompt_messages, model_parameters, tools, stop, stream, user + ) + except NotImplementedError: + pass + + +# --------------------------------------------------------------------------- +# A subclass that deliberately leaves abstract methods un-implemented, +# used to verify that instantiation raises TypeError. +# --------------------------------------------------------------------------- + + +# =========================================================================== +# Tests for _TEXT_COLOR_MAPPING module-level constant +# =========================================================================== + + +class TestTextColorMapping: + """Tests for the module-level _TEXT_COLOR_MAPPING dictionary.""" + + def test_contains_all_expected_colors(self): + expected_keys = {"blue", "yellow", "pink", "green", "red"} + assert set(_TEXT_COLOR_MAPPING.keys()) == expected_keys + + def test_blue_escape_code(self): + assert _TEXT_COLOR_MAPPING["blue"] == "36;1" + + def test_yellow_escape_code(self): + assert _TEXT_COLOR_MAPPING["yellow"] == "33;1" + + def test_pink_escape_code(self): + assert _TEXT_COLOR_MAPPING["pink"] == "38;5;200" + + def test_green_escape_code(self): + assert _TEXT_COLOR_MAPPING["green"] == "32;1" + + def test_red_escape_code(self): + assert _TEXT_COLOR_MAPPING["red"] == "31;1" + + def test_mapping_is_dict(self): + assert isinstance(_TEXT_COLOR_MAPPING, dict) + + def test_all_values_are_strings(self): + for key, value in _TEXT_COLOR_MAPPING.items(): + assert isinstance(value, str), f"Value for {key!r} should be str" + + +# =========================================================================== +# Tests for the Callback ABC itself +# =========================================================================== + + +class TestCallbackAbstract: + """Tests verifying Callback is a proper ABC.""" + + def test_cannot_instantiate_abstract_class_directly(self): + """Callback cannot be instantiated since it has abstract methods.""" + with pytest.raises(TypeError): + Callback() # type: ignore[abstract] + + def test_concrete_subclass_can_be_instantiated(self): + cb = ConcreteCallback() + assert isinstance(cb, Callback) + + def test_default_raise_error_is_false(self): + cb = ConcreteCallback() + assert cb.raise_error is False + + def test_raise_error_can_be_set_to_true(self): + cb = ConcreteCallback(raise_error=True) + assert cb.raise_error is True + + def test_subclass_missing_on_before_invoke_raises_type_error(self): + """A subclass missing any single abstract method cannot be instantiated.""" + + class IncompleteCallback(Callback): + def on_new_chunk(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_new_chunk_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_after_invoke_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_new_chunk(self, *a, **kw): ... + def on_invoke_error(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + def test_subclass_missing_on_invoke_error_raises_type_error(self): + class IncompleteCallback(Callback): + def on_before_invoke(self, *a, **kw): ... + def on_new_chunk(self, *a, **kw): ... + def on_after_invoke(self, *a, **kw): ... + + with pytest.raises(TypeError): + IncompleteCallback() # type: ignore[abstract] + + +# =========================================================================== +# Tests for on_before_invoke +# =========================================================================== + + +class TestOnBeforeInvoke: + """Tests for the on_before_invoke callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.model = "gpt-4" + self.credentials = {"api_key": "sk-test"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"temperature": 0.7} + + def test_on_before_invoke_called_with_required_args(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.before_invoke_calls) == 1 + call = self.cb.before_invoke_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["model"] == self.model + assert call["credentials"] == self.credentials + assert call["prompt_messages"] is self.prompt_messages + assert call["model_parameters"] is self.model_parameters + + def test_on_before_invoke_defaults_tools_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["tools"] is None + + def test_on_before_invoke_defaults_stop_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["stop"] is None + + def test_on_before_invoke_defaults_stream_true(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["stream"] is True + + def test_on_before_invoke_defaults_user_none(self): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.before_invoke_calls[0]["user"] is None + + def test_on_before_invoke_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["stop1", "stop2"] + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="user-123", + ) + call = self.cb.before_invoke_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "user-123" + + def test_on_before_invoke_called_multiple_times(self): + for i in range(3): + self.cb.on_before_invoke( + llm_instance=self.llm_instance, + model=f"model-{i}", + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.before_invoke_calls) == 3 + assert self.cb.before_invoke_calls[2]["model"] == "model-2" + + +# =========================================================================== +# Tests for on_new_chunk +# =========================================================================== + + +class TestOnNewChunk: + """Tests for the on_new_chunk callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.chunk = MagicMock(spec=LLMResultChunk) + self.model = "gpt-3.5-turbo" + self.credentials = {"api_key": "sk-test"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"max_tokens": 256} + + def test_on_new_chunk_called_with_required_args(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.new_chunk_calls) == 1 + call = self.cb.new_chunk_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["chunk"] is self.chunk + assert call["model"] == self.model + assert call["credentials"] == self.credentials + + def test_on_new_chunk_defaults_tools_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["tools"] is None + + def test_on_new_chunk_defaults_stop_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["stop"] is None + + def test_on_new_chunk_defaults_stream_true(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["stream"] is True + + def test_on_new_chunk_defaults_user_none(self): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.new_chunk_calls[0]["user"] is None + + def test_on_new_chunk_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["END"] + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="chunk-user", + ) + call = self.cb.new_chunk_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "chunk-user" + + def test_on_new_chunk_called_multiple_times(self): + for i in range(5): + self.cb.on_new_chunk( + llm_instance=self.llm_instance, + chunk=self.chunk, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.new_chunk_calls) == 5 + + +# =========================================================================== +# Tests for on_after_invoke +# =========================================================================== + + +class TestOnAfterInvoke: + """Tests for the on_after_invoke callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.result = MagicMock(spec=LLMResult) + self.model = "claude-3" + self.credentials = {"api_key": "anthropic-key"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"temperature": 1.0} + + def test_on_after_invoke_called_with_required_args(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.after_invoke_calls) == 1 + call = self.cb.after_invoke_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["result"] is self.result + assert call["model"] == self.model + assert call["credentials"] is self.credentials + + def test_on_after_invoke_defaults_tools_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["tools"] is None + + def test_on_after_invoke_defaults_stop_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["stop"] is None + + def test_on_after_invoke_defaults_stream_true(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["stream"] is True + + def test_on_after_invoke_defaults_user_none(self): + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.after_invoke_calls[0]["user"] is None + + def test_on_after_invoke_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["STOP"] + self.cb.on_after_invoke( + llm_instance=self.llm_instance, + result=self.result, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="after-user", + ) + call = self.cb.after_invoke_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "after-user" + + +# =========================================================================== +# Tests for on_invoke_error +# =========================================================================== + + +class TestOnInvokeError: + """Tests for the on_invoke_error callback method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + self.llm_instance = MagicMock() + self.ex = ValueError("something went wrong") + self.model = "gemini-pro" + self.credentials = {"api_key": "google-key"} + self.prompt_messages = [MagicMock(spec=PromptMessage)] + self.model_parameters = {"top_p": 0.9} + + def test_on_invoke_error_called_with_required_args(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.invoke_error_calls) == 1 + call = self.cb.invoke_error_calls[0] + assert call["llm_instance"] is self.llm_instance + assert call["ex"] is self.ex + assert call["model"] == self.model + assert call["credentials"] is self.credentials + + def test_on_invoke_error_defaults_tools_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["tools"] is None + + def test_on_invoke_error_defaults_stop_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["stop"] is None + + def test_on_invoke_error_defaults_stream_true(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["stream"] is True + + def test_on_invoke_error_defaults_user_none(self): + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert self.cb.invoke_error_calls[0]["user"] is None + + def test_on_invoke_error_with_all_optional_args(self): + tools = [MagicMock(spec=PromptMessageTool)] + stop = ["HALT"] + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=self.ex, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + tools=tools, + stop=stop, + stream=False, + user="error-user", + ) + call = self.cb.invoke_error_calls[0] + assert call["tools"] is tools + assert call["stop"] == stop + assert call["stream"] is False + assert call["user"] == "error-user" + + def test_on_invoke_error_accepts_various_exception_types(self): + for exc in [RuntimeError("r"), KeyError("k"), Exception("e")]: + self.cb.on_invoke_error( + llm_instance=self.llm_instance, + ex=exc, + model=self.model, + credentials=self.credentials, + prompt_messages=self.prompt_messages, + model_parameters=self.model_parameters, + ) + assert len(self.cb.invoke_error_calls) == 3 + + +# =========================================================================== +# Tests for print_text (concrete method on Callback) +# =========================================================================== + + +class TestPrintText: + """Tests for the concrete print_text method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + + def test_print_text_without_color_prints_plain_text(self, capsys): + self.cb.print_text("hello world") + captured = capsys.readouterr() + assert captured.out == "hello world" + + def test_print_text_with_color_prints_colored_text(self, capsys): + self.cb.print_text("colored text", color="blue") + captured = capsys.readouterr() + # Should contain ANSI escape sequences + assert "colored text" in captured.out + assert "\001b[" in captured.out or "\033[" in captured.out or "\x1b[" in captured.out + + def test_print_text_without_color_no_ansi(self, capsys): + self.cb.print_text("plain text", color=None) + captured = capsys.readouterr() + assert captured.out == "plain text" + # No ANSI escape sequences + assert "\x1b" not in captured.out + + def test_print_text_default_end_is_empty_string(self, capsys): + self.cb.print_text("no newline") + captured = capsys.readouterr() + assert not captured.out.endswith("\n") + + def test_print_text_with_custom_end(self, capsys): + self.cb.print_text("with newline", end="\n") + captured = capsys.readouterr() + assert captured.out.endswith("\n") + + def test_print_text_with_empty_string(self, capsys): + self.cb.print_text("", color=None) + captured = capsys.readouterr() + assert captured.out == "" + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_print_text_all_colors_work(self, color, capsys): + """Verify no KeyError is thrown for any valid color.""" + self.cb.print_text("test", color=color) + captured = capsys.readouterr() + assert "test" in captured.out + + def test_print_text_calls_get_colored_text_when_color_given(self): + with patch.object(self.cb, "_get_colored_text", return_value="[COLORED]") as mock_gct: + with patch("builtins.print") as mock_print: + self.cb.print_text("hello", color="green") + mock_gct.assert_called_once_with("hello", "green") + mock_print.assert_called_once_with("[COLORED]", end="") + + def test_print_text_does_not_call_get_colored_text_when_no_color(self): + with patch.object(self.cb, "_get_colored_text") as mock_gct: + with patch("builtins.print"): + self.cb.print_text("hello", color=None) + mock_gct.assert_not_called() + + def test_print_text_passes_end_to_print(self): + with patch("builtins.print") as mock_print: + self.cb.print_text("text", end="---") + mock_print.assert_called_once_with("text", end="---") + + +# =========================================================================== +# Tests for _get_colored_text (private helper method) +# =========================================================================== + + +class TestGetColoredText: + """Tests for the _get_colored_text private method.""" + + def setup_method(self): + self.cb = ConcreteCallback() + + @pytest.mark.parametrize(("color", "expected_code"), list(_TEXT_COLOR_MAPPING.items())) + def test_get_colored_text_uses_correct_escape_code(self, color, expected_code): + result = self.cb._get_colored_text("text", color) + assert expected_code in result + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_contains_input_text(self, color): + result = self.cb._get_colored_text("hello", color) + assert "hello" in result + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_starts_with_escape(self, color): + result = self.cb._get_colored_text("text", color) + # Should start with an ANSI escape (\x1b or \u001b) + assert result.startswith("\x1b[") or result.startswith("\u001b[") + + @pytest.mark.parametrize("color", ["blue", "yellow", "pink", "green", "red"]) + def test_get_colored_text_ends_with_reset(self, color): + result = self.cb._get_colored_text("text", color) + # Should end with the ANSI reset code + assert result.endswith("\x1b[0m") or result.endswith("\u001b[0m") + + def test_get_colored_text_returns_string(self): + result = self.cb._get_colored_text("text", "blue") + assert isinstance(result, str) + + def test_get_colored_text_blue_exact_format(self): + result = self.cb._get_colored_text("hello", "blue") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['blue']}m\033[1;3mhello\u001b[0m" + assert result == expected + + def test_get_colored_text_red_exact_format(self): + result = self.cb._get_colored_text("error", "red") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['red']}m\033[1;3merror\u001b[0m" + assert result == expected + + def test_get_colored_text_green_exact_format(self): + result = self.cb._get_colored_text("ok", "green") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['green']}m\033[1;3mok\u001b[0m" + assert result == expected + + def test_get_colored_text_yellow_exact_format(self): + result = self.cb._get_colored_text("warn", "yellow") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['yellow']}m\033[1;3mwarn\u001b[0m" + assert result == expected + + def test_get_colored_text_pink_exact_format(self): + result = self.cb._get_colored_text("info", "pink") + expected = f"\u001b[{_TEXT_COLOR_MAPPING['pink']}m\033[1;3minfo\u001b[0m" + assert result == expected + + def test_get_colored_text_empty_string(self): + result = self.cb._get_colored_text("", "blue") + assert isinstance(result, str) + # Empty text should still have escape codes + assert _TEXT_COLOR_MAPPING["blue"] in result + + def test_get_colored_text_invalid_color_raises_key_error(self): + with pytest.raises(KeyError): + self.cb._get_colored_text("text", "purple") + + def test_get_colored_text_with_special_characters(self): + special = "hello\nworld\ttab" + result = self.cb._get_colored_text(special, "blue") + assert special in result + + def test_get_colored_text_with_long_text(self): + long_text = "a" * 10000 + result = self.cb._get_colored_text(long_text, "green") + assert long_text in result + + +# =========================================================================== +# Integration-style tests: full workflow through a ConcreteCallback +# =========================================================================== + + +class TestConcreteCallbackIntegration: + """End-to-end workflow tests using ConcreteCallback.""" + + def test_full_invocation_lifecycle(self): + """Simulate a complete LLM invocation lifecycle through all callbacks.""" + cb = ConcreteCallback() + llm_instance = MagicMock() + model = "gpt-4o" + credentials = {"api_key": "sk-xyz"} + prompt_messages = [MagicMock(spec=PromptMessage)] + model_parameters = {"temperature": 0.5} + tools = [MagicMock(spec=PromptMessageTool)] + stop = [""] + user = "user-abc" + + # 1. Before invoke + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + # 2. Multiple chunks during streaming + for i in range(3): + chunk = MagicMock(spec=LLMResultChunk) + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + # 3. After invoke + result = MagicMock(spec=LLMResult) + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + tools=tools, + stop=stop, + stream=True, + user=user, + ) + + assert len(cb.before_invoke_calls) == 1 + assert len(cb.new_chunk_calls) == 3 + assert len(cb.after_invoke_calls) == 1 + assert len(cb.invoke_error_calls) == 0 + + def test_error_lifecycle(self): + """Simulate an invoke that results in an error.""" + cb = ConcreteCallback() + llm_instance = MagicMock() + model = "gpt-4" + credentials = {} + prompt_messages = [] + model_parameters = {} + + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + ) + + ex = RuntimeError("API timeout") + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model=model, + credentials=credentials, + prompt_messages=prompt_messages, + model_parameters=model_parameters, + ) + + assert len(cb.before_invoke_calls) == 1 + assert len(cb.invoke_error_calls) == 1 + assert cb.invoke_error_calls[0]["ex"] is ex + assert len(cb.after_invoke_calls) == 0 + + def test_print_text_with_color_in_integration(self, capsys): + """verify print_text works correctly in a concrete instance.""" + cb = ConcreteCallback() + cb.print_text("SUCCESS", color="green", end="\n") + captured = capsys.readouterr() + assert "SUCCESS" in captured.out + assert "\n" in captured.out + + def test_print_text_no_color_in_integration(self, capsys): + cb = ConcreteCallback() + cb.print_text("plain output") + captured = capsys.readouterr() + assert captured.out == "plain output" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py new file mode 100644 index 0000000000..0c6c1fd191 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/callbacks/test_logging_callback.py @@ -0,0 +1,700 @@ +""" +Comprehensive unit tests for core/model_runtime/callbacks/logging_callback.py + +Coverage targets: + - LoggingCallback.on_before_invoke (all branches: stop, tools, user, stream, + prompt_message.name, model_parameters) + - LoggingCallback.on_new_chunk (writes to stdout) + - LoggingCallback.on_after_invoke (all branches: tool_calls present / absent) + - LoggingCallback.on_invoke_error (logs exception via logger.exception) +""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from decimal import Decimal +from unittest.mock import MagicMock, patch + +import pytest + +from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageTool, + SystemPromptMessage, + UserPromptMessage, +) + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_usage() -> LLMUsage: + """Return a minimal LLMUsage instance.""" + return LLMUsage( + prompt_tokens=10, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("0.001"), + prompt_price=Decimal("0.01"), + completion_tokens=20, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("0.002"), + completion_price=Decimal("0.04"), + total_tokens=30, + total_price=Decimal("0.05"), + currency="USD", + latency=0.5, + ) + + +def _make_llm_result( + content: str = "hello world", + tool_calls: list | None = None, + model: str = "gpt-4", + system_fingerprint: str | None = "fp-abc", +) -> LLMResult: + """Return an LLMResult with an AssistantPromptMessage.""" + assistant_msg = AssistantPromptMessage( + content=content, + tool_calls=tool_calls or [], + ) + return LLMResult( + model=model, + message=assistant_msg, + usage=_make_usage(), + system_fingerprint=system_fingerprint, + ) + + +def _make_chunk(content: str = "chunk-text") -> LLMResultChunk: + """Return a minimal LLMResultChunk.""" + return LLMResultChunk( + model="gpt-4", + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content), + ), + ) + + +def _make_user_prompt(content: str = "Hello!", name: str | None = None) -> UserPromptMessage: + return UserPromptMessage(content=content, name=name) + + +def _make_system_prompt(content: str = "You are helpful.") -> SystemPromptMessage: + return SystemPromptMessage(content=content) + + +def _make_tool(name: str = "my_tool") -> PromptMessageTool: + return PromptMessageTool(name=name, description="A tool", parameters={}) + + +def _make_tool_call( + call_id: str = "call-1", + func_name: str = "some_func", + arguments: str = '{"key": "value"}', +) -> AssistantPromptMessage.ToolCall: + return AssistantPromptMessage.ToolCall( + id=call_id, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=func_name, arguments=arguments), + ) + + +# --------------------------------------------------------------------------- +# Fixture: shared LoggingCallback instance (no heavy state) +# --------------------------------------------------------------------------- + + +@pytest.fixture +def cb() -> LoggingCallback: + return LoggingCallback() + + +@pytest.fixture +def llm_instance() -> MagicMock: + return MagicMock() + + +# =========================================================================== +# Tests for on_before_invoke +# =========================================================================== + + +class TestOnBeforeInvoke: + """Tests for LoggingCallback.on_before_invoke.""" + + def _invoke( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + *, + model: str = "gpt-4", + credentials: dict | None = None, + prompt_messages: list | None = None, + model_parameters: dict | None = None, + tools: list[PromptMessageTool] | None = None, + stop: Sequence[str] | None = None, + stream: bool = True, + user: str | None = None, + ): + cb.on_before_invoke( + llm_instance=llm_instance, + model=model, + credentials=credentials or {}, + prompt_messages=prompt_messages or [], + model_parameters=model_parameters or {}, + tools=tools, + stop=stop, + stream=stream, + user=user, + ) + + def test_minimal_call_does_not_raise(self, cb: LoggingCallback, llm_instance: MagicMock): + """Calling with bare-minimum args should not raise.""" + self._invoke(cb, llm_instance) + + def test_model_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """The model name must appear in print_text calls.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, model="claude-3") + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "claude-3" in calls_text + + def test_model_parameters_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """Each key-value pair of model_parameters must be printed.""" + params = {"temperature": 0.7, "max_tokens": 512} + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, model_parameters=params) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "temperature" in calls_text + assert "0.7" in calls_text + assert "max_tokens" in calls_text + assert "512" in calls_text + + def test_empty_model_parameters(self, cb: LoggingCallback, llm_instance: MagicMock): + """Empty model_parameters dict should not raise.""" + self._invoke(cb, llm_instance, model_parameters={}) + + # ------------------------------------------------------------------ + # stop branch + # ------------------------------------------------------------------ + + def test_stop_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """stop words must appear in output when provided.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=["STOP", "END"]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "stop" in calls_text + + def test_stop_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stop=None the stop line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tstop:" not in calls_text + + def test_stop_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stop=[] (falsy) the stop line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stop=[]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tstop:" not in calls_text + + # ------------------------------------------------------------------ + # tools branch + # ------------------------------------------------------------------ + + def test_tools_branch_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """Tool names must appear in output when tools are provided.""" + tools = [_make_tool("search"), _make_tool("calculate")] + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=tools) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "search" in calls_text + assert "calculate" in calls_text + + def test_tools_branch_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tools=None the Tools section must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tools:" not in calls_text + + def test_tools_branch_skipped_when_empty_list(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tools=[] (falsy) the Tools section must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, tools=[]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tools:" not in calls_text + + # ------------------------------------------------------------------ + # user branch + # ------------------------------------------------------------------ + + def test_user_printed_when_provided(self, cb: LoggingCallback, llm_instance: MagicMock): + """User string must appear in output when provided.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, user="alice") + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "alice" in calls_text + + def test_user_skipped_when_none(self, cb: LoggingCallback, llm_instance: MagicMock): + """When user=None the User line must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, user=None) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "User:" not in calls_text + + # ------------------------------------------------------------------ + # stream branch + # ------------------------------------------------------------------ + + def test_stream_true_prints_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stream=True the [on_llm_new_chunk] marker must be printed.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stream=True) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_new_chunk]" in calls_text + + def test_stream_false_no_new_chunk_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """When stream=False the [on_llm_new_chunk] marker must NOT appear.""" + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, stream=False) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_new_chunk]" not in calls_text + + # ------------------------------------------------------------------ + # prompt_messages branch + # ------------------------------------------------------------------ + + def test_prompt_message_with_name_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """When a PromptMessage has a name it must be printed.""" + msg = _make_user_prompt("hi", name="bob") + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "bob" in calls_text + + def test_prompt_message_without_name_skips_name_line(self, cb: LoggingCallback, llm_instance: MagicMock): + """When a PromptMessage has no name the name line must NOT appear.""" + msg = _make_user_prompt("hi", name=None) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "\tname:" not in calls_text + + def test_prompt_message_role_and_content_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """Role and content of each PromptMessage must appear in output.""" + msg = _make_system_prompt("Be concise.") + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=[msg]) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "system" in calls_text + assert "Be concise." in calls_text + + def test_multiple_prompt_messages_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """All entries in prompt_messages are iterated and printed.""" + msgs = [ + _make_system_prompt("sys"), + _make_user_prompt("user msg"), + ] + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, prompt_messages=msgs) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "sys" in calls_text + assert "user msg" in calls_text + + # ------------------------------------------------------------------ + # Combination: everything provided + # ------------------------------------------------------------------ + + def test_all_optional_fields_combined(self, cb: LoggingCallback, llm_instance: MagicMock): + """Supply stop, tools, user, multiple params, named message – no exception.""" + msgs = [_make_user_prompt("question", name="alice")] + tools = [_make_tool("tool_a")] + with patch.object(cb, "print_text"): + self._invoke( + cb, + llm_instance, + model="gpt-3.5", + model_parameters={"temperature": 1.0, "top_p": 0.9}, + tools=tools, + stop=["DONE"], + stream=True, + user="alice", + prompt_messages=msgs, + ) + + +# =========================================================================== +# Tests for on_new_chunk +# =========================================================================== + + +class TestOnNewChunk: + """Tests for LoggingCallback.on_new_chunk.""" + + def test_chunk_content_written_to_stdout(self, cb: LoggingCallback, llm_instance: MagicMock): + """on_new_chunk must write the chunk's text content to sys.stdout.""" + chunk = _make_chunk("hello from LLM") + written = [] + + with patch("sys.stdout") as mock_stdout: + mock_stdout.write.side_effect = written.append + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + mock_stdout.write.assert_called_once_with("hello from LLM") + mock_stdout.flush.assert_called_once() + + def test_chunk_content_empty_string(self, cb: LoggingCallback, llm_instance: MagicMock): + """Works correctly even when the chunk content is an empty string.""" + chunk = _make_chunk("") + with patch("sys.stdout") as mock_stdout: + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + mock_stdout.write.assert_called_once_with("") + mock_stdout.flush.assert_called_once() + + def test_chunk_passes_all_optional_params(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters are accepted without errors.""" + chunk = _make_chunk("data") + with patch("sys.stdout"): + cb.on_new_chunk( + llm_instance=llm_instance, + chunk=chunk, + model="gpt-4", + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.5}, + tools=[_make_tool("t1")], + stop=["EOS"], + stream=True, + user="bob", + ) + + +# =========================================================================== +# Tests for on_after_invoke +# =========================================================================== + + +class TestOnAfterInvoke: + """Tests for LoggingCallback.on_after_invoke.""" + + def _invoke( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + result: LLMResult, + **kwargs, + ): + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=result.model, + credentials={}, + prompt_messages=[], + model_parameters={}, + **kwargs, + ) + + def test_basic_result_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """After-invoke header, content, model, usage, fingerprint must be printed.""" + result = _make_llm_result() + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_after_invoke]" in calls_text + assert "hello world" in calls_text + assert "gpt-4" in calls_text + assert "fp-abc" in calls_text + + def test_no_tool_calls_skips_tool_call_block(self, cb: LoggingCallback, llm_instance: MagicMock): + """When there are no tool_calls the 'Tool calls:' block must NOT appear.""" + result = _make_llm_result(tool_calls=[]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tool calls:" not in calls_text + + def test_with_tool_calls_prints_all_fields(self, cb: LoggingCallback, llm_instance: MagicMock): + """When tool_calls exist their id, name, and JSON arguments must be printed.""" + tc = _make_tool_call( + call_id="call-xyz", + func_name="fetch_data", + arguments='{"url": "https://example.com"}', + ) + result = _make_llm_result(tool_calls=[tc]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Tool calls:" in calls_text + assert "call-xyz" in calls_text + assert "fetch_data" in calls_text + # arguments should be JSON-dumped + assert "https://example.com" in calls_text + + def test_multiple_tool_calls_all_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """All tool calls in the list must be iterated.""" + tcs = [ + _make_tool_call("id-1", "func_a", '{"a": 1}'), + _make_tool_call("id-2", "func_b", '{"b": 2}'), + ] + result = _make_llm_result(tool_calls=tcs) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "id-1" in calls_text + assert "func_a" in calls_text + assert "id-2" in calls_text + assert "func_b" in calls_text + + def test_system_fingerprint_none_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """When system_fingerprint is None it should still be printed (as None).""" + result = _make_llm_result(system_fingerprint=None) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "System Fingerprint: None" in calls_text + + def test_usage_printed(self, cb: LoggingCallback, llm_instance: MagicMock): + """The usage object must appear in the printed output.""" + result = _make_llm_result() + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "Usage:" in calls_text + + def test_tool_call_arguments_are_json_dumped(self, cb: LoggingCallback, llm_instance: MagicMock): + """Verify json.dumps is applied to the arguments field (a string).""" + raw_args = '{"x": 42}' + tc = _make_tool_call(arguments=raw_args) + result = _make_llm_result(tool_calls=[tc]) + with patch.object(cb, "print_text") as mock_print: + self._invoke(cb, llm_instance, result) + + # Check if any call to print_text included the expected (json-encoded) arguments + # json.dumps(raw_args) produces a string starting and ending with quotes + expected_substring = json.dumps(raw_args) + found = any(expected_substring in str(call.args[0]) for call in mock_print.call_args_list) + assert found, f"Expected {expected_substring} to be printed in one of the calls" + + def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters should be accepted without error.""" + result = _make_llm_result() + cb.on_after_invoke( + llm_instance=llm_instance, + result=result, + model=result.model, + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.9}, + tools=[_make_tool("t")], + stop=[""], + stream=False, + user="carol", + ) + + +# =========================================================================== +# Tests for on_invoke_error +# =========================================================================== + + +class TestOnInvokeError: + """Tests for LoggingCallback.on_invoke_error.""" + + def _invoke_error( + self, + cb: LoggingCallback, + llm_instance: MagicMock, + ex: Exception, + **kwargs, + ): + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + **kwargs, + ) + + def test_prints_error_header(self, cb: LoggingCallback, llm_instance: MagicMock): + """The [on_llm_invoke_error] banner must be printed.""" + with patch.object(cb, "print_text") as mock_print: + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, RuntimeError("boom")) + calls_text = " ".join(str(c) for c in mock_print.call_args_list) + assert "[on_llm_invoke_error]" in calls_text + + def test_exception_logged_via_logger_exception(self, cb: LoggingCallback, llm_instance: MagicMock): + """logger.exception must be called with the exception.""" + ex = ValueError("something went wrong") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, ex) + mock_logger.exception.assert_called_once_with(ex) + + def test_exception_type_variety(self, cb: LoggingCallback, llm_instance: MagicMock): + """Works with any exception type (TypeError, IOError, etc.).""" + for exc_cls in (TypeError, IOError, KeyError, Exception): + ex = exc_cls("error") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger") as mock_logger: + self._invoke_error(cb, llm_instance, ex) + mock_logger.exception.assert_called_once_with(ex) + + def test_optional_params_accepted(self, cb: LoggingCallback, llm_instance: MagicMock): + """All optional parameters should be accepted without error.""" + ex = RuntimeError("fail") + with patch.object(cb, "print_text"): + with patch("dify_graph.model_runtime.callbacks.logging_callback.logger"): + cb.on_invoke_error( + llm_instance=llm_instance, + ex=ex, + model="gpt-4", + credentials={"key": "secret"}, + prompt_messages=[_make_user_prompt("q")], + model_parameters={"temperature": 0.7}, + tools=[_make_tool("t")], + stop=["STOP"], + stream=True, + user="dave", + ) + + +# =========================================================================== +# Tests for print_text (inherited from Callback, exercised through LoggingCallback) +# =========================================================================== + + +class TestPrintText: + """Verify that print_text from the Callback base class works correctly.""" + + def test_print_text_with_color(self, cb: LoggingCallback, capsys): + """print_text with a known colour should emit an ANSI escape sequence.""" + cb.print_text("hello", color="blue") + captured = capsys.readouterr() + assert "hello" in captured.out + # ANSI escape codes should be present + assert "\x1b[" in captured.out + + def test_print_text_without_color(self, cb: LoggingCallback, capsys): + """print_text without colour should print plain text.""" + cb.print_text("plain text") + captured = capsys.readouterr() + assert "plain text" in captured.out + + def test_print_text_all_colours(self, cb: LoggingCallback, capsys): + """Verify all supported colour keys don't raise.""" + for colour in ("blue", "yellow", "pink", "green", "red"): + cb.print_text("x", color=colour) + captured = capsys.readouterr() + # All outputs should contain 'x' (5 calls) + assert captured.out.count("x") >= 5 + + +# =========================================================================== +# Integration-style test: real print_text called (no mocking) +# =========================================================================== + + +class TestLoggingCallbackIntegration: + """Light integration tests – real print_text calls, just checking no exceptions.""" + + def test_on_before_invoke_full_run(self, capsys): + """Full on_before_invoke run with all optional fields – verifies real output.""" + cb = LoggingCallback() + llm = MagicMock() + msgs = [_make_user_prompt("Who are you?", name="tester")] + tools = [_make_tool("calculator")] + cb.on_before_invoke( + llm_instance=llm, + model="gpt-4-turbo", + credentials={"api_key": "sk-xxx"}, + prompt_messages=msgs, + model_parameters={"temperature": 0.8}, + tools=tools, + stop=["STOP"], + stream=True, + user="test_user", + ) + captured = capsys.readouterr() + assert "gpt-4-turbo" in captured.out + assert "calculator" in captured.out + assert "test_user" in captured.out + assert "STOP" in captured.out + assert "tester" in captured.out + + def test_on_new_chunk_full_run(self, capsys): + """Full on_new_chunk run – verifies real stdout write.""" + cb = LoggingCallback() + chunk = _make_chunk("streaming token") + cb.on_new_chunk( + llm_instance=MagicMock(), + chunk=chunk, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "streaming token" in captured.out + + def test_on_after_invoke_full_run_with_tool_calls(self, capsys): + """Full on_after_invoke run with tool calls – verifies real output.""" + cb = LoggingCallback() + tc = _make_tool_call("call-99", "do_thing", '{"n": 5}') + result = _make_llm_result(content="result content", tool_calls=[tc], system_fingerprint="fp-xyz") + cb.on_after_invoke( + llm_instance=MagicMock(), + result=result, + model=result.model, + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "result content" in captured.out + assert "call-99" in captured.out + assert "do_thing" in captured.out + assert "fp-xyz" in captured.out + + def test_on_invoke_error_full_run(self, capsys): + """Full on_invoke_error run – just verifies no exception is raised.""" + cb = LoggingCallback() + ex = RuntimeError("something bad happened") + # logger.exception writes to stderr; we just confirm it doesn't crash + cb.on_invoke_error( + llm_instance=MagicMock(), + ex=ex, + model="gpt-4", + credentials={}, + prompt_messages=[], + model_parameters={}, + ) + captured = capsys.readouterr() + assert "[on_llm_invoke_error]" in captured.out diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py new file mode 100644 index 0000000000..db147fb0cd --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_common_entities.py @@ -0,0 +1,35 @@ +from dify_graph.model_runtime.entities.common_entities import I18nObject + + +class TestI18nObject: + def test_i18n_object_with_both_languages(self): + """ + Test I18nObject when both zh_Hans and en_US are provided. + """ + i18n = I18nObject(zh_Hans="你好", en_US="Hello") + assert i18n.zh_Hans == "你好" + assert i18n.en_US == "Hello" + + def test_i18n_object_fallback_to_en_us(self): + """ + Test I18nObject when zh_Hans is missing, it should fallback to en_US. + """ + i18n = I18nObject(en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" + + def test_i18n_object_with_none_zh_hans(self): + """ + Test I18nObject when zh_Hans is None, it should fallback to en_US. + """ + i18n = I18nObject(zh_Hans=None, en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" + + def test_i18n_object_with_empty_zh_hans(self): + """ + Test I18nObject when zh_Hans is an empty string, it should fallback to en_US. + """ + i18n = I18nObject(zh_Hans="", en_US="Hello") + assert i18n.zh_Hans == "Hello" + assert i18n.en_US == "Hello" diff --git a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py similarity index 98% rename from api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py rename to api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py index c10f7b89c3..4e435cb4c6 100644 --- a/api/tests/unit_tests/core/model_runtime/entities/test_llm_entities.py +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_llm_entities.py @@ -2,7 +2,7 @@ from decimal import Decimal -from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata class TestLLMUsage: diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py new file mode 100644 index 0000000000..a96a38f5cd --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_message_entities.py @@ -0,0 +1,210 @@ +import pytest + +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + AudioPromptMessageContent, + DocumentPromptMessageContent, + ImagePromptMessageContent, + PromptMessageContent, + PromptMessageContentType, + PromptMessageFunction, + PromptMessageRole, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, + VideoPromptMessageContent, +) + + +class TestPromptMessageRole: + def test_value_of(self): + assert PromptMessageRole.value_of("system") == PromptMessageRole.SYSTEM + assert PromptMessageRole.value_of("user") == PromptMessageRole.USER + assert PromptMessageRole.value_of("assistant") == PromptMessageRole.ASSISTANT + assert PromptMessageRole.value_of("tool") == PromptMessageRole.TOOL + + with pytest.raises(ValueError, match="invalid prompt message type value invalid"): + PromptMessageRole.value_of("invalid") + + +class TestPromptMessageEntities: + def test_prompt_message_tool(self): + tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) + assert tool.name == "test_tool" + assert tool.description == "test desc" + assert tool.parameters == {"foo": "bar"} + + def test_prompt_message_function(self): + tool = PromptMessageTool(name="test_tool", description="test desc", parameters={"foo": "bar"}) + func = PromptMessageFunction(function=tool) + assert func.type == "function" + assert func.function == tool + + +class TestPromptMessageContent: + def test_text_content(self): + content = TextPromptMessageContent(data="hello") + assert content.type == PromptMessageContentType.TEXT + assert content.data == "hello" + + def test_image_content(self): + content = ImagePromptMessageContent( + format="jpg", base64_data="abc", mime_type="image/jpeg", detail=ImagePromptMessageContent.DETAIL.HIGH + ) + assert content.type == PromptMessageContentType.IMAGE + assert content.detail == ImagePromptMessageContent.DETAIL.HIGH + assert content.data == "data:image/jpeg;base64,abc" + + def test_image_content_url(self): + content = ImagePromptMessageContent(format="jpg", url="https://example.com/image.jpg", mime_type="image/jpeg") + assert content.data == "https://example.com/image.jpg" + + def test_audio_content(self): + content = AudioPromptMessageContent(format="mp3", base64_data="abc", mime_type="audio/mpeg") + assert content.type == PromptMessageContentType.AUDIO + assert content.data == "data:audio/mpeg;base64,abc" + + def test_video_content(self): + content = VideoPromptMessageContent(format="mp4", base64_data="abc", mime_type="video/mp4") + assert content.type == PromptMessageContentType.VIDEO + assert content.data == "data:video/mp4;base64,abc" + + def test_document_content(self): + content = DocumentPromptMessageContent(format="pdf", base64_data="abc", mime_type="application/pdf") + assert content.type == PromptMessageContentType.DOCUMENT + assert content.data == "data:application/pdf;base64,abc" + + +class TestPromptMessages: + def test_user_prompt_message(self): + msg = UserPromptMessage(content="hello") + assert msg.role == PromptMessageRole.USER + assert msg.content == "hello" + assert msg.is_empty() is False + assert msg.get_text_content() == "hello" + + def test_user_prompt_message_complex_content(self): + content = [TextPromptMessageContent(data="hello "), TextPromptMessageContent(data="world")] + msg = UserPromptMessage(content=content) + assert msg.get_text_content() == "hello world" + + # Test validation from dict + msg2 = UserPromptMessage(content=[{"type": "text", "data": "hi"}]) + assert isinstance(msg2.content[0], TextPromptMessageContent) + assert msg2.content[0].data == "hi" + + def test_prompt_message_empty(self): + msg = UserPromptMessage(content=None) + assert msg.is_empty() is True + assert msg.get_text_content() == "" + + def test_assistant_prompt_message(self): + msg = AssistantPromptMessage(content="thinking...") + assert msg.role == PromptMessageRole.ASSISTANT + assert msg.is_empty() is False + + tool_call = AssistantPromptMessage.ToolCall( + id="call_1", + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), + ) + msg_with_tools = AssistantPromptMessage(content=None, tool_calls=[tool_call]) + assert msg_with_tools.is_empty() is False + assert msg_with_tools.role == PromptMessageRole.ASSISTANT + + def test_assistant_tool_call_id_transform(self): + tool_call = AssistantPromptMessage.ToolCall( + id=123, + type="function", + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="test", arguments="{}"), + ) + assert tool_call.id == "123" + + def test_system_prompt_message(self): + msg = SystemPromptMessage(content="you are a bot") + assert msg.role == PromptMessageRole.SYSTEM + assert msg.content == "you are a bot" + + def test_tool_prompt_message(self): + # Case 1: Both content and tool_call_id are present + msg = ToolPromptMessage(content="result", tool_call_id="call_1") + assert msg.role == PromptMessageRole.TOOL + assert msg.tool_call_id == "call_1" + assert msg.is_empty() is False + + # Case 2: Content is present, but tool_call_id is empty + msg_content_only = ToolPromptMessage(content="result", tool_call_id="") + assert msg_content_only.is_empty() is False + + # Case 3: Content is None, but tool_call_id is present + msg_id_only = ToolPromptMessage(content=None, tool_call_id="call_1") + assert msg_id_only.is_empty() is False + + # Case 4: Both content and tool_call_id are empty + msg_empty = ToolPromptMessage(content=None, tool_call_id="") + assert msg_empty.is_empty() is True + + def test_prompt_message_validation_errors(self): + with pytest.raises(KeyError): + # Invalid content type in list + UserPromptMessage(content=[{"type": "invalid", "data": "foo"}]) + + with pytest.raises(ValueError, match="invalid prompt message"): + # Not a dict or PromptMessageContent + UserPromptMessage(content=[123]) + + def test_prompt_message_serialization(self): + # Case: content is None + assert UserPromptMessage(content=None).serialize_content(None) is None + + # Case: content is str + assert UserPromptMessage(content="hello").serialize_content("hello") == "hello" + + # Case: content is list of dict + content_list = [{"type": "text", "data": "hi"}] + msg = UserPromptMessage(content=content_list) + assert msg.serialize_content(msg.content) == [{"type": PromptMessageContentType.TEXT, "data": "hi"}] + + # Case: content is Sequence but not list (e.g. tuple) + # To hit line 204, we can call serialize_content manually or + # try to pass a type that pydantic doesn't convert to list in its internal state. + # Actually, let's just call it manually on the instance. + msg = UserPromptMessage(content="test") + content_tuple = (TextPromptMessageContent(data="hi"),) + assert msg.serialize_content(content_tuple) == content_tuple + + def test_prompt_message_mixed_content_validation(self): + # Test branch: isinstance(prompt, PromptMessageContent) + # but not (TextPromptMessageContent | MultiModalPromptMessageContent) + # Line 187: prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) + + # We need a PromptMessageContent that is NOT Text or MultiModal. + # But PromptMessageContentUnionTypes discriminator handles this usually. + # We can bypass high-level validation by passing the object directly in a list. + + class MockContent(PromptMessageContent): + type: PromptMessageContentType = PromptMessageContentType.TEXT + data: str + + mock_item = MockContent(data="test") + msg = UserPromptMessage(content=[mock_item]) + # It should hit line 187 and convert to TextPromptMessageContent + assert isinstance(msg.content[0], TextPromptMessageContent) + assert msg.content[0].data == "test" + + def test_prompt_message_get_text_content_branches(self): + # content is None + msg_none = UserPromptMessage(content=None) + assert msg_none.get_text_content() == "" + + # content is list but no text content + image = ImagePromptMessageContent(format="jpg", base64_data="abc", mime_type="image/jpeg") + msg_image = UserPromptMessage(content=[image]) + assert msg_image.get_text_content() == "" + + # content is list with mixed + text = TextPromptMessageContent(data="hello") + msg_mixed = UserPromptMessage(content=[text, image]) + assert msg_mixed.get_text_content() == "hello" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py new file mode 100644 index 0000000000..3d03361f2a --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/entities/test_model_entities.py @@ -0,0 +1,220 @@ +from decimal import Decimal + +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ModelUsage, + ParameterRule, + ParameterType, + PriceConfig, + PriceInfo, + PriceType, + ProviderModel, +) + + +class TestModelType: + def test_value_of(self): + assert ModelType.value_of("text-generation") == ModelType.LLM + assert ModelType.value_of(ModelType.LLM) == ModelType.LLM + assert ModelType.value_of("embeddings") == ModelType.TEXT_EMBEDDING + assert ModelType.value_of(ModelType.TEXT_EMBEDDING) == ModelType.TEXT_EMBEDDING + assert ModelType.value_of("reranking") == ModelType.RERANK + assert ModelType.value_of(ModelType.RERANK) == ModelType.RERANK + assert ModelType.value_of("speech2text") == ModelType.SPEECH2TEXT + assert ModelType.value_of(ModelType.SPEECH2TEXT) == ModelType.SPEECH2TEXT + assert ModelType.value_of("tts") == ModelType.TTS + assert ModelType.value_of(ModelType.TTS) == ModelType.TTS + assert ModelType.value_of(ModelType.MODERATION) == ModelType.MODERATION + + with pytest.raises(ValueError, match="invalid origin model type invalid"): + ModelType.value_of("invalid") + + def test_to_origin_model_type(self): + assert ModelType.LLM.to_origin_model_type() == "text-generation" + assert ModelType.TEXT_EMBEDDING.to_origin_model_type() == "embeddings" + assert ModelType.RERANK.to_origin_model_type() == "reranking" + assert ModelType.SPEECH2TEXT.to_origin_model_type() == "speech2text" + assert ModelType.TTS.to_origin_model_type() == "tts" + assert ModelType.MODERATION.to_origin_model_type() == "moderation" + + # Testing the else branch in to_origin_model_type + # Since it's a StrEnum, it's hard to get an invalid value here unless we mock or Force it. + # But if we look at the implementation: + # if self == self.LLM: ... elif ... else: raise ValueError + # We can try to create a "dummy" member if possible, or just skip it if we have 100% coverage otherwise. + # Actually, adding a new member to an enum at runtime is possible but messy. + # Let's see if we can trigger it. + + +class TestFetchFrom: + def test_values(self): + assert FetchFrom.PREDEFINED_MODEL == "predefined-model" + assert FetchFrom.CUSTOMIZABLE_MODEL == "customizable-model" + + +class TestModelFeature: + def test_values(self): + assert ModelFeature.TOOL_CALL == "tool-call" + assert ModelFeature.MULTI_TOOL_CALL == "multi-tool-call" + assert ModelFeature.AGENT_THOUGHT == "agent-thought" + assert ModelFeature.VISION == "vision" + assert ModelFeature.STREAM_TOOL_CALL == "stream-tool-call" + assert ModelFeature.DOCUMENT == "document" + assert ModelFeature.VIDEO == "video" + assert ModelFeature.AUDIO == "audio" + assert ModelFeature.STRUCTURED_OUTPUT == "structured-output" + + +class TestDefaultParameterName: + def test_value_of(self): + assert DefaultParameterName.value_of("temperature") == DefaultParameterName.TEMPERATURE + assert DefaultParameterName.value_of("top_p") == DefaultParameterName.TOP_P + + with pytest.raises(ValueError, match="invalid parameter name invalid"): + DefaultParameterName.value_of("invalid") + + +class TestParameterType: + def test_values(self): + assert ParameterType.FLOAT == "float" + assert ParameterType.INT == "int" + assert ParameterType.STRING == "string" + assert ParameterType.BOOLEAN == "boolean" + assert ParameterType.TEXT == "text" + + +class TestModelPropertyKey: + def test_values(self): + assert ModelPropertyKey.MODE == "mode" + assert ModelPropertyKey.CONTEXT_SIZE == "context_size" + + +class TestProviderModel: + def test_provider_model(self): + model = ProviderModel( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + ) + assert model.model == "gpt-4" + assert model.support_structure_output is False + + model_with_features = ProviderModel( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[ModelFeature.STRUCTURED_OUTPUT], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + ) + assert model_with_features.support_structure_output is True + + +class TestParameterRule: + def test_parameter_rule(self): + rule = ParameterRule( + name="temperature", + label=I18nObject(en_US="Temperature"), + type=ParameterType.FLOAT, + default=0.7, + min=0.0, + max=1.0, + precision=2, + ) + assert rule.name == "temperature" + assert rule.default == 0.7 + + +class TestPriceConfig: + def test_price_config(self): + config = PriceConfig(input=Decimal("0.01"), output=Decimal("0.02"), unit=Decimal("0.001"), currency="USD") + assert config.input == Decimal("0.01") + assert config.output == Decimal("0.02") + + +class TestAIModelEntity: + def test_ai_model_entity_no_json_schema(self): + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="temperature", label=I18nObject(en_US="Temperature"), type=ParameterType.FLOAT) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT not in (entity.features or []) + + def test_ai_model_entity_with_json_schema(self): + # Case: json_schema in parameter rules, features is None + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + + def test_ai_model_entity_with_json_schema_and_features_empty(self): + # Case: json_schema in parameter rules, features is empty list + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + + def test_ai_model_entity_with_json_schema_and_other_features(self): + # Case: json_schema in parameter rules, features has other things + entity = AIModelEntity( + model="gpt-4", + label=I18nObject(en_US="GPT-4"), + model_type=ModelType.LLM, + features=[ModelFeature.VISION], + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 8192}, + parameter_rules=[ + ParameterRule(name="json_schema", label=I18nObject(en_US="JSON Schema"), type=ParameterType.STRING) + ], + ) + assert ModelFeature.STRUCTURED_OUTPUT in entity.features + assert ModelFeature.VISION in entity.features + + +class TestModelUsage: + def test_model_usage(self): + usage = ModelUsage() + assert isinstance(usage, ModelUsage) + + +class TestPriceType: + def test_values(self): + assert PriceType.INPUT == "input" + assert PriceType.OUTPUT == "output" + + +class TestPriceInfo: + def test_price_info(self): + info = PriceInfo(unit_price=Decimal("0.01"), unit=Decimal(1000), total_amount=Decimal("0.05"), currency="USD") + assert info.total_amount == Decimal("0.05") diff --git a/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py b/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py new file mode 100644 index 0000000000..af62b2a84c --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/errors/test_invoke.py @@ -0,0 +1,63 @@ +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) + + +class TestInvokeErrors: + def test_invoke_error_with_description(self): + error = InvokeError("Custom description") + assert error.description == "Custom description" + assert str(error) == "Custom description" + assert isinstance(error, ValueError) + + def test_invoke_error_without_description(self): + error = InvokeError() + assert error.description is None + assert str(error) == "InvokeError" + + def test_invoke_connection_error(self): + # Now preserves class-level description + error = InvokeConnectionError() + assert error.description == "Connection Error" + assert str(error) == "Connection Error" + assert isinstance(error, InvokeError) + + # Test with explicit description + error_with_desc = InvokeConnectionError("Connection Error") + assert error_with_desc.description == "Connection Error" + assert str(error_with_desc) == "Connection Error" + + def test_invoke_server_unavailable_error(self): + error = InvokeServerUnavailableError() + assert error.description == "Server Unavailable Error" + assert str(error) == "Server Unavailable Error" + assert isinstance(error, InvokeError) + + def test_invoke_rate_limit_error(self): + error = InvokeRateLimitError() + assert error.description == "Rate Limit Error" + assert str(error) == "Rate Limit Error" + assert isinstance(error, InvokeError) + + def test_invoke_authorization_error(self): + error = InvokeAuthorizationError() + assert error.description == "Incorrect model credentials provided, please check and try again. " + assert str(error) == "Incorrect model credentials provided, please check and try again. " + assert isinstance(error, InvokeError) + + def test_invoke_bad_request_error(self): + error = InvokeBadRequestError() + assert error.description == "Bad Request Error" + assert str(error) == "Bad Request Error" + assert isinstance(error, InvokeError) + + def test_invoke_error_inheritance(self): + # Test that we can override the default description in subclasses + error = InvokeBadRequestError("Overridden Error") + assert error.description == "Overridden Error" + assert str(error) == "Overridden Error" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py new file mode 100644 index 0000000000..382dce876e --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_ai_model.py @@ -0,0 +1,336 @@ +import decimal +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + DefaultParameterName, + FetchFrom, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, + PriceConfig, + PriceType, +) +from dify_graph.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel + + +class TestAIModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def ai_model(self, mock_plugin_model_provider): + return AIModel( + tenant_id="tenant_123", + model_type=ModelType.LLM, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_invoke_error_mapping(self, ai_model): + mapping = ai_model._invoke_error_mapping + assert InvokeConnectionError in mapping + assert InvokeServerUnavailableError in mapping + assert InvokeRateLimitError in mapping + assert InvokeAuthorizationError in mapping + assert InvokeBadRequestError in mapping + assert PluginDaemonInnerError in mapping + assert ValueError in mapping + + def test_transform_invoke_error(self, ai_model): + # Case: mapped error (InvokeAuthorizationError) + err = Exception("Original error") + with patch.object(AIModel, "_invoke_error_mapping", {InvokeAuthorizationError: [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert isinstance(transformed, InvokeAuthorizationError) + assert "Incorrect model credentials provided" in str(transformed.description) + + # Case: mapped error (InvokeError subclass) + with patch.object(AIModel, "_invoke_error_mapping", {InvokeRateLimitError("Rate limit"): [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert isinstance(transformed, InvokeError) + assert "[test_provider]" in transformed.description + + # Case: mapped error (not InvokeError) + class CustomNonInvokeError(Exception): + pass + + with patch.object(AIModel, "_invoke_error_mapping", {CustomNonInvokeError: [Exception]}): + transformed = ai_model._transform_invoke_error(err) + assert transformed == err + + # Case: unmapped error + unmapped_err = Exception("Unmapped") + transformed = ai_model._transform_invoke_error(unmapped_err) + assert isinstance(transformed, InvokeError) + assert "Error: Unmapped" in transformed.description + + def test_get_price(self, ai_model): + model_name = "test_model" + credentials = {"key": "value"} + + # Mock get_model_schema + mock_schema = MagicMock(spec=AIModelEntity) + mock_schema.pricing = PriceConfig( + input=decimal.Decimal("0.002"), + output=decimal.Decimal("0.004"), + unit=decimal.Decimal(1000), # 1000 tokens per unit + currency="USD", + ) + + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + # Test INPUT + price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 2000) + assert price_info.unit_price == decimal.Decimal("0.002") + + # Test OUTPUT + price_info = ai_model.get_price(model_name, credentials, PriceType.OUTPUT, 2000) + assert price_info.unit_price == decimal.Decimal("0.004") + + # Case: unit_price is None (returns zeroed PriceInfo) + mock_schema.pricing = None + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + price_info = ai_model.get_price(model_name, credentials, PriceType.INPUT, 1000) + assert price_info.total_amount == decimal.Decimal("0.0") + + def test_get_price_no_price_config_error(self, ai_model): + model_name = "test_model" + + # We need it to be truthy at line 107 and 112 but falsy at line 127. + class ChangingPriceConfig: + def __init__(self): + self.input = decimal.Decimal("0.01") + self.unit = decimal.Decimal(1) + self.currency = "USD" + self.called = 0 + + def __bool__(self): + self.called += 1 + return self.called <= 2 + + mock_schema = MagicMock() + mock_schema.pricing = ChangingPriceConfig() + + with patch.object(AIModel, "get_model_schema", return_value=mock_schema): + with pytest.raises(ValueError) as excinfo: + ai_model.get_price(model_name, {}, PriceType.INPUT, 1000) + assert "Price config not found" in str(excinfo.value) + + def test_get_model_schema_cache_hit(self, ai_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis: + mock_redis.get.return_value = mock_schema.model_dump_json().encode() + + schema = ai_model.get_model_schema(model_name, credentials) + + assert schema.model == "test_model" + mock_redis.get.assert_called_once() + + def test_get_model_schema_cache_miss(self, ai_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = None + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = mock_schema + + schema = ai_model.get_model_schema(model_name, credentials) + + assert schema == mock_schema + mock_manager.get_model_schema.assert_called_once() + mock_redis.setex.assert_called_once() + + def test_get_model_schema_redis_error(self, ai_model): + model_name = "test_model" + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.side_effect = RedisError("Connection refused") + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = None + + schema = ai_model.get_model_schema(model_name, {}) + + assert schema is None + mock_manager.get_model_schema.assert_called_once() + + def test_get_model_schema_validation_error(self, ai_model): + model_name = "test_model" + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = b"invalid json" + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = None + + # This should trigger ValidationError at line 166 and go to delete() + schema = ai_model.get_model_schema(model_name, {}) + + assert schema is None + mock_redis.delete.assert_called() + + def test_get_model_schema_redis_delete_error(self, ai_model): + model_name = "test_model" + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = b'{"invalid": "schema"}' + mock_redis.delete.side_effect = RedisError("Delete failed") + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = None + + schema = ai_model.get_model_schema(model_name, {}) + + assert schema is None + mock_redis.delete.assert_called() + + def test_get_model_schema_redis_setex_error(self, ai_model): + model_name = "test_model" + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + with ( + patch("dify_graph.model_runtime.model_providers.__base.ai_model.redis_client") as mock_redis, + patch("core.plugin.impl.model.PluginModelClient") as mock_client, + ): + mock_redis.get.return_value = None + mock_redis.setex.side_effect = RuntimeError("Setex failed") + mock_manager = mock_client.return_value + mock_manager.get_model_schema.return_value = mock_schema + + schema = ai_model.get_model_schema(model_name, {}) + + assert schema == mock_schema + mock_redis.setex.assert_called() + + def test_get_customizable_model_schema_from_credentials_template_mapping_value_error(self, ai_model): + model_name = "test_model" + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[ + ParameterRule( + name="invalid", + use_template="invalid_template_name", + label=I18nObject(en_US="Invalid"), + type=ParameterType.FLOAT, + ) + ], + ) + + with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): + schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {}) + assert schema.parameter_rules[0].use_template == "invalid_template_name" + + def test_get_customizable_model_schema_from_credentials(self, ai_model): + model_name = "test_model" + + mock_schema = AIModelEntity( + model="test_model", + label=I18nObject(en_US="Test Model"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[ + ParameterRule( + name="temp", use_template="temperature", label=I18nObject(en_US="Temp"), type=ParameterType.FLOAT + ), + ParameterRule( + name="top_p", + use_template="top_p", + label=I18nObject(en_US="Top P"), + type=ParameterType.FLOAT, + help=I18nObject(en_US=""), + ), + ParameterRule( + name="max_tokens", + use_template="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + help=I18nObject(en_US="", zh_Hans=""), + ), + ParameterRule(name="custom", label=I18nObject(en_US="Custom"), type=ParameterType.STRING), + ], + ) + + with patch.object(AIModel, "get_customizable_model_schema", return_value=mock_schema): + schema = ai_model.get_customizable_model_schema_from_credentials(model_name, {}) + + assert schema.parameter_rules[0].max == 1.0 + assert schema.parameter_rules[1].help.en_US != "" + assert schema.parameter_rules[2].help.zh_Hans != "" + assert schema.parameter_rules[3].use_template is None + + def test_get_customizable_model_schema_from_credentials_none(self, ai_model): + with patch.object(AIModel, "get_customizable_model_schema", return_value=None): + schema = ai_model.get_customizable_model_schema_from_credentials("model", {}) + assert schema is None + + def test_get_customizable_model_schema_default(self, ai_model): + assert ai_model.get_customizable_model_schema("model", {}) is None + + def test_get_default_parameter_rule_variable_map(self, ai_model): + # Valid + res = ai_model._get_default_parameter_rule_variable_map(DefaultParameterName.TEMPERATURE) + assert res["default"] == 0.0 + + # Invalid + with pytest.raises(Exception) as excinfo: + ai_model._get_default_parameter_rule_variable_map("invalid_name") + assert "Invalid model parameter rule name" in str(excinfo.value) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py new file mode 100644 index 0000000000..a692f8023a --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_large_language_model.py @@ -0,0 +1,476 @@ +import logging +from collections.abc import Generator, Iterator, Sequence +from dataclasses import dataclass, field +from datetime import datetime +from decimal import Decimal +from types import SimpleNamespace +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import dify_graph.model_runtime.model_providers.__base.large_language_model as llm_module + +# Access large_language_model members via llm_module to avoid partial import issues in CI +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.callbacks.base_callback import Callback +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelType, PriceInfo +from dify_graph.model_runtime.model_providers.__base.large_language_model import _build_llm_result_from_chunks + + +def _usage(prompt_tokens: int = 1, completion_tokens: int = 2) -> LLMUsage: + return LLMUsage( + prompt_tokens=prompt_tokens, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal(1), + prompt_price=Decimal(prompt_tokens) * Decimal("0.001"), + completion_tokens=completion_tokens, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal(1), + completion_price=Decimal(completion_tokens) * Decimal("0.002"), + total_tokens=prompt_tokens + completion_tokens, + total_price=Decimal(prompt_tokens) * Decimal("0.001") + Decimal(completion_tokens) * Decimal("0.002"), + currency="USD", + latency=0.0, + ) + + +def _tool_call_delta( + *, + tool_call_id: str, + tool_type: str = "function", + function_name: str = "", + function_arguments: str = "", +) -> AssistantPromptMessage.ToolCall: + return AssistantPromptMessage.ToolCall( + id=tool_call_id, + type=tool_type, + function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=function_name, arguments=function_arguments), + ) + + +def _chunk( + *, + model: str = "test-model", + content: str | list[Any] | None = None, + tool_calls: list[AssistantPromptMessage.ToolCall] | None = None, + usage: LLMUsage | None = None, + system_fingerprint: str | None = None, +) -> LLMResultChunk: + return LLMResultChunk( + model=model, + system_fingerprint=system_fingerprint, + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=content, tool_calls=tool_calls or []), + usage=usage, + ), + ) + + +@dataclass +class SpyCallback(Callback): + raise_error: bool = False + before: list[dict[str, Any]] = field(default_factory=list) + new_chunk: list[dict[str, Any]] = field(default_factory=list) + after: list[dict[str, Any]] = field(default_factory=list) + error: list[dict[str, Any]] = field(default_factory=list) + + def on_before_invoke(self, **kwargs: Any) -> None: # type: ignore[override] + self.before.append(kwargs) + + def on_new_chunk(self, **kwargs: Any) -> None: # type: ignore[override] + self.new_chunk.append(kwargs) + + def on_after_invoke(self, **kwargs: Any) -> None: # type: ignore[override] + self.after.append(kwargs) + + def on_invoke_error(self, **kwargs: Any) -> None: # type: ignore[override] + self.error.append(kwargs) + + +class _TestLLM(llm_module.LargeLanguageModel): + def get_price(self, model: str, credentials: dict, price_type: Any, tokens: int) -> PriceInfo: # type: ignore[override] + return PriceInfo( + unit_price=Decimal("0.01"), + unit=Decimal(1), + total_amount=Decimal(tokens) * Decimal("0.01"), + currency="USD", + ) + + def _transform_invoke_error(self, error: Exception) -> Exception: # type: ignore[override] + return RuntimeError(f"transformed: {error}") + + +@pytest.fixture +def llm() -> _TestLLM: + plugin_provider = PluginModelProviderEntity.model_construct( + id="provider-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="provider", + tenant_id="tenant", + plugin_unique_identifier="plugin-uid", + plugin_id="plugin-id", + declaration=MagicMock(), + ) + return _TestLLM.model_construct( + tenant_id="tenant", + model_type=ModelType.LLM, + plugin_id="plugin-id", + provider_name="provider", + plugin_model_provider=plugin_provider, + started_at=1.0, + ) + + +def test_gen_tool_call_id_is_uuid_based(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="abc123")) + assert llm_module._gen_tool_call_id() == "chatcmpl-tool-abc123" + + +def test_run_callbacks_no_callbacks_noop() -> None: + invoked: list[int] = [] + llm_module._run_callbacks(None, event="x", invoke=lambda _: invoked.append(1)) + llm_module._run_callbacks([], event="x", invoke=lambda _: invoked.append(1)) + assert invoked == [] + + +def test_run_callbacks_swallows_error_when_raise_error_false(caplog: pytest.LogCaptureFixture) -> None: + class Boom: + raise_error = False + + caplog.set_level(logging.WARNING) + llm_module._run_callbacks( + [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) + ) + assert any("Callback" in record.message and "failed with error" in record.message for record in caplog.records) + + +def test_run_callbacks_reraises_when_raise_error_true() -> None: + class Boom: + raise_error = True + + with pytest.raises(ValueError, match="boom"): + llm_module._run_callbacks( + [Boom()], event="on_before_invoke", invoke=lambda _: (_ for _ in ()).throw(ValueError("boom")) + ) + + +def test_get_or_create_tool_call_empty_id_returns_last() -> None: + calls = [ + _tool_call_delta(tool_call_id="id1", function_name="a"), + _tool_call_delta(tool_call_id="id2", function_name="b"), + ] + assert llm_module._get_or_create_tool_call(calls, "") is calls[-1] + + +def test_get_or_create_tool_call_empty_id_without_existing_raises() -> None: + with pytest.raises(ValueError, match="tool_call_id is empty"): + llm_module._get_or_create_tool_call([], "") + + +def test_get_or_create_tool_call_creates_if_missing() -> None: + calls: list[AssistantPromptMessage.ToolCall] = [] + tool_call = llm_module._get_or_create_tool_call(calls, "new-id") + assert tool_call.id == "new-id" + assert tool_call.function.name == "" + assert tool_call.function.arguments == "" + assert calls == [tool_call] + + +def test_get_or_create_tool_call_returns_existing_when_found() -> None: + existing = _tool_call_delta(tool_call_id="same-id", function_name="fn", function_arguments="{}") + calls = [existing] + assert llm_module._get_or_create_tool_call(calls, "same-id") is existing + + +def test_merge_tool_call_delta_updates_fields_and_appends_arguments() -> None: + tool_call = _tool_call_delta(tool_call_id="id", tool_type="function", function_name="x", function_arguments="{") + delta = _tool_call_delta(tool_call_id="id2", tool_type="function", function_name="y", function_arguments="}") + llm_module._merge_tool_call_delta(tool_call, delta) + assert tool_call.id == "id2" + assert tool_call.type == "function" + assert tool_call.function.name == "y" + assert tool_call.function.arguments == "{}" + + +def test_increase_tool_call_generates_id_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="fixed")) + delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{") + existing: list[AssistantPromptMessage.ToolCall] = [] + llm_module._increase_tool_call([delta], existing) + assert len(existing) == 1 + assert existing[0].id == "chatcmpl-tool-fixed" + assert existing[0].function.name == "fn" + assert existing[0].function.arguments == "{" + + +def test_increase_tool_call_merges_incremental_arguments() -> None: + existing: list[AssistantPromptMessage.ToolCall] = [] + llm_module._increase_tool_call( + [_tool_call_delta(tool_call_id="id", function_name="fn", function_arguments="{")], existing + ) + llm_module._increase_tool_call( + [_tool_call_delta(tool_call_id="id", function_name="", function_arguments="}")], existing + ) + assert len(existing) == 1 + assert existing[0].function.name == "fn" + assert existing[0].function.arguments == "{}" + + +@pytest.mark.parametrize( + ("content", "expected_type"), + [ + ("hello", str), + ([TextPromptMessageContent(data="hello")], list), + ], +) +def test_build_llm_result_from_chunks_accumulates_and_raises_error( + content: str | list[TextPromptMessageContent], + expected_type: type, + monkeypatch: pytest.MonkeyPatch, + caplog: pytest.LogCaptureFixture, +) -> None: + monkeypatch.setattr(llm_module.uuid, "uuid4", lambda: SimpleNamespace(hex="drain")) + caplog.set_level(logging.DEBUG) + + tool_delta = _tool_call_delta(tool_call_id="", function_name="fn", function_arguments="{}") + first = _chunk(content=content, tool_calls=[tool_delta], usage=_usage(3, 4), system_fingerprint="fp1") + + def iter_with_error() -> Iterator[LLMResultChunk]: + yield first + raise RuntimeError("drain boom") + + with pytest.raises(RuntimeError, match="drain boom"): + _build_llm_result_from_chunks( + model="m", prompt_messages=[UserPromptMessage(content="u")], chunks=iter_with_error() + ) + + assert any("Error while consuming non-stream plugin chunk iterator" in record.message for record in caplog.records) + + +def test_build_llm_result_from_chunks_empty_iterator() -> None: + def empty() -> Iterator[LLMResultChunk]: + if False: # pragma: no cover + yield _chunk() + return + + result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=empty()) + assert result.message.content == [] + assert result.usage.total_tokens == 0 + assert result.system_fingerprint is None + + +def test_build_llm_result_from_chunks_accumulates_all_chunks() -> None: + chunks = iter([_chunk(content="first"), _chunk(content="second")]) + result = _build_llm_result_from_chunks(model="m", prompt_messages=[], chunks=chunks) + assert result.message.content == "firstsecond" + + +def test_invoke_llm_via_plugin_passes_list_converted_stop(monkeypatch: pytest.MonkeyPatch) -> None: + invoked: dict[str, Any] = {} + + class FakePluginModelClient: + def invoke_llm(self, **kwargs: Any) -> str: + invoked.update(kwargs) + return "ok" + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + + prompt_messages: Sequence[PromptMessage] = (UserPromptMessage(content="hi"),) + result = llm_module._invoke_llm_via_plugin( + tenant_id="t", + user_id="u", + plugin_id="p", + provider="prov", + model="m", + credentials={"k": "v"}, + model_parameters={"temp": 1}, + prompt_messages=prompt_messages, + tools=None, + stop=("a", "b"), + stream=True, + ) + + assert result == "ok" + assert invoked["prompt_messages"] == list(prompt_messages) + assert invoked["stop"] == ["a", "b"] + + +def test_normalize_non_stream_plugin_result_passthrough_llmresult() -> None: + llm_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) + assert ( + llm_module._normalize_non_stream_plugin_result(model="m", prompt_messages=[], result=llm_result) is llm_result + ) + + +def test_normalize_non_stream_plugin_result_builds_from_chunks() -> None: + chunks = iter([_chunk(content="hello", usage=_usage(1, 1))]) + result = llm_module._normalize_non_stream_plugin_result( + model="m", prompt_messages=[UserPromptMessage(content="u")], result=chunks + ) + assert isinstance(result, LLMResult) + assert result.message.content == "hello" + + +def test_invoke_non_stream_normalizes_and_sets_prompt_messages(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + plugin_result = LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()) + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", + lambda **_: plugin_result, + ) + cb = SpyCallback() + prompt_messages = [UserPromptMessage(content="hi")] + result = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=False, callbacks=[cb]) + assert isinstance(result, LLMResult) + assert result.prompt_messages == prompt_messages + assert len(cb.before) == 1 + assert len(cb.after) == 1 + assert cb.after[0]["result"].prompt_messages == prompt_messages + + +def test_invoke_stream_wraps_generator_and_triggers_callbacks(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + plugin_chunks = iter( + [ + _chunk(model="m1", content="a"), + _chunk( + model="m2", content=[TextPromptMessageContent(data="b")], usage=_usage(2, 3), system_fingerprint="fp" + ), + _chunk(model="m3", content=None), + ] + ) + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", + lambda **_: plugin_chunks, + ) + + cb = SpyCallback() + prompt_messages = [UserPromptMessage(content="hi")] + gen = llm.invoke(model="m", credentials={}, prompt_messages=prompt_messages, stream=True, callbacks=[cb]) + + assert isinstance(gen, Generator) + chunks = list(gen) + assert len(chunks) == 3 + assert all(chunk.prompt_messages == prompt_messages for chunk in chunks) + assert len(cb.before) == 1 + assert len(cb.new_chunk) == 3 + assert len(cb.after) == 1 + final_result: LLMResult = cb.after[0]["result"] + assert final_result.model == "m3" + assert final_result.system_fingerprint == "fp" + assert isinstance(final_result.message.content, list) + assert [c.data for c in final_result.message.content] == ["a", "b"] + assert final_result.usage.total_tokens == 5 + + +def test_invoke_triggers_error_callbacks_and_raises_transformed(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + def boom(**_: Any) -> Any: + raise ValueError("plugin down") + + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", boom + ) + cb = SpyCallback() + with pytest.raises(RuntimeError, match="transformed: plugin down"): + llm.invoke( + model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False, callbacks=[cb] + ) + assert len(cb.error) == 1 + assert isinstance(cb.error[0]["ex"], ValueError) + + +def test_invoke_raises_not_implemented_for_unsupported_result_type( + llm: _TestLLM, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(llm_module, "_invoke_llm_via_plugin", lambda **_: "not-a-result") + monkeypatch.setattr(llm_module, "_normalize_non_stream_plugin_result", lambda **_: "not-a-result") + with pytest.raises(NotImplementedError, match="unsupported invoke result type"): + llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) + + +def test_invoke_appends_logging_callback_in_debug(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + captured_callbacks: list[list[Callback]] = [] + + class FakeLoggingCallback(SpyCallback): + pass + + monkeypatch.setattr(llm_module, "LoggingCallback", FakeLoggingCallback) + monkeypatch.setattr(llm_module.dify_config, "DEBUG", True) + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.__base.large_language_model._invoke_llm_via_plugin", + lambda **_: LLMResult(model="m", message=AssistantPromptMessage(content="x"), usage=_usage()), + ) + + original_trigger = llm._trigger_before_invoke_callbacks + + def spy_trigger(*args: Any, **kwargs: Any) -> None: + captured_callbacks.append(list(kwargs["callbacks"])) + original_trigger(*args, **kwargs) + + monkeypatch.setattr(llm, "_trigger_before_invoke_callbacks", spy_trigger) + llm.invoke(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")], stream=False) + assert any(isinstance(cb, FakeLoggingCallback) for cb in captured_callbacks[0]) + + +def test_get_num_tokens_returns_0_when_plugin_disabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 0 + + +def test_get_num_tokens_uses_plugin_when_enabled(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", True) + + class FakePluginModelClient: + def get_llm_num_tokens(self, **kwargs: Any) -> int: + assert kwargs["tenant_id"] == "tenant" + assert kwargs["plugin_id"] == "plugin-id" + assert kwargs["provider"] == "provider" + assert kwargs["model_type"] == "llm" + return 42 + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + assert llm.get_num_tokens(model="m", credentials={}, prompt_messages=[UserPromptMessage(content="x")]) == 42 + + +def test_calc_response_usage_uses_prices_and_latency(llm: _TestLLM, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(llm_module.time, "perf_counter", lambda: 4.5) + llm.started_at = 1.0 + usage = llm.calc_response_usage(model="m", credentials={}, prompt_tokens=10, completion_tokens=5) + assert usage.total_tokens == 15 + assert usage.total_price == Decimal("0.15") + assert usage.latency == 3.5 + + +def test_invoke_result_generator_raises_transformed_on_iteration_error(llm: _TestLLM) -> None: + def broken() -> Iterator[LLMResultChunk]: + yield _chunk(content="ok") + raise ValueError("chunk stream broken") + + gen = llm._invoke_result_generator( + model="m", + result=broken(), + credentials={}, + prompt_messages=[UserPromptMessage(content="u")], + model_parameters={}, + callbacks=[SpyCallback()], + ) + + with pytest.raises(RuntimeError, match="transformed: chunk stream broken"): + list(gen) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py new file mode 100644 index 0000000000..6ccc44ceb8 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_moderation_model.py @@ -0,0 +1,90 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel + + +class TestModerationModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def moderation_model(self, mock_plugin_model_provider): + return ModerationModel( + tenant_id="tenant_123", + model_type=ModelType.MODERATION, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, moderation_model): + assert moderation_model.model_type == ModelType.MODERATION + + def test_invoke_success(self, moderation_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + text = "test text" + user = "user_123" + + with ( + patch("core.plugin.impl.model.PluginModelClient") as mock_client_class, + patch("time.perf_counter", return_value=1.0), + ): + mock_client = mock_client_class.return_value + mock_client.invoke_moderation.return_value = True + + result = moderation_model.invoke(model=model_name, credentials=credentials, text=text, user=user) + + assert result is True + assert moderation_model.started_at == 1.0 + mock_client.invoke_moderation.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + text=text, + ) + + def test_invoke_success_no_user(self, moderation_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + text = "test text" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_moderation.return_value = False + + result = moderation_model.invoke(model=model_name, credentials=credentials, text=text) + + assert result is False + mock_client.invoke_moderation.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + text=text, + ) + + def test_invoke_exception(self, moderation_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + text = "test text" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_moderation.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + moderation_model.invoke(model=model_name, credentials=credentials, text=text) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py new file mode 100644 index 0000000000..67828894b3 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_rerank_model.py @@ -0,0 +1,181 @@ +from datetime import datetime +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel + + +@pytest.fixture +def rerank_model() -> RerankModel: + plugin_provider = PluginModelProviderEntity.model_construct( + id="provider-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="provider", + tenant_id="tenant", + plugin_unique_identifier="plugin-uid", + plugin_id="plugin-id", + declaration=MagicMock(), + ) + return RerankModel.model_construct( + tenant_id="tenant", + model_type=ModelType.RERANK, + plugin_id="plugin-id", + provider_name="provider", + plugin_model_provider=plugin_provider, + ) + + +def test_model_type_is_rerank_by_default() -> None: + plugin_provider = PluginModelProviderEntity.model_construct( + id="provider-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider="provider", + tenant_id="tenant", + plugin_unique_identifier="plugin-uid", + plugin_id="plugin-id", + declaration=MagicMock(), + ) + model = RerankModel( + tenant_id="tenant", + plugin_id="plugin-id", + provider_name="provider", + plugin_model_provider=plugin_provider, + ) + assert model.model_type == ModelType.RERANK + + +def test_invoke_calls_plugin_and_passes_args(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None: + expected = RerankResult(model="rerank", docs=[RerankDocument(index=0, text="a", score=0.5)]) + + class FakePluginModelClient: + def __init__(self) -> None: + self.invoke_rerank_called_with: dict[str, Any] | None = None + + def invoke_rerank(self, **kwargs: Any) -> RerankResult: + self.invoke_rerank_called_with = kwargs + return expected + + import core.plugin.impl.model as plugin_model_module + + fake_client = FakePluginModelClient() + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + + result = rerank_model.invoke( + model="rerank", + credentials={"k": "v"}, + query="q", + docs=["d1", "d2"], + score_threshold=0.2, + top_n=10, + user="user-1", + ) + + assert result == expected + assert fake_client.invoke_rerank_called_with == { + "tenant_id": "tenant", + "user_id": "user-1", + "plugin_id": "plugin-id", + "provider": "provider", + "model": "rerank", + "credentials": {"k": "v"}, + "query": "q", + "docs": ["d1", "d2"], + "score_threshold": 0.2, + "top_n": 10, + } + + +def test_invoke_uses_unknown_user_when_not_provided(rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch) -> None: + class FakePluginModelClient: + def __init__(self) -> None: + self.kwargs: dict[str, Any] | None = None + + def invoke_rerank(self, **kwargs: Any) -> RerankResult: + self.kwargs = kwargs + return RerankResult(model="m", docs=[]) + + import core.plugin.impl.model as plugin_model_module + + fake_client = FakePluginModelClient() + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + + rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) + assert fake_client.kwargs is not None + assert fake_client.kwargs["user_id"] == "unknown" + + +def test_invoke_transforms_and_raises_on_plugin_error( + rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch +) -> None: + class FakePluginModelClient: + def invoke_rerank(self, **_: Any) -> RerankResult: + raise ValueError("plugin down") + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}")) + + with pytest.raises(RuntimeError, match="transformed: plugin down"): + rerank_model.invoke(model="m", credentials={}, query="q", docs=["d"]) + + +def test_invoke_multimodal_calls_plugin_and_passes_args( + rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch +) -> None: + expected = RerankResult(model="mm", docs=[RerankDocument(index=0, text="x", score=0.9)]) + + class FakePluginModelClient: + def __init__(self) -> None: + self.invoke_multimodal_rerank_called_with: dict[str, Any] | None = None + + def invoke_multimodal_rerank(self, **kwargs: Any) -> RerankResult: + self.invoke_multimodal_rerank_called_with = kwargs + return expected + + import core.plugin.impl.model as plugin_model_module + + fake_client = FakePluginModelClient() + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: fake_client) + + query = {"type": "text", "text": "q"} + docs = [{"type": "text", "text": "d1"}] + result = rerank_model.invoke_multimodal_rerank( + model="mm", + credentials={"k": "v"}, + query=query, + docs=docs, + score_threshold=None, + top_n=None, + user=None, + ) + + assert result == expected + assert fake_client.invoke_multimodal_rerank_called_with is not None + assert fake_client.invoke_multimodal_rerank_called_with["tenant_id"] == "tenant" + assert fake_client.invoke_multimodal_rerank_called_with["user_id"] == "unknown" + assert fake_client.invoke_multimodal_rerank_called_with["query"] == query + assert fake_client.invoke_multimodal_rerank_called_with["docs"] == docs + + +def test_invoke_multimodal_transforms_and_raises_on_plugin_error( + rerank_model: RerankModel, monkeypatch: pytest.MonkeyPatch +) -> None: + class FakePluginModelClient: + def invoke_multimodal_rerank(self, **_: Any) -> RerankResult: + raise ValueError("plugin down") + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", FakePluginModelClient) + monkeypatch.setattr(rerank_model, "_transform_invoke_error", lambda e: RuntimeError(f"transformed: {e}")) + + with pytest.raises(RuntimeError, match="transformed: plugin down"): + rerank_model.invoke_multimodal_rerank(model="m", credentials={}, query={"q": 1}, docs=[{"d": 1}]) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py new file mode 100644 index 0000000000..f891718dc6 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_speech2text_model.py @@ -0,0 +1,87 @@ +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel + + +class TestSpeech2TextModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def speech2text_model(self, mock_plugin_model_provider): + return Speech2TextModel( + tenant_id="tenant_123", + model_type=ModelType.SPEECH2TEXT, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, speech2text_model): + assert speech2text_model.model_type == ModelType.SPEECH2TEXT + + def test_invoke_success(self, speech2text_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + file = BytesIO(b"audio data") + user = "user_123" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_speech_to_text.return_value = "transcribed text" + + result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file, user=user) + + assert result == "transcribed text" + mock_client.invoke_speech_to_text.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + file=file, + ) + + def test_invoke_success_no_user(self, speech2text_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + file = BytesIO(b"audio data") + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_speech_to_text.return_value = "transcribed text" + + result = speech2text_model.invoke(model=model_name, credentials=credentials, file=file) + + assert result == "transcribed text" + mock_client.invoke_speech_to_text.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + file=file, + ) + + def test_invoke_exception(self, speech2text_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + file = BytesIO(b"audio data") + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_speech_to_text.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + speech2text_model.invoke(model=model_name, credentials=credentials, file=file) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py new file mode 100644 index 0000000000..c8f0a2ad49 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_text_embedding_model.py @@ -0,0 +1,185 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.entities.embedding_type import EmbeddingInputType +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel + + +class TestTextEmbeddingModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def text_embedding_model(self, mock_plugin_model_provider): + return TextEmbeddingModel( + tenant_id="tenant_123", + model_type=ModelType.TEXT_EMBEDDING, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, text_embedding_model): + assert text_embedding_model.model_type == ModelType.TEXT_EMBEDDING + + def test_invoke_with_texts(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello", "world"] + user = "user_123" + expected_result = MagicMock(spec=EmbeddingResult) + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_text_embedding.return_value = expected_result + + result = text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts, user=user) + + assert result == expected_result + mock_client.invoke_text_embedding.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + texts=texts, + input_type=EmbeddingInputType.DOCUMENT, + ) + + def test_invoke_with_multimodel_documents(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + multimodel_documents = [{"type": "text", "text": "hello"}] + expected_result = MagicMock(spec=EmbeddingResult) + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_multimodal_embedding.return_value = expected_result + + result = text_embedding_model.invoke( + model=model_name, credentials=credentials, multimodel_documents=multimodel_documents + ) + + assert result == expected_result + mock_client.invoke_multimodal_embedding.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + documents=multimodel_documents, + input_type=EmbeddingInputType.DOCUMENT, + ) + + def test_invoke_no_input(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + with pytest.raises(ValueError) as excinfo: + text_embedding_model.invoke(model=model_name, credentials=credentials) + + assert "No texts or files provided" in str(excinfo.value) + + def test_invoke_precedence(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello"] + multimodel_documents = [{"type": "text", "text": "world"}] + expected_result = MagicMock(spec=EmbeddingResult) + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_text_embedding.return_value = expected_result + + result = text_embedding_model.invoke( + model=model_name, credentials=credentials, texts=texts, multimodel_documents=multimodel_documents + ) + + assert result == expected_result + mock_client.invoke_text_embedding.assert_called_once() + mock_client.invoke_multimodal_embedding.assert_not_called() + + def test_invoke_exception(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello"] + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_text_embedding.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + text_embedding_model.invoke(model=model_name, credentials=credentials, texts=texts) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) + + def test_get_num_tokens(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + texts = ["hello", "world"] + expected_tokens = [1, 1] + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.get_text_embedding_num_tokens.return_value = expected_tokens + + result = text_embedding_model.get_num_tokens(model=model_name, credentials=credentials, texts=texts) + + assert result == expected_tokens + mock_client.get_text_embedding_num_tokens.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + texts=texts, + ) + + def test_get_context_size(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + # Test case 1: Context size in schema + mock_schema = MagicMock() + mock_schema.model_properties = {ModelPropertyKey.CONTEXT_SIZE: 2048} + + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_context_size(model_name, credentials) == 2048 + + # Test case 2: No schema + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): + assert text_embedding_model._get_context_size(model_name, credentials) == 1000 + + # Test case 3: Context size NOT in schema properties + mock_schema.model_properties = {} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_context_size(model_name, credentials) == 1000 + + def test_get_max_chunks(self, text_embedding_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + + # Test case 1: Max chunks in schema + mock_schema = MagicMock() + mock_schema.model_properties = {ModelPropertyKey.MAX_CHUNKS: 10} + + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_max_chunks(model_name, credentials) == 10 + + # Test case 2: No schema + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=None): + assert text_embedding_model._get_max_chunks(model_name, credentials) == 1 + + # Test case 3: Max chunks NOT in schema properties + mock_schema.model_properties = {} + with patch.object(TextEmbeddingModel, "get_model_schema", return_value=mock_schema): + assert text_embedding_model._get_max_chunks(model_name, credentials) == 1 diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py new file mode 100644 index 0000000000..b1aca9baa3 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/test_tts_model.py @@ -0,0 +1,131 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.invoke import InvokeError +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel + + +class TestTTSModel: + @pytest.fixture + def mock_plugin_model_provider(self): + return MagicMock(spec=PluginModelProviderEntity) + + @pytest.fixture + def tts_model(self, mock_plugin_model_provider): + return TTSModel( + tenant_id="tenant_123", + model_type=ModelType.TTS, + plugin_id="plugin_123", + provider_name="test_provider", + plugin_model_provider=mock_plugin_model_provider, + ) + + def test_model_type(self, tts_model): + assert tts_model.model_type == ModelType.TTS + + def test_invoke_success(self, tts_model): + model_name = "test_model" + tenant_id = "ignored_tenant_id" + credentials = {"api_key": "abc"} + content_text = "Hello world" + voice = "alloy" + user = "user_123" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_tts.return_value = [b"audio_chunk"] + + result = tts_model.invoke( + model=model_name, + tenant_id=tenant_id, + credentials=credentials, + content_text=content_text, + voice=voice, + user=user, + ) + + assert list(result) == [b"audio_chunk"] + mock_client.invoke_tts.assert_called_once_with( + tenant_id="tenant_123", + user_id="user_123", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def test_invoke_success_no_user(self, tts_model): + model_name = "test_model" + tenant_id = "ignored_tenant_id" + credentials = {"api_key": "abc"} + content_text = "Hello world" + voice = "alloy" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_tts.return_value = [b"audio_chunk"] + + result = tts_model.invoke( + model=model_name, tenant_id=tenant_id, credentials=credentials, content_text=content_text, voice=voice + ) + + assert list(result) == [b"audio_chunk"] + mock_client.invoke_tts.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def test_invoke_exception(self, tts_model): + model_name = "test_model" + tenant_id = "ignored_tenant_id" + credentials = {"api_key": "abc"} + content_text = "Hello world" + voice = "alloy" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.invoke_tts.side_effect = Exception("Test error") + + with pytest.raises(InvokeError) as excinfo: + tts_model.invoke( + model=model_name, + tenant_id=tenant_id, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + assert "[test_provider] Error: Test error" in str(excinfo.value.description) + + def test_get_tts_model_voices(self, tts_model): + model_name = "test_model" + credentials = {"api_key": "abc"} + language = "en-US" + + with patch("core.plugin.impl.model.PluginModelClient") as mock_client_class: + mock_client = mock_client_class.return_value + mock_client.get_tts_model_voices.return_value = [{"name": "Voice1"}] + + result = tts_model.get_tts_model_voices(model=model_name, credentials=credentials, language=language) + + assert result == [{"name": "Voice1"}] + mock_client.get_tts_model_voices.assert_called_once_with( + tenant_id="tenant_123", + user_id="unknown", + plugin_id="plugin_123", + provider="test_provider", + model=model_name, + credentials=credentials, + language=language, + ) diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py new file mode 100644 index 0000000000..dde6ea02b5 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/__base/tokenizers/test_gpt2_tokenizer.py @@ -0,0 +1,96 @@ +from unittest.mock import MagicMock, patch + +import dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer as gpt2_tokenizer_module +from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + + +class TestGPT2Tokenizer: + def setup_method(self): + # Reset the global tokenizer before each test to ensure we test initialization + gpt2_tokenizer_module._tokenizer = None + + def test_get_encoder_tiktoken(self): + """ + Test that get_encoder successfully uses tiktoken when available. + """ + mock_encoding = MagicMock() + # Mock tiktoken to be sure it's used + with patch("tiktoken.get_encoding", return_value=mock_encoding) as mock_get_encoding: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_encoding + mock_get_encoding.assert_called_once_with("gpt2") + + # Verify singleton behavior within the same test + encoder2 = GPT2Tokenizer.get_encoder() + assert encoder2 is encoder + assert mock_get_encoding.call_count == 1 + + def test_get_encoder_tiktoken_fallback(self): + """ + Test that get_encoder falls back to transformers when tiktoken fails. + """ + # patch tiktoken.get_encoding to raise an exception + with patch("tiktoken.get_encoding", side_effect=Exception("Tiktoken failure")): + # patch transformers.GPT2Tokenizer + with patch("transformers.GPT2Tokenizer.from_pretrained") as mock_from_pretrained: + mock_transformer_tokenizer = MagicMock() + mock_from_pretrained.return_value = mock_transformer_tokenizer + + with patch( + "dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer.logger" + ) as mock_logger: + encoder = GPT2Tokenizer.get_encoder() + + assert encoder == mock_transformer_tokenizer + mock_from_pretrained.assert_called_once() + mock_logger.info.assert_called_once_with("Fallback to Transformers' GPT-2 tokenizer from tiktoken") + + def test_get_num_tokens(self): + """ + Test get_num_tokens returns the correct count. + """ + mock_encoder = MagicMock() + mock_encoder.encode.return_value = [1, 2, 3, 4, 5] + + with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): + tokens_count = GPT2Tokenizer.get_num_tokens("test text") + assert tokens_count == 5 + mock_encoder.encode.assert_called_once_with("test text") + + def test_get_num_tokens_by_gpt2_direct(self): + """ + Test _get_num_tokens_by_gpt2 directly. + """ + mock_encoder = MagicMock() + mock_encoder.encode.return_value = [1, 2] + + with patch.object(GPT2Tokenizer, "get_encoder", return_value=mock_encoder): + tokens_count = GPT2Tokenizer._get_num_tokens_by_gpt2("hello") + assert tokens_count == 2 + mock_encoder.encode.assert_called_once_with("hello") + + def test_get_encoder_already_initialized(self): + """ + Test that if _tokenizer is already set, it returns it immediately. + """ + mock_existing_tokenizer = MagicMock() + gpt2_tokenizer_module._tokenizer = mock_existing_tokenizer + + # Tiktoken should not be called if already initialized + with patch("tiktoken.get_encoding") as mock_get_encoding: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_existing_tokenizer + mock_get_encoding.assert_not_called() + + def test_get_encoder_thread_safety(self): + """ + Simple test to ensure the lock is used. + """ + mock_encoding = MagicMock() + with patch("tiktoken.get_encoding", return_value=mock_encoding): + # We patch the lock in the module + with patch("dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer._lock") as mock_lock: + encoder = GPT2Tokenizer.get_encoder() + assert encoder == mock_encoding + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() diff --git a/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py new file mode 100644 index 0000000000..1ad0210375 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/model_providers/test_model_provider_factory.py @@ -0,0 +1,522 @@ +import logging +from datetime import datetime +from threading import Lock +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from redis import RedisError + +import contexts +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelPropertyKey, + ModelType, +) +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _provider_entity( + *, + provider: str, + supported_model_types: list[ModelType] | None = None, + models: list[AIModelEntity] | None = None, + icon_small: I18nObject | None = None, + icon_small_dark: I18nObject | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + label=I18nObject(en_US=provider), + supported_model_types=supported_model_types or [ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + icon_small=icon_small, + icon_small_dark=icon_small_dark, + ) + + +def _plugin_provider( + *, plugin_id: str, declaration: ProviderEntity, provider: str = "provider" +) -> PluginModelProviderEntity: + return PluginModelProviderEntity.model_construct( + id=f"{plugin_id}-id", + created_at=datetime.now(), + updated_at=datetime.now(), + provider=provider, + tenant_id="tenant", + plugin_unique_identifier=f"{plugin_id}-uid", + plugin_id=plugin_id, + declaration=declaration, + ) + + +@pytest.fixture(autouse=True) +def _reset_plugin_model_provider_context() -> None: + contexts.plugin_model_providers_lock.set(Lock()) + contexts.plugin_model_providers.set(None) + + +@pytest.fixture +def fake_plugin_manager(monkeypatch: pytest.MonkeyPatch) -> MagicMock: + manager = MagicMock() + + import core.plugin.impl.model as plugin_model_module + + monkeypatch.setattr(plugin_model_module, "PluginModelClient", lambda: manager) + return manager + + +@pytest.fixture +def factory(fake_plugin_manager: MagicMock) -> ModelProviderFactory: + return ModelProviderFactory(tenant_id="tenant") + + +def test_get_plugin_model_providers_initializes_context_on_lookup_error( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + original_get = contexts.plugin_model_providers.get + calls = {"n": 0} + + def flaky_get() -> Any: + calls["n"] += 1 + if calls["n"] == 1: + raise LookupError + return original_get() + + monkeypatch.setattr(contexts.plugin_model_providers, "get", flaky_get) + + providers = factory.get_plugin_model_providers() + assert len(providers) == 1 + assert providers[0].declaration.provider == "langgenius/openai/openai" + + +def test_get_plugin_model_providers_caches_and_does_not_refetch( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + first = factory.get_plugin_model_providers() + second = factory.get_plugin_model_providers() + + assert first is second + fake_plugin_manager.fetch_model_providers.assert_called_once_with("tenant") + + +def test_get_providers_returns_declarations(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: + d1 = _provider_entity(provider="openai") + d2 = _provider_entity(provider="anthropic") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=d1), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=d2), + ] + + providers = factory.get_providers() + assert [p.provider for p in providers] == ["langgenius/openai/openai", "langgenius/anthropic/anthropic"] + + +def test_get_plugin_model_provider_converts_short_provider_id( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + provider = factory.get_plugin_model_provider("openai") + assert provider.declaration.provider == "langgenius/openai/openai" + + +def test_get_plugin_model_provider_raises_on_invalid_provider( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + with pytest.raises(ValueError, match="Invalid provider"): + factory.get_plugin_model_provider("langgenius/unknown/unknown") + + +def test_get_provider_schema_returns_declaration(factory: ModelProviderFactory, fake_plugin_manager: MagicMock) -> None: + declaration = _provider_entity(provider="openai") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=declaration) + ] + + schema = factory.get_provider_schema("openai") + assert schema.provider == "langgenius/openai/openai" + + +def test_provider_credentials_validate_errors_when_schema_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.provider_credential_schema = None + monkeypatch.setattr( + factory, + "get_plugin_model_provider", + lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), + ) + + with pytest.raises(ValueError, match="does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"x": "y"}) + + +def test_provider_credentials_validate_filters_and_calls_plugin_validation( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.provider_credential_schema = MagicMock() + plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) + + fake_validator = MagicMock() + fake_validator.validate_and_filter.return_value = {"filtered": True} + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.model_provider_factory.ProviderCredentialSchemaValidator", + lambda _: fake_validator, + ) + + filtered = factory.provider_credentials_validate(provider="openai", credentials={"raw": True}) + assert filtered == {"filtered": True} + fake_plugin_manager.validate_provider_credentials.assert_called_once() + kwargs = fake_plugin_manager.validate_provider_credentials.call_args.kwargs + assert kwargs["plugin_id"] == "langgenius/openai" + assert kwargs["provider"] == "provider" + assert kwargs["credentials"] == {"filtered": True} + + +def test_model_credentials_validate_errors_when_schema_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.model_credential_schema = None + monkeypatch.setattr( + factory, + "get_plugin_model_provider", + lambda **_: _plugin_provider(plugin_id="langgenius/openai", declaration=schema), + ) + + with pytest.raises(ValueError, match="does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", model_type=ModelType.LLM, model="m", credentials={"x": "y"} + ) + + +def test_model_credentials_validate_filters_and_calls_plugin_validation( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock, monkeypatch: pytest.MonkeyPatch +) -> None: + schema = _provider_entity(provider="openai") + schema.model_credential_schema = MagicMock() + plugin_provider = _plugin_provider(plugin_id="langgenius/openai", declaration=schema) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda **_: plugin_provider) + + fake_validator = MagicMock() + fake_validator.validate_and_filter.return_value = {"filtered": True} + monkeypatch.setattr( + "dify_graph.model_runtime.model_providers.model_provider_factory.ModelCredentialSchemaValidator", + lambda *_: fake_validator, + ) + + filtered = factory.model_credentials_validate( + provider="openai", model_type=ModelType.TEXT_EMBEDDING, model="m", credentials={"raw": True} + ) + assert filtered == {"filtered": True} + kwargs = fake_plugin_manager.validate_model_credentials.call_args.kwargs + assert kwargs["plugin_id"] == "langgenius/openai" + assert kwargs["provider"] == "provider" + assert kwargs["model_type"] == "text-embedding" + assert kwargs["model"] == "m" + assert kwargs["credentials"] == {"filtered": True} + + +def test_get_model_schema_cache_hit(factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch) -> None: + model_schema = AIModelEntity( + model="m", + label=I18nObject(en_US="m"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = model_schema.model_dump_json().encode() + assert ( + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials={"k": "v"}) + == model_schema + ) + + +def test_get_model_schema_cache_invalid_json_deletes_key( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = b'{"model":"m"}' + factory.plugin_model_manager.get_model_schema.return_value = None + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None + assert mock_redis.delete.called + assert any("Failed to validate cached plugin model schema" in r.message for r in caplog.records) + + +def test_get_model_schema_cache_delete_redis_error_is_logged( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = b'{"model":"m"}' + mock_redis.delete.side_effect = RedisError("nope") + factory.plugin_model_manager.get_model_schema.return_value = None + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) + assert any("Failed to delete invalid plugin model schema cache" in r.message for r in caplog.records) + + +def test_get_model_schema_redis_get_error_falls_back_to_plugin( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + factory.plugin_model_manager.get_model_schema.return_value = None + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.side_effect = RedisError("down") + assert factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) is None + assert any("Failed to read plugin model schema cache" in r.message for r in caplog.records) + + +def test_get_model_schema_cache_miss_sets_cache_and_handles_setex_error( + factory: ModelProviderFactory, caplog: pytest.LogCaptureFixture +) -> None: + caplog.set_level(logging.WARNING) + factory.get_plugin_id_and_provider_name_from_provider = lambda *_: ("pid", "prov") # type: ignore[method-assign] + + model_schema = AIModelEntity( + model="m", + label=I18nObject(en_US="m"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + factory.plugin_model_manager.get_model_schema.return_value = model_schema + + with patch("dify_graph.model_runtime.model_providers.model_provider_factory.redis_client") as mock_redis: + mock_redis.get.return_value = None + mock_redis.setex.side_effect = RedisError("nope") + assert ( + factory.get_model_schema(provider="x", model_type=ModelType.LLM, model="m", credentials=None) + == model_schema + ) + assert any("Failed to write plugin model schema cache" in r.message for r in caplog.records) + + +@pytest.mark.parametrize( + ("model_type", "expected_class"), + [ + (ModelType.LLM, "LargeLanguageModel"), + (ModelType.TEXT_EMBEDDING, "TextEmbeddingModel"), + (ModelType.RERANK, "RerankModel"), + (ModelType.SPEECH2TEXT, "Speech2TextModel"), + (ModelType.MODERATION, "ModerationModel"), + (ModelType.TTS, "TTSModel"), + ], +) +def test_get_model_type_instance_dispatches_by_type( + factory: ModelProviderFactory, model_type: ModelType, expected_class: str, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) + + sentinel = object() + monkeypatch.setattr( + f"dify_graph.model_runtime.model_providers.model_provider_factory.{expected_class}", + MagicMock(model_validate=lambda _: sentinel), + ) + + assert factory.get_model_type_instance("langgenius/openai/openai", model_type) is sentinel + + +def test_get_model_type_instance_raises_on_unsupported( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + monkeypatch.setattr(factory, "get_plugin_id_and_provider_name_from_provider", lambda *_: ("pid", "prov")) + monkeypatch.setattr(factory, "get_plugin_model_provider", lambda *_: MagicMock(spec=PluginModelProviderEntity)) + + class UnknownModelType: + pass + + with pytest.raises(ValueError, match="Unsupported model type"): + factory.get_model_type_instance("langgenius/openai/openai", UnknownModelType()) # type: ignore[arg-type] + + +def test_get_models_filters_by_provider_and_model_type( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + llm = AIModelEntity( + model="m1", + label=I18nObject(en_US="m1"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + embed = AIModelEntity( + model="e1", + label=I18nObject(en_US="e1"), + model_type=ModelType.TEXT_EMBEDDING, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={ModelPropertyKey.CONTEXT_SIZE: 1024}, + parameter_rules=[], + ) + + openai = _provider_entity( + provider="openai", supported_model_types=[ModelType.LLM, ModelType.TEXT_EMBEDDING], models=[llm, embed] + ) + anthropic = _provider_entity(provider="anthropic", supported_model_types=[ModelType.LLM], models=[llm]) + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=openai), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), + ] + + # ModelType filter picks only matching models + providers = factory.get_models(model_type=ModelType.TEXT_EMBEDDING) + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert [m.model for m in providers[0].models] == ["e1"] + + # Provider filter excludes others + providers = factory.get_models(provider="langgenius/anthropic/anthropic", model_type=ModelType.LLM) + assert len(providers) == 1 + assert providers[0].provider == "langgenius/anthropic/anthropic" + + +def test_get_models_provider_filter_skips_non_matching( + factory: ModelProviderFactory, fake_plugin_manager: MagicMock +) -> None: + openai = _provider_entity(provider="openai") + anthropic = _provider_entity(provider="anthropic") + fake_plugin_manager.fetch_model_providers.return_value = [ + _plugin_provider(plugin_id="langgenius/openai", declaration=openai), + _plugin_provider(plugin_id="langgenius/anthropic", declaration=anthropic), + ] + + providers = factory.get_models(provider="langgenius/not-exist/not-exist", model_type=ModelType.LLM) + assert providers == [] + + +def test_get_provider_icon_fetches_asset_and_returns_mime_type( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="icon.png", zh_Hans="icon-zh.png"), + icon_small_dark=I18nObject(en_US="dark.svg", zh_Hans="dark-zh.svg"), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + class FakePluginAssetManager: + def fetch_asset(self, tenant_id: str, id: str) -> bytes: + assert tenant_id == "tenant" + return f"bytes:{id}".encode() + + import core.plugin.impl.asset as asset_module + + monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) + + data, mime = factory.get_provider_icon("openai", "icon_small", "en_US") + assert data == b"bytes:icon.png" + assert mime == "image/png" + + data, mime = factory.get_provider_icon("openai", "icon_small_dark", "zh_Hans") + assert data == b"bytes:dark-zh.svg" + assert mime == "image/svg+xml" + + +def test_get_provider_icon_uses_zh_hans_for_small_and_en_us_for_dark( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="icon-en.png", zh_Hans="icon-zh.png"), + icon_small_dark=I18nObject(en_US="dark-en.svg", zh_Hans="dark-zh.svg"), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + class FakePluginAssetManager: + def fetch_asset(self, tenant_id: str, id: str) -> bytes: + return id.encode() + + import core.plugin.impl.asset as asset_module + + monkeypatch.setattr(asset_module, "PluginAssetManager", FakePluginAssetManager) + + data, _ = factory.get_provider_icon("openai", "icon_small", "zh_Hans") + assert data == b"icon-zh.png" + + data, _ = factory.get_provider_icon("openai", "icon_small_dark", "en_US") + assert data == b"dark-en.svg" + + +def test_get_provider_icon_raises_for_missing_icons( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity(provider="langgenius/openai/openai") + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + + with pytest.raises(ValueError, match="does not have small icon"): + factory.get_provider_icon("openai", "icon_small", "en_US") + + with pytest.raises(ValueError, match="does not have small dark icon"): + factory.get_provider_icon("openai", "icon_small_dark", "en_US") + + +def test_get_provider_icon_raises_for_unsupported_icon_type( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="", zh_Hans=""), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + with pytest.raises(ValueError, match="Unsupported icon type"): + factory.get_provider_icon("openai", "nope", "en_US") + + +def test_get_provider_icon_raises_when_file_name_missing( + factory: ModelProviderFactory, monkeypatch: pytest.MonkeyPatch +) -> None: + provider_schema = _provider_entity( + provider="langgenius/openai/openai", + icon_small=I18nObject(en_US="", zh_Hans=""), + ) + monkeypatch.setattr(factory, "get_provider_schema", lambda *_: provider_schema) + with pytest.raises(ValueError, match="does not have icon"): + factory.get_provider_icon("openai", "icon_small", "en_US") + + +def test_get_plugin_id_and_provider_name_from_provider_handles_google_special_case( + factory: ModelProviderFactory, +) -> None: + plugin_id, provider_name = factory.get_plugin_id_and_provider_name_from_provider("google") + assert plugin_id == "langgenius/gemini" + assert provider_name == "google" diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py new file mode 100644 index 0000000000..6d52457c8c --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_common_validator.py @@ -0,0 +1,201 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FormOption, + FormShowOnObject, + FormType, +) +from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator + + +class TestCommonValidator: + def test_validate_credential_form_schema_required_missing(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + with pytest.raises(ValueError, match="Variable api_key is required"): + validator._validate_credential_form_schema(schema, {}) + + def test_validate_credential_form_schema_not_required_missing_with_default(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + required=False, + default="default_value", + ) + assert validator._validate_credential_form_schema(schema, {}) == "default_value" + + def test_validate_credential_form_schema_not_required_missing_no_default(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=False + ) + assert validator._validate_credential_form_schema(schema, {}) is None + + def test_validate_credential_form_schema_max_length_exceeded(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, max_length=5 + ) + with pytest.raises(ValueError, match="Variable api_key length should not be greater than 5"): + validator._validate_credential_form_schema(schema, {"api_key": "123456"}) + + def test_validate_credential_form_schema_not_string(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT) + with pytest.raises(ValueError, match="Variable api_key should be string"): + validator._validate_credential_form_schema(schema, {"api_key": 123}) + + def test_validate_credential_form_schema_select_invalid_option(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="mode", + label=I18nObject(en_US="Mode"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="Fast"), value="fast"), + FormOption(label=I18nObject(en_US="Slow"), value="slow"), + ], + ) + with pytest.raises(ValueError, match="Variable mode is not in options"): + validator._validate_credential_form_schema(schema, {"mode": "medium"}) + + def test_validate_credential_form_schema_select_valid_option(self): + validator = CommonValidator() + schema = CredentialFormSchema( + variable="mode", + label=I18nObject(en_US="Mode"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="Fast"), value="fast"), + FormOption(label=I18nObject(en_US="Slow"), value="slow"), + ], + ) + assert validator._validate_credential_form_schema(schema, {"mode": "fast"}) == "fast" + + def test_validate_credential_form_schema_switch_invalid(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) + with pytest.raises(ValueError, match="Variable enabled should be true or false"): + validator._validate_credential_form_schema(schema, {"enabled": "maybe"}) + + def test_validate_credential_form_schema_switch_valid(self): + validator = CommonValidator() + schema = CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH) + assert validator._validate_credential_form_schema(schema, {"enabled": "true"}) is True + assert validator._validate_credential_form_schema(schema, {"enabled": "FALSE"}) is False + + def test_validate_and_filter_credential_form_schemas_with_show_on(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="auth_type", + label=I18nObject(en_US="Auth Type"), + type=FormType.SELECT, + options=[ + FormOption(label=I18nObject(en_US="API Key"), value="api_key"), + FormOption(label=I18nObject(en_US="OAuth"), value="oauth"), + ], + ), + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ), + CredentialFormSchema( + variable="client_id", + label=I18nObject(en_US="Client ID"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="oauth")], + ), + ] + + # Case 1: auth_type = api_key + credentials = {"auth_type": "api_key", "api_key": "my_secret"} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + assert "auth_type" in result + assert "api_key" in result + assert "client_id" not in result + assert result["api_key"] == "my_secret" + + # Case 2: auth_type = oauth + credentials = {"auth_type": "oauth", "client_id": "my_client"} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + # Note: 'auth_type' contains 'oauth'. 'result' contains keys that pass validation. + # Since 'oauth' is not an empty string, it is in result. + assert "auth_type" in result + assert "api_key" not in result + assert "client_id" in result + assert result["client_id"] == "my_client" + + def test_validate_and_filter_show_on_missing_variable(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ) + ] + # auth_type is missing in credentials, so api_key should be filtered out + result = validator._validate_and_filter_credential_form_schemas(schemas, {}) + assert result == {} + + def test_validate_and_filter_show_on_mismatch_value(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="auth_type", value="api_key")], + ) + ] + # auth_type is oauth, which doesn't match show_on + result = validator._validate_and_filter_credential_form_schemas(schemas, {"auth_type": "oauth"}) + assert result == {} + + def test_validate_and_filter_multiple_show_on(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema( + variable="target", + label=I18nObject(en_US="Target"), + type=FormType.TEXT_INPUT, + show_on=[FormShowOnObject(variable="v1", value="a"), FormShowOnObject(variable="v2", value="b")], + ) + ] + # Both match + assert "target" in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "v2": "b", "target": "val"} + ) + # One mismatch + assert "target" not in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "v2": "c", "target": "val"} + ) + # One missing + assert "target" not in validator._validate_and_filter_credential_form_schemas( + schemas, {"v1": "a", "target": "val"} + ) + + def test_validate_and_filter_skips_falsy_results(self): + validator = CommonValidator() + schemas = [ + CredentialFormSchema(variable="enabled", label=I18nObject(en_US="Enabled"), type=FormType.SWITCH), + CredentialFormSchema( + variable="empty_str", label=I18nObject(en_US="Empty"), type=FormType.TEXT_INPUT, required=False + ), + ] + # Result of false switch is False. if result: is false. Not added. + # Result of empty string is "", if result: is false. Not added. + credentials = {"enabled": "false", "empty_str": ""} + result = validator._validate_and_filter_credential_form_schemas(schemas, credentials) + assert "enabled" not in result + assert "empty_str" not in result diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py new file mode 100644 index 0000000000..bab2805276 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_model_credential_schema_validator.py @@ -0,0 +1,233 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormOption, + FormShowOnObject, + FormType, + ModelCredentialSchema, +) +from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator + + +def test_validate_and_filter_with_none_schema(): + validator = ModelCredentialSchemaValidator(ModelType.LLM, None) + with pytest.raises(ValueError, match="Model credential schema is None"): + validator.validate_and_filter({}) + + +def test_validate_and_filter_success(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key", zh_Hans="API Key"), + type=FormType.SECRET_INPUT, + required=True, + ), + CredentialFormSchema( + variable="optional_field", + label=I18nObject(en_US="Optional", zh_Hans="可选"), + type=FormType.TEXT_INPUT, + required=False, + default="default_val", + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + credentials = {"api_key": "sk-123456"} + result = validator.validate_and_filter(credentials) + + assert result["api_key"] == "sk-123456" + assert result["optional_field"] == "default_val" + assert credentials["__model_type"] == ModelType.LLM.value + + +def test_validate_and_filter_with_show_on(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="mode", label=I18nObject(en_US="Mode", zh_Hans="模式"), type=FormType.TEXT_INPUT, required=True + ), + CredentialFormSchema( + variable="conditional_field", + label=I18nObject(en_US="Conditional", zh_Hans="条件"), + type=FormType.TEXT_INPUT, + required=True, + show_on=[FormShowOnObject(variable="mode", value="advanced")], + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + # mode is 'simple', conditional_field should be filtered out + credentials = {"mode": "simple", "conditional_field": "secret"} + result = validator.validate_and_filter(credentials) + assert "conditional_field" not in result + assert result["mode"] == "simple" + + # mode is 'advanced', conditional_field should be kept + credentials = {"mode": "advanced", "conditional_field": "secret"} + result = validator.validate_and_filter(credentials) + assert result["conditional_field"] == "secret" + assert result["mode"] == "advanced" + + # show_on variable missing in credentials + credentials = {"conditional_field": "secret"} # mode missing + with pytest.raises(ValueError, match="Variable mode is required"): # because mode is required in schema + validator.validate_and_filter(credentials) + + +def test_validate_and_filter_show_on_missing_trigger_var(): + # specifically test all_show_on_match = False when variable not in credentials + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="optional_trigger", + label=I18nObject(en_US="Optional Trigger", zh_Hans="可选触发"), + type=FormType.TEXT_INPUT, + required=False, + ), + CredentialFormSchema( + variable="conditional_field", + label=I18nObject(en_US="Conditional", zh_Hans="条件"), + type=FormType.TEXT_INPUT, + required=False, + show_on=[FormShowOnObject(variable="optional_trigger", value="active")], + ), + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + # optional_trigger missing, conditional_field should be skipped + result = validator.validate_and_filter({"conditional_field": "val"}) + assert "conditional_field" not in result + + +def test_common_validator_logic_required(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API Key", zh_Hans="API Key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({}) + + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({"api_key": ""}) + + +def test_common_validator_logic_max_length(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="key", + label=I18nObject(en_US="Key", zh_Hans="Key"), + type=FormType.TEXT_INPUT, + required=True, + max_length=5, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable key length should not be greater than 5"): + validator.validate_and_filter({"key": "123456"}) + + +def test_common_validator_logic_invalid_type(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="key", label=I18nObject(en_US="Key", zh_Hans="Key"), type=FormType.TEXT_INPUT, required=True + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + with pytest.raises(ValueError, match="Variable key should be string"): + validator.validate_and_filter({"key": 123}) + + +def test_common_validator_logic_switch(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="enabled", + label=I18nObject(en_US="Enabled", zh_Hans="启用"), + type=FormType.SWITCH, + required=True, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({"enabled": "true"}) + assert result["enabled"] is True + + result = validator.validate_and_filter({"enabled": "false"}) + assert "enabled" not in result + + with pytest.raises(ValueError, match="Variable enabled should be true or false"): + validator.validate_and_filter({"enabled": "not_a_bool"}) + + +def test_common_validator_logic_options(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="choice", + label=I18nObject(en_US="Choice", zh_Hans="选择"), + type=FormType.SELECT, + required=True, + options=[ + FormOption(label=I18nObject(en_US="A", zh_Hans="A"), value="a"), + FormOption(label=I18nObject(en_US="B", zh_Hans="B"), value="b"), + ], + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({"choice": "a"}) + assert result["choice"] == "a" + + with pytest.raises(ValueError, match="Variable choice is not in options"): + validator.validate_and_filter({"choice": "c"}) + + +def test_validate_and_filter_optional_no_default(): + schema = ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model", zh_Hans="模型")), + credential_form_schemas=[ + CredentialFormSchema( + variable="optional", + label=I18nObject(en_US="Optional", zh_Hans="可选"), + type=FormType.TEXT_INPUT, + required=False, + ) + ], + ) + validator = ModelCredentialSchemaValidator(ModelType.LLM, schema) + + result = validator.validate_and_filter({}) + assert "optional" not in result diff --git a/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py new file mode 100644 index 0000000000..043306840f --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/schema_validators/test_provider_credential_schema_validator.py @@ -0,0 +1,72 @@ +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType, ProviderCredentialSchema +from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( + ProviderCredentialSchemaValidator, +) + + +class TestProviderCredentialSchemaValidator: + def test_validate_and_filter_success(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ), + CredentialFormSchema( + variable="endpoint", + label=I18nObject(en_US="Endpoint"), + type=FormType.TEXT_INPUT, + required=False, + default="https://api.example.com", + ), + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test valid credentials + credentials = {"api_key": "my-secret-key"} + result = validator.validate_and_filter(credentials) + + assert result == {"api_key": "my-secret-key", "endpoint": "https://api.example.com"} + + def test_validate_and_filter_missing_required(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test missing required credentials + with pytest.raises(ValueError, match="Variable api_key is required"): + validator.validate_and_filter({}) + + def test_validate_and_filter_extra_fields_filtered(self): + # Setup schema + schema = ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.TEXT_INPUT, required=True + ) + ] + ) + validator = ProviderCredentialSchemaValidator(schema) + + # Test credentials with extra fields + credentials = {"api_key": "my-secret-key", "extra_field": "should-be-filtered"} + result = validator.validate_and_filter(credentials) + + assert "api_key" in result + assert "extra_field" not in result + assert result == {"api_key": "my-secret-key"} + + def test_init(self): + schema = ProviderCredentialSchema(credential_form_schemas=[]) + validator = ProviderCredentialSchemaValidator(schema) + assert validator.provider_credential_schema == schema diff --git a/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py new file mode 100644 index 0000000000..1ce8765a3b --- /dev/null +++ b/api/tests/unit_tests/dify_graph/model_runtime/utils/test_encoders.py @@ -0,0 +1,231 @@ +import dataclasses +import datetime +from collections import deque +from decimal import Decimal +from enum import Enum +from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network +from pathlib import Path, PurePath +from re import compile +from typing import Any +from unittest.mock import MagicMock +from uuid import UUID + +import pytest +from pydantic import BaseModel, ConfigDict +from pydantic.networks import AnyUrl, NameEmail +from pydantic.types import SecretBytes, SecretStr +from pydantic_core import Url +from pydantic_extra_types.color import Color + +from dify_graph.model_runtime.utils.encoders import ( + _model_dump, + decimal_encoder, + generate_encoders_by_class_tuples, + isoformat, + jsonable_encoder, +) + + +class MockEnum(Enum): + A = "a" + B = "b" + + +class MockPydanticModel(BaseModel): + model_config = ConfigDict(populate_by_name=True) + name: str + age: int + + +@dataclasses.dataclass +class MockDataclass: + name: str + value: Any + + +class MockWithDict: + def __init__(self, data): + self.data = data + + def __iter__(self): + return iter(self.data.items()) + + +class MockWithVars: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class TestEncoders: + def test_model_dump(self): + model = MockPydanticModel(name="test", age=20) + result = _model_dump(model) + assert result == {"name": "test", "age": 20} + + def test_isoformat(self): + d = datetime.date(2023, 1, 1) + assert isoformat(d) == "2023-01-01" + t = datetime.time(12, 0, 0) + assert isoformat(t) == "12:00:00" + + def test_decimal_encoder(self): + assert decimal_encoder(Decimal("1.0")) == 1.0 + assert decimal_encoder(Decimal(1)) == 1 + assert decimal_encoder(Decimal("1.5")) == 1.5 + assert decimal_encoder(Decimal(0)) == 0 + assert decimal_encoder(Decimal(-1)) == -1 + + def test_generate_encoders_by_class_tuples(self): + type_map = {int: str, float: str, str: int} + result = generate_encoders_by_class_tuples(type_map) + assert result[str] == (int, float) + assert result[int] == (str,) + + def test_jsonable_encoder_basic_types(self): + assert jsonable_encoder("string") == "string" + assert jsonable_encoder(123) == 123 + assert jsonable_encoder(1.23) == 1.23 + assert jsonable_encoder(None) is None + + def test_jsonable_encoder_pydantic(self): + model = MockPydanticModel(name="test", age=20) + assert jsonable_encoder(model) == {"name": "test", "age": 20} + + def test_jsonable_encoder_pydantic_root(self): + # Manually create a mock that behaves like a model with __root__ + # because Pydantic v2 handles root differently, but the code checks for "__root__" + model = MagicMock(spec=BaseModel) + # _model_dump(obj, mode="json", ...) -> model.model_dump(mode="json", ...) + model.model_dump.return_value = {"__root__": [1, 2, 3]} + assert jsonable_encoder(model) == [1, 2, 3] + + def test_jsonable_encoder_dataclass(self): + obj = MockDataclass(name="test", value=1) + assert jsonable_encoder(obj) == {"name": "test", "value": 1} + # Test dataclass type (should not be treated as instance) + # It should fall back to vars() or dict() or at least not crash + with pytest.raises(ValueError): + jsonable_encoder(MockDataclass) + + def test_jsonable_encoder_enum(self): + assert jsonable_encoder(MockEnum.A) == "a" + + def test_jsonable_encoder_path(self): + assert jsonable_encoder(Path("/tmp/test")) == "/tmp/test" + assert jsonable_encoder(PurePath("/tmp/test")) == "/tmp/test" + + def test_jsonable_encoder_decimal(self): + # In jsonable_encoder, Decimal is formatted as string via format(obj, "f") + assert jsonable_encoder(Decimal("1.23")) == "1.23" + assert jsonable_encoder(Decimal("1.000")) == "1.000" + + def test_jsonable_encoder_dict(self): + d = {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + assert jsonable_encoder(d) == {"a": 1, "b": [2, 3]} + assert jsonable_encoder(d, sqlalchemy_safe=False) == {"a": 1, "b": [2, 3], "_sa_instance": "hidden"} + + d_with_none = {"a": 1, "b": None} + assert jsonable_encoder(d_with_none, exclude_none=True) == {"a": 1} + assert jsonable_encoder(d_with_none, exclude_none=False) == {"a": 1, "b": None} + + def test_jsonable_encoder_collections(self): + assert jsonable_encoder([1, 2]) == [1, 2] + assert jsonable_encoder((1, 2)) == [1, 2] + assert jsonable_encoder({1, 2}) == [1, 2] + assert jsonable_encoder(frozenset([1, 2])) == [1, 2] + assert jsonable_encoder(deque([1, 2])) == [1, 2] + + def gen(): + yield 1 + yield 2 + + assert jsonable_encoder(gen()) == [1, 2] + + def test_jsonable_encoder_custom_encoder(self): + custom = {int: lambda x: str(x + 1)} + assert jsonable_encoder(1, custom_encoder=custom) == "2" + + # Test subclass matching for custom encoder + class SubInt(int): + pass + + assert jsonable_encoder(SubInt(1), custom_encoder=custom) == "2" + + def test_jsonable_encoder_special_types(self): + # These hit ENCODERS_BY_TYPE or encoders_by_class_tuples + assert jsonable_encoder(b"bytes") == "bytes" + assert jsonable_encoder(Color("red")) == "red" + + dt = datetime.datetime(2023, 1, 1, 12, 0, 0) + assert jsonable_encoder(dt) == dt.isoformat() + + date = datetime.date(2023, 1, 1) + assert jsonable_encoder(date) == date.isoformat() + + time = datetime.time(12, 0, 0) + assert jsonable_encoder(time) == time.isoformat() + + td = datetime.timedelta(seconds=60) + assert jsonable_encoder(td) == 60.0 + + assert jsonable_encoder(IPv4Address("127.0.0.1")) == "127.0.0.1" + assert jsonable_encoder(IPv4Interface("127.0.0.1/24")) == "127.0.0.1/24" + assert jsonable_encoder(IPv4Network("127.0.0.0/24")) == "127.0.0.0/24" + assert jsonable_encoder(IPv6Address("::1")) == "::1" + assert jsonable_encoder(IPv6Interface("::1/128")) == "::1/128" + assert jsonable_encoder(IPv6Network("::/128")) == "::/128" + + assert jsonable_encoder(NameEmail(name="test", email="test@example.com")) == "test " + + assert jsonable_encoder(compile("abc")) == "abc" + + # Secret types + # Check what they actually return in this environment + res_bytes = jsonable_encoder(SecretBytes(b"secret")) + assert "**********" in res_bytes + + res_str = jsonable_encoder(SecretStr("secret")) + assert res_str == "**********" + + u = UUID("12345678-1234-5678-1234-567812345678") + assert jsonable_encoder(u) == str(u) + + url = AnyUrl("https://example.com") + assert jsonable_encoder(url) == "https://example.com/" + + purl = Url("https://example.com") + assert jsonable_encoder(purl) == "https://example.com/" + + def test_jsonable_encoder_fallback(self): + # dict(obj) success + obj_dict = MockWithDict({"a": 1}) + assert jsonable_encoder(obj_dict) == {"a": 1} + + # vars(obj) success + obj_vars = MockWithVars(x=10, y=20) + assert jsonable_encoder(obj_vars) == {"x": 10, "y": 20} + + # error fallback + class ReallyUnserializable: + __slots__ = ["__weakref__"] # No __dict__ + + def __iter__(self): + raise TypeError("not iterable") + + with pytest.raises(ValueError) as exc: + jsonable_encoder(ReallyUnserializable()) + assert "not iterable" in str(exc.value) + + def test_jsonable_encoder_nested(self): + data = { + "model": MockPydanticModel(name="test", age=20), + "list": [Decimal("1.1"), {MockEnum.A: Path("/tmp")}], + "set": {1, 2}, + } + expected = { + "model": {"name": "test", "age": 20}, + "list": ["1.1", {"a": "/tmp"}], + "set": [1, 2], + } + assert jsonable_encoder(data) == expected diff --git a/api/tests/unit_tests/dify_graph/node_events/test_base.py b/api/tests/unit_tests/dify_graph/node_events/test_base.py new file mode 100644 index 0000000000..6d789abac0 --- /dev/null +++ b/api/tests/unit_tests/dify_graph/node_events/test_base.py @@ -0,0 +1,19 @@ +from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from dify_graph.node_events.base import NodeRunResult + + +def test_node_run_result_accepts_trigger_info_metadata() -> None: + result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + metadata={ + WorkflowNodeExecutionMetadataKey.TRIGGER_INFO: { + "provider_id": "provider-id", + "event_name": "event-name", + } + }, + ) + + assert result.metadata[WorkflowNodeExecutionMetadataKey.TRIGGER_INFO] == { + "provider_id": "provider-id", + "event_name": "event-name", + } diff --git a/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py b/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py new file mode 100644 index 0000000000..7a537b0502 --- /dev/null +++ b/api/tests/unit_tests/extensions/otel/test_celery_sqlcommenter.py @@ -0,0 +1,172 @@ +"""Tests for Celery SQL comment context injection.""" + +from unittest.mock import MagicMock, patch + +from opentelemetry import context + + +class TestBuildCelerySqlcommenterTags: + """Tests for _build_celery_sqlcommenter_tags.""" + + def test_includes_framework_and_task_name(self): + """Tags include celery framework version and task name.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.async_workflow_tasks.execute_workflow_team" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "framework" in tags + assert tags["framework"].startswith("celery:") + assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team" + + def test_includes_celery_retries_when_nonzero(self): + """celery_retries is included when retries > 0.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 3 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["celery_retries"] == 3 + + def test_omits_celery_retries_when_zero(self): + """celery_retries is omitted when retries is 0.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "celery_retries" not in tags + + def test_includes_routing_key_from_delivery_info(self): + """routing_key is included when present in delivery_info.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {"routing_key": "workflow_based_app_execution"} + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["routing_key"] == "workflow_based_app_execution" + + def test_includes_traceparent_when_available(self): + """traceparent is included when injectable from current context.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + task.request = MagicMock() + task.request.retries = 0 + task.request.delivery_info = {} + + traceparent = "00-5db86c23fa8d05b67db315694b518684-737bbf30cdcda066-00" + with patch( + "extensions.otel.celery_sqlcommenter._get_traceparent", + return_value=traceparent, + ): + tags = _build_celery_sqlcommenter_tags(task) + + assert tags["traceparent"] == traceparent + + def test_handles_task_without_request(self): + """Gracefully handles task without request attribute.""" + from extensions.otel.celery_sqlcommenter import _build_celery_sqlcommenter_tags + + task = MagicMock() + task.name = "tasks.my_task" + del task.request + + with patch("extensions.otel.celery_sqlcommenter._get_traceparent", return_value=None): + tags = _build_celery_sqlcommenter_tags(task) + + assert "framework" in tags + assert "task_name" in tags + + +class TestTaskPrerunPostrunHandlers: + """Tests for task_prerun and task_postrun signal handlers.""" + + def test_prerun_sets_context_postrun_detaches(self): + """task_prerun attaches SQLCOMMENTER context; task_postrun detaches it.""" + from extensions.otel.celery_sqlcommenter import ( + _SQLCOMMENTER_CONTEXT_KEY, + _TOKEN_ATTR, + _on_task_postrun, + _on_task_prerun, + ) + + clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None) + token = context.attach(clean_ctx) + try: + task = MagicMock() + task.name = "tasks.async_workflow_tasks.execute_workflow_team" + task.request = MagicMock() + task.request.retries = 1 + task.request.delivery_info = {"routing_key": "workflow_based_app_execution"} + + with patch( + "extensions.otel.celery_sqlcommenter._get_traceparent", + return_value="00-abc123-def456-00", + ): + _on_task_prerun(task=task) + + tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags is not None + assert tags["framework"].startswith("celery:") + assert tags["task_name"] == "tasks.async_workflow_tasks.execute_workflow_team" + assert tags["celery_retries"] == 1 + assert tags["routing_key"] == "workflow_based_app_execution" + assert tags["traceparent"] == "00-abc123-def456-00" + assert hasattr(task, _TOKEN_ATTR) + + _on_task_postrun(task=task) + + tags_after = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags_after is None + assert not hasattr(task, _TOKEN_ATTR) + finally: + context.detach(token) + + def test_prerun_skips_when_no_task(self): + """prerun does nothing when task is missing from kwargs.""" + from extensions.otel.celery_sqlcommenter import ( + _SQLCOMMENTER_CONTEXT_KEY, + _on_task_prerun, + ) + + clean_ctx = context.set_value(_SQLCOMMENTER_CONTEXT_KEY, None) + token = context.attach(clean_ctx) + try: + _on_task_prerun() + tags = context.get_value(_SQLCOMMENTER_CONTEXT_KEY) + assert tags is None + finally: + context.detach(token) + + def test_postrun_skips_when_no_token(self): + """postrun does nothing when task has no token (e.g. prerun was skipped).""" + from extensions.otel.celery_sqlcommenter import _on_task_postrun + + task = MagicMock() + _on_task_postrun(task=task) diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 87d02cb187..ce6b9232ce 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -7,8 +7,8 @@ import pytest from hypothesis import HealthCheck, given, settings from hypothesis import strategies as st -from core.workflow.file import File, FileTransferMethod, FileType -from core.workflow.variables import ( +from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -17,8 +17,8 @@ from core.workflow.variables import ( SecretVariable, StringVariable, ) -from core.workflow.variables.exc import VariableError -from core.workflow.variables.segments import ( +from dify_graph.variables.exc import VariableError +from dify_graph.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -33,7 +33,7 @@ from core.workflow.variables.segments import ( Segment, StringSegment, ) -from core.workflow.variables.types import SegmentType +from dify_graph.variables.types import SegmentType from factories import variable_factory from factories.variable_factory import TypeMismatchError, build_segment, build_segment_with_type diff --git a/api/tests/unit_tests/libs/_human_input/support.py b/api/tests/unit_tests/libs/_human_input/support.py index bd86c13a2c..3fff54f487 100644 --- a/api/tests/unit_tests/libs/_human_input/support.py +++ b/api/tests/unit_tests/libs/_human_input/support.py @@ -4,8 +4,8 @@ from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Any -from core.workflow.nodes.human_input.entities import FormInput -from core.workflow.nodes.human_input.enums import TimeoutUnit +from dify_graph.nodes.human_input.entities import FormInput +from dify_graph.nodes.human_input.enums import TimeoutUnit # Exceptions diff --git a/api/tests/unit_tests/libs/_human_input/test_form_service.py b/api/tests/unit_tests/libs/_human_input/test_form_service.py index 15e7d41e85..82598c5c6d 100644 --- a/api/tests/unit_tests/libs/_human_input/test_form_service.py +++ b/api/tests/unit_tests/libs/_human_input/test_form_service.py @@ -6,11 +6,11 @@ from datetime import datetime, timedelta import pytest -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormInput, UserAction, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) diff --git a/api/tests/unit_tests/libs/_human_input/test_models.py b/api/tests/unit_tests/libs/_human_input/test_models.py index 962eeb9e11..5d14b5eb4e 100644 --- a/api/tests/unit_tests/libs/_human_input/test_models.py +++ b/api/tests/unit_tests/libs/_human_input/test_models.py @@ -6,11 +6,11 @@ from datetime import datetime, timedelta import pytest -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormInput, UserAction, ) -from core.workflow.nodes.human_input.enums import ( +from dify_graph.nodes.human_input.enums import ( FormInputType, TimeoutUnit, ) diff --git a/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py new file mode 100644 index 0000000000..248aa0b145 --- /dev/null +++ b/api/tests/unit_tests/libs/broadcast_channel/redis/test_streams_channel_unit_tests.py @@ -0,0 +1,145 @@ +import time + +import pytest + +from libs.broadcast_channel.redis.streams_channel import ( + StreamsBroadcastChannel, + StreamsTopic, + _StreamsSubscription, +) + + +class FakeStreamsRedis: + """Minimal in-memory Redis Streams stub for unit tests. + + - Stores entries per key as [(id, {b"data": bytes}), ...] + - xadd appends entries and returns an auto-increment id like "1-0" + - xread returns entries strictly greater than last_id + - expire is recorded but has no effect on behavior + """ + + def __init__(self) -> None: + self._store: dict[str, list[tuple[str, dict]]] = {} + self._next_id: dict[str, int] = {} + self._expire_calls: dict[str, int] = {} + + # Publisher API + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + """Append entry to stream; accept optional maxlen for API compatibility. + + The test double ignores maxlen trimming semantics; only records the entry. + """ + n = self._next_id.get(key, 0) + 1 + self._next_id[key] = n + entry_id = f"{n}-0" + self._store.setdefault(key, []).append((entry_id, fields)) + return entry_id + + def expire(self, key: str, seconds: int) -> None: + self._expire_calls[key] = self._expire_calls.get(key, 0) + 1 + + # Consumer API + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + # Expect a single key + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._store.get(key, []) + + # Find position strictly greater than last_id + start_idx = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start_idx = i + 1 + break + if start_idx >= len(entries): + # Simulate blocking wait (bounded) if requested + if block and block > 0: + time.sleep(min(0.01, block / 1000.0)) + return [] + + end_idx = len(entries) if count is None else min(len(entries), start_idx + count) + batch = entries[start_idx:end_idx] + return [(key, batch)] + + +@pytest.fixture +def fake_redis() -> FakeStreamsRedis: + return FakeStreamsRedis() + + +@pytest.fixture +def streams_channel(fake_redis: FakeStreamsRedis) -> StreamsBroadcastChannel: + return StreamsBroadcastChannel(fake_redis, retention_seconds=60) + + +class TestStreamsBroadcastChannel: + def test_topic_creation(self, streams_channel: StreamsBroadcastChannel, fake_redis: FakeStreamsRedis): + topic = streams_channel.topic("alpha") + assert isinstance(topic, StreamsTopic) + assert topic._client is fake_redis + assert topic._topic == "alpha" + assert topic._key == "stream:alpha" + + def test_publish_calls_xadd_and_expire( + self, + streams_channel: StreamsBroadcastChannel, + fake_redis: FakeStreamsRedis, + ): + topic = streams_channel.topic("beta") + payload = b"hello" + topic.publish(payload) + # One entry stored under stream key (bytes key for payload field) + assert fake_redis._store["stream:beta"][0][1] == {b"data": payload} + # Expire called after publish + assert fake_redis._expire_calls.get("stream:beta", 0) >= 1 + + +class TestStreamsSubscription: + def test_subscribe_and_receive_from_beginning(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("gamma") + # Pre-publish events before subscribing (late subscriber) + topic.publish(b"e1") + topic.publish(b"e2") + + sub = topic.subscribe() + assert isinstance(sub, _StreamsSubscription) + + received: list[bytes] = [] + with sub: + # Give listener thread a moment to xread + time.sleep(0.05) + # Drain using receive() to avoid indefinite iteration in tests + for _ in range(5): + msg = sub.receive(timeout=0.1) + if msg is None: + break + received.append(msg) + + assert received == [b"e1", b"e2"] + + def test_receive_timeout_returns_none(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("delta") + sub = topic.subscribe() + with sub: + # No messages yet + assert sub.receive(timeout=0.05) is None + + def test_close_stops_listener(self, streams_channel: StreamsBroadcastChannel): + topic = streams_channel.topic("epsilon") + sub = topic.subscribe() + with sub: + # Listener running; now close and ensure no crash + sub.close() + # After close, receive should raise SubscriptionClosedError + from libs.broadcast_channel.exc import SubscriptionClosedError + + with pytest.raises(SubscriptionClosedError): + sub.receive() + + def test_no_expire_when_zero_retention(self, fake_redis: FakeStreamsRedis): + channel = StreamsBroadcastChannel(fake_redis, retention_seconds=0) + topic = channel.topic("zeta") + topic.publish(b"payload") + # No expire recorded when retention is disabled + assert fake_redis._expire_calls.get("stream:zeta") is None diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py index df80428ee8..a94ba0c00b 100644 --- a/api/tests/unit_tests/libs/test_login.py +++ b/api/tests/unit_tests/libs/test_login.py @@ -140,7 +140,7 @@ class TestLoginRequired: # Remove ensure_sync to simulate Flask 1.x if hasattr(setup_app, "ensure_sync"): - delattr(setup_app, "ensure_sync") + del setup_app.ensure_sync with setup_app.test_request_context(): mock_user = MockUser("test_user", is_authenticated=True) diff --git a/api/tests/unit_tests/libs/test_oauth_clients.py b/api/tests/unit_tests/libs/test_oauth_clients.py index bc7880ccc8..3918e8ee4b 100644 --- a/api/tests/unit_tests/libs/test_oauth_clients.py +++ b/api/tests/unit_tests/libs/test_oauth_clients.py @@ -95,13 +95,11 @@ class TestGitHubOAuth(BaseOAuthTest): ], "primary@example.com", ), - # User with no emails - fallback to noreply - ({"id": 12345, "login": "testuser", "name": "Test User"}, [], "12345+testuser@users.noreply.github.com"), - # User with only secondary email - fallback to noreply + # User with private email (null email and name from API) ( - {"id": 12345, "login": "testuser", "name": "Test User"}, - [{"email": "secondary@example.com", "primary": False}], - "12345+testuser@users.noreply.github.com", + {"id": 12345, "login": "testuser", "name": None, "email": None}, + [{"email": "primary@example.com", "primary": True}], + "primary@example.com", ), ], ) @@ -118,9 +116,54 @@ class TestGitHubOAuth(BaseOAuthTest): user_info = oauth.get_user_info("test_token") assert user_info.id == str(user_data["id"]) - assert user_info.name == user_data["name"] + assert user_info.name == (user_data["name"] or "") assert user_info.email == expected_email + @pytest.mark.parametrize( + ("user_data", "email_data"), + [ + # User with no emails + ({"id": 12345, "login": "testuser", "name": "Test User"}, []), + # User with only secondary email + ( + {"id": 12345, "login": "testuser", "name": "Test User"}, + [{"email": "secondary@example.com", "primary": False}], + ), + # User with private email and no primary in emails endpoint + ( + {"id": 12345, "login": "testuser", "name": None, "email": None}, + [], + ), + ], + ) + @patch("httpx.get", autospec=True) + def test_should_raise_error_when_no_primary_email(self, mock_get, oauth, user_data, email_data): + user_response = MagicMock() + user_response.json.return_value = user_data + + email_response = MagicMock() + email_response.json.return_value = email_data + + mock_get.side_effect = [user_response, email_response] + + with pytest.raises(ValueError, match="Keep my email addresses private"): + oauth.get_user_info("test_token") + + @patch("httpx.get", autospec=True) + def test_should_raise_error_when_email_endpoint_fails(self, mock_get, oauth): + user_response = MagicMock() + user_response.json.return_value = {"id": 12345, "login": "testuser", "name": "Test User"} + + email_response = MagicMock() + email_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "Forbidden", request=MagicMock(), response=MagicMock() + ) + + mock_get.side_effect = [user_response, email_response] + + with pytest.raises(ValueError, match="Keep my email addresses private"): + oauth.get_user_info("test_token") + @patch("httpx.get", autospec=True) def test_should_handle_network_errors(self, mock_get, oauth): mock_get.side_effect = httpx.RequestError("Network error") diff --git a/api/tests/unit_tests/libs/test_rsa.py b/api/tests/unit_tests/libs/test_rsa.py index 6a448d4f1f..7063a068ff 100644 --- a/api/tests/unit_tests/libs/test_rsa.py +++ b/api/tests/unit_tests/libs/test_rsa.py @@ -1,13 +1,12 @@ -import rsa as pyrsa from Crypto.PublicKey import RSA from libs import gmpy2_pkcs10aep_cipher def test_gmpy2_pkcs10aep_cipher(): - rsa_key_pair = pyrsa.newkeys(2048) - public_key = rsa_key_pair[0].save_pkcs1() - private_key = rsa_key_pair[1].save_pkcs1() + rsa_key = RSA.generate(2048) + public_key = rsa_key.publickey().export_key(format="PEM") + private_key = rsa_key.export_key(format="PEM") public_rsa_key = RSA.import_key(public_key) public_cipher_rsa2 = gmpy2_pkcs10aep_cipher.new(public_rsa_key) diff --git a/api/tests/unit_tests/models/test_account_models.py b/api/tests/unit_tests/models/test_account_models.py index cc311d447f..f48db77bb5 100644 --- a/api/tests/unit_tests/models/test_account_models.py +++ b/api/tests/unit_tests/models/test_account_models.py @@ -98,7 +98,7 @@ class TestAccountModelValidation: ) # Assert - assert account.status == "active" + assert account.status == AccountStatus.ACTIVE def test_account_get_status_method(self): """Test the get_status method returns AccountStatus enum.""" @@ -106,7 +106,7 @@ class TestAccountModelValidation: account = Account( name="Test User", email="test@example.com", - status="pending", + status=AccountStatus.PENDING, ) # Act @@ -622,28 +622,10 @@ class TestAccountGetByOpenId: mock_account = Account(name="Test User", email="test@example.com") mock_account.id = account_id - # Mock the query chain - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.one_or_none.return_value = mock_account_integrate - mock_query.where.return_value = mock_where - mock_db.session.query.return_value = mock_query - - # Mock the second query for account - mock_account_query = MagicMock() - mock_account_where = MagicMock() - mock_account_where.one_or_none.return_value = mock_account - mock_account_query.where.return_value = mock_account_where - - # Setup query to return different results based on model - def query_side_effect(model): - if model.__name__ == "AccountIntegrate": - return mock_query - elif model.__name__ == "Account": - return mock_account_query - return MagicMock() - - mock_db.session.query.side_effect = query_side_effect + # Mock db.session.execute().scalar_one_or_none() for AccountIntegrate lookup + mock_db.session.execute.return_value.scalar_one_or_none.return_value = mock_account_integrate + # Mock db.session.scalar() for Account lookup + mock_db.session.scalar.return_value = mock_account # Act result = Account.get_by_openid(provider, open_id) @@ -658,12 +640,8 @@ class TestAccountGetByOpenId: provider = "github" open_id = "github_user_456" - # Mock the query chain to return None - mock_query = MagicMock() - mock_where = MagicMock() - mock_where.one_or_none.return_value = None - mock_query.where.return_value = mock_where - mock_db.session.query.return_value = mock_query + # Mock db.session.execute().scalar_one_or_none() to return None + mock_db.session.execute.return_value.scalar_one_or_none.return_value = None # Act result = Account.get_by_openid(provider, open_id) diff --git a/api/tests/unit_tests/models/test_app_models.py b/api/tests/unit_tests/models/test_app_models.py index 8b96c62dc9..e5f92fbed5 100644 --- a/api/tests/unit_tests/models/test_app_models.py +++ b/api/tests/unit_tests/models/test_app_models.py @@ -16,6 +16,7 @@ from uuid import uuid4 import pytest +from models.enums import ConversationFromSource from models.model import ( App, AppAnnotationHitHistory, @@ -300,10 +301,8 @@ class TestAppModelConfig: created_by=str(uuid4()), ) - # Mock database query to return None - with patch("models.model.db.session.query", autospec=True) as mock_query: - mock_query.return_value.where.return_value.first.return_value = None - + # Mock database scalar to return None (no annotation setting found) + with patch("models.model.db.session.scalar", return_value=None): # Act result = config.annotation_reply_dict @@ -326,7 +325,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=from_end_user_id, ) @@ -347,7 +346,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), ) conversation._inputs = inputs @@ -366,7 +365,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), ) inputs = {"query": "Hello", "context": "test"} @@ -385,7 +384,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), summary="Test summary", ) @@ -404,7 +403,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), summary=None, ) @@ -427,7 +426,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), override_model_configs='{"model": "gpt-4"}', ) @@ -448,7 +447,7 @@ class TestConversationModel: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=from_end_user_id, dialogue_count=5, ) @@ -489,7 +488,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) # Assert @@ -513,7 +512,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) message._inputs = inputs @@ -535,7 +534,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) inputs = {"query": "Hello", "context": "test"} @@ -557,7 +556,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, override_model_configs='{"model": "gpt-4"}', ) @@ -580,7 +579,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, message_metadata=json.dumps(metadata), ) @@ -602,7 +601,7 @@ class TestMessageModel: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, message_metadata=None, ) @@ -629,7 +628,7 @@ class TestMessageModel: answer_unit_price=Decimal("0.0002"), total_price=Decimal("0.0003"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, status="normal", ) message.id = str(uuid4()) @@ -951,10 +950,8 @@ class TestSiteModel: def test_site_generate_code(self): """Test Site.generate_code static method.""" - # Mock database query to return 0 (no existing codes) - with patch("models.model.db.session.query", autospec=True) as mock_query: - mock_query.return_value.where.return_value.count.return_value = 0 - + # Mock database scalar to return 0 (no existing codes) + with patch("models.model.db.session.scalar", return_value=0): # Act code = Site.generate_code(8) @@ -992,7 +989,7 @@ class TestModelIntegration: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, from_end_user_id=str(uuid4()), ) conversation.id = conversation_id @@ -1007,7 +1004,7 @@ class TestModelIntegration: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) message.id = message_id @@ -1068,7 +1065,7 @@ class TestModelIntegration: message_unit_price=Decimal("0.0001"), answer_unit_price=Decimal("0.0002"), currency="USD", - from_source="api", + from_source=ConversationFromSource.API, ) message.id = message_id @@ -1162,7 +1159,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = str(uuid4()) @@ -1187,7 +1184,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1204,7 +1201,7 @@ class TestConversationStatusCount: def test_status_count_batch_loading_implementation(self): """Test that status_count uses batch loading instead of N+1 queries.""" # Arrange - from core.workflow.enums import WorkflowExecutionStatus + from dify_graph.enums import WorkflowExecutionStatus app_id = str(uuid4()) conversation_id = str(uuid4()) @@ -1219,7 +1216,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1311,7 +1308,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1365,7 +1362,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id @@ -1411,7 +1408,7 @@ class TestConversationStatusCount: def test_status_count_paused(self): """Test status_count includes paused workflow runs.""" # Arrange - from core.workflow.enums import WorkflowExecutionStatus + from dify_graph.enums import WorkflowExecutionStatus app_id = str(uuid4()) conversation_id = str(uuid4()) @@ -1422,7 +1419,7 @@ class TestConversationStatusCount: mode=AppMode.CHAT, name="Test Conversation", status="normal", - from_source="api", + from_source=ConversationFromSource.API, ) conversation.id = conversation_id diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index d44aa56488..7d7674da3c 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -1,6 +1,6 @@ from uuid import uuid4 -from core.workflow.variables import SegmentType +from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py index 9bb7c05a91..98dd07907a 100644 --- a/api/tests/unit_tests/models/test_dataset_models.py +++ b/api/tests/unit_tests/models/test_dataset_models.py @@ -25,6 +25,13 @@ from models.dataset import ( DocumentSegment, Embedding, ) +from models.enums import ( + DataSourceType, + DocumentCreatedFrom, + IndexingStatus, + ProcessRuleMode, + SegmentStatus, +) class TestDatasetModelValidation: @@ -40,14 +47,14 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, ) # Assert assert dataset.name == "Test Dataset" assert dataset.tenant_id == tenant_id - assert dataset.data_source_type == "upload_file" + assert dataset.data_source_type == DataSourceType.UPLOAD_FILE assert dataset.created_by == created_by # Note: Default values are set by database, not by model instantiation @@ -57,7 +64,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), description="Test description", indexing_technique="high_quality", @@ -77,14 +84,14 @@ class TestDatasetModelValidation: dataset_high_quality = Dataset( tenant_id=str(uuid4()), name="High Quality Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), indexing_technique="high_quality", ) dataset_economy = Dataset( tenant_id=str(uuid4()), name="Economy Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), indexing_technique="economy", ) @@ -101,14 +108,14 @@ class TestDatasetModelValidation: dataset_vendor = Dataset( tenant_id=str(uuid4()), name="Vendor Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), provider="vendor", ) dataset_external = Dataset( tenant_id=str(uuid4()), name="External Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), provider="external", ) @@ -126,7 +133,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), index_struct=json.dumps(index_struct_data), ) @@ -145,7 +152,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -161,7 +168,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -178,7 +185,7 @@ class TestDatasetModelValidation: dataset = Dataset( tenant_id=str(uuid4()), name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=str(uuid4()), ) @@ -218,10 +225,10 @@ class TestDocumentModelRelationships: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test_document.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, ) @@ -229,10 +236,10 @@ class TestDocumentModelRelationships: assert document.tenant_id == tenant_id assert document.dataset_id == dataset_id assert document.position == 1 - assert document.data_source_type == "upload_file" + assert document.data_source_type == DataSourceType.UPLOAD_FILE assert document.batch == "batch_001" assert document.name == "test_document.pdf" - assert document.created_from == "web" + assert document.created_from == DocumentCreatedFrom.WEB assert document.created_by == created_by # Note: Default values are set by database, not by model instantiation @@ -250,12 +257,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="waiting", + indexing_status=IndexingStatus.WAITING, ) # Act @@ -271,12 +278,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="parsing", + indexing_status=IndexingStatus.PARSING, is_paused=True, ) @@ -289,15 +296,20 @@ class TestDocumentModelRelationships: def test_document_display_status_indexing(self): """Test document display_status property for indexing state.""" # Arrange - for indexing_status in ["parsing", "cleaning", "splitting", "indexing"]: + for indexing_status in [ + IndexingStatus.PARSING, + IndexingStatus.CLEANING, + IndexingStatus.SPLITTING, + IndexingStatus.INDEXING, + ]: document = Document( tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), indexing_status=indexing_status, ) @@ -315,12 +327,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="error", + indexing_status=IndexingStatus.ERROR, ) # Act @@ -336,12 +348,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=True, archived=False, ) @@ -359,12 +371,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, enabled=False, archived=False, ) @@ -382,12 +394,12 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, archived=True, ) @@ -405,10 +417,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), data_source_info=json.dumps(data_source_info), ) @@ -428,10 +440,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), ) @@ -448,10 +460,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), word_count=1000, ) @@ -471,10 +483,10 @@ class TestDocumentModelRelationships: tenant_id=str(uuid4()), dataset_id=str(uuid4()), position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=str(uuid4()), word_count=0, ) @@ -582,7 +594,7 @@ class TestDocumentSegmentIndexing: word_count=1, tokens=2, created_by=str(uuid4()), - status="waiting", + status=SegmentStatus.WAITING, ) segment_completed = DocumentSegment( tenant_id=str(uuid4()), @@ -593,12 +605,12 @@ class TestDocumentSegmentIndexing: word_count=1, tokens=2, created_by=str(uuid4()), - status="completed", + status=SegmentStatus.COMPLETED, ) # Assert - assert segment_waiting.status == "waiting" - assert segment_completed.status == "completed" + assert segment_waiting.status == SegmentStatus.WAITING + assert segment_completed.status == SegmentStatus.COMPLETED def test_document_segment_enabled_disabled_tracking(self): """Test document segment enabled/disabled state tracking.""" @@ -769,13 +781,13 @@ class TestDatasetProcessRule: # Act process_rule = DatasetProcessRule( dataset_id=dataset_id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, created_by=created_by, ) # Assert assert process_rule.dataset_id == dataset_id - assert process_rule.mode == "automatic" + assert process_rule.mode == ProcessRuleMode.AUTOMATIC assert process_rule.created_by == created_by def test_dataset_process_rule_modes(self): @@ -797,7 +809,7 @@ class TestDatasetProcessRule: } process_rule = DatasetProcessRule( dataset_id=str(uuid4()), - mode="custom", + mode=ProcessRuleMode.CUSTOM, created_by=str(uuid4()), rules=json.dumps(rules_data), ) @@ -817,7 +829,7 @@ class TestDatasetProcessRule: rules_data = {"test": "data"} process_rule = DatasetProcessRule( dataset_id=dataset_id, - mode="automatic", + mode=ProcessRuleMode.AUTOMATIC, created_by=str(uuid4()), rules=json.dumps(rules_data), ) @@ -827,7 +839,7 @@ class TestDatasetProcessRule: # Assert assert result["dataset_id"] == dataset_id - assert result["mode"] == "automatic" + assert result["mode"] == ProcessRuleMode.AUTOMATIC assert result["rules"] == rules_data def test_dataset_process_rule_automatic_rules(self): @@ -969,7 +981,7 @@ class TestModelIntegration: dataset = Dataset( tenant_id=tenant_id, name="Test Dataset", - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, created_by=created_by, indexing_technique="high_quality", ) @@ -980,10 +992,10 @@ class TestModelIntegration: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=100, ) @@ -999,7 +1011,7 @@ class TestModelIntegration: word_count=3, tokens=5, created_by=created_by, - status="completed", + status=SegmentStatus.COMPLETED, ) # Assert @@ -1009,7 +1021,7 @@ class TestModelIntegration: assert segment.document_id == document_id assert dataset.indexing_technique == "high_quality" assert document.word_count == 100 - assert segment.status == "completed" + assert segment.status == SegmentStatus.COMPLETED def test_document_to_dict_serialization(self): """Test document to_dict method for serialization.""" @@ -1022,13 +1034,13 @@ class TestModelIntegration: tenant_id=tenant_id, dataset_id=dataset_id, position=1, - data_source_type="upload_file", + data_source_type=DataSourceType.UPLOAD_FILE, batch="batch_001", name="test.pdf", - created_from="web", + created_from=DocumentCreatedFrom.WEB, created_by=created_by, word_count=100, - indexing_status="completed", + indexing_status=IndexingStatus.COMPLETED, ) # Mock segment_count and hit_count @@ -1044,6 +1056,6 @@ class TestModelIntegration: assert result["dataset_id"] == dataset_id assert result["name"] == "test.pdf" assert result["word_count"] == 100 - assert result["indexing_status"] == "completed" + assert result["indexing_status"] == IndexingStatus.COMPLETED assert result["segment_count"] == 5 assert result["hit_count"] == 10 diff --git a/api/tests/unit_tests/models/test_enums_creator_user_role.py b/api/tests/unit_tests/models/test_enums_creator_user_role.py new file mode 100644 index 0000000000..6317166fdc --- /dev/null +++ b/api/tests/unit_tests/models/test_enums_creator_user_role.py @@ -0,0 +1,19 @@ +import pytest + +from models.enums import CreatorUserRole + + +def test_creator_user_role_missing_maps_hyphen_to_enum(): + # given an alias with hyphen + value = "end-user" + + # when converting to enum (invokes StrEnum._missing_ override) + role = CreatorUserRole(value) + + # then it should map to END_USER + assert role is CreatorUserRole.END_USER + + +def test_creator_user_role_missing_raises_for_unknown(): + with pytest.raises(ValueError): + CreatorUserRole("unknown") diff --git a/api/tests/unit_tests/models/test_provider_models.py b/api/tests/unit_tests/models/test_provider_models.py index ec84a61c8e..f628e54a4d 100644 --- a/api/tests/unit_tests/models/test_provider_models.py +++ b/api/tests/unit_tests/models/test_provider_models.py @@ -19,6 +19,7 @@ from uuid import uuid4 import pytest +from models.enums import CredentialSourceType, PaymentStatus from models.provider import ( LoadBalancingModelConfig, Provider, @@ -158,7 +159,7 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == provider_name - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_used == 0 @@ -172,10 +173,10 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="anthropic", - provider_type="system", + provider_type=ProviderType.SYSTEM, is_valid=True, credential_id=credential_id, - quota_type="paid", + quota_type=ProviderQuotaType.PAID, quota_limit=10000, quota_used=500, ) @@ -183,10 +184,10 @@ class TestProviderModel: # Assert assert provider.tenant_id == tenant_id assert provider.provider_name == "anthropic" - assert provider.provider_type == "system" + assert provider.provider_type == ProviderType.SYSTEM assert provider.is_valid is True assert provider.credential_id == credential_id - assert provider.quota_type == "paid" + assert provider.quota_type == ProviderQuotaType.PAID assert provider.quota_limit == 10000 assert provider.quota_used == 500 @@ -199,7 +200,7 @@ class TestProviderModel: ) # Assert - assert provider.provider_type == "custom" + assert provider.provider_type == ProviderType.CUSTOM assert provider.is_valid is False assert provider.quota_type == "" assert provider.quota_limit is None @@ -213,7 +214,7 @@ class TestProviderModel: provider = Provider( tenant_id=tenant_id, provider_name="openai", - provider_type="custom", + provider_type=ProviderType.CUSTOM, ) # Act @@ -253,7 +254,7 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - provider_type=ProviderType.SYSTEM.value, + provider_type=ProviderType.SYSTEM, is_valid=True, ) @@ -266,13 +267,13 @@ class TestProviderModel: provider = Provider( tenant_id=str(uuid4()), provider_name="openai", - quota_type="trial", + quota_type=ProviderQuotaType.TRIAL, quota_limit=1000, quota_used=250, ) # Assert - assert provider.quota_type == "trial" + assert provider.quota_type == ProviderQuotaType.TRIAL assert provider.quota_limit == 1000 assert provider.quota_used == 250 remaining = provider.quota_limit - provider.quota_used @@ -429,13 +430,13 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=tenant_id, provider_name="openai", - preferred_provider_type="custom", + preferred_provider_type=ProviderType.CUSTOM, ) # Assert assert preferred.tenant_id == tenant_id assert preferred.provider_name == "openai" - assert preferred.preferred_provider_type == "custom" + assert preferred.preferred_provider_type == ProviderType.CUSTOM def test_tenant_preferred_provider_system_type(self): """Test tenant preferred provider with system type.""" @@ -443,11 +444,11 @@ class TestTenantPreferredModelProvider: preferred = TenantPreferredModelProvider( tenant_id=str(uuid4()), provider_name="anthropic", - preferred_provider_type="system", + preferred_provider_type=ProviderType.SYSTEM, ) # Assert - assert preferred.preferred_provider_type == "system" + assert preferred.preferred_provider_type == ProviderType.SYSTEM class TestProviderOrder: @@ -470,7 +471,7 @@ class TestProviderOrder: quantity=1, currency=None, total_amount=None, - payment_status="wait_pay", + payment_status=PaymentStatus.WAIT_PAY, paid_at=None, pay_failed_at=None, refunded_at=None, @@ -481,7 +482,7 @@ class TestProviderOrder: assert order.provider_name == "openai" assert order.account_id == account_id assert order.payment_product_id == "prod_123" - assert order.payment_status == "wait_pay" + assert order.payment_status == PaymentStatus.WAIT_PAY assert order.quantity == 1 def test_provider_order_with_payment_details(self): @@ -502,7 +503,7 @@ class TestProviderOrder: quantity=5, currency="USD", total_amount=9999, - payment_status="paid", + payment_status=PaymentStatus.PAID, paid_at=paid_time, pay_failed_at=None, refunded_at=None, @@ -514,7 +515,7 @@ class TestProviderOrder: assert order.quantity == 5 assert order.currency == "USD" assert order.total_amount == 9999 - assert order.payment_status == "paid" + assert order.payment_status == PaymentStatus.PAID assert order.paid_at == paid_time def test_provider_order_payment_statuses(self): @@ -536,23 +537,23 @@ class TestProviderOrder: } # Act & Assert - Wait pay status - wait_order = ProviderOrder(**base_params, payment_status="wait_pay") - assert wait_order.payment_status == "wait_pay" + wait_order = ProviderOrder(**base_params, payment_status=PaymentStatus.WAIT_PAY) + assert wait_order.payment_status == PaymentStatus.WAIT_PAY # Act & Assert - Paid status - paid_order = ProviderOrder(**base_params, payment_status="paid") - assert paid_order.payment_status == "paid" + paid_order = ProviderOrder(**base_params, payment_status=PaymentStatus.PAID) + assert paid_order.payment_status == PaymentStatus.PAID # Act & Assert - Failed status failed_params = {**base_params, "pay_failed_at": datetime.now(UTC)} - failed_order = ProviderOrder(**failed_params, payment_status="failed") - assert failed_order.payment_status == "failed" + failed_order = ProviderOrder(**failed_params, payment_status=PaymentStatus.FAILED) + assert failed_order.payment_status == PaymentStatus.FAILED assert failed_order.pay_failed_at is not None # Act & Assert - Refunded status refunded_params = {**base_params, "refunded_at": datetime.now(UTC)} - refunded_order = ProviderOrder(**refunded_params, payment_status="refunded") - assert refunded_order.payment_status == "refunded" + refunded_order = ProviderOrder(**refunded_params, payment_status=PaymentStatus.REFUNDED) + assert refunded_order.payment_status == PaymentStatus.REFUNDED assert refunded_order.refunded_at is not None @@ -650,13 +651,13 @@ class TestLoadBalancingModelConfig: name="Secondary API Key", encrypted_config='{"api_key": "encrypted_value"}', credential_id=credential_id, - credential_source_type="custom", + credential_source_type=CredentialSourceType.CUSTOM_MODEL, ) # Assert assert config.encrypted_config == '{"api_key": "encrypted_value"}' assert config.credential_id == credential_id - assert config.credential_source_type == "custom" + assert config.credential_source_type == CredentialSourceType.CUSTOM_MODEL def test_load_balancing_config_disabled(self): """Test disabled load balancing config.""" diff --git a/api/tests/unit_tests/models/test_tool_models.py b/api/tests/unit_tests/models/test_tool_models.py index 1a75eb9a01..a6c2eae2c0 100644 --- a/api/tests/unit_tests/models/test_tool_models.py +++ b/api/tests/unit_tests/models/test_tool_models.py @@ -12,7 +12,7 @@ This test suite covers: import json from uuid import uuid4 -from core.tools.entities.tool_entities import ApiProviderSchemaType +from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolProviderType from models.tools import ( ApiToolProvider, BuiltinToolProvider, @@ -631,7 +631,7 @@ class TestToolLabelBinding: """Test creating a tool label binding.""" # Arrange tool_id = "google.search" - tool_type = "builtin" + tool_type = ToolProviderType.BUILT_IN label_name = "search" # Act @@ -655,7 +655,7 @@ class TestToolLabelBinding: # Act label_binding = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name=label_name, ) @@ -667,7 +667,7 @@ class TestToolLabelBinding: """Test multiple labels can be bound to the same tool.""" # Arrange tool_id = "google.search" - tool_type = "builtin" + tool_type = ToolProviderType.BUILT_IN # Act binding1 = ToolLabelBinding( @@ -688,7 +688,7 @@ class TestToolLabelBinding: def test_tool_label_binding_different_tool_types(self): """Test label bindings for different tool types.""" # Arrange - tool_types = ["builtin", "api", "workflow"] + tool_types = [ToolProviderType.BUILT_IN, ToolProviderType.API, ToolProviderType.WORKFLOW] # Act & Assert for tool_type in tool_types: @@ -951,12 +951,12 @@ class TestToolProviderRelationships: # Act binding1 = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name="search", ) binding2 = ToolLabelBinding( tool_id=tool_id, - tool_type="builtin", + tool_type=ToolProviderType.BUILT_IN, label_name="web", ) diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 544693da34..ef29b26a7a 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -4,12 +4,18 @@ from unittest import mock from uuid import uuid4 from constants import HIDDEN_VALUE -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable -from core.workflow.variables.segments import IntegerSegment, Segment +from core.helper import encrypter +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.variables import FloatVariable, IntegerVariable, SecretVariable, StringVariable +from dify_graph.variables.segments import IntegerSegment, Segment from factories.variable_factory import build_segment -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import ( + Workflow, + WorkflowDraftVariable, + WorkflowNodeExecutionModel, + is_system_variable_editable, +) def test_environment_variables(): @@ -144,6 +150,36 @@ def test_to_dict(): assert workflow_dict["environment_variables"][1]["value"] == "text" +def test_normalize_environment_variable_mappings_converts_full_mask_to_hidden_value(): + normalized = Workflow.normalize_environment_variable_mappings( + [ + { + "id": str(uuid4()), + "name": "secret", + "value": encrypter.full_mask_token(), + "value_type": "secret", + } + ] + ) + + assert normalized[0]["value"] == HIDDEN_VALUE + + +def test_normalize_environment_variable_mappings_keeps_hidden_value(): + normalized = Workflow.normalize_environment_variable_mappings( + [ + { + "id": str(uuid4()), + "name": "secret", + "value": HIDDEN_VALUE, + "value_type": "secret", + } + ] + ) + + assert normalized[0]["value"] == HIDDEN_VALUE + + class TestWorkflowNodeExecution: def test_execution_metadata_dict(self): node_exec = WorkflowNodeExecutionModel() diff --git a/api/tests/unit_tests/models/test_workflow_models.py b/api/tests/unit_tests/models/test_workflow_models.py index 9907cf05c0..4fcef34549 100644 --- a/api/tests/unit_tests/models/test_workflow_models.py +++ b/api/tests/unit_tests/models/test_workflow_models.py @@ -14,8 +14,8 @@ from uuid import uuid4 import pytest -from core.workflow.enums import ( - NodeType, +from dify_graph.enums import ( + BuiltinNodeTypes, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, ) @@ -471,7 +471,7 @@ class TestNodeExecutionRelationships: workflow_run_id=workflow_run_id, index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start Node", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -484,7 +484,7 @@ class TestNodeExecutionRelationships: assert node_execution.workflow_id == workflow_id assert node_execution.workflow_run_id == workflow_run_id assert node_execution.node_id == "start" - assert node_execution.node_type == NodeType.START.value + assert node_execution.node_type == BuiltinNodeTypes.START assert node_execution.index == 1 def test_node_execution_with_predecessor_relationship(self): @@ -503,7 +503,7 @@ class TestNodeExecutionRelationships: index=2, predecessor_node_id=predecessor_node_id, node_id=current_node_id, - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="LLM Node", status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -526,7 +526,7 @@ class TestNodeExecutionRelationships: workflow_run_id=None, # Single-step has no workflow run index=1, node_id="llm_test", - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="Test LLM", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -553,7 +553,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="llm_1", - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="LLM Node", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -579,7 +579,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="code_1", - node_type=NodeType.CODE.value, + node_type=BuiltinNodeTypes.CODE, title="Code Node", status=WorkflowNodeExecutionStatus.RUNNING.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -610,7 +610,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=3, node_id="code_1", - node_type=NodeType.CODE.value, + node_type=BuiltinNodeTypes.CODE, title="Code Node", status=WorkflowNodeExecutionStatus.FAILED.value, error=error_message, @@ -641,7 +641,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="llm_1", - node_type=NodeType.LLM.value, + node_type=BuiltinNodeTypes.LLM, title="LLM Node", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -664,7 +664,7 @@ class TestNodeExecutionRelationships: workflow_run_id=str(uuid4()), index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -682,12 +682,12 @@ class TestNodeExecutionRelationships: """Test node execution with different node types.""" # Test various node types node_types = [ - (NodeType.START, "Start Node"), - (NodeType.LLM, "LLM Node"), - (NodeType.CODE, "Code Node"), - (NodeType.TOOL, "Tool Node"), - (NodeType.IF_ELSE, "Conditional Node"), - (NodeType.END, "End Node"), + (BuiltinNodeTypes.START, "Start Node"), + (BuiltinNodeTypes.LLM, "LLM Node"), + (BuiltinNodeTypes.CODE, "Code Node"), + (BuiltinNodeTypes.TOOL, "Tool Node"), + (BuiltinNodeTypes.IF_ELSE, "Conditional Node"), + (BuiltinNodeTypes.END, "End Node"), ] for node_type, title in node_types: @@ -699,8 +699,8 @@ class TestNodeExecutionRelationships: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value, workflow_run_id=str(uuid4()), index=1, - node_id=f"{node_type.value}_1", - node_type=node_type.value, + node_id=f"{node_type}_1", + node_type=node_type, title=title, status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -708,7 +708,7 @@ class TestNodeExecutionRelationships: ) # Assert - assert node_execution.node_type == node_type.value + assert node_execution.node_type == node_type assert node_execution.title == title @@ -1004,7 +1004,7 @@ class TestGraphConfigurationValidation: workflow_run_id=str(uuid4()), index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, @@ -1029,7 +1029,7 @@ class TestGraphConfigurationValidation: workflow_run_id=str(uuid4()), index=1, node_id="start", - node_type=NodeType.START.value, + node_type=BuiltinNodeTypes.START, title="Start", status=WorkflowNodeExecutionStatus.SUCCEEDED.value, created_by_role=CreatorUserRole.ACCOUNT.value, diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py deleted file mode 100644 index 4b5b3b318c..0000000000 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ /dev/null @@ -1,135 +0,0 @@ -"""Unit tests for non-SQL helper logic in workflow run repository.""" - -import secrets -from datetime import UTC, datetime -from unittest.mock import Mock, patch - -import pytest - -from core.workflow.entities.pause_reason import HumanInputRequired, PauseReasonType -from core.workflow.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormStatus -from models.human_input import BackstageRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType -from models.workflow import WorkflowPause as WorkflowPauseModel -from models.workflow import WorkflowPauseReason -from repositories.sqlalchemy_api_workflow_run_repository import ( - _build_human_input_required_reason, - _PrivateWorkflowPauseEntity, -) - - -@pytest.fixture -def sample_workflow_pause() -> Mock: - """Create a sample WorkflowPause model.""" - pause = Mock(spec=WorkflowPauseModel) - pause.id = "pause-123" - pause.workflow_id = "workflow-123" - pause.workflow_run_id = "workflow-run-123" - pause.state_object_key = "workflow-state-123.json" - pause.resumed_at = None - pause.created_at = datetime.now(UTC) - return pause - - -class TestPrivateWorkflowPauseEntity: - """Test _PrivateWorkflowPauseEntity class.""" - - def test_properties(self, sample_workflow_pause: Mock) -> None: - """Test entity properties.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - - # Assert - assert entity.id == sample_workflow_pause.id - assert entity.workflow_execution_id == sample_workflow_pause.workflow_run_id - assert entity.resumed_at == sample_workflow_pause.resumed_at - - def test_get_state(self, sample_workflow_pause: Mock) -> None: - """Test getting state from storage.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - expected_state = b'{"test": "state"}' - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - mock_storage.load.return_value = expected_state - - # Act - result = entity.get_state() - - # Assert - assert result == expected_state - mock_storage.load.assert_called_once_with(sample_workflow_pause.state_object_key) - - def test_get_state_caching(self, sample_workflow_pause: Mock) -> None: - """Test state caching in get_state method.""" - # Arrange - entity = _PrivateWorkflowPauseEntity(pause_model=sample_workflow_pause, reason_models=[], human_input_form=[]) - expected_state = b'{"test": "state"}' - - with patch("repositories.sqlalchemy_api_workflow_run_repository.storage") as mock_storage: - mock_storage.load.return_value = expected_state - - # Act - result1 = entity.get_state() - result2 = entity.get_state() - - # Assert - assert result1 == expected_state - assert result2 == expected_state - mock_storage.load.assert_called_once() - - -class TestBuildHumanInputRequiredReason: - """Test helper that builds HumanInputRequired pause reasons.""" - - def test_prefers_backstage_token_when_available(self) -> None: - """Use backstage token when multiple recipient types may exist.""" - # Arrange - expiration_time = datetime.now(UTC) - form_definition = FormDefinition( - form_content="content", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - rendered_content="rendered", - expiration_time=expiration_time, - default_values={"name": "Alice"}, - node_title="Ask Name", - display_in_ui=True, - ) - form_model = HumanInputForm( - id="form-1", - tenant_id="tenant-1", - app_id="app-1", - workflow_run_id="run-1", - node_id="node-1", - form_definition=form_definition.model_dump_json(), - rendered_content="rendered", - status=HumanInputFormStatus.WAITING, - expiration_time=expiration_time, - ) - reason_model = WorkflowPauseReason( - pause_id="pause-1", - type_=PauseReasonType.HUMAN_INPUT_REQUIRED, - form_id="form-1", - node_id="node-1", - message="", - ) - access_token = secrets.token_urlsafe(8) - backstage_recipient = HumanInputFormRecipient( - form_id="form-1", - delivery_id="delivery-1", - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - - # Act - reason = _build_human_input_required_reason(reason_model, form_model, [backstage_recipient]) - - # Assert - assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token - assert reason.node_title == "Ask Name" - assert reason.form_content == "content" - assert reason.inputs[0].output_variable_name == "name" - assert reason.actions[0].id == "approve" diff --git a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index f5428b46ff..8daf91c538 100644 --- a/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/unit_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -6,11 +6,11 @@ from datetime import UTC, datetime, timedelta from core.entities.execution_extra_content import HumanInputContent as HumanInputContentDomain from core.entities.execution_extra_content import HumanInputFormSubmissionData -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormDefinition, UserAction, ) -from core.workflow.nodes.human_input.enums import HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormStatus from models.execution_extra_content import HumanInputContent as HumanInputContentModel from models.human_input import ConsoleRecipientPayload, HumanInputForm, HumanInputFormRecipient, RecipientType from repositories.sqlalchemy_execution_extra_content_repository import SQLAlchemyExecutionExtraContentRepository diff --git a/api/tests/unit_tests/repositories/test_workflow_run_repository.py b/api/tests/unit_tests/repositories/test_workflow_run_repository.py deleted file mode 100644 index 8f47f0df48..0000000000 --- a/api/tests/unit_tests/repositories/test_workflow_run_repository.py +++ /dev/null @@ -1,251 +0,0 @@ -"""Unit tests for workflow run repository with status filter.""" - -import uuid -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import sessionmaker - -from models import WorkflowRun, WorkflowRunTriggeredFrom -from repositories.sqlalchemy_api_workflow_run_repository import DifyAPISQLAlchemyWorkflowRunRepository - - -class TestDifyAPISQLAlchemyWorkflowRunRepository: - """Test workflow run repository with status filtering.""" - - @pytest.fixture - def mock_session_maker(self): - """Create a mock session maker.""" - return MagicMock(spec=sessionmaker) - - @pytest.fixture - def repository(self, mock_session_maker): - """Create repository instance with mock session.""" - return DifyAPISQLAlchemyWorkflowRunRepository(mock_session_maker) - - def test_get_paginated_workflow_runs_without_status(self, repository, mock_session_maker): - """Test getting paginated workflow runs without status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - mock_runs = [MagicMock(spec=WorkflowRun) for _ in range(3)] - mock_session.scalars.return_value.all.return_value = mock_runs - - # Act - result = repository.get_paginated_workflow_runs( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - limit=20, - last_id=None, - status=None, - ) - - # Assert - assert len(result.data) == 3 - assert result.limit == 20 - assert result.has_more is False - - def test_get_paginated_workflow_runs_with_status_filter(self, repository, mock_session_maker): - """Test getting paginated workflow runs with status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - mock_runs = [MagicMock(spec=WorkflowRun, status="succeeded") for _ in range(2)] - mock_session.scalars.return_value.all.return_value = mock_runs - - # Act - result = repository.get_paginated_workflow_runs( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - limit=20, - last_id=None, - status="succeeded", - ) - - # Assert - assert len(result.data) == 2 - assert all(run.status == "succeeded" for run in result.data) - - def test_get_workflow_runs_count_without_status(self, repository, mock_session_maker): - """Test getting workflow runs count without status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the GROUP BY query results - mock_results = [ - ("succeeded", 5), - ("failed", 2), - ("running", 1), - ] - mock_session.execute.return_value.all.return_value = mock_results - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status=None, - ) - - # Assert - assert result["total"] == 8 - assert result["succeeded"] == 5 - assert result["failed"] == 2 - assert result["running"] == 1 - assert result["stopped"] == 0 - assert result["partial-succeeded"] == 0 - - def test_get_workflow_runs_count_with_status_filter(self, repository, mock_session_maker): - """Test getting workflow runs count with status filter.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the count query for succeeded status - mock_session.scalar.return_value = 5 - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status="succeeded", - ) - - # Assert - assert result["total"] == 5 - assert result["succeeded"] == 5 - assert result["running"] == 0 - assert result["failed"] == 0 - assert result["stopped"] == 0 - assert result["partial-succeeded"] == 0 - - def test_get_workflow_runs_count_with_invalid_status(self, repository, mock_session_maker): - """Test that invalid status is still counted in total but not in any specific status.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock count query returning 0 for invalid status - mock_session.scalar.return_value = 0 - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status="invalid_status", - ) - - # Assert - assert result["total"] == 0 - assert all(result[status] == 0 for status in ["running", "succeeded", "failed", "stopped", "partial-succeeded"]) - - def test_get_workflow_runs_count_with_time_range(self, repository, mock_session_maker): - """Test getting workflow runs count with time range filter verifies SQL query construction.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the GROUP BY query results - mock_results = [ - ("succeeded", 3), - ("running", 2), - ] - mock_session.execute.return_value.all.return_value = mock_results - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status=None, - time_range="1d", - ) - - # Assert results - assert result["total"] == 5 - assert result["succeeded"] == 3 - assert result["running"] == 2 - assert result["failed"] == 0 - - # Verify that execute was called (which means GROUP BY query was used) - assert mock_session.execute.called, "execute should have been called for GROUP BY query" - - # Verify SQL query includes time filter by checking the statement - call_args = mock_session.execute.call_args - assert call_args is not None, "execute should have been called with a statement" - - # The first argument should be the SQL statement - stmt = call_args[0][0] - # Convert to string to inspect the query - query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) - - # Verify the query includes created_at filter - # The query should have a WHERE clause with created_at comparison - assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( - "Query should include created_at filter for time range" - ) - - def test_get_workflow_runs_count_with_status_and_time_range(self, repository, mock_session_maker): - """Test getting workflow runs count with both status and time range filters verifies SQL query.""" - # Arrange - tenant_id = str(uuid.uuid4()) - app_id = str(uuid.uuid4()) - mock_session = MagicMock() - mock_session_maker.return_value.__enter__.return_value = mock_session - - # Mock the count query for running status within time range - mock_session.scalar.return_value = 2 - - # Act - result = repository.get_workflow_runs_count( - tenant_id=tenant_id, - app_id=app_id, - triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, - status="running", - time_range="1d", - ) - - # Assert results - assert result["total"] == 2 - assert result["running"] == 2 - assert result["succeeded"] == 0 - assert result["failed"] == 0 - - # Verify that scalar was called (which means COUNT query was used) - assert mock_session.scalar.called, "scalar should have been called for count query" - - # Verify SQL query includes both status and time filter - call_args = mock_session.scalar.call_args - assert call_args is not None, "scalar should have been called with a statement" - - # The first argument should be the SQL statement - stmt = call_args[0][0] - # Convert to string to inspect the query - query_str = str(stmt.compile(compile_kwargs={"literal_binds": True})) - - # Verify the query includes both filters - assert "created_at" in query_str.lower() or "workflow_runs.created_at" in query_str.lower(), ( - "Query should include created_at filter for time range" - ) - assert "status" in query_str.lower() or "workflow_runs.status" in query_str.lower(), ( - "Query should include status filter" - ) diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py index 5cba43714a..086d1ac52e 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_repository.py @@ -12,17 +12,17 @@ import pytest from pytest_mock import MockerFixture from sqlalchemy.orm import Session, sessionmaker -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.entities import ( +from dify_graph.entities import ( WorkflowNodeExecution, ) -from core.workflow.enums import ( - NodeType, +from dify_graph.enums import ( + BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig +from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models.account import Account, Tenant from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom @@ -230,7 +230,7 @@ def test_to_db_model(repository): index=1, predecessor_node_id="test-predecessor-id", node_id="test-node-id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, title="Test Node", inputs={"input_key": "input_value"}, process_data={"process_key": "process_value"}, @@ -298,7 +298,7 @@ def test_to_domain_model(repository): db_model.predecessor_node_id = "test-predecessor-id" db_model.node_execution_id = "test-node-execution-id" db_model.node_id = "test-node-id" - db_model.node_type = NodeType.START + db_model.node_type = BuiltinNodeTypes.START db_model.title = "Test Node" db_model.inputs = json.dumps(inputs_dict) db_model.process_data = json.dumps(process_data_dict) @@ -324,7 +324,7 @@ def test_to_domain_model(repository): assert domain_model.predecessor_node_id == db_model.predecessor_node_id assert domain_model.node_execution_id == db_model.node_execution_id assert domain_model.node_id == db_model.node_id - assert domain_model.node_type == NodeType(db_model.node_type) + assert domain_model.node_type == db_model.node_type assert domain_model.title == db_model.title assert domain_model.inputs == inputs_dict assert domain_model.process_data == process_data_dict diff --git a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py index 5539856083..e01fb8456f 100644 --- a/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/repositories/workflow_node_execution/test_sqlalchemy_workflow_node_execution_repository.py @@ -11,8 +11,8 @@ from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution -from core.workflow.enums import NodeType +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution +from dify_graph.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom @@ -62,7 +62,7 @@ class TestSQLAlchemyWorkflowNodeExecutionRepositoryProcessData: workflow_id="test-workflow-id", index=1, node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, title="Test Node", process_data=process_data, created_at=datetime.now(), diff --git a/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py new file mode 100644 index 0000000000..c2fcd71875 --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_jina_auth_standalone_module.py @@ -0,0 +1,157 @@ +from __future__ import annotations + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import MagicMock + +import httpx +import pytest + + +@pytest.fixture(scope="module") +def jina_module() -> ModuleType: + """ + Load `api/services/auth/jina.py` as a standalone module. + + This repo contains both `services/auth/jina.py` and a package at + `services/auth/jina/`, so importing `services.auth.jina` can be ambiguous. + """ + + module_path = Path(__file__).resolve().parents[4] / "services" / "auth" / "jina.py" + # Use a stable module name so pytest-cov can target it with `--cov=services.auth.jina_file`. + spec = importlib.util.spec_from_file_location("services.auth.jina_file", module_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) + return module + + +def _credentials(api_key: str | None = "test_api_key_123", auth_type: str = "bearer") -> dict: + config: dict = {} if api_key is None else {"api_key": api_key} + return {"auth_type": auth_type, "config": config} + + +def test_init_valid_bearer_credentials(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials()) + assert auth.api_key == "test_api_key_123" + assert auth.credentials["auth_type"] == "bearer" + + +def test_init_rejects_invalid_auth_type(jina_module: ModuleType) -> None: + with pytest.raises(ValueError, match="Invalid auth type.*Bearer"): + jina_module.JinaAuth(_credentials(auth_type="basic")) + + +@pytest.mark.parametrize("credentials", [{"auth_type": "bearer", "config": {}}, {"auth_type": "bearer"}]) +def test_init_requires_api_key(jina_module: ModuleType, credentials: dict) -> None: + with pytest.raises(ValueError, match="No API key provided"): + jina_module.JinaAuth(credentials) + + +def test_prepare_headers_includes_bearer_api_key(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + assert auth._prepare_headers() == {"Content-Type": "application/json", "Authorization": "Bearer k"} + + +def test_post_request_calls_httpx(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + post_mock = MagicMock(name="httpx.post") + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + auth._post_request("https://r.jina.ai", {"url": "https://example.com"}, {"h": "v"}) + post_mock.assert_called_once_with("https://r.jina.ai", headers={"h": "v"}, json={"url": "https://example.com"}) + + +def test_validate_credentials_success(jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 200 + post_mock = MagicMock(return_value=response) + monkeypatch.setattr(jina_module.httpx, "post", post_mock) + + assert auth.validate_credentials() is True + post_mock.assert_called_once_with( + "https://r.jina.ai", + headers={"Content-Type": "application/json", "Authorization": "Bearer k"}, + json={"url": "https://example.com"}, + ) + + +def test_validate_credentials_non_200_raises_via_handle_error( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + + response = MagicMock() + response.status_code = 402 + response.json.return_value = {"error": "Payment required"} + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(return_value=response)) + + with pytest.raises(Exception, match="Status code: 402.*Payment required"): + auth.validate_credentials() + + +@pytest.mark.parametrize("status_code", [402, 409, 500]) +def test_handle_error_statuses_use_response_json(jina_module: ModuleType, status_code: int) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = status_code + response.json.return_value = {"error": "boom"} + + with pytest.raises(Exception, match=f"Status code: {status_code}.*boom"): + auth._handle_error(response) + + +def test_handle_error_statuses_default_unknown_error(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 402 + response.json.return_value = {} + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = '{"error": "Forbidden"}' + + with pytest.raises(Exception, match="Status code: 403.*Forbidden"): + auth._handle_error(response) + + +def test_handle_error_with_text_json_body_missing_error(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 403 + response.text = "{}" + + with pytest.raises(Exception, match="Unknown error occurred"): + auth._handle_error(response) + + +def test_handle_error_without_text_raises_unexpected(jina_module: ModuleType) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + response = MagicMock() + response.status_code = 404 + response.text = "" + + with pytest.raises(Exception, match="Unexpected error occurred.*404"): + auth._handle_error(response) + + +def test_validate_credentials_propagates_network_errors( + jina_module: ModuleType, monkeypatch: pytest.MonkeyPatch +) -> None: + auth = jina_module.JinaAuth(_credentials(api_key="k")) + monkeypatch.setattr(jina_module.httpx, "post", MagicMock(side_effect=httpx.ConnectError("boom"))) + + with pytest.raises(httpx.ConnectError, match="boom"): + auth.validate_credentials() diff --git a/api/tests/unit_tests/services/dataset_permission_service.py b/api/tests/unit_tests/services/dataset_permission_service.py index b687f472a5..e098e90455 100644 --- a/api/tests/unit_tests/services/dataset_permission_service.py +++ b/api/tests/unit_tests/services/dataset_permission_service.py @@ -258,323 +258,6 @@ class DatasetPermissionTestDataFactory: return [{"user_id": user_id} for user_id in user_ids] -# ============================================================================ -# Tests for get_dataset_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceGetPartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.get_dataset_partial_member_list method. - - This test class covers the retrieval of partial member lists for datasets, - which returns a list of account IDs that have explicit permissions for - a given dataset. - - The get_dataset_partial_member_list method: - 1. Queries DatasetPermission table for the dataset ID - 2. Selects account_id values - 3. Returns list of account IDs - - Test scenarios include: - - Retrieving list with multiple members - - Retrieving list with single member - - Retrieving empty list (no partial members) - - Database query validation - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - query construction and execution. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_get_dataset_partial_member_list_with_members(self, mock_db_session): - """ - Test retrieving partial member list with multiple members. - - Verifies that when a dataset has multiple partial members, all - account IDs are returned correctly. - - This test ensures: - - Query is constructed correctly - - All account IDs are returned - - Database query is executed - """ - # Arrange - dataset_id = "dataset-123" - expected_account_ids = ["user-456", "user-789", "user-012"] - - # Mock the scalars query to return account IDs - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = expected_account_ids - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == expected_account_ids - assert len(result) == 3 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - def test_get_dataset_partial_member_list_with_single_member(self, mock_db_session): - """ - Test retrieving partial member list with single member. - - Verifies that when a dataset has only one partial member, the - single account ID is returned correctly. - - This test ensures: - - Query works correctly for single member - - Result is a list with one element - - Database query is executed - """ - # Arrange - dataset_id = "dataset-123" - expected_account_ids = ["user-456"] - - # Mock the scalars query to return single account ID - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = expected_account_ids - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == expected_account_ids - assert len(result) == 1 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - def test_get_dataset_partial_member_list_empty(self, mock_db_session): - """ - Test retrieving partial member list when no members exist. - - Verifies that when a dataset has no partial members, an empty - list is returned. - - This test ensures: - - Empty list is returned correctly - - Query is executed even when no results - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the scalars query to return empty list - mock_scalars_result = Mock() - mock_scalars_result.all.return_value = [] - mock_db_session.scalars.return_value = mock_scalars_result - - # Act - result = DatasetPermissionService.get_dataset_partial_member_list(dataset_id) - - # Assert - assert result == [] - assert len(result) == 0 - - # Verify query was executed - mock_db_session.scalars.assert_called_once() - - -# ============================================================================ -# Tests for update_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceUpdatePartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.update_partial_member_list method. - - This test class covers the update of partial member lists for datasets, - which replaces the existing partial member list with a new one. - - The update_partial_member_list method: - 1. Deletes all existing DatasetPermission records for the dataset - 2. Creates new DatasetPermission records for each user in the list - 3. Adds all new permissions to the session - 4. Commits the transaction - 5. Rolls back on error - - Test scenarios include: - - Adding new partial members - - Updating existing partial members - - Replacing entire member list - - Handling empty member list - - Database transaction handling - - Error handling and rollback - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database operations including queries, adds, commits, and rollbacks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_update_partial_member_list_add_new_members(self, mock_db_session): - """ - Test adding new partial members to a dataset. - - Verifies that when updating with new members, the old members - are deleted and new members are added correctly. - - This test ensures: - - Old permissions are deleted - - New permissions are created - - All permissions are added to session - - Transaction is committed - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-456", "user-789"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - mock_query.where.assert_called() - - # Verify new permissions were added - mock_db_session.add_all.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - # Verify no rollback occurred - mock_db_session.rollback.assert_not_called() - - def test_update_partial_member_list_replace_existing(self, mock_db_session): - """ - Test replacing existing partial members with new ones. - - Verifies that when updating with a different member list, the - old members are removed and new members are added. - - This test ensures: - - Old permissions are deleted - - New permissions replace old ones - - Transaction is committed successfully - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-999", "user-888"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - - # Verify new permissions were added - mock_db_session.add_all.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_update_partial_member_list_empty_list(self, mock_db_session): - """ - Test updating with empty member list (clearing all members). - - Verifies that when updating with an empty list, all existing - permissions are deleted and no new permissions are added. - - This test ensures: - - Old permissions are deleted - - No new permissions are added - - Transaction is committed - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = [] - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Assert - # Verify old permissions were deleted - mock_db_session.query.assert_called() - - # Verify add_all was called with empty list - mock_db_session.add_all.assert_called_once_with([]) - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_update_partial_member_list_database_error_rollback(self, mock_db_session): - """ - Test error handling and rollback on database error. - - Verifies that when a database error occurs during the update, - the transaction is rolled back and the error is re-raised. - - This test ensures: - - Error is caught and handled - - Transaction is rolled back - - Error is re-raised - - No commit occurs after error - """ - # Arrange - tenant_id = "tenant-123" - dataset_id = "dataset-123" - user_list = DatasetPermissionTestDataFactory.create_user_list_mock(["user-456"]) - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Mock commit to raise an error - database_error = Exception("Database connection error") - mock_db_session.commit.side_effect = database_error - - # Act & Assert - with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.update_partial_member_list(tenant_id, dataset_id, user_list) - - # Verify rollback was called - mock_db_session.rollback.assert_called_once() - - # ============================================================================ # Tests for check_permission # ============================================================================ @@ -776,144 +459,6 @@ class TestDatasetPermissionServiceCheckPermission: mock_get_partial_member_list.assert_called_once_with(dataset.id) -# ============================================================================ -# Tests for clear_partial_member_list -# ============================================================================ - - -class TestDatasetPermissionServiceClearPartialMemberList: - """ - Comprehensive unit tests for DatasetPermissionService.clear_partial_member_list method. - - This test class covers the clearing of partial member lists, which removes - all DatasetPermission records for a given dataset. - - The clear_partial_member_list method: - 1. Deletes all DatasetPermission records for the dataset - 2. Commits the transaction - 3. Rolls back on error - - Test scenarios include: - - Clearing list with existing members - - Clearing empty list (no members) - - Database transaction handling - - Error handling and rollback - """ - - @pytest.fixture - def mock_db_session(self): - """ - Mock database session for testing. - - Provides a mocked database session that can be used to verify - database operations including queries, deletes, commits, and rollbacks. - """ - with patch("services.dataset_service.db.session") as mock_db: - yield mock_db - - def test_clear_partial_member_list_success(self, mock_db_session): - """ - Test successful clearing of partial member list. - - Verifies that when clearing a partial member list, all permissions - are deleted and the transaction is committed. - - This test ensures: - - All permissions are deleted - - Transaction is committed - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Assert - # Verify query was executed - mock_db_session.query.assert_called() - - # Verify delete was called - mock_query.where.assert_called() - mock_query.delete.assert_called_once() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - # Verify no rollback occurred - mock_db_session.rollback.assert_not_called() - - def test_clear_partial_member_list_empty_list(self, mock_db_session): - """ - Test clearing partial member list when no members exist. - - Verifies that when clearing an already empty list, the operation - completes successfully without errors. - - This test ensures: - - Operation works correctly for empty lists - - Transaction is committed - - No errors are raised - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Act - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Assert - # Verify query was executed - mock_db_session.query.assert_called() - - # Verify transaction was committed - mock_db_session.commit.assert_called_once() - - def test_clear_partial_member_list_database_error_rollback(self, mock_db_session): - """ - Test error handling and rollback on database error. - - Verifies that when a database error occurs during clearing, - the transaction is rolled back and the error is re-raised. - - This test ensures: - - Error is caught and handled - - Transaction is rolled back - - Error is re-raised - - No commit occurs after error - """ - # Arrange - dataset_id = "dataset-123" - - # Mock the query delete operation - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_query.delete.return_value = None - mock_db_session.query.return_value = mock_query - - # Mock commit to raise an error - database_error = Exception("Database connection error") - mock_db_session.commit.side_effect = database_error - - # Act & Assert - with pytest.raises(Exception, match="Database connection error"): - DatasetPermissionService.clear_partial_member_list(dataset_id) - - # Verify rollback was called - mock_db_session.rollback.assert_called_once() - - # ============================================================================ # Tests for DatasetService.check_dataset_permission # ============================================================================ @@ -1047,72 +592,6 @@ class TestDatasetServiceCheckDatasetPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_members_with_permission_success(self, mock_db_session): - """ - Test that user with explicit permission can access partial_members dataset. - - Verifies that when a user has an explicit DatasetPermission record - for a partial_members dataset, they can access it successfully. - - This test ensures: - - Explicit permissions are checked correctly - - Users with permissions can access - - Database query is executed - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=user.id - ) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.first.return_value = mock_permission - mock_db_session.query.return_value = mock_query - - # Act (should not raise) - DatasetService.check_dataset_permission(dataset, user) - - # Assert - # Verify permission query was executed - mock_db_session.query.assert_called() - - def test_check_dataset_permission_partial_members_without_permission_error(self, mock_db_session): - """ - Test error when user without permission tries to access partial_members dataset. - - Verifies that when a user does not have an explicit DatasetPermission - record for a partial_members dataset, a NoPermissionError is raised. - - This test ensures: - - Missing permissions are detected - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return None (no permission) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.first.return_value = None # No permission found - mock_db_session.query.return_value = mock_query - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_permission(dataset, user) - def test_check_dataset_permission_partial_members_creator_success(self, mock_db_session): """ Test that creator can access partial_members dataset without explicit permission. @@ -1311,72 +790,6 @@ class TestDatasetServiceCheckDatasetOperatorPermission: with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - def test_check_dataset_operator_permission_partial_members_with_permission_success(self, mock_db_session): - """ - Test that user with explicit permission can access partial_members dataset. - - Verifies that when a user has an explicit DatasetPermission record - for a partial_members dataset, they can access it successfully. - - This test ensures: - - Explicit permissions are checked correctly - - Users with permissions can access - - Database query is executed - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return permission records - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=user.id - ) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.all.return_value = [mock_permission] # User has permission - mock_db_session.query.return_value = mock_query - - # Act (should not raise) - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - - # Assert - # Verify permission query was executed - mock_db_session.query.assert_called() - - def test_check_dataset_operator_permission_partial_members_without_permission_error(self, mock_db_session): - """ - Test error when user without permission tries to access partial_members dataset. - - Verifies that when a user does not have an explicit DatasetPermission - record for a partial_members dataset, a NoPermissionError is raised. - - This test ensures: - - Missing permissions are detected - - Error message is clear - - Error type is correct - """ - # Arrange - user = DatasetPermissionTestDataFactory.create_user_mock(user_id="user-123", role=TenantAccountRole.NORMAL) - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", - permission=DatasetPermissionEnum.PARTIAL_TEAM, - created_by="other-user-456", # Not the creator - ) - - # Mock permission query to return empty list (no permission) - mock_query = Mock() - mock_query.filter_by.return_value = mock_query - mock_query.all.return_value = [] # No permissions found - mock_db_session.query.return_value = mock_query - - # Act & Assert - with pytest.raises(NoPermissionError, match="You do not have permission to access this dataset"): - DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) - # ============================================================================ # Additional Documentation and Notes diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 4923e29d73..6829691507 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -111,7 +111,7 @@ from unittest.mock import Mock, patch import pytest from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError -from core.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py index 03c4f793cf..59c07bfb37 100644 --- a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -1,9 +1,8 @@ """Unit tests for enterprise service integrations. -This module covers the enterprise-only default workspace auto-join behavior: -- Enterprise mode disabled: no external calls -- Successful join / skipped join: no errors -- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise +Covers: +- Default workspace auto-join behavior +- License status caching (get_cached_license_status) """ from unittest.mock import patch @@ -11,6 +10,9 @@ from unittest.mock import patch import pytest from services.enterprise.enterprise_service import ( + INVALID_LICENSE_CACHE_TTL, + LICENSE_STATUS_CACHE_KEY, + VALID_LICENSE_CACHE_TTL, DefaultWorkspaceJoinResult, EnterpriseService, try_join_default_workspace, @@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace: "/default-workspace/members", json={"account_id": account_id}, timeout=1.0, - raise_for_status=True, ) def test_join_default_workspace_invalid_response_format_raises(self): @@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace: # Should not raise even though UUID parsing fails inside join_default_workspace try_join_default_workspace("not-a-uuid") + + +# --------------------------------------------------------------------------- +# get_cached_license_status +# --------------------------------------------------------------------------- + +_EE_SVC = "services.enterprise.enterprise_service" + + +class TestGetCachedLicenseStatus: + """Tests for EnterpriseService.get_cached_license_status.""" + + def test_returns_none_when_enterprise_disabled(self): + with patch(f"{_EE_SVC}.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + assert EnterpriseService.get_cached_license_status() is None + + def test_cache_hit_returns_license_status_enum(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = b"active" + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.ACTIVE + assert isinstance(result, LicenseStatus) + mock_get_info.assert_not_called() + + def test_cache_miss_fetches_api_and_caches_valid_status(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.return_value = {"License": {"status": "active"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.ACTIVE + mock_redis.setex.assert_called_once_with( + LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE + ) + + def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.return_value = {"License": {"status": "expired"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.EXPIRED + mock_redis.setex.assert_called_once_with( + LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED + ) + + def test_redis_read_failure_falls_through_to_api(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.side_effect = ConnectionError("redis down") + mock_get_info.return_value = {"License": {"status": "active"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.ACTIVE + mock_get_info.assert_called_once() + + def test_redis_write_failure_still_returns_status(self): + from services.feature_service import LicenseStatus + + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_redis.setex.side_effect = ConnectionError("redis down") + mock_get_info.return_value = {"License": {"status": "expiring"}} + + result = EnterpriseService.get_cached_license_status() + + assert result == LicenseStatus.EXPIRING + + def test_api_failure_returns_none(self): + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.side_effect = Exception("network failure") + + assert EnterpriseService.get_cached_license_status() is None + + def test_api_returns_no_license_info(self): + with ( + patch(f"{_EE_SVC}.dify_config") as mock_config, + patch(f"{_EE_SVC}.redis_client") as mock_redis, + patch.object(EnterpriseService, "get_info") as mock_get_info, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_redis.get.return_value = None + mock_get_info.return_value = {} # no "License" key + + assert EnterpriseService.get_cached_license_status() is None + mock_redis.setex.assert_not_called() diff --git a/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py new file mode 100644 index 0000000000..6ee328ae2c --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_plugin_manager_service.py @@ -0,0 +1,90 @@ +"""Unit tests for PluginManagerService. + +This module covers the pre-uninstall plugin hook behavior: +- Successful API call: no exception raised, correct request sent +- API failure: soft-fail (logs and does not re-raise) +""" + +from unittest.mock import patch + +from httpx import HTTPStatusError + +from configs import dify_config +from services.enterprise.plugin_manager_service import ( + PluginManagerService, + PreUninstallPluginRequest, +) + + +class TestTryPreUninstallPlugin: + def test_try_pre_uninstall_plugin_success(self): + body = PreUninstallPluginRequest( + tenant_id="tenant-123", + plugin_unique_identifier="com.example.my_plugin", + ) + + with patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request: + mock_send_request.return_value = {} + + PluginManagerService.try_pre_uninstall_plugin(body) + + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"}, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + + def test_try_pre_uninstall_plugin_http_error_soft_fails(self): + body = PreUninstallPluginRequest( + tenant_id="tenant-456", + plugin_unique_identifier="com.example.other_plugin", + ) + + with ( + patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request, + patch("services.enterprise.plugin_manager_service.logger") as mock_logger, + ): + mock_send_request.side_effect = HTTPStatusError( + "502 Bad Gateway", + request=None, + response=None, + ) + + PluginManagerService.try_pre_uninstall_plugin(body) + + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"}, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + mock_logger.exception.assert_called_once() + + def test_try_pre_uninstall_plugin_generic_exception_soft_fails(self): + body = PreUninstallPluginRequest( + tenant_id="tenant-789", + plugin_unique_identifier="com.example.failing_plugin", + ) + + with ( + patch( + "services.enterprise.plugin_manager_service.EnterprisePluginManagerRequest.send_request" + ) as mock_send_request, + patch("services.enterprise.plugin_manager_service.logger") as mock_logger, + ): + mock_send_request.side_effect = ConnectionError("network unreachable") + + PluginManagerService.try_pre_uninstall_plugin(body) + + mock_send_request.assert_called_once_with( + "POST", + "/pre-uninstall-plugin", + json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"}, + timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT, + ) + mock_logger.exception.assert_called_once() diff --git a/api/tests/unit_tests/services/external_dataset_service.py b/api/tests/unit_tests/services/external_dataset_service.py index 57364142ad..afc3b29fca 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -545,7 +545,7 @@ class TestExternalDatasetServiceProcessExternalApi: params={}, ) - from core.workflow.nodes.http_request.exc import InvalidHttpMethodError + from dify_graph.nodes.http_request.exc import InvalidHttpMethodError with pytest.raises(InvalidHttpMethodError): ExternalDatasetService.process_external_api(settings, files=None) diff --git a/api/tests/unit_tests/services/plugin/__init__.py b/api/tests/unit_tests/services/plugin/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/plugin/conftest.py b/api/tests/unit_tests/services/plugin/conftest.py new file mode 100644 index 0000000000..80c6077b0c --- /dev/null +++ b/api/tests/unit_tests/services/plugin/conftest.py @@ -0,0 +1,39 @@ +"""Shared fixtures for services.plugin test suite.""" + +from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from services.feature_service import PluginInstallationScope + + +def make_features( + restrict_to_marketplace: bool = False, + scope: PluginInstallationScope = PluginInstallationScope.ALL, +) -> MagicMock: + """Create a mock FeatureService.get_system_features() result.""" + features = MagicMock() + features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace + features.plugin_installation_permission.plugin_installation_scope = scope + return features + + +@pytest.fixture +def mock_installer(monkeypatch): + """Patch PluginInstaller at the service import site.""" + mock = MagicMock() + monkeypatch.setattr("services.plugin.plugin_service.PluginInstaller", lambda: mock) + return mock + + +@pytest.fixture +def mock_features(): + """Patch FeatureService to return permissive defaults.""" + from unittest.mock import patch + + features = make_features() + with patch("services.plugin.plugin_service.FeatureService") as mock_fs: + mock_fs.get_system_features.return_value = features + yield features diff --git a/api/tests/unit_tests/services/plugin/test_dependencies_analysis.py b/api/tests/unit_tests/services/plugin/test_dependencies_analysis.py new file mode 100644 index 0000000000..8f0886769c --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_dependencies_analysis.py @@ -0,0 +1,172 @@ +"""Tests for services.plugin.dependencies_analysis.DependenciesAnalysisService. + +Covers: provider ID resolution, leaked dependency detection with version +extraction, dependency generation from multiple sources, and latest +dependencies via marketplace. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin import PluginDependency, PluginInstallationSource +from services.plugin.dependencies_analysis import DependenciesAnalysisService + + +class TestAnalyzeToolDependency: + def test_valid_three_part_id(self): + result = DependenciesAnalysisService.analyze_tool_dependency("langgenius/google/google") + assert result == "langgenius/google" + + def test_single_part_expands_to_langgenius(self): + result = DependenciesAnalysisService.analyze_tool_dependency("websearch") + assert result == "langgenius/websearch" + + def test_invalid_format_raises(self): + with pytest.raises(ValueError): + DependenciesAnalysisService.analyze_tool_dependency("bad/format") + + +class TestAnalyzeModelProviderDependency: + def test_valid_three_part_id(self): + result = DependenciesAnalysisService.analyze_model_provider_dependency("langgenius/openai/openai") + assert result == "langgenius/openai" + + def test_google_maps_to_gemini(self): + result = DependenciesAnalysisService.analyze_model_provider_dependency("langgenius/google/google") + assert result == "langgenius/gemini" + + def test_single_part_expands(self): + result = DependenciesAnalysisService.analyze_model_provider_dependency("anthropic") + assert result == "langgenius/anthropic" + + +class TestGetLeakedDependencies: + def _make_dependency(self, identifier: str, dep_type=PluginDependency.Type.Marketplace): + return PluginDependency( + type=dep_type, + value=PluginDependency.Marketplace(marketplace_plugin_unique_identifier=identifier), + ) + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_returns_empty_when_all_present(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_missing_dependencies.return_value = [] + deps = [self._make_dependency("org/plugin:1.0.0@hash")] + + result = DependenciesAnalysisService.get_leaked_dependencies("t1", deps) + + assert result == [] + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_returns_missing_with_version_extracted(self, mock_installer_cls): + missing = MagicMock() + missing.plugin_unique_identifier = "org/plugin:1.2.3@hash" + missing.current_identifier = "org/plugin:1.0.0@oldhash" + mock_installer_cls.return_value.fetch_missing_dependencies.return_value = [missing] + + deps = [self._make_dependency("org/plugin:1.2.3@hash")] + + result = DependenciesAnalysisService.get_leaked_dependencies("t1", deps) + + assert len(result) == 1 + assert result[0].value.version == "1.2.3" + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_skips_present_dependencies(self, mock_installer_cls): + missing = MagicMock() + missing.plugin_unique_identifier = "org/missing:1.0.0@hash" + missing.current_identifier = None + mock_installer_cls.return_value.fetch_missing_dependencies.return_value = [missing] + + deps = [ + self._make_dependency("org/present:1.0.0@hash"), + self._make_dependency("org/missing:1.0.0@hash"), + ] + + result = DependenciesAnalysisService.get_leaked_dependencies("t1", deps) + + assert len(result) == 1 + + +class TestGenerateDependencies: + def _make_installation(self, source, identifier, meta=None): + install = MagicMock() + install.source = source + install.plugin_unique_identifier = identifier + install.meta = meta or {} + return install + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_github_source(self, mock_installer_cls): + install = self._make_installation( + PluginInstallationSource.Github, + "org/plugin:1.0.0@hash", + {"repo": "org/repo", "version": "v1.0", "package": "plugin.difypkg"}, + ) + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + result = DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + assert len(result) == 1 + assert result[0].type == PluginDependency.Type.Github + assert result[0].value.repo == "org/repo" + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_marketplace_source(self, mock_installer_cls): + install = self._make_installation(PluginInstallationSource.Marketplace, "org/plugin:1.0.0@hash") + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + result = DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + assert result[0].type == PluginDependency.Type.Marketplace + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_package_source(self, mock_installer_cls): + install = self._make_installation(PluginInstallationSource.Package, "org/plugin:1.0.0@hash") + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + result = DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + assert result[0].type == PluginDependency.Type.Package + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_remote_source_raises(self, mock_installer_cls): + install = self._make_installation(PluginInstallationSource.Remote, "org/plugin:1.0.0@hash") + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [install] + + with pytest.raises(ValueError, match="remote plugin"): + DependenciesAnalysisService.generate_dependencies("t1", ["p1"]) + + @patch("services.plugin.dependencies_analysis.PluginInstaller") + def test_deduplicates_input_ids(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_plugin_installation_by_ids.return_value = [] + + DependenciesAnalysisService.generate_dependencies("t1", ["p1", "p1", "p2"]) + + call_args = mock_installer_cls.return_value.fetch_plugin_installation_by_ids.call_args[0] + assert len(call_args[1]) == 2 # deduplicated + + +class TestGenerateLatestDependencies: + @patch("services.plugin.dependencies_analysis.dify_config") + def test_returns_empty_when_marketplace_disabled(self, mock_config): + mock_config.MARKETPLACE_ENABLED = False + + result = DependenciesAnalysisService.generate_latest_dependencies(["p1"]) + + assert result == [] + + @patch("services.plugin.dependencies_analysis.marketplace") + @patch("services.plugin.dependencies_analysis.dify_config") + def test_returns_marketplace_deps_when_enabled(self, mock_config, mock_marketplace): + mock_config.MARKETPLACE_ENABLED = True + manifest = MagicMock() + manifest.latest_package_identifier = "org/plugin:2.0.0@newhash" + mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest] + + result = DependenciesAnalysisService.generate_latest_dependencies(["p1"]) + + assert len(result) == 1 + assert result[0].type == PluginDependency.Type.Marketplace diff --git a/api/tests/unit_tests/services/plugin/test_endpoint_service.py b/api/tests/unit_tests/services/plugin/test_endpoint_service.py new file mode 100644 index 0000000000..ddf80c8017 --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_endpoint_service.py @@ -0,0 +1,41 @@ +"""Tests for services.plugin.endpoint_service.EndpointService. + +Smoke tests to confirm delegation to PluginEndpointClient. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from services.plugin.endpoint_service import EndpointService + + +class TestEndpointServiceDelegation: + @patch("services.plugin.endpoint_service.PluginEndpointClient") + def test_create_delegates_correctly(self, mock_client_cls): + expected = MagicMock() + mock_client_cls.return_value.create_endpoint.return_value = expected + + result = EndpointService.create_endpoint("t1", "u1", "uid-1", "my-endpoint", {"key": "val"}) + + assert result is expected + mock_client_cls.return_value.create_endpoint.assert_called_once_with( + tenant_id="t1", user_id="u1", plugin_unique_identifier="uid-1", name="my-endpoint", settings={"key": "val"} + ) + + @patch("services.plugin.endpoint_service.PluginEndpointClient") + def test_list_delegates_correctly(self, mock_client_cls): + expected = MagicMock() + mock_client_cls.return_value.list_endpoints.return_value = expected + + result = EndpointService.list_endpoints("t1", "u1", 1, 10) + + assert result is expected + + @patch("services.plugin.endpoint_service.PluginEndpointClient") + def test_enable_disable_delegates(self, mock_client_cls): + EndpointService.enable_endpoint("t1", "u1", "ep-1") + mock_client_cls.return_value.enable_endpoint.assert_called_once() + + EndpointService.disable_endpoint("t1", "u1", "ep-2") + mock_client_cls.return_value.disable_endpoint.assert_called_once() diff --git a/api/tests/unit_tests/services/plugin/test_oauth_service.py b/api/tests/unit_tests/services/plugin/test_oauth_service.py new file mode 100644 index 0000000000..27df4556bc --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_oauth_service.py @@ -0,0 +1,90 @@ +"""Tests for services.plugin.oauth_service.OAuthProxyService. + +Covers: CSRF proxy context creation with Redis TTL, context consumption +with one-time use semantics, and validation error paths. +""" + +from __future__ import annotations + +import json + +import pytest + +from services.plugin.oauth_service import OAuthProxyService + + +class TestCreateProxyContext: + def test_stores_context_in_redis_with_ttl(self): + context_id = OAuthProxyService.create_proxy_context( + user_id="u1", tenant_id="t1", plugin_id="p1", provider="github" + ) + + assert context_id # non-empty UUID string + from extensions.ext_redis import redis_client + + redis_client.setex.assert_called_once() + call_args = redis_client.setex.call_args + key = call_args[0][0] + ttl = call_args[0][1] + stored_data = json.loads(call_args[0][2]) + + assert key.startswith("oauth_proxy_context:") + assert ttl == 5 * 60 + assert stored_data["user_id"] == "u1" + assert stored_data["tenant_id"] == "t1" + assert stored_data["plugin_id"] == "p1" + assert stored_data["provider"] == "github" + + def test_includes_credential_id_when_provided(self): + OAuthProxyService.create_proxy_context( + user_id="u1", tenant_id="t1", plugin_id="p1", provider="github", credential_id="cred-1" + ) + + from extensions.ext_redis import redis_client + + stored_data = json.loads(redis_client.setex.call_args[0][2]) + assert stored_data["credential_id"] == "cred-1" + + def test_excludes_credential_id_when_none(self): + OAuthProxyService.create_proxy_context(user_id="u1", tenant_id="t1", plugin_id="p1", provider="github") + + from extensions.ext_redis import redis_client + + stored_data = json.loads(redis_client.setex.call_args[0][2]) + assert "credential_id" not in stored_data + + def test_includes_extra_data(self): + OAuthProxyService.create_proxy_context( + user_id="u1", tenant_id="t1", plugin_id="p1", provider="github", extra_data={"scope": "repo"} + ) + + from extensions.ext_redis import redis_client + + stored_data = json.loads(redis_client.setex.call_args[0][2]) + assert stored_data["scope"] == "repo" + + +class TestUseProxyContext: + def test_raises_when_context_id_empty(self): + with pytest.raises(ValueError, match="context_id is required"): + OAuthProxyService.use_proxy_context("") + + def test_raises_when_context_not_found(self): + from extensions.ext_redis import redis_client + + redis_client.get.return_value = None + + with pytest.raises(ValueError, match="context_id is invalid"): + OAuthProxyService.use_proxy_context("nonexistent-id") + + def test_returns_data_and_deletes_key(self): + from extensions.ext_redis import redis_client + + stored = {"user_id": "u1", "tenant_id": "t1", "plugin_id": "p1", "provider": "github"} + redis_client.get.return_value = json.dumps(stored).encode() + + result = OAuthProxyService.use_proxy_context("valid-id") + + assert result == stored + expected_key = "oauth_proxy_context:valid-id" + redis_client.delete.assert_called_once_with(expected_key) diff --git a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py new file mode 100644 index 0000000000..bfa9fe976b --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py @@ -0,0 +1,216 @@ +"""Tests for services.plugin.plugin_parameter_service.PluginParameterService. + +Covers: dynamic select options via tool and trigger credential paths, +HIDDEN_VALUE replacement, and error handling for missing records. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from services.plugin.plugin_parameter_service import PluginParameterService + + +class TestGetDynamicSelectOptionsTool: + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.ToolManager") + def test_no_credentials_needed(self, mock_tool_mgr, mock_client_cls): + provider_ctrl = MagicMock() + provider_ctrl.need_credentials = False + mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) + + assert result == ["opt1"] + call_kwargs = mock_client_cls.return_value.fetch_dynamic_select_options.call_args + assert call_kwargs[0][5] == {} # empty credentials + + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") + @patch("services.plugin.plugin_parameter_service.db") + @patch("services.plugin.plugin_parameter_service.ToolManager") + def test_fetches_credentials_with_credential_id(self, mock_tool_mgr, mock_db, mock_encrypter_fn, mock_client_cls): + provider_ctrl = MagicMock() + provider_ctrl.need_credentials = True + mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl + encrypter = MagicMock() + encrypter.decrypt.return_value = {"api_key": "decrypted"} + mock_encrypter_fn.return_value = (encrypter, None) + + # Mock the Session/query chain + db_record = MagicMock() + db_record.credentials = {"api_key": "encrypted"} + db_record.credential_type = "api_key" + + with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_session.query.return_value.where.return_value.first.return_value = db_record + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id="cred-1", + provider_type="tool", + ) + + assert result == ["opt1"] + + @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") + @patch("services.plugin.plugin_parameter_service.db") + @patch("services.plugin.plugin_parameter_service.ToolManager") + def test_raises_when_tool_provider_not_found(self, mock_tool_mgr, mock_db, mock_encrypter_fn): + provider_ctrl = MagicMock() + provider_ctrl.need_credentials = True + mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl + mock_encrypter_fn.return_value = (MagicMock(), None) + + with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + mock_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) + + +class TestGetDynamicSelectOptionsTrigger: + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.TriggerSubscriptionBuilderService") + def test_uses_subscription_builder_when_credential_id(self, mock_builder_svc, mock_client_cls): + sub = MagicMock() + sub.credentials = {"token": "abc"} + sub.credential_type = "api_key" + mock_builder_svc.get_subscription_builder.return_value = sub + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="builder-1", + provider_type="trigger", + ) + + assert result == ["opt"] + + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + @patch("services.plugin.plugin_parameter_service.TriggerSubscriptionBuilderService") + def test_falls_back_to_trigger_service(self, mock_builder_svc, mock_provider_svc, mock_client_cls): + mock_builder_svc.get_subscription_builder.return_value = None + trigger_sub = MagicMock() + api_entity = MagicMock() + api_entity.credentials = {"token": "abc"} + api_entity.credential_type = "api_key" + trigger_sub.to_api_entity.return_value = api_entity + mock_provider_svc.get_subscription_by_id.return_value = trigger_sub + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt"] + + result = PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="sub-1", + provider_type="trigger", + ) + + assert result == ["opt"] + + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + @patch("services.plugin.plugin_parameter_service.TriggerSubscriptionBuilderService") + def test_raises_when_no_subscription_found(self, mock_builder_svc, mock_provider_svc): + mock_builder_svc.get_subscription_builder.return_value = None + mock_provider_svc.get_subscription_by_id.return_value = None + + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="nonexistent", + provider_type="trigger", + ) + + +class TestGetDynamicSelectOptionsWithCredentials: + @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + def test_replaces_hidden_values(self, mock_provider_svc, mock_client_cls): + from constants import HIDDEN_VALUE + + original = MagicMock() + original.credentials = {"token": "real-secret", "name": "real-name"} + original.credential_type = "api_key" + mock_provider_svc.get_subscription_by_id.return_value = original + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt"] + + result = PluginParameterService.get_dynamic_select_options_with_credentials( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="cred-1", + credentials={"token": HIDDEN_VALUE, "name": "new-name"}, + ) + + assert result == ["opt"] + call_args = mock_client_cls.return_value.fetch_dynamic_select_options.call_args[0] + resolved = call_args[5] + assert resolved["token"] == "real-secret" # replaced + assert resolved["name"] == "new-name" # kept as-is + + @patch("services.plugin.plugin_parameter_service.TriggerProviderService") + def test_raises_when_subscription_not_found(self, mock_provider_svc): + mock_provider_svc.get_subscription_by_id.return_value = None + + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options_with_credentials( + tenant_id="t1", + user_id="u1", + plugin_id="p1", + provider="github", + action="on_push", + parameter="branch", + credential_id="nonexistent", + credentials={"token": "val"}, + ) diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/unit_tests/services/plugin/test_plugin_service.py new file mode 100644 index 0000000000..09b9ab498b --- /dev/null +++ b/api/tests/unit_tests/services/plugin/test_plugin_service.py @@ -0,0 +1,357 @@ +"""Tests for services.plugin.plugin_service.PluginService. + +Covers: version caching with Redis, install permission/scope gates, +icon URL construction, asset retrieval with MIME guessing, plugin +verification, marketplace upgrade flows, and uninstall with credential cleanup. +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from core.plugin.entities.plugin import PluginInstallationSource +from core.plugin.entities.plugin_daemon import PluginVerification +from services.errors.plugin import PluginInstallationForbiddenError +from services.feature_service import PluginInstallationScope +from services.plugin.plugin_service import PluginService +from tests.unit_tests.services.plugin.conftest import make_features + + +class TestFetchLatestPluginVersion: + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_returns_cached_version(self, mock_redis, mock_marketplace): + cached_json = PluginService.LatestPluginCache( + plugin_id="p1", + version="1.0.0", + unique_identifier="uid-1", + status="active", + deprecated_reason="", + alternative_plugin_id="", + ).model_dump_json() + mock_redis.get.return_value = cached_json + + result = PluginService.fetch_latest_plugin_version(["p1"]) + + assert result["p1"].version == "1.0.0" + mock_marketplace.batch_fetch_plugin_manifests.assert_not_called() + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_fetches_from_marketplace_on_cache_miss(self, mock_redis, mock_marketplace): + mock_redis.get.return_value = None + manifest = MagicMock() + manifest.plugin_id = "p1" + manifest.latest_version = "2.0.0" + manifest.latest_package_identifier = "uid-2" + manifest.status = "active" + manifest.deprecated_reason = "" + manifest.alternative_plugin_id = "" + mock_marketplace.batch_fetch_plugin_manifests.return_value = [manifest] + + result = PluginService.fetch_latest_plugin_version(["p1"]) + + assert result["p1"].version == "2.0.0" + mock_redis.setex.assert_called_once() + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_returns_none_for_unknown_plugin(self, mock_redis, mock_marketplace): + mock_redis.get.return_value = None + mock_marketplace.batch_fetch_plugin_manifests.return_value = [] + + result = PluginService.fetch_latest_plugin_version(["unknown"]) + + assert result["unknown"] is None + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.redis_client") + def test_handles_marketplace_exception_gracefully(self, mock_redis, mock_marketplace): + mock_redis.get.return_value = None + mock_marketplace.batch_fetch_plugin_manifests.side_effect = RuntimeError("network error") + + result = PluginService.fetch_latest_plugin_version(["p1"]) + + assert result == {} + + +class TestCheckMarketplaceOnlyPermission: + @patch("services.plugin.plugin_service.FeatureService") + def test_raises_when_restricted(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True) + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_marketplace_only_permission() + + @patch("services.plugin.plugin_service.FeatureService") + def test_passes_when_not_restricted(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False) + + PluginService._check_marketplace_only_permission() # should not raise + + +class TestCheckPluginInstallationScope: + @patch("services.plugin.plugin_service.FeatureService") + def test_official_only_allows_langgenius(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + verification = MagicMock() + verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius + + PluginService._check_plugin_installation_scope(verification) # should not raise + + @patch("services.plugin.plugin_service.FeatureService") + def test_official_only_rejects_third_party(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_plugin_installation_scope(None) + + @patch("services.plugin.plugin_service.FeatureService") + def test_official_and_partners_allows_partner(self, mock_fs): + mock_fs.get_system_features.return_value = make_features( + scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS + ) + verification = MagicMock() + verification.authorized_category = PluginVerification.AuthorizedCategory.Partner + + PluginService._check_plugin_installation_scope(verification) # should not raise + + @patch("services.plugin.plugin_service.FeatureService") + def test_official_and_partners_rejects_none(self, mock_fs): + mock_fs.get_system_features.return_value = make_features( + scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS + ) + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_plugin_installation_scope(None) + + @patch("services.plugin.plugin_service.FeatureService") + def test_none_scope_always_raises(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE) + verification = MagicMock() + verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius + + with pytest.raises(PluginInstallationForbiddenError): + PluginService._check_plugin_installation_scope(verification) + + @patch("services.plugin.plugin_service.FeatureService") + def test_all_scope_passes_any(self, mock_fs): + mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL) + + PluginService._check_plugin_installation_scope(None) # should not raise + + +class TestGetPluginIconUrl: + @patch("services.plugin.plugin_service.dify_config") + def test_constructs_url_with_params(self, mock_config): + mock_config.CONSOLE_API_URL = "https://console.example.com" + + url = PluginService.get_plugin_icon_url("tenant-1", "icon.svg") + + assert "tenant_id=tenant-1" in url + assert "filename=icon.svg" in url + assert "/plugin/icon" in url + + +class TestGetAsset: + @patch("services.plugin.plugin_service.PluginAssetManager") + def test_returns_bytes_and_guessed_mime(self, mock_asset_cls): + mock_asset_cls.return_value.fetch_asset.return_value = b"" + + data, mime = PluginService.get_asset("t1", "icon.svg") + + assert data == b"" + assert "svg" in mime + + @patch("services.plugin.plugin_service.PluginAssetManager") + def test_fallback_to_octet_stream_for_unknown(self, mock_asset_cls): + mock_asset_cls.return_value.fetch_asset.return_value = b"\x00" + + _, mime = PluginService.get_asset("t1", "unknown_file") + + assert mime == "application/octet-stream" + + +class TestIsPluginVerified: + @patch("services.plugin.plugin_service.PluginInstaller") + def test_returns_true_when_verified(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_plugin_manifest.return_value.verified = True + + assert PluginService.is_plugin_verified("t1", "uid-1") is True + + @patch("services.plugin.plugin_service.PluginInstaller") + def test_returns_false_on_exception(self, mock_installer_cls): + mock_installer_cls.return_value.fetch_plugin_manifest.side_effect = RuntimeError("not found") + + assert PluginService.is_plugin_verified("t1", "uid-1") is False + + +class TestUpgradePluginWithMarketplace: + @patch("services.plugin.plugin_service.dify_config") + def test_raises_when_marketplace_disabled(self, mock_config): + mock_config.MARKETPLACE_ENABLED = False + + with pytest.raises(ValueError, match="marketplace is not enabled"): + PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") + + @patch("services.plugin.plugin_service.dify_config") + def test_raises_when_same_identifier(self, mock_config): + mock_config.MARKETPLACE_ENABLED = True + + with pytest.raises(ValueError, match="same plugin"): + PluginService.upgrade_plugin_with_marketplace("t1", "same-uid", "same-uid") + + @patch("services.plugin.plugin_service.marketplace") + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed + installer.upgrade_plugin.return_value = MagicMock() + + PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") + + mock_marketplace.record_install_plugin_event.assert_called_once_with("new-uid") + installer.upgrade_plugin.assert_called_once() + + @patch("services.plugin.plugin_service.download_plugin_pkg") + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") + mock_download.return_value = b"pkg-bytes" + upload_resp = MagicMock() + upload_resp.verification = None + installer.upload_pkg.return_value = upload_resp + installer.upgrade_plugin.return_value = MagicMock() + + PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") + + mock_download.assert_called_once_with("new-uid") + installer.upload_pkg.assert_called_once() + + +class TestUpgradePluginWithGithub: + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.upgrade_plugin.return_value = MagicMock() + + PluginService.upgrade_plugin_with_github("t1", "old-uid", "new-uid", "org/repo", "v1", "pkg.difypkg") + + installer.upgrade_plugin.assert_called_once() + call_args = installer.upgrade_plugin.call_args + assert call_args[0][3] == PluginInstallationSource.Github + + +class TestUploadPkg: + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): + mock_fs.get_system_features.return_value = make_features() + upload_resp = MagicMock() + upload_resp.verification = None + mock_installer_cls.return_value.upload_pkg.return_value = upload_resp + + result = PluginService.upload_pkg("t1", b"pkg-bytes") + + assert result is upload_resp + + +class TestInstallFromMarketplacePkg: + @patch("services.plugin.plugin_service.dify_config") + def test_raises_when_marketplace_disabled(self, mock_config): + mock_config.MARKETPLACE_ENABLED = False + + with pytest.raises(ValueError, match="marketplace is not enabled"): + PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) + + @patch("services.plugin.plugin_service.download_plugin_pkg") + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") + mock_download.return_value = b"pkg" + upload_resp = MagicMock() + upload_resp.verification = None + upload_resp.unique_identifier = "resolved-uid" + installer.upload_pkg.return_value = upload_resp + installer.install_from_identifiers.return_value = "task-id" + + result = PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) + + assert result == "task-id" + installer.install_from_identifiers.assert_called_once() + call_args = installer.install_from_identifiers.call_args[0] + assert call_args[1] == ["resolved-uid"] # uses response uid, not input + + @patch("services.plugin.plugin_service.FeatureService") + @patch("services.plugin.plugin_service.PluginInstaller") + @patch("services.plugin.plugin_service.dify_config") + def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): + mock_config.MARKETPLACE_ENABLED = True + mock_fs.get_system_features.return_value = make_features() + installer = mock_installer_cls.return_value + installer.fetch_plugin_manifest.return_value = MagicMock() + decode_resp = MagicMock() + decode_resp.verification = None + installer.decode_plugin_from_identifier.return_value = decode_resp + installer.install_from_identifiers.return_value = "task-id" + + PluginService.install_from_marketplace_pkg("t1", ["uid-1"]) + + installer.install_from_identifiers.assert_called_once() + call_args = installer.install_from_identifiers.call_args[0] + assert call_args[1] == ["uid-1"] # uses original uid + + +class TestUninstall: + @patch("services.plugin.plugin_service.PluginInstaller") + def test_direct_uninstall_when_plugin_not_found(self, mock_installer_cls): + installer = mock_installer_cls.return_value + installer.list_plugins.return_value = [] + installer.uninstall.return_value = True + + result = PluginService.uninstall("t1", "install-1") + + assert result is True + installer.uninstall.assert_called_once_with("t1", "install-1") + + @patch("services.plugin.plugin_service.db") + @patch("services.plugin.plugin_service.PluginInstaller") + def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db): + plugin = MagicMock() + plugin.installation_id = "install-1" + plugin.plugin_id = "org/myplugin" + installer = mock_installer_cls.return_value + installer.list_plugins.return_value = [plugin] + installer.uninstall.return_value = True + + # Mock Session context manager + mock_session = MagicMock() + mock_db.engine = MagicMock() + mock_session.scalars.return_value.all.return_value = [] # no credentials found + + with patch("services.plugin.plugin_service.Session") as mock_session_cls: + mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + + result = PluginService.uninstall("t1", "install-1") + + assert result is True + installer.uninstall.assert_called_once() diff --git a/api/tests/unit_tests/services/recommend_app/__init__.py b/api/tests/unit_tests/services/recommend_app/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py new file mode 100644 index 0000000000..770344aa39 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_buildin_retrieval.py @@ -0,0 +1,91 @@ +import json +from unittest.mock import MagicMock, patch + +import pytest + +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + +SAMPLE_BUILTIN_DATA = { + "recommended_apps": { + "en-US": {"categories": ["writing"], "apps": [{"id": "app-1"}]}, + "zh-Hans": {"categories": ["search"], "apps": [{"id": "app-2"}]}, + }, + "app_details": { + "app-1": {"id": "app-1", "name": "Writer", "mode": "chat"}, + "app-2": {"id": "app-2", "name": "Searcher", "mode": "workflow"}, + }, +} + + +@pytest.fixture(autouse=True) +def _reset_cache(): + BuildInRecommendAppRetrieval.builtin_data = None + yield + BuildInRecommendAppRetrieval.builtin_data = None + + +class TestBuildInRecommendAppRetrieval: + def test_get_type(self): + retrieval = BuildInRecommendAppRetrieval() + assert retrieval.get_type() == RecommendAppType.BUILDIN + + def test_get_recommended_apps_and_categories_delegates(self): + with patch.object( + BuildInRecommendAppRetrieval, + "fetch_recommended_apps_from_builtin", + return_value={"apps": []}, + ) as mock_fetch: + retrieval = BuildInRecommendAppRetrieval() + result = retrieval.get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"apps": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + BuildInRecommendAppRetrieval, + "fetch_recommended_app_detail_from_builtin", + return_value={"id": "app-1"}, + ) as mock_fetch: + retrieval = BuildInRecommendAppRetrieval() + result = retrieval.get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + def test_get_builtin_data_reads_json_and_caches(self, tmp_path): + json_file = tmp_path / "constants" / "recommended_apps.json" + json_file.parent.mkdir(parents=True) + json_file.write_text(json.dumps(SAMPLE_BUILTIN_DATA)) + + mock_app = MagicMock() + mock_app.root_path = str(tmp_path) + + with patch( + "services.recommend_app.buildin.buildin_retrieval.current_app", + mock_app, + ): + first = BuildInRecommendAppRetrieval._get_builtin_data() + second = BuildInRecommendAppRetrieval._get_builtin_data() + + assert first == SAMPLE_BUILTIN_DATA + assert first is second + + def test_fetch_recommended_apps_from_builtin(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("en-US") + assert result == SAMPLE_BUILTIN_DATA["recommended_apps"]["en-US"] + + def test_fetch_recommended_apps_from_builtin_missing_language(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin("fr-FR") + assert result == {} + + def test_fetch_recommended_app_detail_from_builtin(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("app-1") + assert result == {"id": "app-1", "name": "Writer", "mode": "chat"} + + def test_fetch_recommended_app_detail_from_builtin_missing(self): + BuildInRecommendAppRetrieval.builtin_data = SAMPLE_BUILTIN_DATA + result = BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin("nonexistent") + assert result is None diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 0000000000..5d21665f75 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,145 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def _make_recommended_app(self, app_id, category, is_public=True, has_site=True): + site = ( + SimpleNamespace( + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + if has_site + else None + ) + app = ( + SimpleNamespace(is_public=is_public, site=site) + if is_public + else SimpleNamespace(is_public=False, site=site) + ) + return SimpleNamespace( + id=f"rec-{app_id}", + app=app, + app_id=app_id, + category=category, + position=1, + is_listed=True, + ) + + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_apps_and_sorted_categories(self, mock_db): + rec1 = self._make_recommended_app("a1", "writing") + rec2 = self._make_recommended_app("a2", "assistant") + mock_db.session.scalars.return_value.all.return_value = [rec1, rec2] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert len(result["recommended_apps"]) == 2 + assert result["categories"] == ["assistant", "writing"] + + @patch("services.recommend_app.database.database_retrieval.db") + def test_falls_back_to_default_language_when_empty(self, mock_db): + mock_db.session.scalars.return_value.all.side_effect = [ + [], + [self._make_recommended_app("a1", "chat")], + ] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + assert len(result["recommended_apps"]) == 1 + assert mock_db.session.scalars.call_count == 2 + + @patch("services.recommend_app.database.database_retrieval.db") + def test_skips_non_public_apps(self, mock_db): + rec = self._make_recommended_app("a1", "chat", is_public=False) + mock_db.session.scalars.return_value.all.return_value = [rec] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert result["recommended_apps"] == [] + + @patch("services.recommend_app.database.database_retrieval.db") + def test_skips_apps_without_site(self, mock_db): + rec = self._make_recommended_app("a1", "chat", has_site=False) + mock_db.session.scalars.return_value.all.return_value = [rec] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + assert result["recommended_apps"] == [] + + +class TestFetchRecommendedAppDetailFromDb: + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_none_when_not_listed(self, mock_db): + mock_db.session.query.return_value.where.return_value.first.return_value = None + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_none_when_app_not_public(self, mock_db, mock_dsl): + rec_chain = MagicMock() + rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") + app_chain = MagicMock() + app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False) + mock_db.session.query.side_effect = [rec_chain, app_chain] + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + @patch("services.recommend_app.database.database_retrieval.db") + def test_returns_detail_on_success(self, mock_db, mock_dsl): + app_model = SimpleNamespace( + id="app-1", + name="My App", + icon="icon.png", + icon_background="#fff", + mode="chat", + is_public=True, + ) + rec_chain = MagicMock() + rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") + app_chain = MagicMock() + app_chain.where.return_value.first.return_value = app_model + mock_db.session.query.side_effect = [rec_chain, app_chain] + mock_dsl.export_dsl.return_value = "exported_yaml" + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") + + assert result["id"] == "app-1" + assert result["name"] == "My App" + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py new file mode 100644 index 0000000000..036cba0cc0 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_factory.py @@ -0,0 +1,28 @@ +import pytest + +from services.recommend_app.buildin.buildin_retrieval import BuildInRecommendAppRetrieval +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory +from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval + + +class TestRecommendAppRetrievalFactory: + @pytest.mark.parametrize( + ("mode", "expected_class"), + [ + ("remote", RemoteRecommendAppRetrieval), + ("builtin", BuildInRecommendAppRetrieval), + ("db", DatabaseRecommendAppRetrieval), + ], + ) + def test_factory_returns_correct_class(self, mode, expected_class): + result = RecommendAppRetrievalFactory.get_recommend_app_factory(mode) + assert result is expected_class + + def test_factory_raises_for_unknown_mode(self): + with pytest.raises(ValueError, match="invalid fetch recommended apps mode"): + RecommendAppRetrievalFactory.get_recommend_app_factory("invalid_mode") + + def test_get_buildin_recommend_app_retrieval(self): + result = RecommendAppRetrievalFactory.get_buildin_recommend_app_retrieval() + assert result is BuildInRecommendAppRetrieval diff --git a/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py new file mode 100644 index 0000000000..08f72a6f77 --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_recommend_app_type.py @@ -0,0 +1,18 @@ +from services.recommend_app.recommend_app_type import RecommendAppType + + +def test_enum_values(): + assert RecommendAppType.REMOTE == "remote" + assert RecommendAppType.BUILDIN == "builtin" + assert RecommendAppType.DATABASE == "db" + + +def test_enum_membership(): + assert "remote" in RecommendAppType.__members__.values() + assert "builtin" in RecommendAppType.__members__.values() + assert "db" in RecommendAppType.__members__.values() + + +def test_enum_is_str(): + for member in RecommendAppType: + assert isinstance(member, str) diff --git a/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py new file mode 100644 index 0000000000..e322fbed4c --- /dev/null +++ b/api/tests/unit_tests/services/recommend_app/test_remote_retrieval.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.recommend_app.recommend_app_type import RecommendAppType +from services.recommend_app.remote.remote_retrieval import RemoteRecommendAppRetrieval + + +class TestRemoteRecommendAppRetrieval: + def test_get_type(self): + assert RemoteRecommendAppRetrieval().get_type() == RecommendAppType.REMOTE + + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_app_detail_from_dify_official", + return_value={"id": "app-1"}, + ) + def test_get_recommend_app_detail_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1") + assert result == {"id": "app-1"} + mock_fetch.assert_called_once_with("app-1") + + @patch( + "services.recommend_app.remote.remote_retrieval" + ".BuildInRecommendAppRetrieval.fetch_recommended_app_detail_from_builtin", + return_value={"id": "fallback"}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_app_detail_from_dify_official", + side_effect=ConnectionError("timeout"), + ) + def test_get_recommend_app_detail_falls_back_on_error(self, mock_fetch, mock_builtin): + result = RemoteRecommendAppRetrieval().get_recommend_app_detail("app-1") + assert result == {"id": "fallback"} + mock_builtin.assert_called_once_with("app-1") + + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_apps_from_dify_official", + return_value={"recommended_apps": [], "categories": []}, + ) + def test_get_recommended_apps_success(self, mock_fetch): + result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + assert result == {"recommended_apps": [], "categories": []} + + @patch( + "services.recommend_app.remote.remote_retrieval" + ".BuildInRecommendAppRetrieval.fetch_recommended_apps_from_builtin", + return_value={"recommended_apps": [{"id": "builtin"}]}, + ) + @patch.object( + RemoteRecommendAppRetrieval, + "fetch_recommended_apps_from_dify_official", + side_effect=ValueError("server error"), + ) + def test_get_recommended_apps_falls_back_on_error(self, mock_fetch, mock_builtin): + result = RemoteRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + assert result == {"recommended_apps": [{"id": "builtin"}]} + + +class TestFetchFromDifyOfficial: + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_detail_returns_json_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"id": "app-1", "name": "Test"} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1") + + assert result == {"id": "app-1", "name": "Test"} + mock_get.assert_called_once() + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_detail_returns_none_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=404) + + result = RemoteRecommendAppRetrieval.fetch_recommended_app_detail_from_dify_official("app-1") + + assert result is None + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_returns_sorted_categories_on_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = { + "recommended_apps": [], + "categories": ["writing", "agent", "chat"], + } + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert result["categories"] == ["agent", "chat", "writing"] + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_raises_on_non_200(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_get.return_value = MagicMock(status_code=500) + + with pytest.raises(ValueError, match="fetch recommended apps failed"): + RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + @patch("services.recommend_app.remote.remote_retrieval.dify_config") + @patch("services.recommend_app.remote.remote_retrieval.httpx.get") + def test_apps_without_categories_key(self, mock_get, mock_config): + mock_config.HOSTED_FETCH_APP_TEMPLATES_REMOTE_DOMAIN = "https://example.com" + mock_response = MagicMock(status_code=200) + mock_response.json.return_value = {"recommended_apps": []} + mock_get.return_value = mock_response + + result = RemoteRecommendAppRetrieval.fetch_recommended_apps_from_dify_official("en-US") + + assert "categories" not in result diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py new file mode 100644 index 0000000000..f9d901fca2 --- /dev/null +++ b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py @@ -0,0 +1,311 @@ +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.conversation.messages_clean_policy import ( + BillingDisabledPolicy, +) +from services.retention.conversation.messages_clean_service import MessagesCleanService + + +class TestMessagesCleanService: + @pytest.fixture(autouse=True) + def mock_db_engine(self): + with patch("services.retention.conversation.messages_clean_service.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db.engine + + @pytest.fixture + def mock_db_session(self, mock_db_engine): + with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: + mock_session = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + yield mock_session + + @pytest.fixture + def mock_policy(self): + policy = MagicMock(spec=BillingDisabledPolicy) + return policy + + def test_run_calls_clean_messages(self, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + with patch.object(service, "_clean_messages_by_time_range") as mock_clean: + mock_clean.return_value = {"total_deleted": 5} + result = service.run() + assert result == {"total_deleted": 5} + mock_clean.assert_called_once() + + def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): + # Arrange + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + end_before=end_before, + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock( + rowcount=1 + ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete relations + MagicMock(rowcount=1), # delete messages + MagicMock(all=lambda: []), # next batch empty + ] + + # Reset side_effect to be more robust + # The service calls session.execute for: + # 1. Fetch messages + # 2. Fetch apps + # 3. Batch delete relations (8 calls if IDs exist) + # 4. Delete messages + + mock_returns = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps + ] + # 8 deletes for relations + mock_returns.extend([MagicMock() for _ in range(8)]) + # 1 delete for messages + mock_returns.append(MagicMock(rowcount=1)) + # Final fetch messages (empty) + mock_returns.append(MagicMock(all=lambda: [])) + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] + + # Act + with patch("services.retention.conversation.messages_clean_service.time.sleep"): + stats = service.run() + + # Assert + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 1 + assert stats["batches"] == 2 + + def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): + start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + start_from=start_from, + end_before=end_before, + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: []), # No messages + ] + + stats = service.run() + assert stats["total_messages"] == 0 + + def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): + # Test pagination with cursor + end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) + service = MessagesCleanService( + policy=mock_policy, + end_before=end_before, + batch_size=1, + ) + + msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) + msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) + + mock_returns = [] + # Batch 1 + mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) + mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + + # Batch 2 + mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) + mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + + # Batch 3 + mock_returns.append(MagicMock(all=lambda: [])) + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified + + with patch("services.retention.conversation.messages_clean_service.time.sleep"): + stats = service.run() + + assert stats["batches"] == 3 + assert stats["total_messages"] == 2 + + def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + dry_run=True, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock(all=lambda: []), # next batch empty + ] + mock_policy.filter_message_ids.return_value = ["msg1"] + + with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: + mock_sample.return_value = ["msg1"] + stats = service.run() + assert stats["filtered_messages"] == 1 + assert stats["total_deleted"] == 0 # Dry run + mock_sample.assert_called() + + def test_clean_messages_by_time_range_no_apps_found(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: []), # apps NOT found + MagicMock(all=lambda: []), # next batch empty + ] + + stats = service.run() + assert stats["total_messages"] == 1 + assert stats["total_deleted"] == 0 + + def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: []), # next batch empty + ] + + # We need to successfully execute line 228 and 229, then return empty at 251. + # line 228: raw_messages = list(session.execute(msg_stmt).all()) + # line 251: app_ids = list({msg.app_id for msg in messages}) + + calls = [] + + def list_side_effect(arg): + calls.append(arg) + if len(calls) == 2: # This is the second call to list() in the loop + return [] + return list(arg) + + with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): + stats = service.run() + assert stats["batches"] == 2 + assert stats["total_messages"] == 1 + + def test_from_time_range_validation(self, mock_policy): + now = datetime.datetime.now() + # Test start_from >= end_before + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range(mock_policy, now, now) + + # Test batch_size <= 0 + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) + + def test_from_time_range_success(self, mock_policy): + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 2, 1) + # Mock logger to avoid actual logging if needed, though it's fine + service = MessagesCleanService.from_time_range(mock_policy, start, end) + assert service._start_from == start + assert service._end_before == end + + def test_from_days_validation(self, mock_policy): + # Test days < 0 + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + MessagesCleanService.from_days(mock_policy, days=-1) + + # Test batch_size <= 0 + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) + + def test_from_days_success(self, mock_policy): + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: + fixed_now = datetime.datetime(2024, 6, 1) + mock_now.return_value = fixed_now + + service = MessagesCleanService.from_days(mock_policy, days=10) + assert service._start_from is None + assert service._end_before == fixed_now - datetime.timedelta(days=10) + + def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_db_session.execute.side_effect = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + MagicMock(all=lambda: []), # next batch empty + ] + mock_policy.filter_message_ids.return_value = [] # Policy says NO + + stats = service.run() + assert stats["total_messages"] == 1 + assert stats["filtered_messages"] == 0 + assert stats["total_deleted"] == 0 + + def test_batch_delete_message_relations_empty(self, mock_db_session): + MessagesCleanService._batch_delete_message_relations(mock_db_session, []) + mock_db_session.execute.assert_not_called() + + def test_batch_delete_message_relations_with_ids(self, mock_db_session): + MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) + assert mock_db_session.execute.call_count == 8 # 8 tables to clean up + + def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): + service = MessagesCleanService( + policy=mock_policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + + mock_returns = [ + MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages + MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps + ] + mock_returns.extend([MagicMock() for _ in range(8)]) # relations + mock_returns.append(MagicMock(rowcount=1)) # messages + mock_returns.append(MagicMock(all=lambda: [])) # next batch empty + + mock_db_session.execute.side_effect = mock_returns + mock_policy.filter_message_ids.return_value = ["msg1"] + + with patch( + "services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", + 500, + ): + with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: + with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: + mock_uniform.return_value = 300.0 + service.run() + mock_uniform.assert_called_with(0, 500) + mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py new file mode 100644 index 0000000000..7d30645d38 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_clear_free_plan_expired_workflow_run_logs.py @@ -0,0 +1,524 @@ +""" +Unit tests for WorkflowRunCleanup service. +""" + +import datetime +from unittest.mock import MagicMock, patch + +import pytest + +from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup + + +def make_run(tenant_id: str = "t1", run_id: str = "r1", created_at: datetime.datetime | None = None): + run = MagicMock() + run.tenant_id = tenant_id + run.id = run_id + run.created_at = created_at or datetime.datetime(2024, 1, 1, tzinfo=datetime.UTC) + return run + + +@pytest.fixture +def mock_repo(): + return MagicMock() + + +@pytest.fixture +def cleanup(mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + yield WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + +# --------------------------------------------------------------------------- +# Constructor validation +# --------------------------------------------------------------------------- + + +class TestWorkflowRunCleanupInit: + def test_only_start_from_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="both set or both omitted"): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_only_end_before_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="both set or both omitted"): + WorkflowRunCleanup( + days=30, + batch_size=10, + end_before=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_end_before_not_greater_than_start_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="end_before must be greater than start_from"): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 6, 1), + end_before=datetime.datetime(2024, 1, 1), + workflow_run_repo=mock_repo, + ) + + def test_equal_start_end_raises(self, mock_repo): + dt = datetime.datetime(2024, 1, 1) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError): + WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=dt, + end_before=dt, + workflow_run_repo=mock_repo, + ) + + def test_zero_batch_size_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError, match="batch_size must be greater than 0"): + WorkflowRunCleanup(days=30, batch_size=0, workflow_run_repo=mock_repo) + + def test_negative_batch_size_raises(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + with pytest.raises(ValueError): + WorkflowRunCleanup(days=30, batch_size=-1, workflow_run_repo=mock_repo) + + def test_valid_window_init(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 7 + cfg.BILLING_ENABLED = False + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 6, 1) + c = WorkflowRunCleanup( + days=30, + batch_size=5, + start_from=start, + end_before=end, + workflow_run_repo=mock_repo, + ) + assert c.window_start == start + assert c.window_end == end + + def test_default_task_label_is_custom(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + assert c._metrics._base_attributes["task_label"] == "custom" + + +# --------------------------------------------------------------------------- +# _empty_related_counts / _format_related_counts +# --------------------------------------------------------------------------- + + +class TestStaticHelpers: + def test_empty_related_counts(self): + counts = WorkflowRunCleanup._empty_related_counts() + assert counts == { + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + + def test_format_related_counts(self): + counts = { + "node_executions": 1, + "offloads": 2, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + } + result = WorkflowRunCleanup._format_related_counts(counts) + assert "node_executions 1" in result + assert "offloads 2" in result + assert "trigger_logs 4" in result + + +# --------------------------------------------------------------------------- +# _expiration_datetime +# --------------------------------------------------------------------------- + + +class TestExpirationDatetime: + def test_negative_returns_none(self, cleanup): + assert cleanup._expiration_datetime("t1", -1) is None + + def test_valid_timestamp(self, cleanup): + ts = int(datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC).timestamp()) + result = cleanup._expiration_datetime("t1", ts) + assert result is not None + assert result.year == 2025 + + def test_overflow_returns_none(self, cleanup): + result = cleanup._expiration_datetime("t1", 2**62) + assert result is None + + +# --------------------------------------------------------------------------- +# _is_within_grace_period +# --------------------------------------------------------------------------- + + +class TestIsWithinGracePeriod: + def test_zero_grace_period_returns_false(self, cleanup): + cleanup.free_plan_grace_period_days = 0 + assert cleanup._is_within_grace_period("t1", {"expiration_date": 9999999999}) is False + + def test_within_grace_period(self, cleanup): + cleanup.free_plan_grace_period_days = 30 + # expired just 1 day ago + expired = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=1) + ts = int(expired.timestamp()) + assert cleanup._is_within_grace_period("t1", {"expiration_date": ts}) is True + + def test_outside_grace_period(self, cleanup): + cleanup.free_plan_grace_period_days = 5 + # expired 100 days ago + expired = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=100) + ts = int(expired.timestamp()) + assert cleanup._is_within_grace_period("t1", {"expiration_date": ts}) is False + + def test_missing_expiration_date_returns_false(self, cleanup): + cleanup.free_plan_grace_period_days = 30 + assert cleanup._is_within_grace_period("t1", {"expiration_date": -1}) is False + + +# --------------------------------------------------------------------------- +# _get_cleanup_whitelist +# --------------------------------------------------------------------------- + + +class TestGetCleanupWhitelist: + def test_billing_disabled_returns_empty(self, cleanup): + cleanup._cleanup_whitelist = None + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + result = cleanup._get_cleanup_whitelist() + assert result == set() + + def test_billing_enabled_fetches_whitelist(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_expired_subscription_cleanup_whitelist.return_value = ["t1", "t2"] + result = c._get_cleanup_whitelist() + assert result == {"t1", "t2"} + + def test_cached_whitelist_returned(self, cleanup): + cleanup._cleanup_whitelist = {"cached"} + result = cleanup._get_cleanup_whitelist() + assert result == {"cached"} + + def test_billing_service_error_returns_empty(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_expired_subscription_cleanup_whitelist.side_effect = Exception("error") + result = c._get_cleanup_whitelist() + assert result == set() + + +# --------------------------------------------------------------------------- +# _filter_free_tenants +# --------------------------------------------------------------------------- + + +class TestFilterFreeTenants: + def test_billing_disabled_all_tenants_free(self, cleanup): + result = cleanup._filter_free_tenants(["t1", "t2"]) + assert result == {"t1", "t2"} + + def test_empty_tenants_returns_empty(self, cleanup): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = True + result = cleanup._filter_free_tenants([]) + assert result == set() + + def test_whitelisted_tenant_excluded(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = {"t1"} + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + from enums.cloud_plan import CloudPlan + + bs.get_plan_bulk_with_cache.return_value = { + "t1": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + "t2": {"plan": CloudPlan.SANDBOX, "expiration_date": -1}, + } + result = c._filter_free_tenants(["t1", "t2"]) + assert "t1" not in result + assert "t2" in result + + def test_paid_tenant_excluded(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.return_value = { + "t1": {"plan": "professional", "expiration_date": -1}, + } + result = c._filter_free_tenants(["t1"]) + assert result == set() + + def test_missing_billing_info_treats_as_non_free(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.return_value = {} + result = c._filter_free_tenants(["t1"]) + assert result == set() + + def test_billing_bulk_error_treats_as_non_free(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = True + c = WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + c._cleanup_whitelist = set() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.BillingService" + ) as bs: + bs.get_plan_bulk_with_cache.side_effect = Exception("fail") + result = c._filter_free_tenants(["t1"]) + assert result == set() + + +# --------------------------------------------------------------------------- +# run() — delete mode +# --------------------------------------------------------------------------- + + +class TestRunDeleteMode: + def _make_cleanup(self, mock_repo, billing_enabled=False): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = billing_enabled + return WorkflowRunCleanup(days=30, batch_size=10, workflow_run_repo=mock_repo) + + def test_no_rows_stops_immediately(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + + def test_all_paid_skips_delete(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + c = self._make_cleanup(mock_repo) + # billing disabled -> all free; but let's override _filter_free_tenants to return empty + c._filter_free_tenants = MagicMock(return_value=set()) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + + def test_runs_deleted_successfully(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.delete_runs_with_related.return_value = { + "runs": 1, + "node_executions": 0, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 0, + "pauses": 0, + "pause_reasons": 0, + } + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.time.sleep"): + c.run() + mock_repo.delete_runs_with_related.assert_called_once() + + def test_delete_exception_reraises(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.delete_runs_with_related.side_effect = RuntimeError("db error") + c = self._make_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + with pytest.raises(RuntimeError): + c.run() + + def test_summary_with_window_start(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 6, 1), + workflow_run_repo=mock_repo, + ) + c.run() + + +# --------------------------------------------------------------------------- +# run() — dry run mode +# --------------------------------------------------------------------------- + + +class TestRunDryRunMode: + def _make_dry_cleanup(self, mock_repo): + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + return WorkflowRunCleanup( + days=30, + batch_size=10, + workflow_run_repo=mock_repo, + dry_run=True, + ) + + def test_dry_run_no_delete_called(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + mock_repo.count_runs_with_related.return_value = { + "node_executions": 2, + "offloads": 0, + "app_logs": 0, + "trigger_logs": 1, + "pauses": 0, + "pause_reasons": 0, + } + c = self._make_dry_cleanup(mock_repo) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.delete_runs_with_related.assert_not_called() + mock_repo.count_runs_with_related.assert_called_once() + + def test_dry_run_summary_with_window_start(self, mock_repo): + mock_repo.get_runs_batch_by_time_range.return_value = [] + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD = 0 + cfg.BILLING_ENABLED = False + c = WorkflowRunCleanup( + days=30, + batch_size=10, + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 6, 1), + workflow_run_repo=mock_repo, + dry_run=True, + ) + c.run() + + def test_dry_run_all_paid_skips_count(self, mock_repo): + run = make_run("t1") + mock_repo.get_runs_batch_by_time_range.side_effect = [[run], []] + c = self._make_dry_cleanup(mock_repo) + c._filter_free_tenants = MagicMock(return_value=set()) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.dify_config") as cfg: + cfg.BILLING_ENABLED = False + c.run() + mock_repo.count_runs_with_related.assert_not_called() + + +# --------------------------------------------------------------------------- +# _delete_trigger_logs / _count_trigger_logs +# --------------------------------------------------------------------------- + + +class TestTriggerLogMethods: + def test_delete_trigger_logs(self, cleanup): + session = MagicMock() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository" + ) as RepoClass: + instance = RepoClass.return_value + instance.delete_by_run_ids.return_value = 5 + result = cleanup._delete_trigger_logs(session, ["r1", "r2"]) + assert result == 5 + + def test_count_trigger_logs(self, cleanup): + session = MagicMock() + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.SQLAlchemyWorkflowTriggerLogRepository" + ) as RepoClass: + instance = RepoClass.return_value + instance.count_by_run_ids.return_value = 3 + result = cleanup._count_trigger_logs(session, ["r1"]) + assert result == 3 + + +# --------------------------------------------------------------------------- +# _count_node_executions / _delete_node_executions +# --------------------------------------------------------------------------- + + +class TestNodeExecutionMethods: + def test_count_node_executions(self, cleanup): + session = MagicMock() + session.get_bind.return_value = MagicMock() + runs = [make_run("t1", "r1")] + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory" + ) as factory: + repo = factory.create_api_workflow_node_execution_repository.return_value + repo.count_by_runs.return_value = (10, 2) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"): + result = cleanup._count_node_executions(session, runs) + assert result == (10, 2) + + def test_delete_node_executions(self, cleanup): + session = MagicMock() + session.get_bind.return_value = MagicMock() + runs = [make_run("t1", "r1")] + with patch( + "services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.DifyAPIRepositoryFactory" + ) as factory: + repo = factory.create_api_workflow_node_execution_repository.return_value + repo.delete_by_runs.return_value = (5, 1) + with patch("services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs.sessionmaker"): + result = cleanup._delete_node_executions(session, runs) + assert result == (5, 1) diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py new file mode 100644 index 0000000000..9fe153c153 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_delete_archived_workflow_run.py @@ -0,0 +1,216 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy.orm import Session + +from models.workflow import WorkflowRun +from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion, DeleteResult + + +class TestArchivedWorkflowRunDeletion: + @pytest.fixture + def mock_db(self): + with patch("services.retention.workflow_run.delete_archived_workflow_run.db") as mock_db: + mock_db.engine = MagicMock() + yield mock_db + + @pytest.fixture + def mock_sessionmaker(self): + with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: + mock_session = MagicMock(spec=Session) + mock_sm.return_value.return_value.__enter__.return_value = mock_session + yield mock_sm, mock_session + + @pytest.fixture + def mock_workflow_run_repo(self): + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.APIWorkflowRunRepository" + ) as mock_repo_cls: + mock_repo = MagicMock() + yield mock_repo + + def test_delete_by_run_id_success(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + tenant_id = "tenant-456" + + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = run_id + mock_run.tenant_id = tenant_id + mock_session.get.return_value = mock_run + + deletion = ArchivedWorkflowRunDeletion() + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_run_ids.return_value = [run_id] + + with patch.object(deletion, "_delete_run") as mock_delete_run: + expected_result = DeleteResult(run_id=run_id, tenant_id=tenant_id, success=True) + mock_delete_run.return_value = expected_result + + result = deletion.delete_by_run_id(run_id) + + assert result == expected_result + mock_session.get.assert_called_once_with(WorkflowRun, run_id) + mock_repo.get_archived_run_ids.assert_called_once() + mock_delete_run.assert_called_once_with(mock_run) + + def test_delete_by_run_id_not_found(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + mock_session.get.return_value = None + + deletion = ArchivedWorkflowRunDeletion() + with patch.object(deletion, "_get_workflow_run_repo"): + result = deletion.delete_by_run_id(run_id) + + assert result.success is False + assert "not found" in result.error + assert result.run_id == run_id + + def test_delete_by_run_id_not_archived(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + run_id = "run-123" + + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = run_id + mock_session.get.return_value = mock_run + + deletion = ArchivedWorkflowRunDeletion() + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_run_ids.return_value = [] + + result = deletion.delete_by_run_id(run_id) + + assert result.success is False + assert "is not archived" in result.error + + def test_delete_batch(self, mock_db, mock_sessionmaker): + mock_sm, mock_session = mock_sessionmaker + deletion = ArchivedWorkflowRunDeletion() + + mock_run1 = MagicMock(spec=WorkflowRun) + mock_run1.id = "run-1" + mock_run2 = MagicMock(spec=WorkflowRun) + mock_run2.id = "run-2" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.get_archived_runs_by_time_range.return_value = [mock_run1, mock_run2] + + with patch.object(deletion, "_delete_run") as mock_delete_run: + mock_delete_run.side_effect = [ + DeleteResult(run_id="run-1", tenant_id="t1", success=True), + DeleteResult(run_id="run-2", tenant_id="t1", success=True), + ] + + results = deletion.delete_batch(tenant_ids=["t1"], start_date=datetime.now(), end_date=datetime.now()) + + assert len(results) == 2 + assert results[0].run_id == "run-1" + assert results[1].run_id == "run-2" + assert mock_delete_run.call_count == 2 + + def test_delete_run_dry_run(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=True) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + mock_run.tenant_id = "tenant-456" + + result = deletion._delete_run(mock_run) + + assert result.success is True + assert result.run_id == "run-123" + + def test_delete_run_success(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=False) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + mock_run.tenant_id = "tenant-456" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.return_value = {"workflow_runs": 1} + + result = deletion._delete_run(mock_run) + + assert result.success is True + assert result.deleted_counts == {"workflow_runs": 1} + + def test_delete_run_exception(self): + deletion = ArchivedWorkflowRunDeletion(dry_run=False) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-123" + + with patch.object(deletion, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = MagicMock() + mock_get_repo.return_value = mock_repo + mock_repo.delete_runs_with_related.side_effect = Exception("Database error") + + result = deletion._delete_run(mock_run) + + assert result.success is False + assert result.error == "Database error" + + def test_delete_trigger_logs(self): + mock_session = MagicMock(spec=Session) + run_ids = ["run-1", "run-2"] + + with patch( + "services.retention.workflow_run.delete_archived_workflow_run.SQLAlchemyWorkflowTriggerLogRepository" + ) as mock_repo_cls: + mock_repo = MagicMock() + mock_repo_cls.return_value = mock_repo + mock_repo.delete_by_run_ids.return_value = 5 + + count = ArchivedWorkflowRunDeletion._delete_trigger_logs(mock_session, run_ids) + + assert count == 5 + mock_repo_cls.assert_called_once_with(mock_session) + mock_repo.delete_by_run_ids.assert_called_once_with(run_ids) + + def test_delete_node_executions(self): + mock_session = MagicMock(spec=Session) + mock_run = MagicMock(spec=WorkflowRun) + mock_run.id = "run-1" + runs = [mock_run] + + with patch( + "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository" + ) as mock_create_repo: + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + mock_repo.delete_by_runs.return_value = (1, 2) + + with patch("services.retention.workflow_run.delete_archived_workflow_run.sessionmaker") as mock_sm: + result = ArchivedWorkflowRunDeletion._delete_node_executions(mock_session, runs) + + assert result == (1, 2) + mock_create_repo.assert_called_once() + mock_repo.delete_by_runs.assert_called_once_with(mock_session, ["run-1"]) + + def test_get_workflow_run_repo(self, mock_db): + deletion = ArchivedWorkflowRunDeletion() + + with patch( + "repositories.factory.DifyAPIRepositoryFactory.create_api_workflow_run_repository" + ) as mock_create_repo: + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + + # First call + repo1 = deletion._get_workflow_run_repo() + assert repo1 == mock_repo + assert deletion.workflow_run_repo == mock_repo + + # Second call (should return cached) + repo2 = deletion._get_workflow_run_repo() + assert repo2 == mock_repo + mock_create_repo.assert_called_once() diff --git a/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py new file mode 100644 index 0000000000..4bfdba87a0 --- /dev/null +++ b/api/tests/unit_tests/services/retention/workflow_run/test_restore_archived_workflow_run.py @@ -0,0 +1,1117 @@ +""" +Comprehensive unit tests for WorkflowRunRestore service. + +This file provides complete test coverage for all WorkflowRunRestore methods. +Tests are organized by functionality and include edge cases, error handling, +and both positive and negative test scenarios. +""" + +import io +import json +import zipfile +from datetime import datetime +from unittest.mock import Mock, create_autospec, patch + +import pytest +from sqlalchemy import Column, Integer, MetaData, String, Table + +from libs.archive_storage import ArchiveStorageNotConfiguredError +from models.trigger import WorkflowTriggerLog +from models.workflow import ( + WorkflowAppLog, + WorkflowArchiveLog, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionOffload, + WorkflowPause, + WorkflowPauseReason, + WorkflowRun, +) +from services.retention.workflow_run.restore_archived_workflow_run import ( + SCHEMA_MAPPERS, + TABLE_MODELS, + RestoreResult, + WorkflowRunRestore, +) + + +class WorkflowRunRestoreTestDataFactory: + """ + Factory for creating test data and mock objects. + + Provides reusable methods to create consistent mock objects for testing + workflow run restore operations. + """ + + @staticmethod + def create_workflow_run_mock( + run_id: str = "run-123", + tenant_id: str = "tenant-123", + app_id: str = "app-123", + created_at: datetime | None = None, + **kwargs, + ) -> Mock: + """ + Create a mock WorkflowRun object. + + Args: + run_id: Unique identifier for the workflow run + tenant_id: Tenant/workspace identifier + app_id: Application identifier + created_at: Creation timestamp + **kwargs: Additional attributes to set on the mock + + Returns: + Mock WorkflowRun object with specified attributes + """ + run = create_autospec(WorkflowRun, instance=True) + run.id = run_id + run.tenant_id = tenant_id + run.app_id = app_id + run.created_at = created_at or datetime(2024, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(run, key, value) + return run + + @staticmethod + def create_workflow_archive_log_mock( + run_id: str = "run-123", + tenant_id: str = "tenant-123", + app_id: str = "app-123", + created_at: datetime | None = None, + **kwargs, + ) -> Mock: + """ + Create a mock WorkflowArchiveLog object. + + Args: + run_id: Unique identifier for the workflow run + tenant_id: Tenant/workspace identifier + app_id: Application identifier + created_at: Creation timestamp + **kwargs: Additional attributes to set on the mock + + Returns: + Mock WorkflowArchiveLog object with specified attributes + """ + archive_log = create_autospec(WorkflowArchiveLog, instance=True) + archive_log.workflow_run_id = run_id + archive_log.tenant_id = tenant_id + archive_log.app_id = app_id + archive_log.run_created_at = created_at or datetime(2024, 1, 1, 12, 0, 0) + for key, value in kwargs.items(): + setattr(archive_log, key, value) + return archive_log + + @staticmethod + def create_archive_zip_mock( + manifest: dict | None = None, + tables_data: dict[str, list[dict]] | None = None, + ) -> bytes: + """ + Create a mock archive zip file in memory. + + Args: + manifest: Archive manifest data + tables_data: Dictionary mapping table names to list of records + + Returns: + Bytes representing the zip file + """ + if manifest is None: + manifest = { + "schema_version": "1.0", + "tables": { + "workflow_runs": {"row_count": 1}, + "workflow_app_logs": {"row_count": 2}, + }, + } + + if tables_data is None: + tables_data = { + "workflow_runs": [ + { + "id": "run-123", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + } + ], + "workflow_app_logs": [ + { + "id": "log-1", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "workflow_run_id": "run-123", + "created_from": "app", + "created_by_role": "account", + "created_by": "user-123", + }, + { + "id": "log-2", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "workflow_run_id": "run-123", + "created_from": "app", + "created_by_role": "account", + "created_by": "user-123", + }, + ], + } + + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zip_file: + zip_file.writestr("manifest.json", json.dumps(manifest)) + for table_name, records in tables_data.items(): + jsonl_data = "\n".join(json.dumps(record) for record in records) + zip_file.writestr(f"{table_name}.jsonl", jsonl_data) + + zip_buffer.seek(0) + return zip_buffer.getvalue() + + +# --------------------------------------------------------------------------- +# Test WorkflowRunRestore Initialization +# --------------------------------------------------------------------------- + + +class TestWorkflowRunRestoreInit: + """Tests for WorkflowRunRestore.__init__ method.""" + + def test_default_initialization(self): + """Service should initialize with default values.""" + restore = WorkflowRunRestore() + assert restore.dry_run is False + assert restore.workers == 1 + assert restore.workflow_run_repo is None + + def test_dry_run_initialization(self): + """Service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + assert restore.dry_run is True + assert restore.workers == 1 + + def test_custom_workers_initialization(self): + """Service should accept custom workers count.""" + restore = WorkflowRunRestore(workers=5) + assert restore.workers == 5 + + def test_invalid_workers_raises_error(self): + """Service should raise ValueError for workers less than 1.""" + with pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunRestore(workers=0) + + def test_negative_workers_raises_error(self): + """Service should raise ValueError for negative workers.""" + with pytest.raises(ValueError, match="workers must be at least 1"): + WorkflowRunRestore(workers=-1) + + +# --------------------------------------------------------------------------- +# Test _get_workflow_run_repo Method +# --------------------------------------------------------------------------- + + +class TestGetWorkflowRunRepo: + """Tests for WorkflowRunRestore._get_workflow_run_repo method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.DifyAPIRepositoryFactory") + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + @patch("services.retention.workflow_run.restore_archived_workflow_run.db") + def test_first_call_creates_repo(self, mock_db, mock_sessionmaker, mock_factory): + """First call should create and cache repository.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + mock_repo = Mock() + mock_factory.create_api_workflow_run_repository.return_value = mock_repo + + result = restore._get_workflow_run_repo() + + assert result is mock_repo + assert restore.workflow_run_repo is mock_repo + mock_sessionmaker.assert_called_once_with(bind=mock_db.engine, expire_on_commit=False) + mock_factory.create_api_workflow_run_repository.assert_called_once_with(mock_session) + + def test_cached_repo_returned(self): + """Subsequent calls should return cached repository.""" + restore = WorkflowRunRestore() + mock_repo = Mock() + restore.workflow_run_repo = mock_repo + + result = restore._get_workflow_run_repo() + + assert result is mock_repo + + +# --------------------------------------------------------------------------- +# Test _load_manifest_from_zip Method +# --------------------------------------------------------------------------- + + +class TestLoadManifestFromZip: + """Tests for WorkflowRunRestore._load_manifest_from_zip method.""" + + def test_load_valid_manifest(self): + """Should load manifest from valid zip.""" + manifest_data = {"schema_version": "1.0", "tables": {}} + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("manifest.json", json.dumps(manifest_data)) + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + result = WorkflowRunRestore._load_manifest_from_zip(archive) + + assert result == manifest_data + + def test_missing_manifest_raises_error(self): + """Should raise ValueError when manifest.json is missing.""" + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("other.txt", "data") + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + with pytest.raises(ValueError, match="manifest.json missing from archive bundle"): + WorkflowRunRestore._load_manifest_from_zip(archive) + + def test_invalid_json_raises_error(self): + """Should raise ValueError when manifest contains invalid JSON.""" + zip_buffer = io.BytesIO() + with zipfile.ZipFile(zip_buffer, "w") as zip_file: + zip_file.writestr("manifest.json", "invalid json") + zip_buffer.seek(0) + + with zipfile.ZipFile(zip_buffer, "r") as archive: + with pytest.raises(json.JSONDecodeError): + WorkflowRunRestore._load_manifest_from_zip(archive) + + +# --------------------------------------------------------------------------- +# Test _get_schema_version Method +# --------------------------------------------------------------------------- + + +class TestGetSchemaVersion: + """Tests for WorkflowRunRestore._get_schema_version method.""" + + def test_valid_schema_version(self): + """Should return valid schema version from manifest.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": "1.0"} + result = restore._get_schema_version(manifest) + assert result == "1.0" + + def test_missing_schema_version_defaults_to_1_0(self): + """Should default to 1.0 when schema_version is missing.""" + restore = WorkflowRunRestore() + manifest = {"tables": {}} + + with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: + result = restore._get_schema_version(manifest) + + assert result == "1.0" + mock_logger.warning.assert_called_once_with("Manifest missing schema_version; defaulting to 1.0") + + def test_unsupported_schema_version_raises_error(self): + """Should raise ValueError for unsupported schema version.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": "2.0"} + + with pytest.raises(ValueError, match="Unsupported schema_version 2.0"): + restore._get_schema_version(manifest) + + def test_numeric_schema_version_converted_to_string(self): + """Should convert numeric schema version to string.""" + restore = WorkflowRunRestore() + manifest = {"schema_version": 1} + + # This should raise ValueError because "1" is not in SCHEMA_MAPPERS (only "1.0" is) + with pytest.raises(ValueError, match="Unsupported schema_version 1"): + restore._get_schema_version(manifest) + + +# --------------------------------------------------------------------------- +# Test _apply_schema_mapping Method +# --------------------------------------------------------------------------- + + +class TestApplySchemaMapping: + """Tests for WorkflowRunRestore._apply_schema_mapping method.""" + + def test_no_mapping_returns_original(self): + """Should return original record when no mapping exists.""" + restore = WorkflowRunRestore() + record = {"id": "test", "name": "test"} + result = restore._apply_schema_mapping("workflow_runs", "1.0", record) + assert result == record + + def test_mapping_applied(self): + """Should apply mapping when it exists.""" + restore = WorkflowRunRestore() + + def test_mapper(record): + return {**record, "mapped": True} + + # Add test mapper to SCHEMA_MAPPERS + original_mappers = SCHEMA_MAPPERS.copy() + SCHEMA_MAPPERS["1.0"]["test_table"] = test_mapper + + try: + record = {"id": "test"} + result = restore._apply_schema_mapping("test_table", "1.0", record) + assert result == {"id": "test", "mapped": True} + finally: + # Restore original mappers + SCHEMA_MAPPERS.clear() + SCHEMA_MAPPERS.update(original_mappers) + + +# --------------------------------------------------------------------------- +# Test _convert_datetime_fields Method +# --------------------------------------------------------------------------- + + +class TestConvertDatetimeFields: + """Tests for WorkflowRunRestore._convert_datetime_fields method.""" + + def test_iso_datetime_conversion(self): + """Should convert ISO datetime strings to datetime objects.""" + restore = WorkflowRunRestore() + + record = {"created_at": "2024-01-01T12:00:00", "name": "test"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["name"] == "test" + + def test_invalid_datetime_ignored(self): + """Should ignore invalid datetime strings.""" + restore = WorkflowRunRestore() + + record = {"created_at": "invalid-date", "name": "test"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert result["created_at"] == "invalid-date" + assert result["name"] == "test" + + def test_non_datetime_columns_unchanged(self): + """Should leave non-datetime columns unchanged.""" + restore = WorkflowRunRestore() + + record = {"id": "test", "tenant_id": "tenant-123"} + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert result["id"] == "test" + assert result["tenant_id"] == "tenant-123" + + +# --------------------------------------------------------------------------- +# Test _get_model_column_info Method +# --------------------------------------------------------------------------- + + +class TestGetModelColumnInfo: + """Tests for WorkflowRunRestore._get_model_column_info method.""" + + def test_column_info_extraction(self): + """Should extract column information correctly.""" + restore = WorkflowRunRestore() + + column_names, required_columns, non_nullable_with_default = restore._get_model_column_info(WorkflowRun) + + # Check that we get some expected columns + assert "id" in column_names + assert "tenant_id" in column_names + assert "app_id" in column_names + assert "created_at" in column_names + assert "created_by" in column_names + assert "status" in column_names + + # Columns without defaults should be required for restore inserts. + assert { + "tenant_id", + "app_id", + "workflow_id", + "type", + "triggered_from", + "version", + "status", + "created_by_role", + "created_by", + }.issubset(required_columns) + assert "id" not in required_columns + assert "created_at" not in required_columns + + # Check columns with defaults or server defaults + assert "id" in non_nullable_with_default + assert "created_at" in non_nullable_with_default + assert "elapsed_time" in non_nullable_with_default + assert "total_tokens" in non_nullable_with_default + assert "tenant_id" not in non_nullable_with_default + + def test_non_pk_auto_autoincrement_column_is_still_required(self): + """`autoincrement='auto'` should not mark non-PK columns as defaulted.""" + restore = WorkflowRunRestore() + + test_table = Table( + "test_autoincrement", + MetaData(), + Column("id", Integer, primary_key=True, autoincrement=True), + Column("required_field", String(255), nullable=False), + Column("defaulted_field", String(255), nullable=False, default="x"), + ) + + class MockModel: + __table__ = test_table + + _, required_columns, non_nullable_with_default = restore._get_model_column_info(MockModel) + + assert required_columns == {"required_field"} + assert "id" in non_nullable_with_default + assert "defaulted_field" in non_nullable_with_default + + +# --------------------------------------------------------------------------- +# Test _restore_table_records Method +# --------------------------------------------------------------------------- + + +class TestRestoreTableRecords: + """Tests for WorkflowRunRestore._restore_table_records method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.TABLE_MODELS") + def test_unknown_table_returns_zero(self, mock_table_models): + """Should return 0 for unknown table.""" + restore = WorkflowRunRestore() + mock_table_models.get.return_value = None + + mock_session = Mock() + records = [{"id": "test"}] + + with patch("services.retention.workflow_run.restore_archived_workflow_run.logger") as mock_logger: + result = restore._restore_table_records(mock_session, "unknown_table", records, schema_version="1.0") + + assert result == 0 + mock_logger.warning.assert_called_once_with("Unknown table: %s", "unknown_table") + + def test_empty_records_returns_zero(self): + """Should return 0 for empty records list.""" + restore = WorkflowRunRestore() + mock_session = Mock() + + result = restore._restore_table_records(mock_session, "workflow_runs", [], schema_version="1.0") + assert result == 0 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") + @patch("services.retention.workflow_run.restore_archived_workflow_run.cast") + def test_successful_restore(self, mock_cast, mock_pg_insert): + """Should successfully restore records.""" + restore = WorkflowRunRestore() + + # Mock session and execution + mock_session = Mock() + mock_result = Mock() + mock_result.rowcount = 2 + mock_session.execute.return_value = mock_result + mock_cast.return_value = mock_result + + # Mock insert statement + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_pg_insert.return_value = mock_stmt + + records = [ + { + "id": "test1", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + }, + { + "id": "test2", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + }, + ] + + result = restore._restore_table_records(mock_session, "workflow_runs", records, schema_version="1.0") + + assert result == 2 + mock_session.execute.assert_called_once() + + def test_missing_required_columns_raises_error(self): + """Should raise ValueError for missing required columns.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + # Use a dedicated mock model to isolate required-column validation behavior. + mock_model = Mock() + + # Mock a required column + required_column = Mock() + required_column.key = "required_field" + required_column.nullable = False + required_column.default = None + required_column.server_default = None + required_column.autoincrement = False + required_column.type = Mock() + + # Mock the __table__ attribute properly + mock_table = Mock() + mock_table.columns = [required_column] + mock_model.__table__ = mock_table + + records = [{"name": "test"}] # Missing required 'required_field' + + with patch.dict(TABLE_MODELS, {"test_table": mock_model}): + with pytest.raises(ValueError, match="Missing required columns for test_table"): + restore._restore_table_records(mock_session, "test_table", records, schema_version="1.0") + + +# --------------------------------------------------------------------------- +# Test _restore_from_run Method +# --------------------------------------------------------------------------- + + +class TestRestoreFromRun: + """Tests for WorkflowRunRestore._restore_from_run method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_archive_storage_not_configured(self, mock_get_storage): + """Should handle ArchiveStorageNotConfiguredError.""" + restore = WorkflowRunRestore() + mock_get_storage.side_effect = ArchiveStorageNotConfiguredError("Storage not configured") + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: Mock()) + + assert result.success is False + assert "Storage not configured" in result.error + assert result.elapsed_time > 0 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_archive_bundle_not_found(self, mock_get_storage): + """Should handle FileNotFoundError when archive bundle is missing.""" + restore = WorkflowRunRestore() + mock_storage = Mock() + mock_storage.get_object.side_effect = FileNotFoundError("Bundle not found") + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: Mock()) + + assert result.success is False + assert "Archive bundle not found" in result.error + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_dry_run_mode(self, mock_get_storage): + """Should handle dry run mode correctly.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Create a proper mock session with context manager support + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + result = restore._restore_from_run(run, session_maker=lambda: mock_session) + + assert result.success is True + assert result.restored_counts["workflow_runs"] == 1 + assert result.restored_counts["workflow_app_logs"] == 2 + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") + @patch("services.retention.workflow_run.restore_archived_workflow_run.cast") + def test_successful_restore(self, mock_cast, mock_pg_insert, mock_get_storage): + """Should successfully restore from archive.""" + restore = WorkflowRunRestore() + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + # Mock session with context manager support + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + def session_maker(): + return mock_session + + # Mock database execution to return integer counts + mock_result_workflow_runs = Mock() + mock_result_workflow_runs.rowcount = 1 + mock_result_app_logs = Mock() + mock_result_app_logs.rowcount = 2 + + # Configure session.execute to return different results based on the table + def mock_execute(stmt): + if "workflow_runs" in str(stmt): + return mock_result_workflow_runs + else: + return mock_result_app_logs + + mock_session.execute.side_effect = mock_execute + mock_cast.return_value = mock_result_workflow_runs + + # Mock insert statement + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_pg_insert.return_value = mock_stmt + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Mock repository methods + with patch.object(restore, "_get_workflow_run_repo") as mock_get_repo: + mock_repo = Mock() + mock_get_repo.return_value = mock_repo + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=session_maker) + + assert result.success is True + assert result.restored_counts["workflow_runs"] == 1 + assert result.restored_counts["workflow_app_logs"] >= 1 # Just check it's restored + mock_session.commit.assert_called_once() + mock_repo.delete_archive_log_by_run_id.assert_called_once_with(mock_session, run.id) + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_invalid_archive_bundle(self, mock_get_storage): + """Should handle invalid archive bundle.""" + restore = WorkflowRunRestore() + + # Mock storage with invalid zip data + mock_storage = Mock() + mock_storage.get_object.return_value = b"invalid zip data" + mock_get_storage.return_value = mock_storage + + run = WorkflowRunRestoreTestDataFactory.create_workflow_run_mock() + + # Create proper mock session + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore._restore_from_run(run, session_maker=lambda: mock_session) + + assert result.success is False + # The error message comes from zipfile.BadZipFile which says "File is not a zip file" + assert "File is not a zip file" in result.error + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + def test_workflow_archive_log_input(self, mock_get_storage): + """Should handle WorkflowArchiveLog input correctly.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock storage and archive data + mock_storage = Mock() + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock() + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + + # Create proper mock session + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + result = restore._restore_from_run(archive_log, session_maker=lambda: mock_session) + + assert result.success is True + assert result.run_id == archive_log.workflow_run_id + assert result.tenant_id == archive_log.tenant_id + + +# --------------------------------------------------------------------------- +# Test restore_batch Method +# --------------------------------------------------------------------------- + + +class TestRestoreBatch: + """Tests for WorkflowRunRestore.restore_batch method.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_empty_tenant_ids_returns_empty(self, mock_sessionmaker): + """Should return empty list when tenant_ids is empty list.""" + restore = WorkflowRunRestore() + + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + result = restore.restore_batch( + tenant_ids=[], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert result == [] + + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_successful_batch_restore(self, mock_executor): + """Should successfully restore batch of workflow runs.""" + restore = WorkflowRunRestore(workers=2) + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + # Mock repository and archive logs + mock_repo = Mock() + archive_log1 = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock("run-1") + archive_log2 = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock("run-2") + mock_repo.get_archived_logs_by_time_range.return_value = [archive_log1, archive_log2] + + # Mock restore results + result1 = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={}) + result2 = RestoreResult(run_id="run-2", tenant_id="tenant-1", success=True, restored_counts={}) + + # Mock ThreadPoolExecutor with context manager support + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[result1, result2]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", side_effect=[result1, result2]): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + results = restore.restore_batch( + tenant_ids=["tenant-1"], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert len(results) == 2 + assert results[0].run_id == "run-1" + assert results[1].run_id == "run-2" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_dry_run_batch_restore(self, mock_executor): + """Should handle dry run mode for batch restore.""" + restore = WorkflowRunRestore(dry_run=True) + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_logs_by_time_range.return_value = [archive_log] + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={"workflow_runs": 1}) + + # Mock ThreadPoolExecutor with context manager support + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[result]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + results = restore.restore_batch( + tenant_ids=["tenant-1"], + start_date=datetime(2024, 1, 1), + end_date=datetime(2024, 1, 2), + ) + + assert len(results) == 1 + assert results[0].success is True + + +# --------------------------------------------------------------------------- +# Test restore_by_run_id Method +# --------------------------------------------------------------------------- + + +class TestRestoreByRunId: + """Tests for WorkflowRunRestore.restore_by_run_id method.""" + + def test_archive_log_not_found(self): + """Should handle case when archive log is not found.""" + restore = WorkflowRunRestore() + + mock_repo = Mock() + mock_repo.get_archived_log_by_run_id.return_value = None + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + result = restore.restore_by_run_id("nonexistent-run") + + assert result.success is False + assert "not found" in result.error + assert result.run_id == "nonexistent-run" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_successful_restore_by_id(self, mock_sessionmaker): + """Should successfully restore by run ID.""" + restore = WorkflowRunRestore() + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={}) + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + actual_result = restore.restore_by_run_id("run-1") + + assert actual_result.success is True + assert actual_result.run_id == "run-1" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.sessionmaker") + def test_dry_run_restore_by_id(self, mock_sessionmaker): + """Should handle dry run mode for restore by ID.""" + restore = WorkflowRunRestore(dry_run=True) + + mock_session = Mock() + mock_sessionmaker.return_value = mock_session + + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + result = RestoreResult(run_id="run-1", tenant_id="tenant-1", success=True, restored_counts={"workflow_runs": 1}) + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch.object(restore, "_restore_from_run", return_value=result): + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock db.engine to avoid SQLAlchemy issues + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + actual_result = restore.restore_by_run_id("run-1") + + assert actual_result.success is True + assert actual_result.run_id == "run-1" + + +# --------------------------------------------------------------------------- +# Test RestoreResult Dataclass +# --------------------------------------------------------------------------- + + +class TestRestoreResult: + """Tests for RestoreResult dataclass.""" + + def test_restore_result_creation(self): + """Should create RestoreResult with all fields.""" + result = RestoreResult( + run_id="run-123", + tenant_id="tenant-123", + success=True, + restored_counts={"workflow_runs": 1, "workflow_app_logs": 2}, + error=None, + elapsed_time=5.5, + ) + + assert result.run_id == "run-123" + assert result.tenant_id == "tenant-123" + assert result.success is True + assert result.restored_counts == {"workflow_runs": 1, "workflow_app_logs": 2} + assert result.error is None + assert result.elapsed_time == 5.5 + + def test_restore_result_with_error(self): + """Should create RestoreResult with error.""" + result = RestoreResult( + run_id="run-123", + tenant_id="tenant-123", + success=False, + restored_counts={}, + error="Something went wrong", + ) + + assert result.success is False + assert result.error == "Something went wrong" + assert result.restored_counts == {} + assert result.elapsed_time == 0.0 # Default value + + +# --------------------------------------------------------------------------- +# Test Constants and Mappings +# --------------------------------------------------------------------------- + + +class TestConstantsAndMappings: + """Tests for module constants and mappings.""" + + def test_table_models_mapping(self): + """TABLE_MODELS should contain expected table mappings.""" + expected_tables = { + "workflow_runs": WorkflowRun, + "workflow_app_logs": WorkflowAppLog, + "workflow_node_executions": WorkflowNodeExecutionModel, + "workflow_node_execution_offload": WorkflowNodeExecutionOffload, + "workflow_pauses": WorkflowPause, + "workflow_pause_reasons": WorkflowPauseReason, + "workflow_trigger_logs": WorkflowTriggerLog, + } + + assert expected_tables == TABLE_MODELS + + def test_schema_mappers_structure(self): + """SCHEMA_MAPPERS should have correct structure.""" + assert isinstance(SCHEMA_MAPPERS, dict) + assert "1.0" in SCHEMA_MAPPERS + assert isinstance(SCHEMA_MAPPERS["1.0"], dict) + + +# --------------------------------------------------------------------------- +# Integration Tests +# --------------------------------------------------------------------------- + + +class TestIntegration: + """Integration tests combining multiple components.""" + + @patch("services.retention.workflow_run.restore_archived_workflow_run.get_archive_storage") + @patch("services.retention.workflow_run.restore_archived_workflow_run.ThreadPoolExecutor") + def test_full_restore_flow(self, mock_executor, mock_get_storage): + """Test complete restore flow with all components.""" + restore = WorkflowRunRestore(workers=1) + + # Mock storage + mock_storage = Mock() + manifest = { + "schema_version": "1.0", + "tables": { + "workflow_runs": {"row_count": 1}, + }, + } + tables_data = { + "workflow_runs": [ + { + "id": "run-123", + "tenant_id": "tenant-123", + "app_id": "app-123", + "workflow_id": "workflow-123", + "type": "workflow", + "triggered_from": "app", + "version": "1", + "status": "succeeded", + "created_by_role": "account", + "created_by": "user-123", + "created_at": "2024-01-01T12:00:00", + } + ], + } + archive_data = WorkflowRunRestoreTestDataFactory.create_archive_zip_mock(manifest, tables_data) + mock_storage.get_object.return_value = archive_data + mock_get_storage.return_value = mock_storage + + # Mock session that supports context manager protocol + mock_session = Mock() + mock_session.__enter__ = Mock(return_value=mock_session) + mock_session.__exit__ = Mock(return_value=None) + + # Mock session factory that returns context manager sessions + mock_session_factory = Mock(return_value=mock_session) + + mock_result = Mock() + mock_result.rowcount = 1 + mock_session.execute.return_value = mock_result + + # Mock repository + mock_repo = Mock() + archive_log = WorkflowRunRestoreTestDataFactory.create_workflow_archive_log_mock() + mock_repo.get_archived_log_by_run_id.return_value = archive_log + + # Mock ThreadPoolExecutor (not actually used in restore_by_run_id but needed for patch) + mock_executor_instance = Mock() + mock_executor_instance.__enter__ = Mock(return_value=mock_executor_instance) + mock_executor_instance.__exit__ = Mock(return_value=None) + mock_executor_instance.map = Mock(return_value=[]) + mock_executor.return_value = mock_executor_instance + + with patch.object(restore, "_get_workflow_run_repo", return_value=mock_repo): + with patch("services.retention.workflow_run.restore_archived_workflow_run.pg_insert") as mock_insert: + mock_stmt = Mock() + mock_stmt.on_conflict_do_nothing.return_value = mock_stmt + mock_insert.return_value = mock_stmt + + with patch("services.retention.workflow_run.restore_archived_workflow_run.cast") as mock_cast: + mock_cast.return_value = mock_result + + with patch("services.retention.workflow_run.restore_archived_workflow_run.click") as mock_click: + # Mock sessionmaker and db.engine to avoid SQLAlchemy issues + with patch( + "services.retention.workflow_run.restore_archived_workflow_run.sessionmaker" + ) as mock_sessionmaker: + mock_sessionmaker.return_value = mock_session_factory + with patch("services.retention.workflow_run.restore_archived_workflow_run.db") as mock_db: + mock_db.engine = Mock() + result = restore.restore_by_run_id("run-123") + + assert result.success is True + assert result.restored_counts.get("workflow_runs") == 1 diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py index 635c86a14b..dcd6785464 100644 --- a/api/tests/unit_tests/services/test_account_service.py +++ b/api/tests/unit_tests/services/test_account_service.py @@ -1125,6 +1125,38 @@ class TestRegisterService: mock_create_workspace.assert_called_once_with(account=mock_account) mock_join_default_workspace.assert_not_called() + def test_create_account_and_tenant_still_calls_default_workspace_join_when_workspace_creation_fails( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should still be attempted when personal workspace creation fails.""" + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_workspace.side_effect = WorkSpaceNotAllowedCreateError() + + with pytest.raises(WorkSpaceNotAllowedCreateError): + AccountService.create_account_and_tenant( + email="test@example.com", + name="Test User", + interface_language="en-US", + password=None, + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies): """Test successful account registration.""" # Setup mocks @@ -1235,6 +1267,84 @@ class TestRegisterService: mock_join_default_workspace.assert_not_called() + def test_register_still_calls_default_workspace_join_when_personal_workspace_creation_fails( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should run even when personal workspace creation raises.""" + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_tenant.side_effect = WorkSpaceNotAllowedCreateError() + + with pytest.raises(AccountRegisterError, match="Workspace is not allowed to create."): + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + mock_db_dependencies["db"].session.commit.assert_not_called() + + def test_register_still_calls_default_workspace_join_when_workspace_limit_exceeded( + self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch + ): + """Default workspace join should run before propagating workspace-limit registration failure.""" + from services.errors.workspace import WorkspacesLimitExceededError + + monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False) + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="11111111-1111-1111-1111-111111111111" + ) + + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace, + ): + mock_create_account.return_value = mock_account + mock_create_tenant.side_effect = WorkspacesLimitExceededError() + + with pytest.raises(AccountRegisterError, match="Registration failed:"): + RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + mock_join_default_workspace.assert_called_once_with(str(mock_account.id)) + mock_db_dependencies["db"].session.commit.assert_not_called() + def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies): """Test account registration with OAuth integration.""" # Setup mocks diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py new file mode 100644 index 0000000000..a6bc79e82b --- /dev/null +++ b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py @@ -0,0 +1,214 @@ +""" +Unit tests for services.advanced_prompt_template_service +""" + +import copy + +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) +from models.model import AppMode +from services.advanced_prompt_template_service import AdvancedPromptTemplateService + + +class TestAdvancedPromptTemplateService: + """Test suite for AdvancedPromptTemplateService.""" + + def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: + """Test baichuan model names use baichuan context prompt.""" + # Arrange + args = { + "app_mode": AppMode.CHAT, + "model_mode": "chat", + "model_name": "Baichuan2-13B", + "has_context": "true", + } + + # Act + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) + + def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: + """Test non-baichuan model names use common prompt.""" + # Arrange + args = { + "app_mode": AppMode.CHAT, + "model_mode": "completion", + "model_name": "gpt-4", + "has_context": "false", + } + original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert + assert result == original_config + assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: + """Test invalid app mode returns empty dict.""" + # Arrange + app_mode = "invalid" + model_mode = "chat" + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") + + # Assert + assert result == {} + + def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: + """Test context is prepended for completion prompt when has_context is true.""" + # Arrange + original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) + assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: + """Test context is prepended for chat prompt when has_context is true.""" + # Arrange + original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) + assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: + """Test chat prompt remains unchanged when has_context is false.""" + # Arrange + original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") + + # Assert + assert result == original_config + assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG + + def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: + """Test completion app mode with completion model returns completion prompt.""" + # Arrange + original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") + + # Assert + assert result == original_config + assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: + """Test invalid model mode returns empty dict.""" + # Arrange + app_mode = AppMode.CHAT + model_mode = "invalid" + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") + + # Assert + assert result == {} + + def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: + """Test helper keeps completion prompt unchanged when context is disabled.""" + # Arrange + prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] + + # Act + result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"] == original_text + + def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: + """Test helper keeps chat prompt unchanged when context is disabled.""" + # Arrange + prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] + + # Act + result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text + + def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: + """Test baichuan chat/completion returns the expected config.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") + + # Assert + assert result == original_config + assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: + """Test baichuan completion/chat returns the expected config.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") + + # Assert + assert result == original_config + assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: + """Test baichuan completion/completion prepends baichuan context when enabled.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) + assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: + """Test baichuan chat/chat prepends baichuan context when enabled.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) + assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: + """Test invalid baichuan mode combinations return empty dict.""" + # Arrange + app_mode = "invalid" + model_mode = "invalid" + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") + + # Assert + assert result == {} diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py new file mode 100644 index 0000000000..7ce3d7ef7b --- /dev/null +++ b/api/tests/unit_tests/services/test_agent_service.py @@ -0,0 +1,346 @@ +""" +Unit tests for services.agent_service +""" + +from collections.abc import Callable +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +import pytz + +from core.plugin.impl.exc import PluginDaemonClientSideError +from models import Account +from models.model import App, Conversation, EndUser, Message, MessageAgentThought +from services.agent_service import AgentService + + +def _make_current_user_account(timezone: str = "UTC") -> Account: + account = Account(name="Test User", email="test@example.com") + account.timezone = timezone + return account + + +def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: + app_model = MagicMock(spec=App) + app_model.id = "app-123" + app_model.tenant_id = "tenant-123" + app_model.app_model_config = app_model_config + return app_model + + +def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: + conversation = MagicMock(spec=Conversation) + conversation.id = "conv-123" + conversation.app_id = "app-123" + conversation.from_end_user_id = from_end_user_id + conversation.from_account_id = from_account_id + return conversation + + +def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: + message = MagicMock(spec=Message) + message.id = "msg-123" + message.conversation_id = "conv-123" + message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) + message.provider_response_latency = 1.23 + message.answer_tokens = 4 + message.message_tokens = 6 + message.agent_thoughts = agent_thoughts + message.message_files = ["file-a.txt"] + return message + + +def _make_agent_thought() -> MagicMock: + agent_thought = MagicMock(spec=MessageAgentThought) + agent_thought.tokens = 3 + agent_thought.tool_input = "raw-input" + agent_thought.observation = "raw-output" + agent_thought.thought = "thinking" + agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) + agent_thought.files = [] + agent_thought.tools = ["tool_a", "dataset_tool"] + agent_thought.tool_labels = {"tool_a": "Tool A"} + agent_thought.tool_meta = { + "tool_a": { + "tool_config": { + "tool_provider_type": "custom", + "tool_provider": "provider-1", + }, + "tool_parameters": {"param": "value"}, + "time_cost": 2.5, + }, + "dataset_tool": { + "tool_config": { + "tool_provider_type": "dataset-retrieval", + "tool_provider": "dataset-provider", + } + }, + } + agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} + agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} + return agent_thought + + +def _build_query_side_effect( + conversation: Conversation | None, + message: Message | None, + executor: EndUser | Account | None, +) -> Callable[..., MagicMock]: + def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: + query = MagicMock() + query.where.return_value = query + if any(arg is Conversation for arg in args): + query.first.return_value = conversation + elif any(arg is Message for arg in args): + query.first.return_value = message + elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): + query.first.return_value = executor + return query + + return _query_side_effect + + +class TestAgentServiceGetAgentLogs: + """Test suite for AgentService.get_agent_logs.""" + + def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: + """Test missing conversation raises ValueError.""" + # Arrange + app_model = _make_app_model(MagicMock()) + with patch("services.agent_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") + + def test_get_agent_logs_should_raise_when_message_missing(self) -> None: + """Test missing message raises ValueError.""" + # Arrange + app_model = _make_app_model(MagicMock()) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + with patch("services.agent_service.db") as mock_db: + conversation_query = MagicMock() + conversation_query.where.return_value = conversation_query + conversation_query.first.return_value = conversation + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = None + + mock_db.session.query.side_effect = [conversation_query, message_query] + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") + + def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: + """Test missing app model config raises ValueError.""" + # Arrange + app_model = _make_app_model(None) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + message = _make_message([]) + current_user = _make_current_user_account() + + with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, message.id) + + def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: + """Test missing agent config raises ValueError.""" + # Arrange + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + message = _make_message([]) + current_user = _make_current_user_account() + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=None), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, message.id) + + def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: + """Test agent logs returned for end-user executor with tool icons.""" + # Arrange + agent_thought = _make_agent_thought() + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + executor = MagicMock(spec=EndUser) + executor.name = "End User" + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_tool = MagicMock() + agent_tool.tool_name = "tool_a" + agent_tool.provider_type = "custom" + agent_tool.provider_id = "provider-2" + agent_config = MagicMock() + agent_config.tools = [agent_tool] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, + patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) + mock_get_icon.side_effect = [None, "icon-a"] + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["status"] == "success" + assert result["meta"]["executor"] == "End User" + assert result["meta"]["total_tokens"] == 10 + assert result["meta"]["agent_mode"] == "react" + assert result["meta"]["iterations"] == 1 + assert result["files"] == ["file-a.txt"] + assert len(result["iterations"]) == 1 + tool_calls = result["iterations"][0]["tool_calls"] + assert tool_calls[0]["tool_name"] == "tool_a" + assert tool_calls[0]["tool_icon"] == "icon-a" + assert tool_calls[1]["tool_name"] == "dataset_tool" + assert tool_calls[1]["tool_icon"] == "" + mock_convert.assert_called_once() + + def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: + """Test agent logs fall back to account executor when end user is missing.""" + # Arrange + agent_thought = _make_agent_thought() + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") + executor = MagicMock(spec=Account) + executor.name = "Account User" + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_config = MagicMock() + agent_config.tools = [] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), + patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["executor"] == "Account User" + + def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: + """Test unknown executor and missing tool details fall back to defaults.""" + # Arrange + agent_thought = _make_agent_thought() + agent_thought.tool_labels = {} + agent_thought.tool_inputs_dict = {} + agent_thought.tool_outputs_dict = None + agent_thought.tool_meta = {"tool_a": {"error": "failed"}} + agent_thought.tools = ["tool_a"] + + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_config = MagicMock() + agent_config.tools = [] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), + patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["executor"] == "Unknown" + assert result["meta"]["agent_mode"] == "react" + tool_call = result["iterations"][0]["tool_calls"][0] + assert tool_call["status"] == "error" + assert tool_call["error"] == "failed" + assert tool_call["tool_label"] == "tool_a" + assert tool_call["tool_input"] == {} + assert tool_call["tool_output"] == {} + assert tool_call["time_cost"] == 0 + assert tool_call["tool_parameters"] == {} + assert tool_call["tool_icon"] is None + + +class TestAgentServiceProviders: + """Test suite for AgentService provider methods.""" + + def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: + """Test list_agent_providers delegates to PluginAgentClient.""" + # Arrange + tenant_id = "tenant-1" + expected = [{"name": "provider"}] + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_providers.return_value = expected + + # Act + result = AgentService.list_agent_providers("user-1", tenant_id) + + # Assert + assert result == expected + mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) + + def test_get_agent_provider_should_return_provider_when_successful(self) -> None: + """Test get_agent_provider returns provider when successful.""" + # Arrange + tenant_id = "tenant-1" + provider_name = "provider-a" + expected = {"name": provider_name} + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_provider.return_value = expected + + # Act + result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) + + # Assert + assert result == expected + mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) + + def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: + """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" + # Arrange + tenant_id = "tenant-1" + provider_name = "provider-a" + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( + "plugin error" + ) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_annotation_service.py b/api/tests/unit_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..0aacfc7f13 --- /dev/null +++ b/api/tests/unit_tests/services/test_annotation_service.py @@ -0,0 +1,1685 @@ +""" +Unit tests for services.annotation_service +""" + +from io import BytesIO +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + +from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation +from services.annotation_service import AppAnnotationService + + +def _make_app(app_id: str = "app-1", tenant_id: str = "tenant-1") -> MagicMock: + app = MagicMock(spec=App) + app.id = app_id + app.tenant_id = tenant_id + app.status = "normal" + return app + + +def _make_user(user_id: str = "user-1") -> MagicMock: + user = MagicMock() + user.id = user_id + return user + + +def _make_message(message_id: str = "msg-1", app_id: str = "app-1") -> MagicMock: + message = MagicMock(spec=Message) + message.id = message_id + message.app_id = app_id + message.conversation_id = "conv-1" + message.query = "default-question" + message.annotation = None + return message + + +def _make_annotation(annotation_id: str = "ann-1") -> MagicMock: + annotation = MagicMock(spec=MessageAnnotation) + annotation.id = annotation_id + annotation.content = "" + annotation.question = "" + annotation.question_text = "" + return annotation + + +def _make_setting(setting_id: str = "setting-1", with_detail: bool = True) -> MagicMock: + setting = MagicMock(spec=AppAnnotationSetting) + setting.id = setting_id + setting.score_threshold = 0.5 + setting.collection_binding_id = "collection-1" + if with_detail: + setting.collection_binding_detail = SimpleNamespace(provider_name="provider-a", model_name="model-a") + else: + setting.collection_binding_detail = None + return setting + + +def _make_file(content: bytes) -> FileStorage: + return FileStorage(stream=BytesIO(content)) + + +class TestAppAnnotationServiceUpInsert: + """Test suite for up_insert_app_annotation_from_message.""" + + def test_up_insert_app_annotation_from_message_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.up_insert_app_annotation_from_message(args, "app-1") + + def test_up_insert_app_annotation_from_message_should_raise_value_error_when_answer_missing(self) -> None: + """Test missing answer and content raises ValueError.""" + # Arrange + args = {"message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_raise_not_found_when_message_missing(self) -> None: + """Test missing message raises NotFound.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, message_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_update_existing_annotation_when_found(self) -> None: + """Test existing annotation is updated and indexed.""" + # Arrange + args = {"answer": "updated", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + message = _make_message(message_id="msg-1", app_id=app.id) + message.annotation = annotation + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = message + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, message_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation + assert annotation.content == "updated" + assert annotation.question == message.query + mock_db.session.add.assert_called_once_with(annotation) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + message.query, + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_up_insert_app_annotation_from_message_should_create_annotation_when_message_has_no_annotation( + self, + ) -> None: + """Test new annotation is created when message has no annotation.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1", "question": "q1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + message = _make_message(message_id="msg-1", app_id=app.id) + message.annotation = None + annotation_instance = _make_annotation("ann-1") + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = message + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, message_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + conversation_id=message.conversation_id, + message_id=message.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_not_called() + + def test_up_insert_app_annotation_from_message_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question without message_id raises ValueError.""" + # Arrange + args = {"answer": "hello"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_create_annotation_when_message_missing(self) -> None: + """Test annotation is created when message_id is not provided.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + annotation_instance = _make_annotation("ann-1") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation_instance.id, + "q1", + tenant_id, + app.id, + setting.collection_binding_id, + ) + + +class TestAppAnnotationServiceEnableDisable: + """Test suite for enable/disable app annotation.""" + + def test_enable_app_annotation_should_return_processing_when_cache_hit(self) -> None: + """Test cache hit returns processing status.""" + # Arrange + args = {"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"} + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.enable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = "job-1" + + # Act + result = AppAnnotationService.enable_app_annotation(args, "app-1") + + # Assert + assert result == {"job_id": "job-1", "job_status": "processing"} + mock_task.delay.assert_not_called() + + def test_enable_app_annotation_should_enqueue_job_when_cache_miss(self) -> None: + """Test cache miss enqueues enable task.""" + # Arrange + args = {"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"} + current_user = _make_user("user-1") + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.uuid.uuid4", return_value="uuid-1"), + patch("services.annotation_service.enable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = None + + # Act + result = AppAnnotationService.enable_app_annotation(args, "app-1") + + # Assert + assert result == {"job_id": "uuid-1", "job_status": "waiting"} + mock_redis.setnx.assert_called_once_with("enable_app_annotation_job_uuid-1", "waiting") + mock_task.delay.assert_called_once_with( + "uuid-1", + "app-1", + current_user.id, + tenant_id, + 0.5, + "p", + "m", + ) + + def test_disable_app_annotation_should_return_processing_when_cache_hit(self) -> None: + """Test disable cache hit returns processing status.""" + # Arrange + tenant_id = "tenant-1" + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.disable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = "job-2" + + # Act + result = AppAnnotationService.disable_app_annotation("app-1") + + # Assert + assert result == {"job_id": "job-2", "job_status": "processing"} + mock_task.delay.assert_not_called() + + def test_disable_app_annotation_should_enqueue_job_when_cache_miss(self) -> None: + """Test disable cache miss enqueues disable task.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.uuid.uuid4", return_value="uuid-2"), + patch("services.annotation_service.disable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = None + + # Act + result = AppAnnotationService.disable_app_annotation("app-1") + + # Assert + assert result == {"job_id": "uuid-2", "job_status": "waiting"} + mock_redis.setnx.assert_called_once_with("disable_app_annotation_job_uuid-2", "waiting") + mock_task.delay.assert_called_once_with("uuid-2", "app-1", tenant_id) + + +class TestAppAnnotationServiceListAndExport: + """Test suite for list and export methods.""" + + def test_get_annotation_list_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_list_by_app_id("app-1", 1, 10, "") + + def test_get_annotation_list_by_app_id_should_return_items_with_keyword(self) -> None: + """Test keyword search returns items and total.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + pagination = SimpleNamespace(items=["a1"], total=1) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("libs.helper.escape_like_pattern", return_value="safe"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_list_by_app_id(app.id, 1, 10, "keyword") + + # Assert + assert items == ["a1"] + assert total == 1 + + def test_get_annotation_list_by_app_id_should_return_items_without_keyword(self) -> None: + """Test list query without keyword returns paginated items.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + pagination = SimpleNamespace(items=["a1", "a2"], total=2) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_list_by_app_id(app.id, 1, 10, "") + + # Assert + assert items == ["a1", "a2"] + assert total == 2 + + def test_export_annotation_list_by_app_id_should_sanitize_fields(self) -> None: + """Test export sanitizes question and content fields.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation1 = _make_annotation("ann-1") + annotation1.question = "=cmd" + annotation1.content = "+1" + annotation2 = _make_annotation("ann-2") + annotation2.question = "@bad" + annotation2.content = "-2" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.CSVSanitizer.sanitize_value", side_effect=lambda v: f"safe:{v}"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.order_by.return_value = annotation_query + annotation_query.all.return_value = [annotation1, annotation2] + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act + result = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Assert + assert result == [annotation1, annotation2] + assert annotation1.question == "safe:=cmd" + assert annotation1.content == "safe:+1" + assert annotation2.question == "safe:@bad" + assert annotation2.content == "safe:-2" + + def test_export_annotation_list_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test export raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.export_annotation_list_by_app_id("app-1") + + +class TestAppAnnotationServiceDirectManipulation: + """Test suite for direct insert/update/delete methods.""" + + def test_insert_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: + """Test insert raises NotFound when app is missing.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.insert_app_annotation_directly(args, "app-1") + + def test_insert_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question raises ValueError.""" + # Arrange + args = {"answer": "hello"} + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.insert_app_annotation_directly(args, app.id) + + def test_insert_app_annotation_directly_should_create_annotation_and_index(self) -> None: + """Test insert creates annotation and triggers index task.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + current_user = _make_user("user-1") + tenant_id = "tenant-1" + app = _make_app() + annotation_instance = _make_annotation("ann-1") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.insert_app_annotation_directly(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation_instance.id, + "q1", + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_update_app_annotation_directly_should_raise_not_found_when_annotation_missing(self) -> None: + """Test missing annotation raises NotFound.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1") + + def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound in update path.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1") + + def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question raises ValueError.""" + # Arrange + args = {"answer": "hello"} + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + + def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None: + """Test update changes fields and triggers index update.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + annotation.question_text = "q1" + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.update_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + + # Assert + assert result == annotation + assert annotation.content == "hello" + assert annotation.question == "q1" + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + annotation.question_text, + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_delete_app_annotation_should_delete_annotation_and_histories(self) -> None: + """Test delete removes annotation and hit histories.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + history1 = MagicMock(spec=AppAnnotationHitHistory) + history2 = MagicMock(spec=AppAnnotationHitHistory) + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + scalars_result = MagicMock() + scalars_result.all.return_value = [history1, history2] + + mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + mock_db.session.scalars.return_value = scalars_result + + # Act + AppAnnotationService.delete_app_annotation(app.id, annotation.id) + + # Assert + mock_db.session.delete.assert_any_call(annotation) + mock_db.session.delete.assert_any_call(history1) + mock_db.session.delete.assert_any_call(history2) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + app.id, + tenant_id, + setting.collection_binding_id, + ) + + def test_delete_app_annotation_should_raise_not_found_when_app_missing(self) -> None: + """Test delete raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotation("app-1", "ann-1") + + def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None: + """Test delete raises NotFound when annotation is missing.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotation(app.id, "ann-1") + + def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None: + """Test batch delete returns zero when no annotations found.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotations_query = MagicMock() + annotations_query.outerjoin.return_value = annotations_query + annotations_query.where.return_value = annotations_query + annotations_query.all.return_value = [] + + mock_db.session.query.side_effect = [app_query, annotations_query] + + # Act + result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1"]) + + # Assert + assert result == {"deleted_count": 0} + + def test_delete_app_annotations_in_batch_should_raise_not_found_when_app_missing(self) -> None: + """Test batch delete raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotations_in_batch("app-1", ["ann-1"]) + + def test_delete_app_annotations_in_batch_should_delete_annotations_and_histories(self) -> None: + """Test batch delete removes annotations and triggers index deletion.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation1 = _make_annotation("ann-1") + annotation2 = _make_annotation("ann-2") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotations_query = MagicMock() + annotations_query.outerjoin.return_value = annotations_query + annotations_query.where.return_value = annotations_query + annotations_query.all.return_value = [(annotation1, setting), (annotation2, None)] + + hit_history_query = MagicMock() + hit_history_query.where.return_value = hit_history_query + hit_history_query.delete.return_value = None + + delete_query = MagicMock() + delete_query.where.return_value = delete_query + delete_query.delete.return_value = 2 + + mock_db.session.query.side_effect = [app_query, annotations_query, hit_history_query, delete_query] + + # Act + result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1", "ann-2"]) + + # Assert + assert result == {"deleted_count": 2} + mock_task.delay.assert_called_once_with(annotation1.id, app.id, tenant_id, setting.collection_binding_id) + mock_db.session.commit.assert_called_once() + + +class TestAppAnnotationServiceBatchImport: + """Test suite for batch import.""" + + def test_batch_import_app_annotations_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.batch_import_app_annotations("app-1", file) + + def test_batch_import_app_annotations_should_return_error_when_columns_invalid(self) -> None: + """Test invalid column count returns error message.""" + # Arrange + file = _make_file(b"question\nq\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["only"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Invalid CSV format" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_file_empty(self) -> None: + """Test empty file returns validation error before CSV parsing.""" + # Arrange + file = _make_file(b"") + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "empty or invalid" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_min_records_not_met(self) -> None: + """Test min records validation returns error message.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=2), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_row_limit_exceeded(self) -> None: + """Test row count over max limit returns explicit error.""" + # Arrange + file = _make_file(b"question,answer\nq1,a1\nq2,a2\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1", "q2"], "a": ["a1", "a2"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=1, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "too many records" in error_msg + + def test_batch_import_app_annotations_should_skip_malformed_rows_and_fail_min_records(self) -> None: + """Test malformed row extraction is skipped and can fail min record validation.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + malformed_row = MagicMock() + malformed_row.iloc.__getitem__.side_effect = IndexError() + df = MagicMock() + df.columns = ["q", "a"] + df.iterrows.return_value = [(0, malformed_row)] + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_skip_nan_rows_and_fail_min_records(self) -> None: + """Test NaN rows are skipped by validation and reported via min record check.""" + # Arrange + file = _make_file(b"question,answer\nnan,nan\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["nan"], "a": ["nan"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_question_too_long(self) -> None: + """Test oversized question is rejected with row context.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q" * 2001], "a": ["a"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Question at row" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_answer_too_long(self) -> None: + """Test oversized answer is rejected with row context.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q"], "a": ["a" * 10001]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Answer at row" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_quota_exceeded(self) -> None: + """Test quota validation returns error message.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True), + annotation_quota_limit=SimpleNamespace(limit=1, size=1), + ) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "exceeds the limit" in error_msg + + def test_batch_import_app_annotations_should_enqueue_job_when_valid(self) -> None: + """Test successful batch import enqueues job and returns status.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch("services.annotation_service.batch_import_annotations_task") as mock_task, + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.uuid.uuid4", return_value="uuid-3"), + patch("services.annotation_service.naive_utc_now", return_value=SimpleNamespace(timestamp=lambda: 1)), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + assert result == {"job_id": "uuid-3", "job_status": "waiting", "record_count": 1} + mock_redis.zadd.assert_called_once() + mock_redis.expire.assert_called_once() + mock_redis.setnx.assert_called_once_with("app_annotation_batch_import_uuid-3", "waiting") + mock_task.delay.assert_called_once() + + def test_batch_import_app_annotations_should_cleanup_active_job_on_unexpected_exception(self) -> None: + """Test unexpected runtime errors trigger cleanup and return wrapped error.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.uuid.uuid4", return_value="uuid-4"), + patch("services.annotation_service.naive_utc_now", return_value=SimpleNamespace(timestamp=lambda: 1)), + patch("services.annotation_service.logger") as mock_logger, + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_redis.zadd.side_effect = RuntimeError("boom") + mock_redis.zrem.side_effect = RuntimeError("cleanup-failed") + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + assert result["error_msg"] == "An error occurred while processing the file: boom" + mock_redis.zrem.assert_called_once_with(f"annotation_import_active:{tenant_id}", "uuid-4") + mock_logger.debug.assert_called_once() + + +class TestAppAnnotationServiceHitHistoryAndSettings: + """Test suite for hit history and settings methods.""" + + def test_get_annotation_hit_histories_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_hit_histories("app-1", "ann-1", 1, 10) + + def test_get_annotation_hit_histories_should_return_items_and_total(self) -> None: + """Test hit histories pagination returns items and total.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + pagination = SimpleNamespace(items=["h1"], total=2) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_hit_histories(app.id, annotation.id, 1, 10) + + # Assert + assert items == ["h1"] + assert total == 2 + + def test_get_annotation_hit_histories_should_raise_not_found_when_annotation_missing(self) -> None: + """Test missing annotation raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_hit_histories(app.id, "ann-1", 1, 10) + + def test_get_annotation_by_id_should_return_none_when_missing(self) -> None: + """Test get_annotation_by_id returns None when not found.""" + # Arrange + with patch("services.annotation_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act + result = AppAnnotationService.get_annotation_by_id("ann-1") + + # Assert + assert result is None + + def test_get_annotation_by_id_should_return_annotation_when_exists(self) -> None: + """Test get_annotation_by_id returns annotation when found.""" + # Arrange + annotation = _make_annotation("ann-1") + with patch("services.annotation_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = annotation + mock_db.session.query.return_value = query + + # Act + result = AppAnnotationService.get_annotation_by_id("ann-1") + + # Assert + assert result == annotation + + def test_add_annotation_history_should_update_hit_count_and_store_history(self) -> None: + """Test add_annotation_history updates hit count and creates history.""" + # Arrange + with ( + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.AppAnnotationHitHistory") as mock_history_cls, + ): + query = MagicMock() + query.where.return_value = query + mock_db.session.query.return_value = query + + # Act + AppAnnotationService.add_annotation_history( + annotation_id="ann-1", + app_id="app-1", + annotation_question="q", + annotation_content="a", + query="q", + user_id="user-1", + message_id="msg-1", + from_source="chat", + score=0.8, + ) + + # Assert + query.update.assert_called_once() + mock_history_cls.assert_called_once() + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_get_app_annotation_setting_by_app_id_should_return_embedding_model_when_detail_exists(self) -> None: + """Test setting detail returns embedding model info.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting(with_detail=True) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result["enabled"] is True + embedding_model = cast(dict[str, Any], result["embedding_model"]) + assert embedding_model["embedding_provider_name"] == "provider-a" + assert embedding_model["embedding_model_name"] == "model-a" + + def test_get_app_annotation_setting_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_app_annotation_setting_by_app_id("app-1") + + def test_get_app_annotation_setting_by_app_id_should_return_empty_embedding_model_when_no_detail(self) -> None: + """Test setting without detail returns empty embedding model.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting(with_detail=False) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result["enabled"] is True + assert result["embedding_model"] == {} + + def test_get_app_annotation_setting_by_app_id_should_return_disabled_when_setting_missing(self) -> None: + """Test missing setting returns disabled payload.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result == {"enabled": False} + + def test_update_app_annotation_setting_should_update_and_return_detail(self) -> None: + """Test update_app_annotation_setting updates fields and returns detail.""" + # Arrange + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + setting = _make_setting(with_detail=True) + args = {"score_threshold": 0.8} + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.naive_utc_now", return_value="now"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) + + # Assert + assert result["enabled"] is True + assert result["score_threshold"] == 0.8 + embedding_model = cast(dict[str, Any], result["embedding_model"]) + assert embedding_model["embedding_provider_name"] == "provider-a" + mock_db.session.add.assert_called_once_with(setting) + mock_db.session.commit.assert_called_once() + + def test_update_app_annotation_setting_should_return_empty_embedding_model_when_detail_missing(self) -> None: + """Test update returns empty embedding_model when collection detail is absent.""" + # Arrange + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + setting = _make_setting(with_detail=False) + args = {"score_threshold": 0.7} + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.naive_utc_now", return_value="now"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) + + # Assert + assert result["enabled"] is True + assert result["score_threshold"] == 0.7 + assert result["embedding_model"] == {} + + def test_update_app_annotation_setting_should_raise_not_found_when_app_missing(self) -> None: + """Test update raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_setting("app-1", "setting-1", {"score_threshold": 0.5}) + + def test_update_app_annotation_setting_should_raise_not_found_when_setting_missing(self) -> None: + """Test update raises NotFound when setting is missing.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_setting(app.id, "setting-1", {"score_threshold": 0.5}) + + +class TestAppAnnotationServiceClearAll: + """Test suite for clear_all_annotations.""" + + def test_clear_all_annotations_should_delete_annotations_and_histories(self) -> None: + """Test clear_all_annotations deletes all data and triggers index removal.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting() + annotation1 = _make_annotation("ann-1") + annotation2 = _make_annotation("ann-2") + history = MagicMock(spec=AppAnnotationHitHistory) + + def query_side_effect(*args: object, **kwargs: object) -> MagicMock: + query = MagicMock() + query.where.return_value = query + if App in args: + query.first.return_value = app + elif AppAnnotationSetting in args: + query.first.return_value = setting + elif MessageAnnotation in args: + query.yield_per.return_value = [annotation1, annotation2] + elif AppAnnotationHitHistory in args: + query.yield_per.return_value = [history] + return query + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + mock_db.session.query.side_effect = query_side_effect + + # Act + result = AppAnnotationService.clear_all_annotations(app.id) + + # Assert + assert result == {"result": "success"} + mock_db.session.delete.assert_any_call(annotation1) + mock_db.session.delete.assert_any_call(annotation2) + mock_db.session.delete.assert_any_call(history) + mock_task.delay.assert_any_call(annotation1.id, app.id, tenant_id, setting.collection_binding_id) + mock_task.delay.assert_any_call(annotation2.id, app.id, tenant_id, setting.collection_binding_id) + mock_db.session.commit.assert_called_once() + + def test_clear_all_annotations_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.clear_all_annotations("app-1") diff --git a/api/tests/unit_tests/services/test_api_based_extension_service.py b/api/tests/unit_tests/services/test_api_based_extension_service.py new file mode 100644 index 0000000000..7f4b5fdaa3 --- /dev/null +++ b/api/tests/unit_tests/services/test_api_based_extension_service.py @@ -0,0 +1,421 @@ +""" +Comprehensive unit tests for services/api_based_extension_service.py + +Covers: +- APIBasedExtensionService.get_all_by_tenant_id +- APIBasedExtensionService.save +- APIBasedExtensionService.delete +- APIBasedExtensionService.get_with_tenant_id +- APIBasedExtensionService._validation (new record & existing record branches) +- APIBasedExtensionService._ping_connection (pong success, wrong response, exception) +""" + +from unittest.mock import MagicMock, patch + +import pytest + +from services.api_based_extension_service import APIBasedExtensionService + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_extension( + *, + id_: str | None = None, + tenant_id: str = "tenant-001", + name: str = "my-ext", + api_endpoint: str = "https://example.com/hook", + api_key: str = "secret-key-123", +) -> MagicMock: + """Return a lightweight mock that mimics APIBasedExtension.""" + ext = MagicMock() + ext.id = id_ + ext.tenant_id = tenant_id + ext.name = name + ext.api_endpoint = api_endpoint + ext.api_key = api_key + return ext + + +# --------------------------------------------------------------------------- +# Tests: get_all_by_tenant_id +# --------------------------------------------------------------------------- + + +class TestGetAllByTenantId: + """Tests for APIBasedExtensionService.get_all_by_tenant_id.""" + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_extensions_with_decrypted_keys(self, mock_db, mock_decrypt): + """Each api_key is decrypted and the list is returned.""" + ext1 = _make_extension(id_="id-1", api_key="enc-key-1") + ext2 = _make_extension(id_="id-2", api_key="enc-key-2") + + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [ + ext1, + ext2, + ] + + result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") + + assert result == [ext1, ext2] + assert ext1.api_key == "decrypted-key" + assert ext2.api_key == "decrypted-key" + assert mock_decrypt.call_count == 2 + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_empty_list_when_no_extensions(self, mock_db, mock_decrypt): + """Returns an empty list gracefully when no records exist.""" + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + result = APIBasedExtensionService.get_all_by_tenant_id("tenant-001") + + assert result == [] + mock_decrypt.assert_not_called() + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_calls_query_with_correct_tenant_id(self, mock_db, mock_decrypt): + """Verifies the DB is queried with the supplied tenant_id.""" + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + APIBasedExtensionService.get_all_by_tenant_id("tenant-xyz") + + mock_db.session.query.return_value.filter_by.assert_called_once_with(tenant_id="tenant-xyz") + + +# --------------------------------------------------------------------------- +# Tests: save +# --------------------------------------------------------------------------- + + +class TestSave: + """Tests for APIBasedExtensionService.save.""" + + @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") + @patch("services.api_based_extension_service.db") + @patch.object(APIBasedExtensionService, "_validation") + def test_save_new_record_encrypts_key_and_commits(self, mock_validation, mock_db, mock_encrypt): + """Happy path: validation passes, key is encrypted, record is added and committed.""" + ext = _make_extension(id_=None, api_key="plain-key-123") + + result = APIBasedExtensionService.save(ext) + + mock_validation.assert_called_once_with(ext) + mock_encrypt.assert_called_once_with(ext.tenant_id, "plain-key-123") + assert ext.api_key == "encrypted-key" + mock_db.session.add.assert_called_once_with(ext) + mock_db.session.commit.assert_called_once() + assert result is ext + + @patch("services.api_based_extension_service.encrypt_token", return_value="encrypted-key") + @patch("services.api_based_extension_service.db") + @patch.object(APIBasedExtensionService, "_validation", side_effect=ValueError("name must not be empty")) + def test_save_raises_when_validation_fails(self, mock_validation, mock_db, mock_encrypt): + """If _validation raises, save should propagate the error without touching the DB.""" + ext = _make_extension(name="") + + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService.save(ext) + + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +# --------------------------------------------------------------------------- +# Tests: delete +# --------------------------------------------------------------------------- + + +class TestDelete: + """Tests for APIBasedExtensionService.delete.""" + + @patch("services.api_based_extension_service.db") + def test_delete_removes_record_and_commits(self, mock_db): + """delete() must call session.delete with the extension and then commit.""" + ext = _make_extension(id_="delete-me") + + APIBasedExtensionService.delete(ext) + + mock_db.session.delete.assert_called_once_with(ext) + mock_db.session.commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# Tests: get_with_tenant_id +# --------------------------------------------------------------------------- + + +class TestGetWithTenantId: + """Tests for APIBasedExtensionService.get_with_tenant_id.""" + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_returns_extension_with_decrypted_key(self, mock_db, mock_decrypt): + """Found extension has its api_key decrypted before being returned.""" + ext = _make_extension(id_="ext-123", api_key="enc-key") + + (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = ext + + result = APIBasedExtensionService.get_with_tenant_id("tenant-001", "ext-123") + + assert result is ext + assert ext.api_key == "decrypted-key" + mock_decrypt.assert_called_once_with(ext.tenant_id, "enc-key") + + @patch("services.api_based_extension_service.db") + def test_raises_value_error_when_not_found(self, mock_db): + """Raises ValueError when no matching extension exists.""" + (mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value) = None + + with pytest.raises(ValueError, match="API based extension is not found"): + APIBasedExtensionService.get_with_tenant_id("tenant-001", "non-existent") + + @patch("services.api_based_extension_service.decrypt_token", return_value="decrypted-key") + @patch("services.api_based_extension_service.db") + def test_queries_with_correct_tenant_and_extension_id(self, mock_db, mock_decrypt): + """Verifies both tenant_id and extension id are used in the query.""" + ext = _make_extension(id_="ext-abc") + chain = mock_db.session.query.return_value + chain.filter_by.return_value.filter_by.return_value.first.return_value = ext + + APIBasedExtensionService.get_with_tenant_id("tenant-002", "ext-abc") + + # First filter_by call uses tenant_id + chain.filter_by.assert_called_once_with(tenant_id="tenant-002") + # Second filter_by call uses id + chain.filter_by.return_value.filter_by.assert_called_once_with(id="ext-abc") + + +# --------------------------------------------------------------------------- +# Tests: _validation (new record — id is falsy) +# --------------------------------------------------------------------------- + + +class TestValidationNewRecord: + """Tests for _validation() with a brand-new record (no id).""" + + def _build_mock_db(self, name_exists: bool = False): + mock_db = MagicMock() + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( + MagicMock() if name_exists else None + ) + return mock_db + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_valid_new_extension_passes(self, mock_db, mock_ping): + """A new record with all valid fields should pass without exceptions.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, name="valid-ext", api_key="longenoughkey") + + # Should not raise + APIBasedExtensionService._validation(ext) + mock_ping.assert_called_once_with(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_is_empty(self, mock_db): + """Empty name raises ValueError.""" + ext = _make_extension(id_=None, name="") + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_is_none(self, mock_db): + """None name raises ValueError.""" + ext = _make_extension(id_=None, name=None) + with pytest.raises(ValueError, match="name must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_name_already_exists_for_new_record(self, mock_db): + """A new record whose name already exists raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = ( + MagicMock() + ) + ext = _make_extension(id_=None, name="duplicate-name") + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_endpoint_is_empty(self, mock_db): + """Empty api_endpoint raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_endpoint="") + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_endpoint_is_none(self, mock_db): + """None api_endpoint raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_endpoint=None) + + with pytest.raises(ValueError, match="api_endpoint must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_is_empty(self, mock_db): + """Empty api_key raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="") + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_is_none(self, mock_db): + """None api_key raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key=None) + + with pytest.raises(ValueError, match="api_key must not be empty"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_too_short(self, mock_db): + """api_key shorter than 5 characters raises ValueError.""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="abc") + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService._validation(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_api_key_exactly_four_chars(self, mock_db): + """api_key with exactly 4 characters raises ValueError (boundary condition).""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="1234") + + with pytest.raises(ValueError, match="api_key must be at least 5 characters"): + APIBasedExtensionService._validation(ext) + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_api_key_exactly_five_chars_is_accepted(self, mock_db, mock_ping): + """api_key with exactly 5 characters should pass (boundary condition).""" + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.first.return_value = None + ext = _make_extension(id_=None, api_key="12345") + + # Should not raise + APIBasedExtensionService._validation(ext) + + +# --------------------------------------------------------------------------- +# Tests: _validation (existing record — id is truthy) +# --------------------------------------------------------------------------- + + +class TestValidationExistingRecord: + """Tests for _validation() with an existing record (id is set).""" + + @patch.object(APIBasedExtensionService, "_ping_connection") + @patch("services.api_based_extension_service.db") + def test_valid_existing_extension_passes(self, mock_db, mock_ping): + """An existing record whose name is unique (excluding self) should pass.""" + # .where(...).first() → None means no *other* record has that name + ( + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value + ) = None + ext = _make_extension(id_="existing-id", name="unique-name", api_key="longenoughkey") + + # Should not raise + APIBasedExtensionService._validation(ext) + mock_ping.assert_called_once_with(ext) + + @patch("services.api_based_extension_service.db") + def test_raises_if_existing_record_name_conflicts_with_another(self, mock_db): + """Existing record cannot use a name already owned by a different record.""" + ( + mock_db.session.query.return_value.filter_by.return_value.filter_by.return_value.where.return_value.first.return_value + ) = MagicMock() + ext = _make_extension(id_="existing-id", name="taken-name") + + with pytest.raises(ValueError, match="name must be unique, it is already existed"): + APIBasedExtensionService._validation(ext) + + +# --------------------------------------------------------------------------- +# Tests: _ping_connection +# --------------------------------------------------------------------------- + + +class TestPingConnection: + """Tests for APIBasedExtensionService._ping_connection.""" + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_successful_ping_returns_pong(self, mock_requestor_class): + """When the endpoint returns {"result": "pong"}, no exception is raised.""" + mock_client = MagicMock() + mock_client.request.return_value = {"result": "pong"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension(api_endpoint="https://ok.example.com", api_key="secret-key") + # Should not raise + APIBasedExtensionService._ping_connection(ext) + + mock_requestor_class.assert_called_once_with(ext.api_endpoint, ext.api_key) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_wrong_ping_response_raises_value_error(self, mock_requestor_class): + """When the response is not {"result": "pong"}, a ValueError is raised.""" + mock_client = MagicMock() + mock_client.request.return_value = {"result": "error"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_network_exception_wraps_in_value_error(self, mock_requestor_class): + """Any exception raised during request is wrapped in a ValueError.""" + mock_client = MagicMock() + mock_client.request.side_effect = ConnectionError("network failure") + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error: network failure"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_requestor_constructor_exception_wraps_in_value_error(self, mock_requestor_class): + """Exception raised by the requestor constructor itself is wrapped.""" + mock_requestor_class.side_effect = RuntimeError("bad config") + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error: bad config"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_missing_result_key_raises_value_error(self, mock_requestor_class): + """A response dict without a 'result' key does not equal 'pong' → raises.""" + mock_client = MagicMock() + mock_client.request.return_value = {} # no 'result' key + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + with pytest.raises(ValueError, match="connection error"): + APIBasedExtensionService._ping_connection(ext) + + @patch("services.api_based_extension_service.APIBasedExtensionRequestor") + def test_uses_ping_extension_point(self, mock_requestor_class): + """The PING extension point is passed to the client.request call.""" + from models.api_based_extension import APIBasedExtensionPoint + + mock_client = MagicMock() + mock_client.request.return_value = {"result": "pong"} + mock_requestor_class.return_value = mock_client + + ext = _make_extension() + APIBasedExtensionService._ping_connection(ext) + + call_kwargs = mock_client.request.call_args + assert call_kwargs.kwargs["point"] == APIBasedExtensionPoint.PING + assert call_kwargs.kwargs["params"] == {} diff --git a/api/tests/unit_tests/services/test_api_token_service.py b/api/tests/unit_tests/services/test_api_token_service.py new file mode 100644 index 0000000000..ad4de93b25 --- /dev/null +++ b/api/tests/unit_tests/services/test_api_token_service.py @@ -0,0 +1,466 @@ +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Unauthorized + +import services.api_token_service as api_token_service_module +from services.api_token_service import ApiTokenCache, CachedApiToken + + +@pytest.fixture +def mock_db_session(): + """Fixture providing common DB session mocking for query_token_from_db tests.""" + fake_engine = MagicMock() + + session = MagicMock() + session_context = MagicMock() + session_context.__enter__.return_value = session + session_context.__exit__.return_value = None + + with ( + patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), + patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + yield { + "session": session, + "mock_session_class": mock_session_class, + "mock_cache_set": mock_cache_set, + "mock_record_usage": mock_record_usage, + "fake_engine": fake_engine, + } + + +class TestQueryTokenFromDb: + def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): + """Test DB lookup success path caches token and records usage.""" + # Arrange + auth_token = "token-123" + scope = "app" + api_token = MagicMock() + + mock_db_session["session"].scalar.return_value = api_token + + # Act + result = api_token_service_module.query_token_from_db(auth_token, scope) + + # Assert + assert result == api_token + mock_db_session["mock_session_class"].assert_called_once_with( + mock_db_session["fake_engine"], expire_on_commit=False + ) + mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) + mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + + def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): + """Test DB lookup miss path caches null marker and raises Unauthorized.""" + # Arrange + auth_token = "missing-token" + scope = "app" + + mock_db_session["session"].scalar.return_value = None + + # Act / Assert + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(auth_token, scope) + + mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) + mock_db_session["mock_record_usage"].assert_not_called() + + +class TestRecordTokenUsage: + def test_should_write_active_key_with_iso_timestamp_and_ttl(self): + """Test record_token_usage writes usage timestamp with one-hour TTL.""" + # Arrange + auth_token = "token-123" + scope = "dataset" + fixed_time = datetime(2026, 2, 24, 12, 0, 0) + expected_key = ApiTokenCache.make_active_key(auth_token, scope) + + with ( + patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), + patch.object(api_token_service_module, "redis_client") as mock_redis, + ): + # Act + api_token_service_module.record_token_usage(auth_token, scope) + + # Assert + mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) + + def test_should_not_raise_when_redis_write_fails(self): + """Test record_token_usage swallows Redis errors.""" + # Arrange + with patch.object(api_token_service_module, "redis_client") as mock_redis: + mock_redis.set.side_effect = Exception("redis unavailable") + + # Act / Assert + api_token_service_module.record_token_usage("token-123", "app") + + +class TestFetchTokenWithSingleFlight: + def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): + """Test single-flight returns cache when another request already populated it.""" + # Arrange + auth_token = "token-123" + scope = "app" + cached_token = CachedApiToken( + id="id-1", + app_id="app-1", + tenant_id="tenant-1", + type="app", + token=auth_token, + last_used_at=None, + created_at=None, + ) + + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == cached_token + mock_redis.lock.assert_called_once_with( + f"api_token_query_lock:{scope}:{auth_token}", + timeout=10, + blocking_timeout=5, + ) + lock.acquire.assert_called_once_with(blocking=True) + lock.release.assert_called_once() + mock_cache_get.assert_called_once_with(auth_token, scope) + mock_query_db.assert_not_called() + + def test_should_query_db_when_lock_acquired_and_cache_missed(self): + """Test single-flight queries DB when cache remains empty after lock acquisition.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None), + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_query_db.assert_called_once_with(auth_token, scope) + lock.release.assert_called_once() + + def test_should_query_db_directly_when_lock_not_acquired(self): + """Test lock timeout branch falls back to direct DB query.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.return_value = False + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get") as mock_cache_get, + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_cache_get.assert_not_called() + mock_query_db.assert_called_once_with(auth_token, scope) + lock.release.assert_not_called() + + def test_should_reraise_unauthorized_from_db_query(self): + """Test Unauthorized from DB query is propagated unchanged.""" + # Arrange + auth_token = "token-123" + scope = "app" + lock = MagicMock() + lock.acquire.return_value = True + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None), + patch.object( + api_token_service_module, + "query_token_from_db", + side_effect=Unauthorized("Access token is invalid"), + ), + ): + mock_redis.lock.return_value = lock + + # Act / Assert + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + lock.release.assert_called_once() + + def test_should_fallback_to_db_query_when_lock_raises_exception(self): + """Test Redis lock errors fall back to direct DB query.""" + # Arrange + auth_token = "token-123" + scope = "app" + db_token = MagicMock() + + lock = MagicMock() + lock.acquire.side_effect = RuntimeError("redis lock error") + + with ( + patch.object(api_token_service_module, "redis_client") as mock_redis, + patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, + ): + mock_redis.lock.return_value = lock + + # Act + result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) + + # Assert + assert result == db_token + mock_query_db.assert_called_once_with(auth_token, scope) + + +class TestApiTokenCacheTenantBranches: + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): + """Test scoped delete removes cache key and tenant index membership.""" + # Arrange + token = "token-123" + scope = "app" + cache_key = ApiTokenCache._make_cache_key(token, scope) + cached_token = CachedApiToken( + id="id-1", + app_id="app-1", + tenant_id="tenant-1", + type="app", + token=token, + last_used_at=None, + created_at=None, + ) + mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") + + with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: + # Act + result = ApiTokenCache.delete(token, scope) + + # Assert + assert result is True + mock_redis.delete.assert_called_once_with(cache_key) + mock_remove_index.assert_called_once_with("tenant-1", cache_key) + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): + """Test tenant invalidation deletes indexed cache entries and index key.""" + # Arrange + tenant_id = "tenant-1" + index_key = ApiTokenCache._make_tenant_index_key(tenant_id) + mock_redis.smembers.return_value = { + b"api_token:app:token-1", + b"api_token:any:token-2", + } + + # Act + result = ApiTokenCache.invalidate_by_tenant(tenant_id) + + # Assert + assert result is True + mock_redis.smembers.assert_called_once_with(index_key) + mock_redis.delete.assert_any_call("api_token:app:token-1") + mock_redis.delete.assert_any_call("api_token:any:token-2") + mock_redis.delete.assert_any_call(index_key) + + +class TestApiTokenCacheCoreBranches: + def test_cached_api_token_repr_should_include_id_and_type(self): + """Test CachedApiToken __repr__ includes key identity fields.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + + assert repr(token) == "" + + def test_serialize_token_should_handle_cached_api_token_instances(self): + """Test serialization path when input is already a CachedApiToken.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + + serialized = ApiTokenCache._serialize_token(token) + + assert isinstance(serialized, bytes) + assert b'"id":"id-123"' in serialized + assert b'"token":"token-123"' in serialized + + def test_deserialize_token_should_return_none_for_null_markers(self): + """Test null cache marker deserializes to None.""" + assert ApiTokenCache._deserialize_token("null") is None + assert ApiTokenCache._deserialize_token(b"null") is None + + def test_deserialize_token_should_return_none_for_invalid_payload(self): + """Test invalid serialized payload returns None.""" + assert ApiTokenCache._deserialize_token("not-json") is None + + @patch("services.api_token_service.redis_client") + def test_get_should_return_none_on_cache_miss(self, mock_redis): + """Test cache miss branch in ApiTokenCache.get.""" + mock_redis.get.return_value = None + + result = ApiTokenCache.get("token-123", "app") + + assert result is None + mock_redis.get.assert_called_once_with("api_token:app:token-123") + + @patch("services.api_token_service.redis_client") + def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): + """Test cache hit branch in ApiTokenCache.get.""" + token = CachedApiToken( + id="id-123", + app_id="app-123", + tenant_id="tenant-123", + type="app", + token="token-123", + last_used_at=None, + created_at=None, + ) + mock_redis.get.return_value = token.model_dump_json().encode("utf-8") + + result = ApiTokenCache.get("token-123", "app") + + assert isinstance(result, CachedApiToken) + assert result.id == "id-123" + + @patch("services.api_token_service.redis_client") + def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): + """Test tenant index update exits early for missing tenant id.""" + ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") + + mock_redis.sadd.assert_not_called() + mock_redis.expire.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): + """Test tenant index update handles Redis write errors gracefully.""" + mock_redis.sadd.side_effect = Exception("redis down") + + ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") + + mock_redis.sadd.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): + """Test tenant index removal exits early for missing tenant id.""" + ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") + + mock_redis.srem.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): + """Test tenant index removal handles Redis errors gracefully.""" + mock_redis.srem.side_effect = Exception("redis down") + + ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") + + mock_redis.srem.assert_called_once() + + @patch("services.api_token_service.redis_client") + def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): + """Test set returns False when Redis setex fails.""" + mock_redis.setex.side_effect = Exception("redis write failed") + api_token = MagicMock() + api_token.id = "id-123" + api_token.app_id = "app-123" + api_token.tenant_id = "tenant-123" + api_token.type = "app" + api_token.token = "token-123" + api_token.last_used_at = None + api_token.created_at = None + + result = ApiTokenCache.set("token-123", "app", api_token) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): + """Test delete(scope=None) returns False when scan_iter raises.""" + mock_redis.scan_iter.side_effect = Exception("scan failed") + + result = ApiTokenCache.delete("token-123", None) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): + """Test scoped delete still succeeds when tenant lookup from cache fails.""" + token = "token-123" + scope = "app" + cache_key = ApiTokenCache._make_cache_key(token, scope) + mock_redis.get.side_effect = Exception("get failed") + + result = ApiTokenCache.delete(token, scope) + + assert result is True + mock_redis.delete.assert_called_once_with(cache_key) + + @patch("services.api_token_service.redis_client") + def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): + """Test scoped delete returns False when delete operation fails.""" + token = "token-123" + scope = "app" + mock_redis.get.return_value = None + mock_redis.delete.side_effect = Exception("delete failed") + + result = ApiTokenCache.delete(token, scope) + + assert result is False + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): + """Test tenant invalidation returns True when tenant index is empty.""" + mock_redis.smembers.return_value = set() + + result = ApiTokenCache.invalidate_by_tenant("tenant-123") + + assert result is True + mock_redis.delete.assert_not_called() + + @patch("services.api_token_service.redis_client") + def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): + """Test tenant invalidation returns False when Redis operation fails.""" + mock_redis.smembers.side_effect = Exception("redis failed") + + result = ApiTokenCache.invalidate_by_tenant("tenant-123") + + assert result is False diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py new file mode 100644 index 0000000000..4f7d184046 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ -0,0 +1,958 @@ +import base64 +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +import yaml + +from core.trigger.constants import ( + TRIGGER_PLUGIN_NODE_TYPE, + TRIGGER_SCHEDULE_NODE_TYPE, + TRIGGER_WEBHOOK_NODE_TYPE, +) +from dify_graph.enums import BuiltinNodeTypes +from models import Account, AppMode +from models.model import IconType +from services import app_dsl_service +from services.app_dsl_service import ( + AppDslService, + CheckDependenciesPendingData, + ImportMode, + ImportStatus, + PendingData, + _check_version_compatibility, +) + + +class _FakeHttpResponse: + def __init__(self, content: bytes, *, raises: Exception | None = None): + self.content = content + self._raises = raises + + def raise_for_status(self) -> None: + if self._raises is not None: + raise self._raises + + +def _account_mock(*, tenant_id: str = "tenant-1", account_id: str = "account-1") -> MagicMock: + account = MagicMock(spec=Account) + account.current_tenant_id = tenant_id + account.id = account_id + return account + + +def _yaml_dump(data: dict) -> str: + return yaml.safe_dump(data, allow_unicode=True) + + +def _workflow_yaml(*, version: str = app_dsl_service.CURRENT_DSL_VERSION) -> str: + return _yaml_dump( + { + "version": version, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + } + ) + + +def test_check_version_compatibility_invalid_version_returns_failed(): + assert _check_version_compatibility("not-a-version") == ImportStatus.FAILED + + +def test_check_version_compatibility_newer_version_returns_pending(): + assert _check_version_compatibility("99.0.0") == ImportStatus.PENDING + + +def test_check_version_compatibility_major_older_returns_pending(monkeypatch): + monkeypatch.setattr(app_dsl_service, "CURRENT_DSL_VERSION", "1.0.0") + assert _check_version_compatibility("0.9.9") == ImportStatus.PENDING + + +def test_check_version_compatibility_minor_older_returns_completed_with_warnings(): + assert _check_version_compatibility("0.5.0") == ImportStatus.COMPLETED_WITH_WARNINGS + + +def test_check_version_compatibility_equal_returns_completed(): + assert _check_version_compatibility(app_dsl_service.CURRENT_DSL_VERSION) == ImportStatus.COMPLETED + + +def test_import_app_invalid_import_mode_raises_value_error(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Invalid import_mode"): + service.import_app(account=_account_mock(), import_mode="invalid-mode", yaml_content="version: '0.1.0'") + + +def test_import_app_yaml_url_requires_url(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url=None) + assert result.status == ImportStatus.FAILED + assert "yaml_url is required" in result.error + + +def test_import_app_yaml_content_requires_content(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=None) + assert result.status == ImportStatus.FAILED + assert "yaml_content is required" in result.error + + +def test_import_app_yaml_url_fetch_error_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "Error fetching YAML from URL: boom" in result.error + + +def test_import_app_yaml_url_empty_content_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + return _FakeHttpResponse(b"") + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "Empty content" in result.error + + +def test_import_app_yaml_url_file_too_large_returns_failed(monkeypatch): + def fake_get(_url: str, **_kwargs): + return _FakeHttpResponse(b"x" * (app_dsl_service.DSL_MAX_SIZE + 1)) + + monkeypatch.setattr(app_dsl_service.ssrf_proxy, "get", fake_get) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_URL, yaml_url="https://example.com/a.yml" + ) + assert result.status == ImportStatus.FAILED + assert "File size exceeds" in result.error + + +def test_import_app_yaml_not_mapping_returns_failed(): + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="[]") + assert result.status == ImportStatus.FAILED + assert "content must be a mapping" in result.error + + +def test_import_app_version_not_str_returns_failed(): + service = AppDslService(MagicMock()) + yaml_content = _yaml_dump({"version": 1, "kind": "app", "app": {"name": "x", "mode": "workflow"}}) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=yaml_content) + assert result.status == ImportStatus.FAILED + assert "Invalid version type" in result.error + + +def test_import_app_missing_app_data_returns_failed(): + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump({"version": "0.6.0", "kind": "app"}), + ) + assert result.status == ImportStatus.FAILED + assert "Missing app data" in result.error + + +def test_import_app_app_id_not_found_returns_failed(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + session = MagicMock() + session.scalar.return_value = None + service = AppDslService(session) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id="missing-app", + ) + assert result.status == ImportStatus.FAILED + assert result.error == "App not found" + + +def test_import_app_overwrite_only_allows_workflow_and_advanced_chat(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + existing_app = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.CHAT.value) + + session = MagicMock() + session.scalar.return_value = existing_app + service = AppDslService(session) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + app_id="app-1", + ) + assert result.status == ImportStatus.FAILED + assert "Only workflow or advanced chat apps" in result.error + + +def test_import_app_pending_stores_import_info_in_redis(): + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(version="99.0.0"), + name="n", + description="d", + icon_type="emoji", + icon="i", + icon_background="#000000", + ) + assert result.status == ImportStatus.PENDING + assert result.imported_dsl_version == "99.0.0" + + app_dsl_service.redis_client.setex.assert_called_once() + call = app_dsl_service.redis_client.setex.call_args + redis_key = call.args[0] + assert redis_key.startswith(app_dsl_service.IMPORT_INFO_REDIS_KEY_PREFIX) + + +def test_import_app_completed_uses_declared_dependencies(monkeypatch): + dependencies_payload = [{"id": "langgenius/google", "version": "1.0.0"}] + + plugin_deps = [SimpleNamespace(model_dump=lambda: dependencies_payload[0])] + monkeypatch.setattr( + app_dsl_service.PluginDependency, + "model_validate", + lambda d: plugin_deps[0], + ) + + created_app = SimpleNamespace(id="app-new", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + draft_var_service = MagicMock() + monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_yaml_dump( + { + "version": app_dsl_service.CURRENT_DSL_VERSION, + "kind": "app", + "app": {"name": "My App", "mode": AppMode.WORKFLOW.value}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + "dependencies": dependencies_payload, + } + ), + ) + + assert result.status == ImportStatus.COMPLETED + assert result.app_id == "app-new" + draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id="app-new") + + +@pytest.mark.parametrize("has_workflow", [True, False]) +def test_import_app_legacy_versions_extract_dependencies(monkeypatch, has_workflow: bool): + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_workflow_graph", + lambda *_args, **_kwargs: ["from-workflow"], + ) + monkeypatch.setattr( + AppDslService, + "_extract_dependencies_from_model_config", + lambda *_args, **_kwargs: ["from-model-config"], + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_latest_dependencies", + lambda deps: [SimpleNamespace(model_dump=lambda: {"dep": deps[0]})], + ) + + created_app = SimpleNamespace(id="app-legacy", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + draft_var_service = MagicMock() + monkeypatch.setattr(app_dsl_service, "WorkflowDraftVariableService", lambda *args, **kwargs: draft_var_service) + + data: dict = { + "version": "0.1.5", + "kind": "app", + "app": {"name": "Legacy", "mode": AppMode.WORKFLOW.value}, + } + if has_workflow: + data["workflow"] = {"graph": {"nodes": []}, "features": {}} + else: + data["model_config"] = {"model": {"provider": "openai"}} + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_yaml_dump(data) + ) + assert result.status == ImportStatus.COMPLETED_WITH_WARNINGS + draft_var_service.delete_app_workflow_variables.assert_called_once_with(app_id="app-legacy") + + +def test_import_app_yaml_error_returns_failed(monkeypatch): + def bad_safe_load(_content: str): + raise yaml.YAMLError("bad") + + monkeypatch.setattr(app_dsl_service.yaml, "safe_load", bad_safe_load) + + service = AppDslService(MagicMock()) + result = service.import_app(account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content="x: y") + assert result.status == ImportStatus.FAILED + assert result.error.startswith("Invalid YAML format:") + + +def test_import_app_unexpected_error_returns_failed(monkeypatch): + monkeypatch.setattr( + AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("oops")) + ) + + service = AppDslService(MagicMock()) + result = service.import_app( + account=_account_mock(), import_mode=ImportMode.YAML_CONTENT, yaml_content=_workflow_yaml() + ) + assert result.status == ImportStatus.FAILED + assert result.error == "oops" + + +def test_confirm_import_expired_returns_failed(): + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "expired" in result.error + + +def test_confirm_import_invalid_pending_data_type_returns_failed(): + app_dsl_service.redis_client.get.return_value = 123 + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert "Invalid import information" in result.error + + +def test_confirm_import_success_deletes_redis_key(monkeypatch): + def fake_select(_model): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(app_dsl_service, "select", fake_select) + + session = MagicMock() + session.scalar.return_value = None + service = AppDslService(session) + + pending = PendingData( + import_mode=ImportMode.YAML_CONTENT, + yaml_content=_workflow_yaml(), + name="name", + description="desc", + icon_type="emoji", + icon="🤖", + icon_background="#fff", + app_id=None, + ) + app_dsl_service.redis_client.get.return_value = pending.model_dump_json() + + created_app = SimpleNamespace(id="confirmed-app", mode=AppMode.WORKFLOW.value, tenant_id="tenant-1") + monkeypatch.setattr(AppDslService, "_create_or_update_app", lambda *_args, **_kwargs: created_app) + + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.COMPLETED + assert result.app_id == "confirmed-app" + app_dsl_service.redis_client.delete.assert_called_once() + + +def test_confirm_import_exception_returns_failed(monkeypatch): + app_dsl_service.redis_client.get.return_value = "not-json" + monkeypatch.setattr( + PendingData, "model_validate_json", lambda *_args, **_kwargs: (_ for _ in ()).throw(ValueError("bad")) + ) + + service = AppDslService(MagicMock()) + result = service.confirm_import(import_id="import-1", account=_account_mock()) + assert result.status == ImportStatus.FAILED + assert result.error == "bad" + + +def test_check_dependencies_returns_empty_when_no_redis_data(): + service = AppDslService(MagicMock()) + result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + assert result.leaked_dependencies == [] + + +def test_check_dependencies_calls_analysis_service(monkeypatch): + pending = CheckDependenciesPendingData(dependencies=[], app_id="app-1").model_dump_json() + app_dsl_service.redis_client.get.return_value = pending + dep = app_dsl_service.PluginDependency.model_validate( + {"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}} + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [dep], + ) + + service = AppDslService(MagicMock()) + result = service.check_dependencies(app_model=SimpleNamespace(id="app-1", tenant_id="tenant-1")) + assert len(result.leaked_dependencies) == 1 + + +def test_create_or_update_app_missing_mode_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="loss app mode"): + service._create_or_update_app(app=None, data={"app": {}}, account=_account_mock()) + + +def test_create_or_update_app_existing_app_updates_fields(monkeypatch): + fixed_now = object() + monkeypatch.setattr(app_dsl_service, "naive_utc_now", lambda: fixed_now) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.WORKFLOW.value, + name="old", + description="old-desc", + icon_type=IconType.EMOJI, + icon="old-icon", + icon_background="#111111", + updated_by=None, + updated_at=None, + app_model_config=None, + ) + service = AppDslService(MagicMock()) + updated = service._create_or_update_app( + app=app, + data={ + "app": {"mode": AppMode.WORKFLOW.value, "name": "yaml-name", "icon_type": IconType.IMAGE, "icon": "X"}, + "workflow": {"graph": {"nodes": []}, "features": {}}, + }, + account=_account_mock(), + name="override-name", + description=None, + icon_background="#222222", + ) + assert updated is app + assert app.name == "override-name" + assert app.icon_type == IconType.IMAGE + assert app.icon == "X" + assert app.icon_background == "#222222" + assert app.updated_at is fixed_now + + +def test_create_or_update_app_new_app_requires_tenant(): + account = _account_mock() + account.current_tenant_id = None + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Current tenant is not set"): + service._create_or_update_app( + app=None, + data={"app": {"mode": AppMode.WORKFLOW.value, "name": "n"}}, + account=account, + ) + + +def test_create_or_update_app_creates_workflow_app_and_saves_dependencies(monkeypatch): + class DummyApp(SimpleNamespace): + pass + + monkeypatch.setattr(app_dsl_service, "App", DummyApp) + + sent: list[tuple[str, object]] = [] + monkeypatch.setattr(app_dsl_service.app_was_created, "send", lambda app, account: sent.append((app.id, account.id))) + + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = SimpleNamespace(unique_hash="uh") + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_environment_variable_from_mapping", + lambda _m: SimpleNamespace(kind="env"), + ) + monkeypatch.setattr( + app_dsl_service.variable_factory, + "build_conversation_variable_from_mapping", + lambda _m: SimpleNamespace(kind="conv"), + ) + + monkeypatch.setattr( + AppDslService, "decrypt_dataset_id", lambda *_args, **_kwargs: "00000000-0000-0000-0000-000000000000" + ) + + session = MagicMock() + service = AppDslService(session) + deps = [ + app_dsl_service.PluginDependency.model_validate( + {"type": "package", "value": {"plugin_unique_identifier": "acme/foo", "version": "1.0.0"}} + ) + ] + data = { + "app": {"mode": AppMode.WORKFLOW.value, "name": "n"}, + "workflow": { + "environment_variables": [{"x": 1}], + "conversation_variables": [{"y": 2}], + "graph": { + "nodes": [ + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["enc-1", "enc-2"]}}, + ] + }, + "features": {}, + }, + } + + app = service._create_or_update_app(app=None, data=data, account=_account_mock(), dependencies=deps) + + assert app.tenant_id == "tenant-1" + assert sent == [(app.id, "account-1")] + app_dsl_service.redis_client.setex.assert_called() + workflow_service.sync_draft_workflow.assert_called_once() + + passed_graph = workflow_service.sync_draft_workflow.call_args.kwargs["graph"] + dataset_ids = passed_graph["nodes"][0]["data"]["dataset_ids"] + assert dataset_ids == ["00000000-0000-0000-0000-000000000000", "00000000-0000-0000-0000-000000000000"] + + +def test_create_or_update_app_workflow_missing_workflow_data_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Missing workflow data"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.WORKFLOW.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.WORKFLOW.value}}, + account=_account_mock(), + ) + + +def test_create_or_update_app_chat_requires_model_config(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Missing model_config"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.CHAT.value}}, + account=_account_mock(), + ) + + +def test_create_or_update_app_chat_creates_model_config_and_sends_event(monkeypatch): + class DummyModelConfig(SimpleNamespace): + def from_model_config_dict(self, _cfg: dict): + return self + + monkeypatch.setattr(app_dsl_service, "AppModelConfig", DummyModelConfig) + + sent: list[str] = [] + monkeypatch.setattr( + app_dsl_service.app_model_config_was_updated, "send", lambda app, app_model_config: sent.append(app.id) + ) + + session = MagicMock() + service = AppDslService(session) + + app = SimpleNamespace( + id="app-1", + tenant_id="tenant-1", + mode=AppMode.CHAT.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ) + service._create_or_update_app( + app=app, + data={"app": {"mode": AppMode.CHAT.value}, "model_config": {"model": {"provider": "openai"}}}, + account=_account_mock(), + ) + + assert app.app_model_config_id is not None + assert sent == ["app-1"] + session.add.assert_called() + + +def test_create_or_update_app_invalid_mode_raises(): + service = AppDslService(MagicMock()) + with pytest.raises(ValueError, match="Invalid app mode"): + service._create_or_update_app( + app=SimpleNamespace( + id="a", + tenant_id="t", + mode=AppMode.RAG_PIPELINE.value, + name="n", + description="d", + icon_background="#fff", + app_model_config=None, + ), + data={"app": {"mode": AppMode.RAG_PIPELINE.value}}, + account=_account_mock(), + ) + + +def test_export_dsl_delegates_by_mode(monkeypatch): + workflow_calls: list[bool] = [] + model_calls: list[bool] = [] + monkeypatch.setattr(AppDslService, "_append_workflow_export_data", lambda **_kwargs: workflow_calls.append(True)) + monkeypatch.setattr( + AppDslService, "_append_model_config_export_data", lambda *_args, **_kwargs: model_calls.append(True) + ) + + workflow_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id="tenant-1", + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + AppDslService.export_dsl(workflow_app) + assert workflow_calls == [True] + + chat_app = SimpleNamespace( + mode=AppMode.CHAT.value, + tenant_id="tenant-1", + name="n", + icon="i", + icon_type="emoji", + icon_background="#fff", + description="d", + use_icon_as_answer_icon=False, + app_model_config=SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": []}}), + ) + AppDslService.export_dsl(chat_app) + assert model_calls == [True] + + +def test_export_dsl_preserves_icon_and_icon_type(monkeypatch): + monkeypatch.setattr(AppDslService, "_append_workflow_export_data", lambda **_kwargs: None) + + emoji_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id="tenant-1", + name="Emoji App", + icon="🎨", + icon_type=IconType.EMOJI, + icon_background="#FF5733", + description="App with emoji icon", + use_icon_as_answer_icon=True, + app_model_config=None, + ) + yaml_output = AppDslService.export_dsl(emoji_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "🎨" + assert data["app"]["icon_type"] == "emoji" + assert data["app"]["icon_background"] == "#FF5733" + + image_app = SimpleNamespace( + mode=AppMode.WORKFLOW.value, + tenant_id="tenant-1", + name="Image App", + icon="https://example.com/icon.png", + icon_type=IconType.IMAGE, + icon_background="#FFEAD5", + description="App with image icon", + use_icon_as_answer_icon=False, + app_model_config=None, + ) + yaml_output = AppDslService.export_dsl(image_app) + data = yaml.safe_load(yaml_output) + assert data["app"]["icon"] == "https://example.com/icon.png" + assert data["app"]["icon_type"] == "image" + assert data["app"]["icon_background"] == "#FFEAD5" + + +def test_append_workflow_export_data_filters_and_overrides(monkeypatch): + workflow_dict = { + "graph": { + "nodes": [ + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "dataset_ids": ["d1", "d2"]}}, + {"data": {"type": BuiltinNodeTypes.TOOL, "credential_id": "secret"}}, + { + "data": { + "type": BuiltinNodeTypes.AGENT, + "agent_parameters": {"tools": {"value": [{"credential_id": "secret"}]}}, + } + }, + {"data": {"type": TRIGGER_SCHEDULE_NODE_TYPE, "config": {"x": 1}}}, + {"data": {"type": TRIGGER_WEBHOOK_NODE_TYPE, "webhook_url": "x", "webhook_debug_url": "y"}}, + {"data": {"type": TRIGGER_PLUGIN_NODE_TYPE, "subscription_id": "s"}}, + ] + } + } + + workflow = SimpleNamespace(to_dict=lambda *, include_secret: workflow_dict) + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = workflow + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + monkeypatch.setattr( + AppDslService, "encrypt_dataset_id", lambda *, dataset_id, tenant_id: f"enc:{tenant_id}:{dataset_id}" + ) + monkeypatch.setattr( + TriggerScheduleNode := app_dsl_service.TriggerScheduleNode, + "get_default_config", + lambda: {"config": {"default": True}}, + ) + monkeypatch.setattr(AppDslService, "_extract_dependencies_from_workflow", lambda *_args, **_kwargs: ["dep-1"]) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]}) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + export_data: dict = {} + AppDslService._append_workflow_export_data( + export_data=export_data, + app_model=SimpleNamespace(tenant_id="tenant-1"), + include_secret=False, + workflow_id=None, + ) + + nodes = export_data["workflow"]["graph"]["nodes"] + assert nodes[0]["data"]["dataset_ids"] == ["enc:tenant-1:d1", "enc:tenant-1:d2"] + assert "credential_id" not in nodes[1]["data"] + assert "credential_id" not in nodes[2]["data"]["agent_parameters"]["tools"]["value"][0] + assert nodes[3]["data"]["config"] == {"default": True} + assert nodes[4]["data"]["webhook_url"] == "" + assert nodes[4]["data"]["webhook_debug_url"] == "" + assert nodes[5]["data"]["subscription_id"] == "" + assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}] + + +def test_append_workflow_export_data_missing_workflow_raises(monkeypatch): + workflow_service = MagicMock() + workflow_service.get_draft_workflow.return_value = None + monkeypatch.setattr(app_dsl_service, "WorkflowService", lambda: workflow_service) + + with pytest.raises(ValueError, match="Missing draft workflow configuration"): + AppDslService._append_workflow_export_data( + export_data={}, + app_model=SimpleNamespace(tenant_id="tenant-1"), + include_secret=False, + workflow_id=None, + ) + + +def test_append_model_config_export_data_filters_credential_id(monkeypatch): + monkeypatch.setattr(AppDslService, "_extract_dependencies_from_model_config", lambda *_args, **_kwargs: ["dep-1"]) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "generate_dependencies", + lambda *, tenant_id, dependencies: [ + SimpleNamespace(model_dump=lambda: {"tenant": tenant_id, "dep": dependencies[0]}) + ], + ) + monkeypatch.setattr(app_dsl_service, "jsonable_encoder", lambda x: x) + + app_model_config = SimpleNamespace(to_dict=lambda: {"agent_mode": {"tools": [{"credential_id": "secret"}]}}) + app_model = SimpleNamespace(tenant_id="tenant-1", app_model_config=app_model_config) + export_data: dict = {} + + AppDslService._append_model_config_export_data(export_data, app_model) + assert export_data["model_config"]["agent_mode"]["tools"] == [{}] + assert export_data["dependencies"] == [{"tenant": "tenant-1", "dep": "dep-1"}] + + +def test_append_model_config_export_data_requires_app_config(): + with pytest.raises(ValueError, match="Missing app configuration"): + AppDslService._append_model_config_export_data({}, SimpleNamespace(app_model_config=None)) + + +def test_extract_dependencies_from_workflow_graph_covers_all_node_types(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + + monkeypatch.setattr(app_dsl_service.ToolNodeData, "model_validate", lambda _d: SimpleNamespace(provider_id="p1")) + monkeypatch.setattr( + app_dsl_service.LLMNodeData, "model_validate", lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m1")) + ) + monkeypatch.setattr( + app_dsl_service.QuestionClassifierNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m2")), + ) + monkeypatch.setattr( + app_dsl_service.ParameterExtractorNodeData, + "model_validate", + lambda _d: SimpleNamespace(model=SimpleNamespace(provider="m3")), + ) + + def kr_validate(_d): + return SimpleNamespace( + retrieval_mode="multiple", + multiple_retrieval_config=SimpleNamespace( + reranking_mode="weighted_score", + weights=SimpleNamespace(vector_setting=SimpleNamespace(embedding_provider_name="m4")), + reranking_model=None, + ), + single_retrieval_config=None, + ) + + monkeypatch.setattr(app_dsl_service.KnowledgeRetrievalNodeData, "model_validate", kr_validate) + + graph = { + "nodes": [ + {"data": {"type": BuiltinNodeTypes.TOOL}}, + {"data": {"type": BuiltinNodeTypes.LLM}}, + {"data": {"type": BuiltinNodeTypes.QUESTION_CLASSIFIER}}, + {"data": {"type": BuiltinNodeTypes.PARAMETER_EXTRACTOR}}, + {"data": {"type": BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL}}, + {"data": {"type": "unknown"}}, + ] + } + + deps = AppDslService._extract_dependencies_from_workflow_graph(graph) + assert deps == ["tool:p1", "model:m1", "model:m2", "model:m3", "model:m4"] + + +def test_extract_dependencies_from_workflow_graph_handles_exceptions(monkeypatch): + monkeypatch.setattr( + app_dsl_service.ToolNodeData, "model_validate", lambda _d: (_ for _ in ()).throw(ValueError("bad")) + ) + deps = AppDslService._extract_dependencies_from_workflow_graph( + {"nodes": [{"data": {"type": BuiltinNodeTypes.TOOL}}]} + ) + assert deps == [] + + +def test_extract_dependencies_from_model_config_parses_providers(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda provider: f"model:{provider}", + ) + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_tool_dependency", + lambda provider_id: f"tool:{provider_id}", + ) + + deps = AppDslService._extract_dependencies_from_model_config( + { + "model": {"provider": "p1"}, + "dataset_configs": { + "datasets": {"datasets": [{"reranking_model": {"reranking_provider_name": {"provider": "p2"}}}]} + }, + "agent_mode": {"tools": [{"provider_id": "t1"}]}, + } + ) + assert deps == ["model:p1", "model:p2", "tool:t1"] + + +def test_extract_dependencies_from_model_config_handles_exceptions(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "analyze_model_provider_dependency", + lambda _p: (_ for _ in ()).throw(ValueError("bad")), + ) + deps = AppDslService._extract_dependencies_from_model_config({"model": {"provider": "p1"}}) + assert deps == [] + + +def test_get_leaked_dependencies_empty_returns_empty(): + assert AppDslService.get_leaked_dependencies("tenant-1", []) == [] + + +def test_get_leaked_dependencies_delegates(monkeypatch): + monkeypatch.setattr( + app_dsl_service.DependenciesAnalysisService, + "get_leaked_dependencies", + lambda *, tenant_id, dependencies: [SimpleNamespace(tenant_id=tenant_id, deps=dependencies)], + ) + res = AppDslService.get_leaked_dependencies("tenant-1", [SimpleNamespace(id="x")]) + assert len(res) == 1 + + +def test_encrypt_decrypt_dataset_id_respects_config(monkeypatch): + tenant_id = "tenant-1" + dataset_uuid = "00000000-0000-0000-0000-000000000000" + + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", False) + assert AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) == dataset_uuid + + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + encrypted = AppDslService.encrypt_dataset_id(dataset_id=dataset_uuid, tenant_id=tenant_id) + assert encrypted != dataset_uuid + assert base64.b64decode(encrypted.encode()) + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id=tenant_id) == dataset_uuid + + +def test_decrypt_dataset_id_returns_plain_uuid_unchanged(): + value = "00000000-0000-0000-0000-000000000000" + assert AppDslService.decrypt_dataset_id(encrypted_data=value, tenant_id="tenant-1") == value + + +def test_decrypt_dataset_id_returns_none_on_invalid_data(monkeypatch): + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + assert AppDslService.decrypt_dataset_id(encrypted_data="not-base64", tenant_id="tenant-1") is None + + +def test_decrypt_dataset_id_returns_none_when_decrypted_is_not_uuid(monkeypatch): + monkeypatch.setattr(app_dsl_service.dify_config, "DSL_EXPORT_ENCRYPT_DATASET_ID", True) + encrypted = AppDslService.encrypt_dataset_id(dataset_id="not-a-uuid", tenant_id="tenant-1") + assert AppDslService.decrypt_dataset_id(encrypted_data=encrypted, tenant_id="tenant-1") is None + + +def test_is_valid_uuid_handles_bad_inputs(): + assert AppDslService._is_valid_uuid("00000000-0000-0000-0000-000000000000") is True + assert AppDslService._is_valid_uuid("nope") is False diff --git a/api/tests/unit_tests/services/test_app_generate_service.py b/api/tests/unit_tests/services/test_app_generate_service.py index 47b759bc7d..c2b430c551 100644 --- a/api/tests/unit_tests/services/test_app_generate_service.py +++ b/api/tests/unit_tests/services/test_app_generate_service.py @@ -1,14 +1,50 @@ +""" +Comprehensive unit tests for services.app_generate_service.AppGenerateService. + +Covers: + - _build_streaming_task_on_subscribe (streams / pubsub / exception / idempotency) + - generate (COMPLETION / AGENT_CHAT / CHAT / ADVANCED_CHAT / WORKFLOW / invalid mode, + streaming & blocking, billing, quota-refund-on-error, rate_limit.exit) + - _get_max_active_requests (all limit combos) + - generate_single_iteration (ADVANCED_CHAT / WORKFLOW / invalid mode) + - generate_single_loop (ADVANCED_CHAT / WORKFLOW / invalid mode) + - generate_more_like_this + - _get_workflow (debugger / non-debugger / specific id / invalid format / not found) + - get_response_generator (ended / non-ended workflow run) +""" + +import threading +import time +import uuid +from contextlib import contextmanager from unittest.mock import MagicMock -import services.app_generate_service as app_generate_service_module +import pytest + +import services.app_generate_service as ags_module +from core.app.entities.app_invoke_entities import InvokeFrom from models.model import AppMode from services.app_generate_service import AppGenerateService +from services.errors.app import WorkflowIdFormatError, WorkflowNotFoundError +# --------------------------------------------------------------------------- +# Helpers / Fakes +# --------------------------------------------------------------------------- class _DummyRateLimit: + """Minimal stand-in for RateLimit that never touches Redis.""" + + _instance_dict: dict[str, "_DummyRateLimit"] = {} + + def __new__(cls, client_id: str, max_active_requests: int): + # avoid singleton caching across tests + instance = object.__new__(cls) + return instance + def __init__(self, client_id: str, max_active_requests: int) -> None: self.client_id = client_id self.max_active_requests = max_active_requests + self._exited: list[str] = [] @staticmethod def gen_request_key() -> str: @@ -18,101 +54,720 @@ class _DummyRateLimit: return request_id or "dummy-request-id" def exit(self, request_id: str) -> None: - return None + self._exited.append(request_id) def generate(self, generator, request_id: str): return generator -def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch): - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) +def _make_app(mode: AppMode | str, *, max_active_requests: int = 0, is_agent: bool = False) -> MagicMock: + app = MagicMock() + app.mode = mode + app.id = "app-id" + app.tenant_id = "tenant-id" + app.max_active_requests = max_active_requests + app.is_agent = is_agent + return app - workflow = MagicMock() - workflow.id = "workflow-id" - workflow.created_by = "owner-id" - - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) - - generator_spy = mocker.patch( - "services.app_generate_service.WorkflowAppGenerator.generate", - return_value={"result": "ok"}, - ) - - app_model = MagicMock() - app_model.mode = AppMode.WORKFLOW - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - app_model.is_agent = False +def _make_user() -> MagicMock: user = MagicMock() user.id = "user-id" - - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args={"inputs": {"k": "v"}}, - invoke_from=MagicMock(), - streaming=False, - ) - - assert result == {"result": "ok"} - - call_kwargs = generator_spy.call_args.kwargs - pause_state_config = call_kwargs.get("pause_state_config") - assert pause_state_config is not None - assert pause_state_config.state_owner_user_id == "owner-id" + return user -def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch): - """ - Regression test: ADVANCED_CHAT in blocking mode should return a plain dict - (non-streaming), and must not go through the async retrieve_events path. - Keeps behavior consistent with WORKFLOW blocking branch. - """ - # Disable billing and stub RateLimit to a no-op that just passes values through - monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False) - mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) - - # Arrange a fake workflow and wire AppGenerateService._get_workflow to return it +def _make_workflow(*, workflow_id: str = "workflow-id", created_by: str = "owner-id") -> MagicMock: workflow = MagicMock() - workflow.id = "workflow-id" - mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + workflow.id = workflow_id + workflow.created_by = created_by + return workflow - # Spy on the streaming retrieval path to ensure it's NOT called - retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events") - # Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False - generate_spy = mocker.patch( - "services.app_generate_service.AdvancedChatAppGenerator.generate", - return_value={"result": "ok"}, - ) +@contextmanager +def _noop_rate_limit_context(rate_limit, request_id): + """Drop-in replacement for rate_limit_context that doesn't touch Redis.""" + yield - # Minimal app model for ADVANCED_CHAT - app_model = MagicMock() - app_model.mode = AppMode.ADVANCED_CHAT - app_model.id = "app-id" - app_model.tenant_id = "tenant-id" - app_model.max_active_requests = 0 - app_model.is_agent = False - user = MagicMock() - user.id = "user-id" +# --------------------------------------------------------------------------- +# _build_streaming_task_on_subscribe +# --------------------------------------------------------------------------- +class TestBuildStreamingTaskOnSubscribe: + """Tests for AppGenerateService._build_streaming_task_on_subscribe.""" - # Must include query and inputs for AdvancedChatAppGenerator - args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}} + def test_streams_mode_starts_immediately(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + # task started immediately during build + assert called == [1] + # calling the returned callback is idempotent + cb() + assert called == [1] # not called again - # Act: call service with streaming=False (blocking mode) - result = AppGenerateService.generate( - app_model=app_model, - user=user, - args=args, - invoke_from=MagicMock(), - streaming=False, - ) + def test_pubsub_mode_starts_on_subscribe(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) # large to prevent timer + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + assert called == [] + cb() + assert called == [1] + # second call is idempotent + cb() + assert called == [1] - # Assert: returns the dict from generate(), and did not call retrieve_events() - assert result == {"result": "ok"} - assert generate_spy.call_args.kwargs.get("streaming") is False - retrieve_spy.assert_not_called() + def test_sharded_mode_starts_on_subscribe(self, monkeypatch): + """sharded is treated like pubsub (i.e. not 'streams').""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "sharded") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) + called = [] + cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + assert called == [] + cb() + assert called == [1] + + def test_pubsub_fallback_timer_fires(self, monkeypatch): + """When nobody subscribes fast enough the fallback timer fires.""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 50) # 50 ms + called = [] + _cb = AppGenerateService._build_streaming_task_on_subscribe(lambda: called.append(1)) + time.sleep(0.2) # give the timer time to fire + assert called == [1] + + def test_exception_in_start_task_returns_false(self, monkeypatch): + """When start_task raises, _try_start returns False and next call retries.""" + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + call_count = 0 + + def _bad(): + nonlocal call_count + call_count += 1 + if call_count == 1: + raise RuntimeError("boom") + + cb = AppGenerateService._build_streaming_task_on_subscribe(_bad) + # first call inside build raised, but is caught; second call via cb succeeds + assert call_count == 1 + cb() + assert call_count == 2 + + def test_concurrent_subscribe_only_starts_once(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub") + monkeypatch.setattr(ags_module, "SSE_TASK_START_FALLBACK_MS", 60_000) + call_count = 0 + + def _inc(): + nonlocal call_count + call_count += 1 + + cb = AppGenerateService._build_streaming_task_on_subscribe(_inc) + threads = [threading.Thread(target=cb) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + assert call_count == 1 + + +# --------------------------------------------------------------------------- +# _get_max_active_requests +# --------------------------------------------------------------------------- +class TestGetMaxActiveRequests: + def test_both_zero_returns_zero(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 0 + + def test_app_limit_only(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=5) + assert AppGenerateService._get_max_active_requests(app) == 5 + + def test_config_limit_only(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 10) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 10 + + def test_both_non_zero_returns_min(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 20) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 0) + app = _make_app(AppMode.CHAT, max_active_requests=5) + assert AppGenerateService._get_max_active_requests(app) == 5 + + def test_default_active_requests_used_when_app_has_none(self, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "APP_MAX_ACTIVE_REQUESTS", 0) + monkeypatch.setattr(ags_module.dify_config, "APP_DEFAULT_ACTIVE_REQUESTS", 15) + app = _make_app(AppMode.CHAT, max_active_requests=0) + assert AppGenerateService._get_max_active_requests(app) == 15 + + +# --------------------------------------------------------------------------- +# generate – every AppMode branch +# --------------------------------------------------------------------------- +class TestGenerate: + """Tests for AppGenerateService.generate covering each mode.""" + + @pytest.fixture(autouse=True) + def _common(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + # Prevent AppExecutionParams.new from touching real models via isinstance + mocker.patch( + "services.app_generate_service.rate_limit_context", + _noop_rate_limit_context, + ) + + # -- COMPLETION --------------------------------------------------------- + def test_completion_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"result": "ok"}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + result = AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "ok"} + gen_spy.assert_called_once() + + # -- AGENT_CHAT via mode ------------------------------------------------ + def test_agent_chat_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.generate", + return_value={"result": "agent"}, + ) + mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + result = AppGenerateService.generate( + app_model=_make_app(AppMode.AGENT_CHAT), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "agent"} + gen_spy.assert_called_once() + + # -- AGENT_CHAT via is_agent flag (non-AGENT_CHAT mode) ----------------- + def test_agent_via_is_agent_flag(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.generate", + return_value={"result": "agent-via-flag"}, + ) + mocker.patch( + "services.app_generate_service.AgentChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + app = _make_app(AppMode.CHAT, is_agent=True) + result = AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "agent-via-flag"} + gen_spy.assert_called_once() + + # -- CHAT --------------------------------------------------------------- + def test_chat_mode(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.ChatAppGenerator.generate", + return_value={"result": "chat"}, + ) + mocker.patch( + "services.app_generate_service.ChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + app = _make_app(AppMode.CHAT, is_agent=False) + result = AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "chat"} + gen_spy.assert_called_once() + + # -- ADVANCED_CHAT blocking --------------------------------------------- + def test_advanced_chat_blocking(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + + retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events") + gen_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.generate", + return_value={"result": "advanced-blocking"}, + ) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.ADVANCED_CHAT), + user=_make_user(), + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "advanced-blocking"} + assert gen_spy.call_args.kwargs.get("streaming") is False + retrieve_spy.assert_not_called() + + # -- ADVANCED_CHAT streaming -------------------------------------------- + def test_advanced_chat_streaming(self, mocker, monkeypatch): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AppExecutionParams.new", + return_value=MagicMock(workflow_run_id="wfr-1", model_dump_json=MagicMock(return_value="{}")), + ) + delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay") + # Let _build_streaming_task_on_subscribe call the real on_subscribe + # so the inner closure (line 165) actually executes. + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.ADVANCED_CHAT), + user=_make_user(), + args={"workflow_id": None, "query": "hi", "inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + # In streaming mode it should go through retrieve_events, not generate + gen_instance.retrieve_events.assert_called_once() + # The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe + delay_spy.assert_called_once() + + # -- WORKFLOW blocking -------------------------------------------------- + def test_workflow_blocking(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + gen_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.generate", + return_value={"result": "workflow-blocking"}, + ) + mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.WORKFLOW), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + assert result == {"result": "workflow-blocking"} + call_kwargs = gen_spy.call_args.kwargs + assert call_kwargs.get("pause_state_config") is not None + assert call_kwargs["pause_state_config"].state_owner_user_id == "owner-id" + + # -- WORKFLOW streaming ------------------------------------------------- + def test_workflow_streaming(self, mocker, monkeypatch): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AppExecutionParams.new", + return_value=MagicMock(workflow_run_id="wfr-2", model_dump_json=MagicMock(return_value="{}")), + ) + delay_spy = mocker.patch("services.app_generate_service.workflow_based_app_execution_task.delay") + # Let _build_streaming_task_on_subscribe invoke the real on_subscribe + # so the inner closure (line 216) actually executes. + monkeypatch.setattr(ags_module.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams") + retrieve_spy = mocker.patch( + "services.app_generate_service.MessageBasedAppGenerator.retrieve_events", + return_value=iter([]), + ) + mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + result = AppGenerateService.generate( + app_model=_make_app(AppMode.WORKFLOW), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + retrieve_spy.assert_called_once() + # The inner on_subscribe closure was invoked by _build_streaming_task_on_subscribe + delay_spy.assert_called_once() + + # -- Invalid mode ------------------------------------------------------- + def test_invalid_mode_raises(self, mocker): + app = _make_app("invalid-mode", is_agent=False) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate( + app_model=app, + user=_make_user(), + args={}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + +# --------------------------------------------------------------------------- +# generate – billing / quota +# --------------------------------------------------------------------------- +class TestGenerateBilling: + @pytest.fixture(autouse=True) + def _common(self, mocker, monkeypatch): + mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit) + mocker.patch( + "services.app_generate_service.rate_limit_context", + _noop_rate_limit_context, + ) + + def test_billing_enabled_consumes_quota(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + quota_charge = MagicMock() + consume_mock = mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + return_value=quota_charge, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"ok": True}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + consume_mock.assert_called_once_with("tenant-id") + + def test_billing_quota_exceeded_raises_rate_limit_error(self, mocker, monkeypatch): + from services.errors.app import QuotaExceededError + from services.errors.llm import InvokeRateLimitError + + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + side_effect=QuotaExceededError(feature="workflow", tenant_id="t", required=1), + ) + + with pytest.raises(InvokeRateLimitError): + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + + def test_exception_refunds_quota_and_exits_rate_limit(self, mocker, monkeypatch): + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", True) + quota_charge = MagicMock() + mocker.patch( + "services.app_generate_service.QuotaType.WORKFLOW.consume", + return_value=quota_charge, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + side_effect=RuntimeError("boom"), + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + with pytest.raises(RuntimeError, match="boom"): + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + quota_charge.refund.assert_called_once() + + def test_rate_limit_exit_called_in_finally_for_blocking(self, mocker, monkeypatch): + """For non-streaming (blocking) calls, rate_limit.exit should be called in finally.""" + monkeypatch.setattr(ags_module.dify_config, "BILLING_ENABLED", False) + + exit_calls: list[str] = [] + + class _TrackingRateLimit(_DummyRateLimit): + def exit(self, request_id: str) -> None: + exit_calls.append(request_id) + + mocker.patch("services.app_generate_service.RateLimit", _TrackingRateLimit) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate", + return_value={"ok": True}, + ) + mocker.patch( + "services.app_generate_service.CompletionAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + + AppGenerateService.generate( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + args={"inputs": {}}, + invoke_from=InvokeFrom.SERVICE_API, + streaming=False, + ) + # exit is called in finally block for non-streaming + assert len(exit_calls) >= 1 + + +# --------------------------------------------------------------------------- +# _get_workflow +# --------------------------------------------------------------------------- +class TestGetWorkflow: + def test_debugger_fetches_draft(self, mocker): + draft_wf = _make_workflow() + ws = MagicMock() + ws.get_draft_workflow.return_value = draft_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER) + assert result is draft_wf + ws.get_draft_workflow.assert_called_once() + + def test_debugger_raises_when_no_draft(self, mocker): + ws = MagicMock() + ws.get_draft_workflow.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(ValueError, match="Workflow not initialized"): + AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.DEBUGGER) + + def test_non_debugger_fetches_published(self, mocker): + pub_wf = _make_workflow() + ws = MagicMock() + ws.get_published_workflow.return_value = pub_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API) + assert result is pub_wf + ws.get_published_workflow.assert_called_once() + + def test_non_debugger_raises_when_no_published(self, mocker): + ws = MagicMock() + ws.get_published_workflow.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(ValueError, match="Workflow not published"): + AppGenerateService._get_workflow(_make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API) + + def test_specific_workflow_id_valid_uuid(self, mocker): + valid_uuid = str(uuid.uuid4()) + specific_wf = _make_workflow(workflow_id=valid_uuid) + ws = MagicMock() + ws.get_published_workflow_by_id.return_value = specific_wf + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + result = AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid + ) + assert result is specific_wf + ws.get_published_workflow_by_id.assert_called_once() + + def test_specific_workflow_id_invalid_uuid(self, mocker): + ws = MagicMock() + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(WorkflowIdFormatError): + AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id="not-a-uuid" + ) + + def test_specific_workflow_id_not_found(self, mocker): + valid_uuid = str(uuid.uuid4()) + ws = MagicMock() + ws.get_published_workflow_by_id.return_value = None + mocker.patch("services.app_generate_service.WorkflowService", return_value=ws) + + with pytest.raises(WorkflowNotFoundError): + AppGenerateService._get_workflow( + _make_app(AppMode.WORKFLOW), InvokeFrom.SERVICE_API, workflow_id=valid_uuid + ) + + +# --------------------------------------------------------------------------- +# generate_single_iteration +# --------------------------------------------------------------------------- +class TestGenerateSingleIteration: + def test_advanced_chat_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + gen_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + iter_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.single_iteration_generate", + return_value={"event": "iteration"}, + ) + app = _make_app(AppMode.ADVANCED_CHAT) + result = AppGenerateService.generate_single_iteration( + app_model=app, user=_make_user(), node_id="n1", args={"k": "v"} + ) + iter_spy.assert_called_once() + assert result == {"event": "iteration"} + + def test_workflow_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + iter_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.single_iteration_generate", + return_value={"event": "wf-iteration"}, + ) + app = _make_app(AppMode.WORKFLOW) + result = AppGenerateService.generate_single_iteration( + app_model=app, user=_make_user(), node_id="n1", args={"k": "v"} + ) + iter_spy.assert_called_once() + assert result == {"event": "wf-iteration"} + + def test_invalid_mode_raises(self, mocker): + app = _make_app(AppMode.CHAT) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate_single_iteration(app_model=app, user=_make_user(), node_id="n1", args={}) + + +# --------------------------------------------------------------------------- +# generate_single_loop +# --------------------------------------------------------------------------- +class TestGenerateSingleLoop: + def test_advanced_chat_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + loop_spy = mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.single_loop_generate", + return_value={"event": "loop"}, + ) + app = _make_app(AppMode.ADVANCED_CHAT) + result = AppGenerateService.generate_single_loop( + app_model=app, user=_make_user(), node_id="n1", args=MagicMock() + ) + loop_spy.assert_called_once() + assert result == {"event": "loop"} + + def test_workflow_mode(self, mocker): + workflow = _make_workflow() + mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow) + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator.convert_to_event_stream", + side_effect=lambda x: x, + ) + loop_spy = mocker.patch( + "services.app_generate_service.WorkflowAppGenerator.single_loop_generate", + return_value={"event": "wf-loop"}, + ) + app = _make_app(AppMode.WORKFLOW) + result = AppGenerateService.generate_single_loop( + app_model=app, user=_make_user(), node_id="n1", args=MagicMock() + ) + loop_spy.assert_called_once() + assert result == {"event": "wf-loop"} + + def test_invalid_mode_raises(self, mocker): + app = _make_app(AppMode.COMPLETION) + with pytest.raises(ValueError, match="Invalid app mode"): + AppGenerateService.generate_single_loop(app_model=app, user=_make_user(), node_id="n1", args=MagicMock()) + + +# --------------------------------------------------------------------------- +# generate_more_like_this +# --------------------------------------------------------------------------- +class TestGenerateMoreLikeThis: + def test_delegates_to_completion_generator(self, mocker): + gen_spy = mocker.patch( + "services.app_generate_service.CompletionAppGenerator.generate_more_like_this", + return_value={"result": "similar"}, + ) + result = AppGenerateService.generate_more_like_this( + app_model=_make_app(AppMode.COMPLETION), + user=_make_user(), + message_id="msg-1", + invoke_from=InvokeFrom.SERVICE_API, + streaming=True, + ) + assert result == {"result": "similar"} + gen_spy.assert_called_once() + assert gen_spy.call_args.kwargs["stream"] is True + + +# --------------------------------------------------------------------------- +# get_response_generator +# --------------------------------------------------------------------------- +class TestGetResponseGenerator: + def test_non_ended_workflow_run(self, mocker): + app = _make_app(AppMode.ADVANCED_CHAT) + workflow_run = MagicMock() + workflow_run.id = "run-1" + workflow_run.status.is_ended.return_value = False + + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([{"event": "started"}]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run) + gen_instance.retrieve_events.assert_called_once() + + def test_ended_workflow_run_still_returns_generator(self, mocker): + """Even when the run is ended, the current code still returns a generator (TODO branch).""" + app = _make_app(AppMode.WORKFLOW) + workflow_run = MagicMock() + workflow_run.id = "run-2" + workflow_run.status.is_ended.return_value = True + + gen_instance = MagicMock() + gen_instance.retrieve_events.return_value = iter([]) + gen_instance.convert_to_event_stream.side_effect = lambda x: x + mocker.patch( + "services.app_generate_service.AdvancedChatAppGenerator", + return_value=gen_instance, + ) + + result = AppGenerateService.get_response_generator(app_model=app, workflow_run=workflow_run) + # current impl falls through the TODO and still creates a generator + gen_instance.retrieve_events.assert_called_once() diff --git a/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py new file mode 100644 index 0000000000..e66d52f66b --- /dev/null +++ b/api/tests/unit_tests/services/test_app_generate_service_streaming_integration.py @@ -0,0 +1,197 @@ +import json +import uuid +from collections import defaultdict, deque + +import pytest + +from core.app.apps.message_generator import MessageGenerator +from models.model import AppMode +from services.app_generate_service import AppGenerateService + + +# ----------------------------- +# Fakes for Redis Pub/Sub flow +# ----------------------------- +class _FakePubSub: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + self._subs: set[str] = set() + self._closed = False + + def subscribe(self, topic: str) -> None: + self._subs.add(topic) + + def unsubscribe(self, topic: str) -> None: + self._subs.discard(topic) + + def close(self) -> None: + self._closed = True + + def get_message(self, ignore_subscribe_messages: bool = True, timeout: int | float | None = 1): + # simulate a non-blocking poll; return first available + if self._closed: + return None + for t in list(self._subs): + q = self._store.get(t) + if q and len(q) > 0: + payload = q.popleft() + return {"type": "message", "channel": t, "data": payload} + # no message + return None + + +class _FakeRedisClient: + def __init__(self, store: dict[str, deque[bytes]]): + self._store = store + + def pubsub(self): + return _FakePubSub(self._store) + + def publish(self, topic: str, payload: bytes) -> None: + self._store.setdefault(topic, deque()).append(payload) + + +# ------------------------------------ +# Fakes for Redis Streams (XADD/XREAD) +# ------------------------------------ +class _FakeStreams: + def __init__(self) -> None: + # key -> list[(id, {field: value})] + self._data: dict[str, list[tuple[str, dict]]] = defaultdict(list) + self._seq: dict[str, int] = defaultdict(int) + + def xadd(self, key: str, fields: dict, *, maxlen: int | None = None) -> str: + # maxlen is accepted for API compatibility with redis-py; ignored in this test double + self._seq[key] += 1 + eid = f"{self._seq[key]}-0" + self._data[key].append((eid, fields)) + return eid + + def expire(self, key: str, seconds: int) -> None: + # no-op for tests + return None + + def xread(self, streams: dict, block: int | None = None, count: int | None = None): + assert len(streams) == 1 + key, last_id = next(iter(streams.items())) + entries = self._data.get(key, []) + start = 0 + if last_id != "0-0": + for i, (eid, _f) in enumerate(entries): + if eid == last_id: + start = i + 1 + break + if start >= len(entries): + return [] + end = len(entries) if count is None else min(len(entries), start + count) + return [(key, entries[start:end])] + + +@pytest.fixture +def _patch_get_channel_streams(monkeypatch): + from libs.broadcast_channel.redis.streams_channel import StreamsBroadcastChannel + + fake = _FakeStreams() + chan = StreamsBroadcastChannel(fake, retention_seconds=60) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees streams mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "streams", raising=False) + + +@pytest.fixture +def _patch_get_channel_pubsub(monkeypatch): + from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel + + store: dict[str, deque[bytes]] = defaultdict(deque) + client = _FakeRedisClient(store) + chan = RedisBroadcastChannel(client) + + def _get_channel(): + return chan + + # Patch both the source and the imported alias used by MessageGenerator + monkeypatch.setattr("extensions.ext_redis.get_pubsub_broadcast_channel", lambda: chan) + monkeypatch.setattr("core.app.apps.message_generator.get_pubsub_broadcast_channel", lambda: chan) + # Ensure AppGenerateService sees pubsub mode + import services.app_generate_service as ags + + monkeypatch.setattr(ags.dify_config, "PUBSUB_REDIS_CHANNEL_TYPE", "pubsub", raising=False) + + +def _publish_events(app_mode: AppMode, run_id: str, events: list[dict]): + # Publish events to the same topic used by MessageGenerator + topic = MessageGenerator.get_response_topic(app_mode, run_id) + for ev in events: + topic.publish(json.dumps(ev).encode()) + + +@pytest.mark.usefixtures("_patch_get_channel_streams") +def test_streams_full_flow_prepublish_and_replay(): + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + # Build start_task that publishes two events immediately + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + + def start_task(): + _publish_events(app_mode, run_id, events) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Start retrieving BEFORE subscription is established; in streams mode, we also started immediately + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + # skip ping events + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] + + +@pytest.mark.usefixtures("_patch_get_channel_pubsub") +def test_pubsub_full_flow_start_on_subscribe_gated(monkeypatch): + # Speed up any potential timer if it accidentally triggers + monkeypatch.setattr("services.app_generate_service.SSE_TASK_START_FALLBACK_MS", 50) + + app_mode = AppMode.WORKFLOW + run_id = str(uuid.uuid4()) + + published_order: list[str] = [] + + def start_task(): + # When called (on subscribe), publish both events + events = [{"event": "workflow_started"}, {"event": "workflow_finished"}] + _publish_events(app_mode, run_id, events) + published_order.extend([e["event"] for e in events]) + + on_subscribe = AppGenerateService._build_streaming_task_on_subscribe(start_task) + + # Producer not started yet; only when subscribe happens + assert published_order == [] + + gen = MessageGenerator.retrieve_events(app_mode, run_id, idle_timeout=2.0, on_subscribe=on_subscribe) + + received = [] + for msg in gen: + if isinstance(msg, str): + continue + received.append(msg) + if msg.get("event") == "workflow_finished": + break + + # Verify publish happened and consumer received in order + assert published_order == ["workflow_started", "workflow_finished"] + assert [m.get("event") for m in received] == ["workflow_started", "workflow_finished"] diff --git a/api/tests/unit_tests/services/test_app_model_config_service.py b/api/tests/unit_tests/services/test_app_model_config_service.py new file mode 100644 index 0000000000..d4b4bf14a3 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_model_config_service.py @@ -0,0 +1,88 @@ +from unittest.mock import patch + +import pytest + +from models.model import AppMode +from services.app_model_config_service import AppModelConfigService + + +@pytest.fixture +def mock_config_managers(): + """Fixture that patches all app config manager validate methods. + + Returns a dictionary containing the mocked config_validate methods for each manager. + """ + with ( + patch("services.app_model_config_service.ChatAppConfigManager.config_validate") as mock_chat_validate, + patch("services.app_model_config_service.AgentChatAppConfigManager.config_validate") as mock_agent_validate, + patch( + "services.app_model_config_service.CompletionAppConfigManager.config_validate" + ) as mock_completion_validate, + ): + mock_chat_validate.return_value = {"manager": "chat"} + mock_agent_validate.return_value = {"manager": "agent"} + mock_completion_validate.return_value = {"manager": "completion"} + + yield { + "chat": mock_chat_validate, + "agent": mock_agent_validate, + "completion": mock_completion_validate, + } + + +class TestAppModelConfigService: + @pytest.mark.parametrize( + ("app_mode", "selected_manager"), + [ + (AppMode.CHAT, "chat"), + (AppMode.AGENT_CHAT, "agent"), + (AppMode.COMPLETION, "completion"), + ], + ) + def test_should_route_validation_to_correct_manager_based_on_app_mode( + self, app_mode, selected_manager, mock_config_managers + ): + """Test configuration validation is delegated to the expected manager for each supported app mode.""" + tenant_id = "tenant-123" + config = {"temperature": 0.5} + + mock_chat_validate = mock_config_managers["chat"] + mock_agent_validate = mock_config_managers["agent"] + mock_completion_validate = mock_config_managers["completion"] + + result = AppModelConfigService.validate_configuration(tenant_id=tenant_id, config=config, app_mode=app_mode) + + assert result == {"manager": selected_manager} + + if selected_manager == "chat": + mock_chat_validate.assert_called_once_with(tenant_id, config) + mock_agent_validate.assert_not_called() + mock_completion_validate.assert_not_called() + elif selected_manager == "agent": + mock_agent_validate.assert_called_once_with(tenant_id, config) + mock_chat_validate.assert_not_called() + mock_completion_validate.assert_not_called() + else: + mock_completion_validate.assert_called_once_with(tenant_id, config) + mock_chat_validate.assert_not_called() + mock_agent_validate.assert_not_called() + + def test_should_raise_value_error_when_app_mode_is_not_supported(self, mock_config_managers): + """Test unsupported app modes raise ValueError with the invalid mode in the message.""" + tenant_id = "tenant-123" + config = {"temperature": 0.5} + + mock_chat_validate = mock_config_managers["chat"] + mock_agent_validate = mock_config_managers["agent"] + mock_completion_validate = mock_config_managers["completion"] + + with pytest.raises(ValueError, match=f"Invalid app mode: {AppMode.WORKFLOW}"): + AppModelConfigService.validate_configuration( + tenant_id=tenant_id, + config=config, + app_mode=AppMode.WORKFLOW, + ) + + mock_chat_validate.assert_not_called() + mock_agent_validate.assert_not_called() + mock_completion_validate.assert_not_called() diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py new file mode 100644 index 0000000000..bff8dc92c6 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_service.py @@ -0,0 +1,609 @@ +"""Unit tests for services.app_service.""" + +import json +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock, patch + +import pytest + +from core.errors.error import ProviderTokenNotInitError +from models import Account, Tenant +from models.model import App, AppMode +from services.app_service import AppService + + +@pytest.fixture +def service() -> AppService: + """Provide AppService instance.""" + return AppService() + + +@pytest.fixture +def account() -> Account: + """Create account object for create_app tests.""" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + result = Account(name="Account User", email="account@example.com") + result.id = "acc-1" + result._current_tenant = tenant + return result + + +@pytest.fixture +def default_args() -> dict: + """Create default create_app args.""" + return { + "name": "Test App", + "mode": AppMode.CHAT.value, + "icon": "🤖", + "icon_background": "#FFFFFF", + } + + +@pytest.fixture +def app_template() -> dict: + """Create basic app template for create_app tests.""" + return { + AppMode.CHAT: { + "app": {}, + "model_config": { + "model": { + "provider": "provider-a", + "name": "model-a", + "mode": "chat", + "completion_params": {}, + } + }, + } + } + + +def _make_current_user() -> Account: + user = Account(name="Tester", email="tester@example.com") + user.id = "user-1" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + user._current_tenant = tenant + return user + + +class TestAppServicePagination: + """Test suite for get_paginate_apps.""" + + def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: + """Test pagination returns None when tag filter has no targets.""" + # Arrange + args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} + + with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): + # Act + result = service.get_paginate_apps("user-1", "tenant-1", args) + + # Assert + assert result is None + + def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: + """Test pagination delegates to db.paginate when filters are valid.""" + # Arrange + args = { + "mode": "workflow", + "is_created_by_me": True, + "name": "My_App%", + "tag_ids": ["tag-1"], + "page": 2, + "limit": 10, + } + expected_pagination = MagicMock() + + with ( + patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), + patch("libs.helper.escape_like_pattern", return_value="escaped"), + patch("services.app_service.db") as mock_db, + ): + mock_db.paginate.return_value = expected_pagination + + # Act + result = service.get_paginate_apps("user-1", "tenant-1", args) + + # Assert + assert result is expected_pagination + mock_db.paginate.assert_called_once() + + +class TestAppServiceCreate: + """Test suite for create_app.""" + + def test_create_app_should_create_with_matching_default_model( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app uses matching default model and persists app config.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + model_instance = SimpleNamespace( + model_name="model-a", + provider="provider-a", + model_type_instance=MagicMock(), + credentials={"k": "v"}, + ) + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + patch("services.app_service.app_was_created") as mock_event, + patch("services.app_service.FeatureService.get_system_features") as mock_features, + patch("services.app_service.BillingService") as mock_billing, + patch("services.app_service.dify_config") as mock_config, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.return_value = model_instance + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_config.BILLING_ENABLED = True + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + assert app_instance.app_model_config_id == "cfg-1" + mock_db.session.add.assert_any_call(app_instance) + mock_db.session.add.assert_any_call(app_model_config) + assert mock_db.session.flush.call_count == 2 + mock_db.session.commit.assert_called_once() + mock_event.send.assert_called_once_with(app_instance, account=account) + mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") + + def test_create_app_should_raise_when_model_schema_missing( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app raises ValueError when non-matching model has no schema.""" + # Arrange + app_instance = SimpleNamespace(id="app-1") + model_instance = SimpleNamespace( + model_name="model-b", + provider="provider-b", + model_type_instance=MagicMock(), + credentials={"k": "v"}, + ) + model_instance.model_type_instance.get_model_schema.return_value = None + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.return_value = model_instance + + # Act & Assert + with pytest.raises(ValueError, match="model schema not found"): + service.create_app("tenant-1", default_args, account) + mock_db.session.commit.assert_not_called() + + def test_create_app_should_fallback_to_default_provider_when_model_missing( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app falls back to provider/model name when no default model instance is available.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + patch("services.app_service.app_was_created") as mock_event, + patch("services.app_service.FeatureService.get_system_features") as mock_features, + patch("services.app_service.EnterpriseService") as mock_enterprise, + patch("services.app_service.dify_config") as mock_config, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") + manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + mock_config.BILLING_ENABLED = False + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + mock_event.send.assert_called_once_with(app_instance, account=account) + mock_db.session.commit.assert_called_once() + mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") + + def test_create_app_should_log_and_fallback_on_unexpected_model_error( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test unexpected model manager errors are logged and fallback provider is used.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db"), + patch("services.app_service.app_was_created"), + patch( + "services.app_service.FeatureService.get_system_features", + return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), + ), + patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), + patch("services.app_service.logger") as mock_logger, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.side_effect = RuntimeError("boom") + manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + mock_logger.exception.assert_called_once() + + +class TestAppServiceGetAndUpdate: + """Test suite for app retrieval and update methods.""" + + def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None: + """Test get_app returns original app for non-agent modes.""" + # Arrange + app = MagicMock() + app.mode = AppMode.CHAT + app.is_agent = False + + with patch("services.app_service.current_user", _make_current_user()): + # Act + result = service.get_app(app) + + # Assert + assert result is app + + def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: + """Test get_app returns app when agent mode has no model config.""" + # Arrange + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = None + + with patch("services.app_service.current_user", _make_current_user()): + # Act + result = service.get_app(app) + + # Assert + assert result is app + + def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: + """Test get_app decrypts and masks secret tool parameters.""" + # Arrange + tool = { + "provider_type": "builtin", + "provider_id": "provider-1", + "tool_name": "tool-a", + "tool_parameters": {"secret": "encrypted"}, + "extra": True, + } + model_config = MagicMock() + model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} + + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = model_config + + manager = MagicMock() + manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} + manager.mask_tool_parameters.return_value = {"secret": "***"} + + with ( + patch("services.app_service.current_user", _make_current_user()), + patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), + patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), + ): + # Act + result = service.get_app(app) + + # Assert + assert result.app_model_config is model_config + assert tool["tool_parameters"] == {"secret": "***"} + assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} + + def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: + """Test get_app logs and continues when masking fails.""" + # Arrange + tool = { + "provider_type": "builtin", + "provider_id": "provider-1", + "tool_name": "tool-a", + "tool_parameters": {"secret": "encrypted"}, + "extra": True, + } + model_config = MagicMock() + model_config.agent_mode_dict = {"tools": [tool]} + + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = model_config + + with ( + patch("services.app_service.current_user", _make_current_user()), + patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), + patch("services.app_service.logger") as mock_logger, + ): + # Act + result = service.get_app(app) + + # Assert + assert result.app_model_config is model_config + mock_logger.exception.assert_called_once() + + def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: + """Test update methods set fields and commit changes.""" + # Arrange + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type="emoji", + icon="a", + icon_background="#111", + enable_site=True, + enable_api=True, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": "image", + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + patch("services.app_service.naive_utc_now", return_value="now"), + ): + # Act + updated = service.update_app(app, args) + renamed = service.update_app_name(app, "rename") + iconed = service.update_app_icon(app, "icon-2", "#333") + site_same = service.update_app_site_status(app, app.enable_site) + api_same = service.update_app_api_status(app, app.enable_api) + site_changed = service.update_app_site_status(app, False) + api_changed = service.update_app_api_status(app, False) + + # Assert + assert updated is app + assert renamed is app + assert iconed is app + assert site_same is app + assert api_same is app + assert site_changed is app + assert api_changed is app + assert mock_db.session.commit.call_count >= 5 + + +class TestAppServiceDeleteAndMeta: + """Test suite for delete and metadata methods.""" + + def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: + """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" + # Arrange + app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + with ( + patch("services.app_service.db") as mock_db, + patch( + "services.app_service.FeatureService.get_system_features", + return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), + ), + patch("services.app_service.EnterpriseService") as mock_enterprise, + patch( + "services.app_service.dify_config", + new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), + ), + patch("services.app_service.BillingService") as mock_billing, + patch("services.app_service.remove_app_and_related_data_task") as mock_task, + ): + # Act + service.delete_app(app) + + # Assert + mock_db.session.delete.assert_called_once_with(app) + mock_db.session.commit.assert_called_once() + mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") + mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") + mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") + + def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: + """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" + # Arrange + workflow = SimpleNamespace( + graph_dict={ + "nodes": [ + { + "data": { + "type": "tool", + "provider_type": "builtin", + "provider_id": "builtin-provider", + "tool_name": "tool_builtin", + } + }, + { + "data": { + "type": "tool", + "provider_type": "api", + "provider_id": "api-provider-id", + "tool_name": "tool_api", + } + }, + ] + } + ) + app = cast( + App, + SimpleNamespace( + mode=AppMode.WORKFLOW.value, + workflow=workflow, + app_model_config=None, + tenant_id="tenant-1", + icon_type="emoji", + icon_background="#fff", + ), + ) + + provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) + + with ( + patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), + patch("services.app_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = provider + mock_db.session.query.return_value = query + + # Act + meta = service.get_app_meta(app) + + # Assert + assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") + assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} + + def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: + """Test get_app_meta falls back to default icon when API provider lookup fails.""" + # Arrange + app_model_config = SimpleNamespace( + agent_mode_dict={ + "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] + } + ) + app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) + + with ( + patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), + patch("services.app_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act + meta = service.get_app_meta(app) + + # Assert + assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} + + def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: + """Test get_app_meta returns empty metadata when workflow/model config is absent.""" + # Arrange + workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) + chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) + + # Act + workflow_meta = service.get_app_meta(workflow_app) + chat_meta = service.get_app_meta(chat_app) + + # Assert + assert workflow_meta == {"tool_icons": {}} + assert chat_meta == {"tool_icons": {}} + + +class TestAppServiceCodeLookup: + """Test suite for app code lookup methods.""" + + def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: + """Test get_app_code_by_id raises when site is missing.""" + # Arrange + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id("app-1") + + def test_get_app_code_by_id_should_return_code(self) -> None: + """Test get_app_code_by_id returns site code.""" + # Arrange + site = SimpleNamespace(code="code-1") + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = site + mock_db.session.query.return_value = query + + # Act + result = AppService.get_app_code_by_id("app-1") + + # Assert + assert result == "code-1" + + def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: + """Test get_app_id_by_code raises when code does not exist.""" + # Arrange + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("missing") + + def test_get_app_id_by_code_should_return_app_id(self) -> None: + """Test get_app_id_by_code returns linked app id.""" + # Arrange + site = SimpleNamespace(app_id="app-1") + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = site + mock_db.session.query.return_value = query + + # Act + result = AppService.get_app_id_by_code("code-1") + + # Assert + assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_async_workflow_service.py b/api/tests/unit_tests/services/test_async_workflow_service.py new file mode 100644 index 0000000000..639e091041 --- /dev/null +++ b/api/tests/unit_tests/services/test_async_workflow_service.py @@ -0,0 +1,507 @@ +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest + +import services.async_workflow_service as async_workflow_service_module +from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus +from services.async_workflow_service import AsyncWorkflowService +from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError +from services.workflow.entities import AsyncTriggerResponse, TriggerData +from services.workflow.queue_dispatcher import QueuePriority + + +class AsyncWorkflowServiceTestDataFactory: + """Factory helpers for async workflow service unit tests.""" + + @staticmethod + def create_trigger_data( + app_id: str = "app-123", + tenant_id: str = "tenant-123", + workflow_id: str | None = "workflow-123", + root_node_id: str = "root-node-123", + ) -> TriggerData: + """Create valid trigger data for async workflow execution tests.""" + return TriggerData( + app_id=app_id, + tenant_id=tenant_id, + workflow_id=workflow_id, + root_node_id=root_node_id, + inputs={"name": "dify"}, + files=[], + trigger_type=AppTriggerType.UNKNOWN, + trigger_from=WorkflowRunTriggeredFrom.APP_RUN, + trigger_metadata=None, + ) + + @staticmethod + def create_trigger_log_with_data(trigger_data: TriggerData, retry_count: int = 0) -> MagicMock: + """Create a mock trigger log with serialized trigger data.""" + trigger_log = MagicMock() + trigger_log.id = "trigger-log-123" + trigger_log.trigger_data = trigger_data.model_dump_json() + trigger_log.retry_count = retry_count + trigger_log.error = "previous-error" + trigger_log.status = WorkflowTriggerStatus.FAILED + trigger_log.to_dict.return_value = {"id": trigger_log.id} + return trigger_log + + +class TestAsyncWorkflowService: + @pytest.fixture + def async_workflow_trigger_mocks(self): + """Shared fixture for async workflow trigger tests. + + Yields mocks for: + - repo: SQLAlchemyWorkflowTriggerLogRepository + - dispatcher_manager_class: QueueDispatcherManager class + - dispatcher: dispatcher instance + - quota_workflow: QuotaType.WORKFLOW + - get_workflow: AsyncWorkflowService._get_workflow method + - professional_task: execute_workflow_professional + - team_task: execute_workflow_team + - sandbox_task: execute_workflow_sandbox + """ + mock_repo = MagicMock() + + def _create_side_effect(new_log): + new_log.id = "trigger-log-123" + return new_log + + mock_repo.create.side_effect = _create_side_effect + + mock_dispatcher = MagicMock() + quota_workflow = MagicMock() + mock_get_workflow = MagicMock() + + mock_professional_task = MagicMock() + mock_team_task = MagicMock() + mock_sandbox_task = MagicMock() + + with ( + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + patch.object(async_workflow_service_module, "QueueDispatcherManager") as mock_dispatcher_manager_class, + patch.object(async_workflow_service_module, "WorkflowService"), + patch.object( + async_workflow_service_module.AsyncWorkflowService, + "_get_workflow", + ) as mock_get_workflow, + patch.object( + async_workflow_service_module, + "QuotaType", + new=SimpleNamespace(WORKFLOW=quota_workflow), + ), + patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task, + patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task, + patch.object(async_workflow_service_module, "execute_workflow_sandbox") as mock_sandbox_task, + ): + # Configure dispatcher_manager to return our mock_dispatcher + mock_dispatcher_manager_class.return_value.get_dispatcher.return_value = mock_dispatcher + + yield { + "repo": mock_repo, + "dispatcher_manager_class": mock_dispatcher_manager_class, + "dispatcher": mock_dispatcher, + "quota_workflow": quota_workflow, + "get_workflow": mock_get_workflow, + "professional_task": mock_professional_task, + "team_task": mock_team_task, + "sandbox_task": mock_sandbox_task, + } + + @pytest.mark.parametrize( + ("queue_name", "selected_task_attr"), + [ + (QueuePriority.PROFESSIONAL, "execute_workflow_professional"), + (QueuePriority.TEAM, "execute_workflow_team"), + (QueuePriority.SANDBOX, "execute_workflow_sandbox"), + ], + ) + def test_should_dispatch_to_matching_celery_task_when_triggering_workflow( + self, queue_name, selected_task_attr, async_workflow_trigger_mocks + ): + """Test queue-based task routing and successful async trigger response.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = queue_name + mocks["get_workflow"].return_value = workflow + + task_result = MagicMock() + task_result.id = "task-123" + mocks["professional_task"].delay.return_value = task_result + mocks["team_task"].delay.return_value = task_result + mocks["sandbox_task"].delay.return_value = task_result + + class DummyAccount: + def __init__(self, user_id: str): + self.id = user_id + + with patch.object(async_workflow_service_module, "Account", DummyAccount): + user = DummyAccount("account-123") + + # Act + result = AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data) + + # Assert + assert isinstance(result, AsyncTriggerResponse) + assert result.workflow_trigger_log_id == "trigger-log-123" + assert result.task_id == "task-123" + assert result.status == "queued" + assert result.queue == queue_name + + mocks["quota_workflow"].consume.assert_called_once_with("tenant-123") + assert session.commit.call_count == 2 + + created_log = mocks["repo"].create.call_args[0][0] + assert created_log.status == WorkflowTriggerStatus.QUEUED + assert created_log.queue_name == queue_name + assert created_log.created_by_role == CreatorUserRole.ACCOUNT + assert created_log.created_by == "account-123" + assert created_log.trigger_data == trigger_data.model_dump_json() + assert created_log.inputs == json.dumps(dict(trigger_data.inputs)) + assert created_log.celery_task_id == "task-123" + + task_mocks = { + "execute_workflow_professional": mocks["professional_task"], + "execute_workflow_team": mocks["team_task"], + "execute_workflow_sandbox": mocks["sandbox_task"], + } + for task_attr, task_mock in task_mocks.items(): + if task_attr == selected_task_attr: + task_mock.delay.assert_called_once_with({"workflow_trigger_log_id": "trigger-log-123"}) + else: + task_mock.delay.assert_not_called() + + def test_should_set_end_user_role_when_triggered_by_end_user(self, async_workflow_trigger_mocks): + """Test that non-account users are tracked as END_USER in trigger logs.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = QueuePriority.SANDBOX + mocks["get_workflow"].return_value = workflow + + task_result = MagicMock(id="task-123") + mocks["sandbox_task"].delay.return_value = task_result + + user = SimpleNamespace(id="end-user-123") + + # Act + AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data) + + # Assert + created_log = mocks["repo"].create.call_args[0][0] + assert created_log.created_by_role == CreatorUserRole.END_USER + assert created_log.created_by == "end-user-123" + + def test_should_raise_workflow_not_found_when_app_does_not_exist(self): + """Test trigger failure when app lookup returns no result.""" + # Arrange + session = MagicMock() + session.scalar.return_value = None + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data(app_id="missing-app") + + with ( + patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository"), + patch.object(async_workflow_service_module, "QueueDispatcherManager"), + patch.object(async_workflow_service_module, "WorkflowService"), + ): + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="App not found: missing-app"): + AsyncWorkflowService.trigger_workflow_async( + session=session, + user=SimpleNamespace(id="user-123"), + trigger_data=trigger_data, + ) + + def test_should_mark_log_rate_limited_and_raise_when_quota_exceeded(self, async_workflow_trigger_mocks): + """Test quota-exceeded path updates trigger log and raises WorkflowQuotaLimitError.""" + # Arrange + session = MagicMock() + session.commit = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + session.scalar.return_value = app_model + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + workflow = MagicMock() + workflow.id = "workflow-123" + + mocks = async_workflow_trigger_mocks + mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM + mocks["get_workflow"].return_value = workflow + mocks["quota_workflow"].consume.side_effect = QuotaExceededError( + feature="workflow", + tenant_id="tenant-123", + required=1, + ) + + # Act / Assert + with pytest.raises( + WorkflowQuotaLimitError, + match="Workflow execution quota limit reached for tenant tenant-123", + ): + AsyncWorkflowService.trigger_workflow_async( + session=session, + user=SimpleNamespace(id="user-123"), + trigger_data=trigger_data, + ) + + assert session.commit.call_count == 2 + updated_log = mocks["repo"].update.call_args[0][0] + assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED + assert "Quota limit reached" in updated_log.error + mocks["professional_task"].delay.assert_not_called() + mocks["team_task"].delay.assert_not_called() + mocks["sandbox_task"].delay.assert_not_called() + + def test_should_raise_when_reinvoke_target_log_does_not_exist(self): + """Test reinvoke_trigger error path when original trigger log is missing.""" + # Arrange + session = MagicMock() + repo = MagicMock() + repo.get_by_id.return_value = None + + with patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo): + # Act / Assert + with pytest.raises(ValueError, match="Trigger log not found: missing-log"): + AsyncWorkflowService.reinvoke_trigger( + session=session, + user=SimpleNamespace(id="user-123"), + workflow_trigger_log_id="missing-log", + ) + + def test_should_update_original_log_and_requeue_when_reinvoking(self): + """Test reinvoke flow updates original log state and triggers a new async run.""" + # Arrange + session = MagicMock() + trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data() + trigger_log = AsyncWorkflowServiceTestDataFactory.create_trigger_log_with_data(trigger_data, retry_count=1) + repo = MagicMock() + repo.get_by_id.return_value = trigger_log + + expected_response = AsyncTriggerResponse( + workflow_trigger_log_id="new-trigger-log-456", + task_id="task-456", + status="queued", + queue=QueuePriority.TEAM, + ) + + with ( + patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo), + patch.object( + async_workflow_service_module.AsyncWorkflowService, + "trigger_workflow_async", + return_value=expected_response, + ) as mock_trigger_workflow_async, + ): + user = SimpleNamespace(id="user-123") + + # Act + response = AsyncWorkflowService.reinvoke_trigger( + session=session, + user=user, + workflow_trigger_log_id="trigger-log-123", + ) + + # Assert + assert response == expected_response + assert trigger_log.status == WorkflowTriggerStatus.RETRYING + assert trigger_log.retry_count == 2 + assert trigger_log.error is None + assert trigger_log.triggered_at is not None + repo.update.assert_called_once_with(trigger_log) + session.commit.assert_called_once() + called_trigger_data = mock_trigger_workflow_async.call_args[0][2] + assert isinstance(called_trigger_data, TriggerData) + assert called_trigger_data.app_id == "app-123" + + @pytest.mark.parametrize( + ("repo_result", "expected"), + [ + (None, None), + (MagicMock(), {"id": "trigger-log-123"}), + ], + ) + def test_should_return_trigger_log_dict_or_none(self, repo_result, expected): + """Test get_trigger_log returns serialized log data or None.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + fake_engine = MagicMock() + mock_repo.get_by_id.return_value = repo_result + if repo_result: + repo_result.to_dict.return_value = expected + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)), + patch.object( + async_workflow_service_module, "Session", return_value=mock_session_context + ) as mock_session_class, + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_trigger_log("trigger-log-123", tenant_id="tenant-123") + + # Assert + assert result == expected + mock_session_class.assert_called_once_with(fake_engine) + mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123") + + def test_should_return_recent_logs_as_dict_list(self): + """Test get_recent_logs converts repository models into dictionaries.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + log1 = MagicMock() + log1.to_dict.return_value = {"id": "log-1"} + log2 = MagicMock() + log2.to_dict.return_value = {"id": "log-2"} + mock_repo.get_recent_logs.return_value = [log1, log2] + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), + patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_recent_logs( + tenant_id="tenant-123", + app_id="app-123", + hours=12, + limit=50, + offset=10, + ) + + # Assert + assert result == [{"id": "log-1"}, {"id": "log-2"}] + mock_repo.get_recent_logs.assert_called_once_with( + tenant_id="tenant-123", + app_id="app-123", + hours=12, + limit=50, + offset=10, + ) + + def test_should_return_failed_logs_for_retry_as_dict_list(self): + """Test get_failed_logs_for_retry serializes repository logs into dicts.""" + # Arrange + mock_session = MagicMock() + mock_repo = MagicMock() + log = MagicMock() + log.to_dict.return_value = {"id": "failed-log-1"} + mock_repo.get_failed_for_retry.return_value = [log] + + mock_session_context = MagicMock() + mock_session_context.__enter__.return_value = mock_session + mock_session_context.__exit__.return_value = None + + with ( + patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())), + patch.object(async_workflow_service_module, "Session", return_value=mock_session_context), + patch.object( + async_workflow_service_module, + "SQLAlchemyWorkflowTriggerLogRepository", + return_value=mock_repo, + ), + ): + # Act + result = AsyncWorkflowService.get_failed_logs_for_retry(tenant_id="tenant-123", max_retry_count=4, limit=20) + + # Assert + assert result == [{"id": "failed-log-1"}] + mock_repo.get_failed_for_retry.assert_called_once_with(tenant_id="tenant-123", max_retry_count=4, limit=20) + + +class TestAsyncWorkflowServiceGetWorkflow: + def test_should_return_specific_workflow_when_workflow_id_exists(self): + """Test _get_workflow returns published workflow by id when provided.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + workflow = MagicMock() + workflow_service.get_published_workflow_by_id.return_value = workflow + + # Act + result = AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-123") + + # Assert + assert result == workflow + workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123") + workflow_service.get_published_workflow.assert_not_called() + + def test_should_raise_when_specific_workflow_id_not_found(self): + """Test _get_workflow raises WorkflowNotFoundError for unknown workflow id.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + workflow_service.get_published_workflow_by_id.return_value = None + + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="Published workflow not found: workflow-404"): + AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-404") + + def test_should_return_default_published_workflow_when_workflow_id_not_provided(self): + """Test _get_workflow returns default published workflow when no id is provided.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + workflow = MagicMock() + workflow_service.get_published_workflow.return_value = workflow + + # Act + result = AsyncWorkflowService._get_workflow(workflow_service, app_model) + + # Assert + assert result == workflow + workflow_service.get_published_workflow.assert_called_once_with(app_model) + workflow_service.get_published_workflow_by_id.assert_not_called() + + def test_should_raise_when_default_published_workflow_not_found(self): + """Test _get_workflow raises WorkflowNotFoundError when app has no published workflow.""" + # Arrange + workflow_service = MagicMock() + app_model = MagicMock() + app_model.id = "app-123" + workflow_service.get_published_workflow.return_value = None + + # Act / Assert + with pytest.raises(WorkflowNotFoundError, match="No published workflow found for app: app-123"): + AsyncWorkflowService._get_workflow(workflow_service, app_model) diff --git a/api/tests/unit_tests/services/test_batch_indexing_base.py b/api/tests/unit_tests/services/test_batch_indexing_base.py new file mode 100644 index 0000000000..bd68b67d89 --- /dev/null +++ b/api/tests/unit_tests/services/test_batch_indexing_base.py @@ -0,0 +1,387 @@ +from dataclasses import asdict +from typing import Any, ClassVar, cast +from unittest.mock import MagicMock, patch + +import pytest + +from core.entities.document_task import DocumentTask +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy + +# --------------------------------------------------------------------------- +# Concrete subclass for testing (the base class is abstract) +# --------------------------------------------------------------------------- + + +class ConcreteBatchProxy(BatchDocumentIndexingProxy): + """Minimal concrete implementation that provides the required class-level vars.""" + + QUEUE_NAME: ClassVar[str] = "test_queue" + NORMAL_TASK_FUNC: ClassVar[Any] = MagicMock(name="NORMAL_TASK_FUNC") + PRIORITY_TASK_FUNC: ClassVar[Any] = MagicMock(name="PRIORITY_TASK_FUNC") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TENANT_ID = "tenant-abc" +DATASET_ID = "dataset-xyz" +DOC_IDS: list[str] = ["doc-1", "doc-2", "doc-3"] + + +def make_proxy(**kwargs: Any) -> ConcreteBatchProxy: + """Factory: returns a ConcreteBatchProxy with TenantIsolatedTaskQueue mocked out.""" + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue: + proxy = ConcreteBatchProxy( + tenant_id=kwargs.get("tenant_id", TENANT_ID), + dataset_id=kwargs.get("dataset_id", DATASET_ID), + document_ids=kwargs.get("document_ids", DOC_IDS), + ) + # Expose the mock queue on the proxy so tests can assert on it + proxy._tenant_isolated_task_queue = MockQueue.return_value + return proxy + + +# --------------------------------------------------------------------------- +# Test suite +# --------------------------------------------------------------------------- + + +class TestBatchDocumentIndexingProxyInit: + """Tests for __init__ of BatchDocumentIndexingProxy.""" + + def test_should_store_document_ids_when_initialized(self) -> None: + """Verify that document_ids are stored on the proxy instance.""" + # Arrange + doc_ids: list[str] = ["doc-a", "doc-b"] + + # Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids) + + # Assert + assert proxy._document_ids == doc_ids + + def test_should_propagate_tenant_and_dataset_to_base_when_initialized(self) -> None: + """Verify that tenant_id and dataset_id are forwarded to the parent class.""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS) + + # Assert + assert proxy._tenant_id == TENANT_ID + assert proxy._dataset_id == DATASET_ID + + def test_should_create_tenant_isolated_queue_with_correct_args_when_initialized(self) -> None: + """Verify that TenantIsolatedTaskQueue is constructed with (tenant_id, QUEUE_NAME).""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue: + ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS) + + # Assert + MockQueue.assert_called_once_with(TENANT_ID, ConcreteBatchProxy.QUEUE_NAME) + + @pytest.mark.parametrize("doc_ids", [[], ["single-doc"], ["d1", "d2", "d3", "d4"]]) + def test_should_accept_any_length_document_ids_when_initialized(self, doc_ids: list[str]) -> None: + """Verify that empty, single, and multiple document IDs are all accepted.""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids) + + # Assert + assert list(proxy._document_ids) == doc_ids + + +class TestSendToDirectQueue: + """Tests for _send_to_direct_queue.""" + + def test_should_call_task_func_delay_with_correct_args_when_sent_to_direct_queue( + self, + ) -> None: + """Verify that task_func.delay is called with the right kwargs.""" + # Arrange + proxy = make_proxy() + task_func = MagicMock() + + # Act + proxy._send_to_direct_queue(task_func) + + # Assert + task_func.delay.assert_called_once_with( + tenant_id=TENANT_ID, + dataset_id=DATASET_ID, + document_ids=DOC_IDS, + ) + + def test_should_not_interact_with_tenant_queue_when_sent_to_direct_queue(self) -> None: + """Direct queue path must never touch the tenant-isolated queue.""" + # Arrange + proxy = make_proxy() + task_func = MagicMock() + + # Act + proxy._send_to_direct_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.push_tasks.assert_not_called() + mock_queue.set_task_waiting_time.assert_not_called() + + def test_should_forward_any_callable_when_sent_to_direct_queue(self) -> None: + """Verify that different task functions are each called correctly.""" + # Arrange + proxy = make_proxy() + task_a, task_b = MagicMock(), MagicMock() + + # Act + proxy._send_to_direct_queue(task_a) + proxy._send_to_direct_queue(task_b) + + # Assert + task_a.delay.assert_called_once() + task_b.delay.assert_called_once() + + +class TestSendToTenantQueue: + """Tests for _send_to_tenant_queue — both branches.""" + + # ------------------------------------------------------------------ + # Branch 1: get_task_key() is truthy → push to waiting queue + # ------------------------------------------------------------------ + + def test_should_push_task_to_queue_when_task_key_exists(self) -> None: + """When get_task_key() is truthy, tasks must be pushed via push_tasks().""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + expected_payload = [asdict(DocumentTask(tenant_id=TENANT_ID, dataset_id=DATASET_ID, document_ids=DOC_IDS))] + mock_queue.push_tasks.assert_called_once_with(expected_payload) + + def test_should_not_call_task_func_delay_when_task_key_exists(self) -> None: + """When a key already exists, task_func.delay must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + cast(MagicMock, task_func.delay).assert_not_called() + + def test_should_not_set_waiting_time_when_task_key_exists(self) -> None: + """When a key already exists, set_task_waiting_time must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_not_called() + + def test_should_serialize_document_task_correctly_when_pushing_to_queue(self) -> None: + """Verify the serialised payload matches asdict(DocumentTask(...)).""" + # Arrange + proxy = make_proxy(document_ids=["doc-x"]) + proxy._tenant_isolated_task_queue.get_task_key.return_value = "k" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert — inspect the payload passed to push_tasks + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + call_args = mock_queue.push_tasks.call_args + pushed_list = call_args[0][0] # first positional arg + assert len(pushed_list) == 1 + assert pushed_list[0]["tenant_id"] == TENANT_ID + assert pushed_list[0]["dataset_id"] == DATASET_ID + assert pushed_list[0]["document_ids"] == ["doc-x"] + + # ------------------------------------------------------------------ + # Branch 2: get_task_key() is falsy → set flag + dispatch via delay + # ------------------------------------------------------------------ + + def test_should_set_waiting_time_and_call_delay_when_no_task_key(self) -> None: + """When get_task_key() is falsy, set_task_waiting_time and task_func.delay are invoked.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_called_once() + cast(MagicMock, task_func.delay).assert_called_once_with( + tenant_id=TENANT_ID, + dataset_id=DATASET_ID, + document_ids=DOC_IDS, + ) + + def test_should_not_push_tasks_when_no_task_key(self) -> None: + """When get_task_key() is falsy, push_tasks must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.push_tasks.assert_not_called() + + @pytest.mark.parametrize("falsy_key", [None, "", 0, False]) + def test_should_init_task_when_key_is_any_falsy_value(self, falsy_key: Any) -> None: + """Verify that any falsy return from get_task_key() triggers the init branch.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = falsy_key + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_called_once() + cast(MagicMock, task_func.delay).assert_called_once() + + +class TestDispatchRouting: + """Tests for the _dispatch / delay routing logic inherited from the base class.""" + + def _mock_features(self, enabled: bool, plan: CloudPlan) -> MagicMock: + features = MagicMock() + features.billing.enabled = enabled + features.billing.subscription.plan = plan + return features + + def test_should_send_to_normal_tenant_queue_when_billing_enabled_and_sandbox_plan(self) -> None: + """Sandbox plan routes to normal priority queue with tenant isolation.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.SANDBOX) + + # Act + with patch.object(proxy, "_send_to_default_tenant_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_send_to_priority_tenant_queue_when_billing_enabled_and_paid_plan(self) -> None: + """Non-sandbox paid plan routes to priority queue with tenant isolation.""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.PROFESSIONAL) + + # Act + with patch.object(proxy, "_send_to_priority_tenant_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_send_to_priority_direct_queue_when_billing_not_enabled(self) -> None: + """Self-hosted / no billing → priority direct queue (no tenant isolation).""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX) + + # Act + with patch.object(proxy, "_send_to_priority_direct_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_call_dispatch_when_delay_is_invoked(self) -> None: + """Calling delay() must invoke _dispatch() exactly once.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_dispatch") as mock_dispatch: + proxy.delay() + + # Assert + mock_dispatch.assert_called_once() + + def test_should_use_feature_service_for_billing_info(self) -> None: + """Verify that FeatureService.get_features is consulted during dispatch.""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX) + with patch.object(proxy, "_send_to_priority_direct_queue"): + # Act + proxy._dispatch() + + # Assert + mock_features.assert_called_once_with(TENANT_ID) + + +class TestBaseRouterHelpers: + """Tests for the three routing helper methods from the base class.""" + + def test_should_call_send_to_tenant_queue_with_normal_func_when_default_tenant_queue(self) -> None: + """_send_to_default_tenant_queue must forward NORMAL_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_tenant_queue") as mock_method: + proxy._send_to_default_tenant_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.NORMAL_TASK_FUNC) + + def test_should_call_send_to_tenant_queue_with_priority_func_when_priority_tenant_queue(self) -> None: + """_send_to_priority_tenant_queue must forward PRIORITY_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_tenant_queue") as mock_method: + proxy._send_to_priority_tenant_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC) + + def test_should_call_send_to_direct_queue_with_priority_func_when_priority_direct_queue(self) -> None: + """_send_to_priority_direct_queue must forward PRIORITY_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_direct_queue") as mock_method: + proxy._send_to_priority_direct_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC) diff --git a/api/tests/unit_tests/services/test_billing_service.py b/api/tests/unit_tests/services/test_billing_service.py index eecb3c7672..316381f0ca 100644 --- a/api/tests/unit_tests/services/test_billing_service.py +++ b/api/tests/unit_tests/services/test_billing_service.py @@ -1303,6 +1303,24 @@ class TestBillingServiceSubscriptionOperations: # Assert assert result == {} + def test_get_plan_bulk_converts_string_expiration_date_to_int(self, mock_send_request): + """Test bulk plan retrieval converts string expiration_date to int.""" + # Arrange + tenant_ids = ["tenant-1"] + mock_send_request.return_value = { + "data": { + "tenant-1": {"plan": "sandbox", "expiration_date": "1735689600"}, + } + } + + # Act + result = BillingService.get_plan_bulk(tenant_ids) + + # Assert + assert "tenant-1" in result + assert isinstance(result["tenant-1"]["expiration_date"], int) + assert result["tenant-1"]["expiration_date"] == 1735689600 + def test_get_plan_bulk_with_invalid_tenant_plan_skipped(self, mock_send_request): """Test bulk plan retrieval when one tenant has invalid plan data (should skip that tenant).""" # Arrange diff --git a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py index 50826d6798..6bf78d3411 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_expired_workflow_run_logs.py @@ -265,6 +265,61 @@ def test_run_exits_on_empty_batch(monkeypatch: pytest.MonkeyPatch) -> None: cleanup.run() +def test_run_records_metrics_on_success(monkeypatch: pytest.MonkeyPatch) -> None: + cutoff = datetime.datetime.now() + repo = FakeRepo( + batches=[[FakeRun("run-free", "t_free", cutoff)]], + delete_result={ + "runs": 0, + "node_executions": 2, + "offloads": 1, + "app_logs": 3, + "trigger_logs": 4, + "pauses": 5, + "pause_reasons": 6, + }, + ) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + batch_calls: list[dict[str, object]] = [] + completion_calls: list[dict[str, object]] = [] + monkeypatch.setattr(cleanup._metrics, "record_batch", lambda **kwargs: batch_calls.append(kwargs)) + monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs)) + + cleanup.run() + + assert len(batch_calls) == 1 + assert batch_calls[0]["batch_rows"] == 1 + assert batch_calls[0]["targeted_runs"] == 1 + assert batch_calls[0]["deleted_runs"] == 1 + assert batch_calls[0]["related_action"] == "deleted" + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "success" + + +def test_run_records_failed_metrics(monkeypatch: pytest.MonkeyPatch) -> None: + class FailingRepo(FakeRepo): + def delete_runs_with_related( + self, runs: list[FakeRun], delete_node_executions=None, delete_trigger_logs=None + ) -> dict[str, int]: + raise RuntimeError("delete failed") + + cutoff = datetime.datetime.now() + repo = FailingRepo(batches=[[FakeRun("run-free", "t_free", cutoff)]]) + cleanup = create_cleanup(monkeypatch, repo=repo, days=30, batch_size=10) + monkeypatch.setattr(cleanup_module.dify_config, "BILLING_ENABLED", False) + + completion_calls: list[dict[str, object]] = [] + monkeypatch.setattr(cleanup._metrics, "record_completion", lambda **kwargs: completion_calls.append(kwargs)) + + with pytest.raises(RuntimeError, match="delete failed"): + cleanup.run() + + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "failed" + + def test_run_dry_run_skips_deletions(monkeypatch: pytest.MonkeyPatch, capsys: pytest.CaptureFixture[str]) -> None: cutoff = datetime.datetime.now() repo = FakeRepo( diff --git a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py index 5099362e00..1926cb133a 100644 --- a/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py +++ b/api/tests/unit_tests/services/test_clear_free_plan_tenant_expired_logs.py @@ -1,9 +1,12 @@ import datetime -from unittest.mock import Mock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, patch import pytest from sqlalchemy.orm import Session +from enums.cloud_plan import CloudPlan +from services import clear_free_plan_tenant_expired_logs as service_module from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs @@ -156,13 +159,453 @@ class TestClearFreePlanTenantExpiredLogs: # Should call delete for each table that has records assert mock_session.query.return_value.where.return_value.delete.called - def test_clear_message_related_tables_logging_output( - self, mock_session, sample_message_ids, sample_records, capsys + def test_clear_message_related_tables_all_serialization_fails_skips_backup_but_deletes( + self, mock_session, sample_message_ids ): - """Test that logging output is generated.""" + record = Mock() + record.id = "record-1" + record.to_dict.side_effect = Exception("Serialization error") + with patch("services.clear_free_plan_tenant_expired_logs.storage") as mock_storage: - mock_session.query.return_value.where.return_value.all.return_value = sample_records + mock_session.query.return_value.where.return_value.all.return_value = [record] ClearFreePlanTenantExpiredLogs._clear_message_related_tables(mock_session, "tenant-123", sample_message_ids) - pass + mock_storage.save.assert_not_called() + assert mock_session.query.return_value.where.return_value.delete.called + + +class _ImmediateFuture: + def __init__(self, fn, args, kwargs): + self._fn = fn + self._args = args + self._kwargs = kwargs + + def result(self): + return self._fn(*self._args, **self._kwargs) + + +class _ImmediateExecutor: + def __init__(self, *args, **kwargs) -> None: + self.submitted: list[tuple[object, tuple[object, ...], dict[str, object]]] = [] + + def submit(self, fn, *args, **kwargs): + self.submitted.append((fn, args, kwargs)) + return _ImmediateFuture(fn, args, kwargs) + + +def _session_wrapper_for_no_autoflush(session: Mock) -> Mock: + """ + ClearFreePlanTenantExpiredLogs.process_tenant uses: + with Session(db.engine).no_autoflush as session: + so Session(db.engine) must return an object with a no_autoflush context manager. + """ + cm = MagicMock() + cm.__enter__.return_value = session + cm.__exit__.return_value = None + + wrapper = MagicMock() + wrapper.no_autoflush = cm + return wrapper + + +def _session_wrapper_for_direct(session: Mock) -> Mock: + """ClearFreePlanTenantExpiredLogs.process uses: with Session(db.engine) as session:""" + wrapper = MagicMock() + wrapper.__enter__.return_value = session + wrapper.__exit__.return_value = None + return wrapper + + +def test_process_tenant_processes_all_batches(monkeypatch: pytest.MonkeyPatch) -> None: + flask_app = service_module.Flask("test-app") + + monkeypatch.setattr( + service_module, + "db", + SimpleNamespace( + engine=object(), + session=SimpleNamespace( + scalars=lambda _stmt: SimpleNamespace( + all=lambda: [SimpleNamespace(id="app-1"), SimpleNamespace(id="app-2")] + ) + ), + ), + ) + + mock_storage = MagicMock() + monkeypatch.setattr(service_module, "storage", mock_storage) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + + clear_related = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", clear_related) + + # Session sequence for messages, conversations, workflow_app_logs loops: + # - messages: one batch then empty + # - conversations: one batch then empty + # - workflow app logs: one batch then empty + msg1 = SimpleNamespace(id="m1", to_dict=lambda: {"id": "m1"}) + conv1 = SimpleNamespace(id="c1", to_dict=lambda: {"id": "c1"}) + log1 = SimpleNamespace(id="l1", to_dict=lambda: {"id": "l1"}) + + def make_query_with_batches(batches: list[list[object]]): + q = MagicMock() + q.where.return_value = q + q.limit.return_value = q + q.all.side_effect = batches + q.delete.return_value = 1 + return q + + msg_session_1 = MagicMock() + msg_session_1.query.side_effect = lambda model: ( + make_query_with_batches([[msg1], []]) if model == service_module.Message else MagicMock() + ) + msg_session_1.commit.return_value = None + + msg_session_2 = MagicMock() + msg_session_2.query.side_effect = lambda model: ( + make_query_with_batches([[]]) if model == service_module.Message else MagicMock() + ) + msg_session_2.commit.return_value = None + + conv_session_1 = MagicMock() + conv_session_1.query.side_effect = lambda model: ( + make_query_with_batches([[conv1], []]) if model == service_module.Conversation else MagicMock() + ) + conv_session_1.commit.return_value = None + + conv_session_2 = MagicMock() + conv_session_2.query.side_effect = lambda model: ( + make_query_with_batches([[]]) if model == service_module.Conversation else MagicMock() + ) + conv_session_2.commit.return_value = None + + wal_session_1 = MagicMock() + wal_session_1.query.side_effect = lambda model: ( + make_query_with_batches([[log1], []]) if model == service_module.WorkflowAppLog else MagicMock() + ) + wal_session_1.commit.return_value = None + + wal_session_2 = MagicMock() + wal_session_2.query.side_effect = lambda model: ( + make_query_with_batches([[]]) if model == service_module.WorkflowAppLog else MagicMock() + ) + wal_session_2.commit.return_value = None + + session_wrappers = [ + _session_wrapper_for_no_autoflush(msg_session_1), + _session_wrapper_for_no_autoflush(msg_session_2), + _session_wrapper_for_no_autoflush(conv_session_1), + _session_wrapper_for_no_autoflush(conv_session_2), + _session_wrapper_for_no_autoflush(wal_session_1), + _session_wrapper_for_no_autoflush(wal_session_2), + ] + + monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_select(*_args, **_kwargs): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(service_module, "select", fake_select) + + # Repositories for workflow node executions and workflow runs + node_repo = MagicMock() + node_repo.get_expired_executions_batch.side_effect = [[SimpleNamespace(id="ne-1")], []] + node_repo.delete_executions_by_ids.return_value = 1 + + run_repo = MagicMock() + run_repo.get_expired_runs_batch.side_effect = [[SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"})], []] + run_repo.delete_runs_by_ids.return_value = 1 + + monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_node_execution_repository", + lambda _sm: node_repo, + ) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda _sm: run_repo, + ) + + ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=10) + + # messages backup, conversations backup, node executions backup, runs backup, workflow app logs backup + assert mock_storage.save.call_count >= 5 + clear_related.assert_called() + + +def test_process_with_tenant_ids_filters_by_plan_and_logs_errors(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + + # Total tenant count query + count_session = MagicMock() + count_query = MagicMock() + count_query.count.return_value = 2 + count_session.query.return_value = count_query + + monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) + + # Avoid LocalProxy usage + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + echo_mock = MagicMock() + monkeypatch.setattr(service_module.click, "echo", echo_mock) + + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", True) + + def fake_get_info(tenant_id: str): + if tenant_id == "t_sandbox": + return {"subscription": {"plan": CloudPlan.SANDBOX}} + if tenant_id == "t_fail": + raise RuntimeError("boom") + return {"subscription": {"plan": "team"}} + + monkeypatch.setattr(service_module.BillingService, "get_info", staticmethod(fake_get_info)) + + process_tenant_mock = MagicMock(side_effect=lambda *_args, **_kwargs: (_ for _ in ()).throw(RuntimeError("err"))) + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + logger_exc = MagicMock() + monkeypatch.setattr(service_module.logger, "exception", logger_exc) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=["t_sandbox", "t_paid", "t_fail"]) + + # Only sandbox tenant should attempt processing, and its failure should be swallowed + logged. + assert process_tenant_mock.call_count == 1 + assert logger_exc.call_count >= 1 + + +def test_process_without_tenant_ids_batches_and_scales_interval(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) + fixed_now = started_at + datetime.timedelta(hours=2) + + class FixedDateTime(datetime.datetime): + @classmethod + def now(cls, tz=None): + return fixed_now + + monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime) + + # Avoid LocalProxy usage + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + + # Sessions used: + # 1) total tenant count + # 2) per-batch tenant scan (count + tenant list) + total_session = MagicMock() + total_query = MagicMock() + total_query.count.return_value = 250 + total_session.query.return_value = total_query + + batch_session = MagicMock() + q1 = MagicMock() + q1.where.return_value = q1 + q1.count.return_value = 200 + q2 = MagicMock() + q2.where.return_value = q2 + q2.count.return_value = 200 + q3 = MagicMock() + q3.where.return_value = q3 + q3.count.return_value = 200 + q4 = MagicMock() + q4.where.return_value = q4 + q4.count.return_value = 50 # choose this interval, then scale it + + rows = [SimpleNamespace(id="tenant-a"), SimpleNamespace(id="tenant-b")] + q_rs = MagicMock() + q_rs.where.return_value = q_rs + q_rs.order_by.return_value = rows + + batch_session.query.side_effect = [q1, q2, q3, q4, q_rs] + + sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] + monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) + + process_tenant_mock = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) + + # Should submit/process tenants from the batch query + assert process_tenant_mock.call_count == 2 + + +def test_process_with_tenant_ids_emits_progress_every_100(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + + count_session = MagicMock() + count_query = MagicMock() + count_query.count.return_value = 100 + count_session.query.return_value = count_query + monkeypatch.setattr(service_module, "Session", lambda _engine: _session_wrapper_for_direct(count_session)) + + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + echo_mock = MagicMock() + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", echo_mock) + + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", MagicMock()) + + tenant_ids = [f"t{i}" for i in range(100)] + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=tenant_ids) + + assert any("Processed 100 tenants" in str(call.args[0]) for call in echo_mock.call_args_list) + + +def test_process_without_tenant_ids_all_intervals_too_many_uses_min_interval(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(service_module, "db", SimpleNamespace(engine=object())) + monkeypatch.setattr(service_module.dify_config, "BILLING_ENABLED", False) + + started_at = datetime.datetime(2023, 4, 3, 8, 59, 24) + # Keep the total range smaller than the minimum interval (1 hour) so the loop runs once. + fixed_now = started_at + datetime.timedelta(minutes=30) + + class FixedDateTime(datetime.datetime): + @classmethod + def now(cls, tz=None): + return fixed_now + + monkeypatch.setattr(service_module.datetime, "datetime", FixedDateTime) + + flask_app = service_module.Flask("test-app") + monkeypatch.setattr(service_module, "current_app", SimpleNamespace(_get_current_object=lambda: flask_app)) + + executor = _ImmediateExecutor() + monkeypatch.setattr(service_module, "ThreadPoolExecutor", lambda **_kwargs: executor) + + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + + total_session = MagicMock() + total_query = MagicMock() + total_query.count.return_value = 250 + total_session.query.return_value = total_query + + batch_session = MagicMock() + # Count results for all 5 intervals, all > 100 => take the for-else path. + count_queries = [] + for _ in range(5): + q = MagicMock() + q.where.return_value = q + q.count.return_value = 200 + count_queries.append(q) + + rows = [SimpleNamespace(id="tenant-a")] + q_rs = MagicMock() + q_rs.where.return_value = q_rs + q_rs.order_by.return_value = rows + + batch_session.query.side_effect = [*count_queries, q_rs] + + sessions = [_session_wrapper_for_direct(total_session), _session_wrapper_for_direct(batch_session)] + monkeypatch.setattr(service_module, "Session", lambda _engine: sessions.pop(0)) + + process_tenant_mock = MagicMock() + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "process_tenant", process_tenant_mock) + + ClearFreePlanTenantExpiredLogs.process(days=7, batch=10, tenant_ids=[]) + + assert process_tenant_mock.call_count == 1 + assert len(count_queries) == 5 + assert batch_session.query.call_count >= 6 + + +def test_process_tenant_repo_loops_break_on_empty_second_batch(monkeypatch: pytest.MonkeyPatch) -> None: + flask_app = service_module.Flask("test-app") + + monkeypatch.setattr( + service_module, + "db", + SimpleNamespace( + engine=object(), + session=SimpleNamespace(scalars=lambda _stmt: SimpleNamespace(all=lambda: [SimpleNamespace(id="app-1")])), + ), + ) + mock_storage = MagicMock() + monkeypatch.setattr(service_module, "storage", mock_storage) + monkeypatch.setattr(service_module.click, "echo", lambda *_args, **_kwargs: None) + monkeypatch.setattr(service_module.click, "style", lambda msg, **_kwargs: msg) + monkeypatch.setattr(ClearFreePlanTenantExpiredLogs, "_clear_message_related_tables", MagicMock()) + + # Make message/conversation/workflow_app_log loops no-op (empty immediately) + empty_session = MagicMock() + q_empty = MagicMock() + q_empty.where.return_value = q_empty + q_empty.limit.return_value = q_empty + q_empty.all.return_value = [] + empty_session.query.return_value = q_empty + empty_session.commit.return_value = None + session_wrappers = [ + _session_wrapper_for_no_autoflush(empty_session), + _session_wrapper_for_no_autoflush(empty_session), + _session_wrapper_for_no_autoflush(empty_session), + ] + monkeypatch.setattr(service_module, "Session", lambda _engine: session_wrappers.pop(0)) + + def fake_select(*_args, **_kwargs): + stmt = MagicMock() + stmt.where.return_value = stmt + return stmt + + monkeypatch.setattr(service_module, "select", fake_select) + + # Repos: first returns exactly batch items -> no "< batch" break, second returns [] -> hit the len==0 break. + node_repo = MagicMock() + node_repo.get_expired_executions_batch.side_effect = [ + [SimpleNamespace(id="ne-1"), SimpleNamespace(id="ne-2")], + [], + ] + node_repo.delete_executions_by_ids.return_value = 2 + + run_repo = MagicMock() + run_repo.get_expired_runs_batch.side_effect = [ + [ + SimpleNamespace(id="wr-1", to_dict=lambda: {"id": "wr-1"}), + SimpleNamespace(id="wr-2", to_dict=lambda: {"id": "wr-2"}), + ], + [], + ] + run_repo.delete_runs_by_ids.return_value = 2 + + monkeypatch.setattr(service_module, "sessionmaker", lambda **_kwargs: object()) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_node_execution_repository", + lambda _sm: node_repo, + ) + monkeypatch.setattr( + service_module.DifyAPIRepositoryFactory, + "create_api_workflow_run_repository", + lambda _sm: run_repo, + ) + + ClearFreePlanTenantExpiredLogs.process_tenant(flask_app, "tenant-1", days=7, batch=2) + + assert node_repo.get_expired_executions_batch.call_count == 2 + assert run_repo.get_expired_runs_batch.call_count == 2 diff --git a/api/tests/unit_tests/services/test_code_based_extension_service.py b/api/tests/unit_tests/services/test_code_based_extension_service.py new file mode 100644 index 0000000000..f6538a140a --- /dev/null +++ b/api/tests/unit_tests/services/test_code_based_extension_service.py @@ -0,0 +1,89 @@ +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from services.code_based_extension_service import CodeBasedExtensionService + + +class TestCodeBasedExtensionService: + def test_should_return_only_non_builtin_extensions_with_public_fields(self, monkeypatch: pytest.MonkeyPatch): + """Test service returns only non-builtin extensions with name/label/form_schema fields.""" + moderation_extension = SimpleNamespace( + name="custom-moderation", + label={"en-US": "Custom Moderation"}, + form_schema=[{"variable": "api_key"}], + builtin=False, + extension_class=object, + position=20, + ) + builtin_extension = SimpleNamespace( + name="builtin-moderation", + label={"en-US": "Builtin Moderation"}, + form_schema=[{"variable": "token"}], + builtin=True, + extension_class=object, + position=1, + ) + retrieval_extension = SimpleNamespace( + name="custom-retrieval", + label={"en-US": "Custom Retrieval"}, + form_schema=None, + builtin=False, + extension_class=object, + position=30, + ) + module_extensions_mock = MagicMock(return_value=[moderation_extension, builtin_extension, retrieval_extension]) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + result = CodeBasedExtensionService.get_code_based_extension("external_data_tool") + + assert result == [ + { + "name": "custom-moderation", + "label": {"en-US": "Custom Moderation"}, + "form_schema": [{"variable": "api_key"}], + }, + { + "name": "custom-retrieval", + "label": {"en-US": "Custom Retrieval"}, + "form_schema": None, + }, + ] + assert set(result[0].keys()) == {"name", "label", "form_schema"} + module_extensions_mock.assert_called_once_with("external_data_tool") + + def test_should_return_empty_list_when_all_extensions_are_builtin(self, monkeypatch: pytest.MonkeyPatch): + """Test builtin extensions are filtered out completely.""" + builtin_extension = SimpleNamespace( + name="builtin-moderation", + label={"en-US": "Builtin Moderation"}, + form_schema=[{"variable": "token"}], + builtin=True, + ) + module_extensions_mock = MagicMock(return_value=[builtin_extension]) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + result = CodeBasedExtensionService.get_code_based_extension("moderation") + + assert result == [] + module_extensions_mock.assert_called_once_with("moderation") + + def test_should_propagate_error_when_module_extensions_lookup_fails(self, monkeypatch: pytest.MonkeyPatch): + """Test ValueError from extension lookup bubbles up unchanged.""" + module_extensions_mock = MagicMock(side_effect=ValueError("Extension Module invalid-module not found")) + monkeypatch.setattr( + "services.code_based_extension_service.code_based_extension.module_extensions", + module_extensions_mock, + ) + + with pytest.raises(ValueError, match="Extension Module invalid-module not found"): + CodeBasedExtensionService.get_code_based_extension("invalid-module") + + module_extensions_mock.assert_called_once_with("invalid-module") diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index d8ecdf45fd..35157790ca 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -1,18 +1,30 @@ """ Comprehensive unit tests for ConversationService. -This file keeps non-SQL guard/unit tests. -SQL-related tests were migrated to testcontainers integration tests. +This file provides complete test coverage for all ConversationService methods. +Tests are organized by functionality and include edge cases, error handling, +and both positive and negative test scenarios. """ -from datetime import datetime +from datetime import datetime, timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch +import pytest +from sqlalchemy import asc, desc + from core.app.entities.app_invoke_entities import InvokeFrom -from models import Account -from models.model import App, Conversation, EndUser +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models import Account, ConversationVariable +from models.enums import ConversationFromSource +from models.model import App, Conversation, EndUser, Message from services.conversation_service import ConversationService -from services.message_service import MessageService +from services.errors.conversation import ( + ConversationNotExistsError, + ConversationVariableNotExistsError, + ConversationVariableTypeMismatchError, + LastConversationNotExistsError, +) +from services.errors.message import MessageNotExistsError class ConversationServiceTestDataFactory: @@ -116,6 +128,84 @@ class ConversationServiceTestDataFactory: setattr(conversation, key, value) return conversation + @staticmethod + def create_message_mock( + message_id: str = "msg-123", + conversation_id: str = "conv-123", + app_id: str = "app-123", + **kwargs, + ) -> Mock: + """ + Create a mock Message object. + + Args: + message_id: Unique identifier for the message + conversation_id: Associated conversation identifier + app_id: Associated app identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock Message object with specified attributes + """ + message = create_autospec(Message, instance=True) + message.id = message_id + message.conversation_id = conversation_id + message.app_id = app_id + message.query = kwargs.get("query", "Test message content") + message.created_at = kwargs.get("created_at", datetime.utcnow()) + for key, value in kwargs.items(): + setattr(message, key, value) + return message + + @staticmethod + def create_conversation_variable_mock( + variable_id: str = "var-123", + conversation_id: str = "conv-123", + app_id: str = "app-123", + **kwargs, + ) -> Mock: + """ + Create a mock ConversationVariable object. + + Args: + variable_id: Unique identifier for the variable + conversation_id: Associated conversation identifier + app_id: Associated app identifier + **kwargs: Additional attributes to set on the mock + + Returns: + Mock ConversationVariable object with specified attributes + """ + variable = create_autospec(ConversationVariable, instance=True) + variable.id = variable_id + variable.conversation_id = conversation_id + variable.app_id = app_id + variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} + variable.created_at = kwargs.get("created_at", datetime.utcnow()) + variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + + # Mock to_variable method + mock_variable = Mock() + mock_variable.id = variable_id + mock_variable.name = kwargs.get("name", "test_var") + mock_variable.value_type = kwargs.get("value_type", "string") + mock_variable.value = kwargs.get("value", "test_value") + mock_variable.description = kwargs.get("description", "") + mock_variable.selector = kwargs.get("selector", {}) + mock_variable.model_dump.return_value = { + "id": variable_id, + "name": kwargs.get("name", "test_var"), + "value_type": kwargs.get("value_type", "string"), + "value": kwargs.get("value", "test_value"), + "description": kwargs.get("description", ""), + "selector": kwargs.get("selector", {}), + } + variable.to_variable.return_value = mock_variable + + for key, value in kwargs.items(): + setattr(variable, key, value) + return variable + class TestConversationServicePagination: """Test conversation pagination operations.""" @@ -175,99 +265,958 @@ class TestConversationServicePagination: assert result.limit == 20 -class TestConversationServiceMessageCreation: - """ - Test message creation and pagination. +class TestConversationServiceHelpers: + """Test helper methods in ConversationService.""" - Tests MessageService operations for creating and retrieving messages - within conversations. - """ - - def test_pagination_returns_empty_when_no_user(self): + def test_get_sort_params_with_descending_sort(self): """ - Test that pagination returns empty result when user is None. + Test _get_sort_params with descending sort prefix. - This ensures proper handling of unauthenticated requests. + When sort_by starts with '-', should return field name and desc function. + """ + # Act + field, direction = ConversationService._get_sort_params("-updated_at") + + # Assert + assert field == "updated_at" + assert direction == desc + + def test_get_sort_params_with_ascending_sort(self): + """ + Test _get_sort_params with ascending sort. + + When sort_by doesn't start with '-', should return field name and asc function. + """ + # Act + field, direction = ConversationService._get_sort_params("created_at") + + # Assert + assert field == "created_at" + assert direction == asc + + def test_build_filter_condition_with_descending_sort(self): + """ + Test _build_filter_condition with descending sort direction. + + Should create a less-than filter condition. """ # Arrange - app_model = ConversationServiceTestDataFactory.create_app_mock() + mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_conversation.updated_at = datetime.utcnow() # Act - result = MessageService.pagination_by_first_id( - app_model=app_model, - user=None, - conversation_id="conv-123", - first_id=None, - limit=10, + condition = ConversationService._build_filter_condition( + sort_field="updated_at", + sort_direction=desc, + reference_conversation=mock_conversation, ) # Assert - assert result.data == [] - assert result.has_more is False + # The condition should be a comparison expression + assert condition is not None - def test_pagination_returns_empty_when_no_conversation_id(self): + def test_build_filter_condition_with_ascending_sort(self): """ - Test that pagination returns empty result when conversation_id is None. + Test _build_filter_condition with ascending sort direction. - This ensures proper handling of invalid requests. + Should create a greater-than filter condition. + """ + # Arrange + mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_conversation.created_at = datetime.utcnow() + + # Act + condition = ConversationService._build_filter_condition( + sort_field="created_at", + sort_direction=asc, + reference_conversation=mock_conversation, + ) + + # Assert + # The condition should be a comparison expression + assert condition is not None + + +class TestConversationServiceGetConversation: + """Test conversation retrieval operations.""" + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_account(self, mock_db_session): + """ + Test successful conversation retrieval with account user. + + Should return conversation when found with proper filters. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_account_id=user.id, from_source=ConversationFromSource.CONSOLE + ) + + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + mock_db_session.query.assert_called_once_with(Conversation) + + @patch("services.conversation_service.db.session") + def test_get_conversation_success_with_end_user(self, mock_db_session): + """ + Test successful conversation retrieval with end user. + + Should return conversation when found with proper filters for API user. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_end_user_id=user.id, from_source=ConversationFromSource.API + ) + + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = conversation + + # Act + result = ConversationService.get_conversation(app_model, "conv-123", user) + + # Assert + assert result == conversation + + @patch("services.conversation_service.db.session") + def test_get_conversation_not_found_raises_error(self, mock_db_session): + """ + Test that get_conversation raises error when conversation not found. + + Should raise ConversationNotExistsError when no matching conversation found. """ # Arrange app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() - # Act - result = MessageService.pagination_by_first_id( - app_model=app_model, - user=user, - conversation_id="", - first_id=None, - limit=10, - ) + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.first.return_value = None - # Assert - assert result.data == [] - assert result.has_more is False + # Act & Assert + with pytest.raises(ConversationNotExistsError): + ConversationService.get_conversation(app_model, "conv-123", user) -class TestConversationServiceSummarization: - """ - Test conversation summarization (auto-generated names). +class TestConversationServiceRename: + """Test conversation rename operations.""" - Tests the auto_generate_name functionality that creates conversation - titles based on the first message. - """ - - @patch("services.conversation_service.db.session", autospec=True) - @patch("services.conversation_service.ConversationService.get_conversation", autospec=True) - @patch("services.conversation_service.ConversationService.auto_generate_name", autospec=True) - def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_rename_with_manual_name(self, mock_get_conversation, mock_db_session): """ - Test renaming conversation with auto-generation enabled. + Test renaming conversation with manual name. - When auto_generate is True, the service should call the auto_generate_name - method to generate a new name for the conversation. + Should update conversation name and timestamp when auto_generate is False. """ # Arrange app_model = ConversationServiceTestDataFactory.create_app_mock() user = ConversationServiceTestDataFactory.create_account_mock() conversation = ConversationServiceTestDataFactory.create_conversation_mock() - conversation.name = "Auto-generated Name" - # Mock the conversation lookup to return our test conversation mock_get_conversation.return_value = conversation - # Mock the auto_generate_name method to return the conversation + # Act + result = ConversationService.rename( + app_model=app_model, + conversation_id="conv-123", + user=user, + name="New Name", + auto_generate=False, + ) + + # Assert + assert result == conversation + assert conversation.name == "New Name" + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.ConversationService.auto_generate_name") + def test_rename_with_auto_generate(self, mock_auto_generate, mock_get_conversation, mock_db_session): + """ + Test renaming conversation with auto-generation. + + Should call auto_generate_name when auto_generate is True. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation mock_auto_generate.return_value = conversation # Act result = ConversationService.rename( app_model=app_model, - conversation_id=conversation.id, + conversation_id="conv-123", user=user, - name="", + name=None, auto_generate=True, ) # Assert - mock_auto_generate.assert_called_once_with(app_model, conversation) assert result == conversation + mock_auto_generate.assert_called_once_with(app_model, conversation) + + +class TestConversationServiceAutoGenerateName: + """Test conversation auto-name generation operations.""" + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_success(self, mock_llm_generator, mock_db_session): + """ + Test successful auto-generation of conversation name. + + Should generate name using LLMGenerator and update conversation. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = message + + # Mock LLM generator + mock_llm_generator.generate_conversation_name.return_value = "Generated Name" + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + assert conversation.name == "Generated Name" + mock_llm_generator.generate_conversation_name.assert_called_once_with( + app_model.tenant_id, message.query, conversation.id, app_model.id + ) + mock_db_session.commit.assert_called_once() + + @patch("services.conversation_service.db.session") + def test_auto_generate_name_no_message_raises_error(self, mock_db_session): + """ + Test auto-generation fails when no message found. + + Should raise MessageNotExistsError when conversation has no messages. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + # Mock database query to return None + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + ConversationService.auto_generate_name(app_model, conversation) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.LLMGenerator") + def test_auto_generate_name_handles_llm_exception(self, mock_llm_generator, mock_db_session): + """ + Test auto-generation handles LLM generator exceptions gracefully. + + Should continue without name when LLMGenerator fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + message = ConversationServiceTestDataFactory.create_message_mock( + conversation_id=conversation.id, app_id=app_model.id + ) + + # Mock database query to return message + mock_query = mock_db_session.query.return_value + mock_query.where.return_value.order_by.return_value.first.return_value = message + + # Mock LLM generator to raise exception + mock_llm_generator.generate_conversation_name.side_effect = Exception("LLM Error") + + # Act + result = ConversationService.auto_generate_name(app_model, conversation) + + # Assert + assert result == conversation + # Name should remain unchanged due to exception + mock_db_session.commit.assert_called_once() + + +class TestConversationServiceDelete: + """Test conversation deletion operations.""" + + @patch("services.conversation_service.delete_conversation_related_data") + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_success(self, mock_get_conversation, mock_db_session, mock_delete_task): + """ + Test successful conversation deletion. + + Should delete conversation and schedule cleanup task. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock(name="Test App") + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Act + ConversationService.delete(app_model, "conv-123", user) + + # Assert + mock_db_session.delete.assert_called_once_with(conversation) + mock_db_session.commit.assert_called_once() + mock_delete_task.delay.assert_called_once_with(conversation.id) + + @patch("services.conversation_service.db.session") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_delete_handles_exception_and_rollback(self, mock_get_conversation, mock_db_session): + """ + Test deletion handles exceptions and rolls back transaction. + + Should rollback database changes when deletion fails. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_db_session.delete.side_effect = Exception("Database Error") + + # Act & Assert + with pytest.raises(Exception, match="Database Error"): + ConversationService.delete(app_model, "conv-123", user) + + # Assert rollback was called + mock_db_session.rollback.assert_called_once() + + +class TestConversationServiceConversationalVariable: + """Test conversational variable operations.""" + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_success(self, mock_get_conversation, mock_session_factory): + """ + Test successful retrieval of conversational variables. + + Should return paginated list of variables for conversation. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + variable1 = ConversationServiceTestDataFactory.create_conversation_variable_mock() + variable2 = ConversationServiceTestDataFactory.create_conversation_variable_mock(variable_id="var-456") + + mock_session.scalars.return_value.all.return_value = [variable1, variable2] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 2 + assert result.limit == 10 + assert result.has_more is False + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_with_last_id(self, mock_get_conversation, mock_session_factory): + """ + Test retrieval of variables with last_id pagination. + + Should filter variables created after last_id. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and variables + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( + created_at=datetime.utcnow() - timedelta(hours=1) + ) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + + mock_session.scalar.return_value = last_variable + mock_session.scalars.return_value.all.return_value = [variable] + + # Act + result = ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="var-123", + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + assert result.limit == 10 + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_get_conversational_variable_last_id_not_found_raises_error( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that invalid last_id raises ConversationVariableNotExistsError. + + Should raise error when last_id doesn't exist. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id="invalid-id", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_mysql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for MySQL databases. + + Should apply JSON extraction filter for variable names. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "mysql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + @patch("services.conversation_service.dify_config") + def test_get_conversational_variable_with_name_filter_postgresql( + self, mock_config, mock_get_conversation, mock_session_factory + ): + """ + Test variable filtering by name for PostgreSQL databases. + + Should apply JSON extraction filter for variable names. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + mock_config.DB_TYPE = "postgresql" + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [] + + # Act + ConversationService.get_conversational_variable( + app_model=app_model, + conversation_id="conv-123", + user=user, + limit=10, + last_id=None, + variable_name="test_var", + ) + + # Assert - JSON filter should be applied + assert mock_session.scalars.called + + +class TestConversationServiceUpdateVariable: + """Test conversation variable update operations.""" + + @patch("services.conversation_service.variable_factory") + @patch("services.conversation_service.ConversationVariableUpdater") + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_success( + self, mock_get_conversation, mock_session_factory, mock_updater_class, mock_variable_factory + ): + """ + Test successful update of conversation variable. + + Should update variable value and return updated data. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="string") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": "new_value"} + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="new_value", + ) + + # Assert + assert result["id"] == "var-123" + assert result["value"] == "new_value" + mock_updater.update.assert_called_once_with("conv-123", updated_variable) + mock_updater.flush.assert_called_once() + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_not_found_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when variable doesn't exist. + + Should raise ConversationVariableNotExistsError. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act & Assert + with pytest.raises(ConversationVariableNotExistsError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="invalid-id", + user=user, + new_value="new_value", + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_type_mismatch_raises_error(self, mock_get_conversation, mock_session_factory): + """ + Test update fails when value type doesn't match expected type. + + Should raise ConversationVariableTypeMismatchError. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="number") + mock_session.scalar.return_value = existing_variable + + # Act & Assert - Try to set string value for number variable + with pytest.raises(ConversationVariableTypeMismatchError): + ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value="string_value", # Wrong type + ) + + @patch("services.conversation_service.session_factory") + @patch("services.conversation_service.ConversationService.get_conversation") + def test_update_conversation_variable_integer_number_compatibility( + self, mock_get_conversation, mock_session_factory + ): + """ + Test that integer type accepts number values. + + Should allow number values for integer type variables. + """ + # Arrange + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + + mock_get_conversation.return_value = conversation + + # Mock session and existing variable + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + existing_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(value_type="integer") + mock_session.scalar.return_value = existing_variable + + # Mock variable factory and updater + updated_variable = Mock() + updated_variable.model_dump.return_value = {"id": "var-123", "name": "test_var", "value": 42} + + with ( + patch("services.conversation_service.variable_factory") as mock_variable_factory, + patch("services.conversation_service.ConversationVariableUpdater") as mock_updater_class, + ): + mock_variable_factory.build_conversation_variable_from_mapping.return_value = updated_variable + mock_updater = MagicMock() + mock_updater_class.return_value = mock_updater + + # Act + result = ConversationService.update_conversation_variable( + app_model=app_model, + conversation_id="conv-123", + variable_id="var-123", + user=user, + new_value=42, # Number value for integer type + ) + + # Assert + assert result["value"] == 42 + mock_updater.update.assert_called_once() + + +class TestConversationServicePaginationAdvanced: + """Advanced pagination tests for ConversationService.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_last_id_not_found(self, mock_session_factory): + """ + Test pagination with invalid last_id raises error. + + Should raise LastConversationNotExistsError when last_id doesn't exist. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act & Assert + with pytest.raises(LastConversationNotExistsError): + ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id="invalid-id", + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_exclude_ids(self, mock_session_factory): + """ + Test pagination with exclude_ids filter. + + Should exclude specified conversation IDs from results. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=["excluded-123"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert len(result.data) == 1 + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_has_more_detection(self, mock_session_factory): + """ + Test pagination has_more detection logic. + + Should set has_more=True when there are more results beyond limit. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + # Return exactly limit items to trigger has_more check + conversations = [ + ConversationServiceTestDataFactory.create_conversation_mock(conversation_id=f"conv-{i}") for i in range(20) + ] + mock_session.scalars.return_value.all.return_value = conversations + mock_session.scalar.return_value = conversations[-1] + + # Mock count query to return > 0 + mock_session.scalar.return_value = 5 # Additional items exist + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is True + + @patch("services.conversation_service.session_factory") + def test_pagination_by_last_id_with_different_sort_by(self, mock_session_factory): + """ + Test pagination with different sort fields. + + Should handle various sort_by parameters correctly. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock() + mock_session.scalars.return_value.all.return_value = [conversation] + mock_session.scalar.return_value = conversation + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Test different sort fields + sort_fields = ["created_at", "-updated_at", "name", "-status"] + + for sort_by in sort_fields: + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + sort_by=sort_by, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + +class TestConversationServiceEdgeCases: + """Test edge cases and error scenarios.""" + + @patch("services.conversation_service.session_factory") + def test_pagination_with_end_user_api_source(self, mock_session_factory): + """ + Test pagination correctly handles EndUser with API source. + + Should use 'api' as from_source for EndUser instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source=ConversationFromSource.API, from_end_user_id="user-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_end_user_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + @patch("services.conversation_service.session_factory") + def test_pagination_with_account_console_source(self, mock_session_factory): + """ + Test pagination correctly handles Account with console source. + + Should use 'console' as from_source for Account instances. + """ + # Arrange + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + conversation = ConversationServiceTestDataFactory.create_conversation_mock( + from_source=ConversationFromSource.CONSOLE, from_account_id="account-123" + ) + mock_session.scalars.return_value.all.return_value = [conversation] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + + def test_pagination_with_include_ids_filter(self): + """ + Test pagination with include_ids filter. + + Should only return conversations with IDs in include_ids list. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + include_ids=["conv-123", "conv-456"], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + # Verify that include_ids filter was applied + assert mock_session.scalars.called + + def test_pagination_with_empty_exclude_ids(self): + """ + Test pagination with empty exclude_ids list. + + Should handle empty exclude_ids gracefully. + """ + # Arrange + mock_session = MagicMock() + mock_session.scalars.return_value.all.return_value = [] + + app_model = ConversationServiceTestDataFactory.create_app_mock() + user = ConversationServiceTestDataFactory.create_account_mock() + + # Act + result = ConversationService.pagination_by_last_id( + session=mock_session, + app_model=app_model, + user=user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + exclude_ids=[], + ) + + # Assert + assert isinstance(result, InfiniteScrollPagination) + assert result.has_more is False diff --git a/api/tests/unit_tests/services/test_dataset_permission.py b/api/tests/unit_tests/services/test_dataset_permission.py deleted file mode 100644 index 4974d6c1ef..0000000000 --- a/api/tests/unit_tests/services/test_dataset_permission.py +++ /dev/null @@ -1,305 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset, DatasetPermission, DatasetPermissionEnum -from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError - - -class DatasetPermissionTestDataFactory: - """Factory class for creating test data and mock objects for dataset permission tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - permission: DatasetPermissionEnum = DatasetPermissionEnum.ONLY_ME, - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.permission = permission - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.NORMAL, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - @staticmethod - def create_dataset_permission_mock( - dataset_id: str = "dataset-123", - account_id: str = "user-789", - **kwargs, - ) -> Mock: - """Create a mock dataset permission record.""" - permission = Mock(spec=DatasetPermission) - permission.dataset_id = dataset_id - permission.account_id = account_id - for key, value in kwargs.items(): - setattr(permission, key, value) - return permission - - -class TestDatasetPermissionService: - """ - Comprehensive unit tests for DatasetService.check_dataset_permission method. - - This test suite covers all permission scenarios including: - - Cross-tenant access restrictions - - Owner privilege checks - - Different permission levels (ONLY_ME, ALL_TEAM, PARTIAL_TEAM) - - Explicit permission checks for PARTIAL_TEAM - - Error conditions and logging - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with patch("services.dataset_service.db.session") as mock_session: - yield { - "db_session": mock_session, - } - - @pytest.fixture - def mock_logging_dependencies(self): - """Mock setup for logging tests.""" - with patch("services.dataset_service.logger") as mock_logging: - yield { - "logging": mock_logging, - } - - def _assert_permission_check_passes(self, dataset: Mock, user: Mock): - """Helper method to verify that permission check passes without raising exceptions.""" - # Should not raise any exception - DatasetService.check_dataset_permission(dataset, user) - - def _assert_permission_check_fails( - self, dataset: Mock, user: Mock, expected_message: str = "You do not have permission to access this dataset." - ): - """Helper method to verify that permission check fails with expected error.""" - with pytest.raises(NoPermissionError, match=expected_message): - DatasetService.check_dataset_permission(dataset, user) - - def _assert_database_query_called(self, mock_session: Mock, dataset_id: str, account_id: str): - """Helper method to verify database query calls for permission checks.""" - mock_session.query().filter_by.assert_called_with(dataset_id=dataset_id, account_id=account_id) - - def _assert_database_query_not_called(self, mock_session: Mock): - """Helper method to verify that database query was not called.""" - mock_session.query.assert_not_called() - - # ==================== Cross-Tenant Access Tests ==================== - - def test_permission_check_different_tenant_should_fail(self): - """Test that users from different tenants cannot access dataset regardless of other permissions.""" - # Create dataset and user from different tenants - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - tenant_id="tenant-123", permission=DatasetPermissionEnum.ALL_TEAM - ) - user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="user-789", tenant_id="different-tenant-456", role=TenantAccountRole.EDITOR - ) - - # Should fail due to different tenant - self._assert_permission_check_fails(dataset, user) - - # ==================== Owner Privilege Tests ==================== - - def test_owner_can_access_any_dataset(self): - """Test that tenant owners can access any dataset regardless of permission level.""" - # Create dataset with restrictive permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ONLY_ME) - - # Create owner user - owner_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="owner-999", role=TenantAccountRole.OWNER - ) - - # Owner should have access regardless of dataset permission - self._assert_permission_check_passes(dataset, owner_user) - - # ==================== ONLY_ME Permission Tests ==================== - - def test_only_me_permission_creator_can_access(self): - """Test ONLY_ME permission allows only the dataset creator to access.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should be able to access - self._assert_permission_check_passes(dataset, creator_user) - - def test_only_me_permission_others_cannot_access(self): - """Test ONLY_ME permission denies access to non-creators.""" - # Create dataset with ONLY_ME permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.ONLY_ME - ) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Non-creator should be denied access - self._assert_permission_check_fails(dataset, normal_user) - - # ==================== ALL_TEAM Permission Tests ==================== - - def test_all_team_permission_allows_access(self): - """Test ALL_TEAM permission allows any team member to access the dataset.""" - # Create dataset with ALL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.ALL_TEAM) - - # Create different types of team members - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - editor_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="editor-456", role=TenantAccountRole.EDITOR - ) - - # All team members should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_permission_check_passes(dataset, editor_user) - - # ==================== PARTIAL_TEAM Permission Tests ==================== - - def test_partial_team_permission_creator_can_access(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows creator to access without database query.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should have access without database query - self._assert_permission_check_passes(dataset, creator_user) - self._assert_database_query_not_called(mock_dataset_service_dependencies["db_session"]) - - def test_partial_team_permission_with_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission allows users with explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return a permission record - mock_permission = DatasetPermissionTestDataFactory.create_dataset_permission_mock( - dataset_id=dataset.id, account_id=normal_user.id - ) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = mock_permission - - # User with explicit permission should have access - self._assert_permission_check_passes(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_without_explicit_permission(self, mock_dataset_service_dependencies): - """Test PARTIAL_TEAM permission denies users without explicit permission records.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # User without explicit permission should be denied access - self._assert_permission_check_fails(dataset, normal_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, normal_user.id) - - def test_partial_team_permission_non_creator_without_permission_fails(self, mock_dataset_service_dependencies): - """Test that non-creators without explicit permission are denied access to PARTIAL_TEAM datasets.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create a different user (not the creator) - other_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="other-user-123", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Non-creator without explicit permission should be denied access - self._assert_permission_check_fails(dataset, other_user) - self._assert_database_query_called(mock_dataset_service_dependencies["db_session"], dataset.id, other_user.id) - - # ==================== Enum Usage Tests ==================== - - def test_partial_team_permission_uses_correct_enum(self): - """Test that the method correctly uses DatasetPermissionEnum.PARTIAL_TEAM instead of string literals.""" - # Create dataset with PARTIAL_TEAM permission using enum - dataset = DatasetPermissionTestDataFactory.create_dataset_mock( - created_by="creator-456", permission=DatasetPermissionEnum.PARTIAL_TEAM - ) - - # Create creator user - creator_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="creator-456", role=TenantAccountRole.EDITOR - ) - - # Creator should always have access regardless of permission level - self._assert_permission_check_passes(dataset, creator_user) - - # ==================== Logging Tests ==================== - - def test_permission_denied_logs_debug_message(self, mock_dataset_service_dependencies, mock_logging_dependencies): - """Test that permission denied events are properly logged for debugging purposes.""" - # Create dataset with PARTIAL_TEAM permission - dataset = DatasetPermissionTestDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) - - # Create normal user (not the creator) - normal_user = DatasetPermissionTestDataFactory.create_user_mock( - user_id="normal-789", role=TenantAccountRole.NORMAL - ) - - # Mock database query to return None (no permission record) - mock_dataset_service_dependencies["db_session"].query().filter_by().first.return_value = None - - # Attempt permission check (should fail) - with pytest.raises(NoPermissionError): - DatasetService.check_dataset_permission(dataset, normal_user) - - # Verify debug message was logged with correct user and dataset information - mock_logging_dependencies["logging"].debug.assert_called_with( - "User %s does not have permission to access dataset %s", normal_user.id, dataset.id - ) diff --git a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py index 69766188f3..abff48347e 100644 --- a/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py +++ b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py @@ -1,13 +1,10 @@ import datetime - -# Mock redis_client before importing dataset_service -from unittest.mock import Mock, call, patch +from unittest.mock import Mock, patch import pytest from models.dataset import Dataset, Document from services.dataset_service import DocumentService -from services.errors.document import DocumentIndexingError from tests.unit_tests.conftest import redis_mock @@ -48,7 +45,6 @@ class DocumentBatchUpdateTestDataFactory: document.indexing_status = indexing_status document.completed_at = completed_at or datetime.datetime.now() - # Set default values for optional fields document.disabled_at = None document.disabled_by = None document.archived_at = None @@ -59,32 +55,9 @@ class DocumentBatchUpdateTestDataFactory: setattr(document, key, value) return document - @staticmethod - def create_multiple_documents( - document_ids: list[str], enabled: bool = True, archived: bool = False, indexing_status: str = "completed" - ) -> list[Mock]: - """Create multiple mock documents with specified attributes.""" - documents = [] - for doc_id in document_ids: - doc = DocumentBatchUpdateTestDataFactory.create_document_mock( - document_id=doc_id, - name=f"document_{doc_id}.pdf", - enabled=enabled, - archived=archived, - indexing_status=indexing_status, - ) - documents.append(doc) - return documents - class TestDatasetServiceBatchUpdateDocumentStatus: - """ - Comprehensive unit tests for DocumentService.batch_update_document_status method. - - This test suite covers all supported actions (enable, disable, archive, un_archive), - error conditions, edge cases, and validates proper interaction with Redis cache, - database operations, and async task triggers. - """ + """Unit tests for non-SQL path in DocumentService.batch_update_document_status.""" @pytest.fixture def mock_document_service_dependencies(self): @@ -104,697 +77,24 @@ class TestDatasetServiceBatchUpdateDocumentStatus: "current_time": current_time, } - @pytest.fixture - def mock_async_task_dependencies(self): - """Mock setup for async task dependencies.""" - with ( - patch("services.dataset_service.add_document_to_index_task") as mock_add_task, - patch("services.dataset_service.remove_document_from_index_task") as mock_remove_task, - ): - yield {"add_task": mock_add_task, "remove_task": mock_remove_task} - - def _assert_document_enabled(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was enabled correctly.""" - assert document.enabled == True - assert document.disabled_at is None - assert document.disabled_by is None - assert document.updated_at == current_time - - def _assert_document_disabled(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was disabled correctly.""" - assert document.enabled == False - assert document.disabled_at == current_time - assert document.disabled_by == user_id - assert document.updated_at == current_time - - def _assert_document_archived(self, document: Mock, user_id: str, current_time: datetime.datetime): - """Helper method to verify document was archived correctly.""" - assert document.archived == True - assert document.archived_at == current_time - assert document.archived_by == user_id - assert document.updated_at == current_time - - def _assert_document_unarchived(self, document: Mock): - """Helper method to verify document was unarchived correctly.""" - assert document.archived == False - assert document.archived_at is None - assert document.archived_by is None - - def _assert_redis_cache_operations(self, document_ids: list[str], action: str = "setex"): - """Helper method to verify Redis cache operations.""" - if action == "setex": - expected_calls = [call(f"document_{doc_id}_indexing", 600, 1) for doc_id in document_ids] - redis_mock.setex.assert_has_calls(expected_calls) - elif action == "get": - expected_calls = [call(f"document_{doc_id}_indexing") for doc_id in document_ids] - redis_mock.get.assert_has_calls(expected_calls) - - def _assert_async_task_calls(self, mock_task, document_ids: list[str], task_type: str): - """Helper method to verify async task calls.""" - expected_calls = [call(doc_id) for doc_id in document_ids] - if task_type in {"add", "remove"}: - mock_task.delay.assert_has_calls(expected_calls) - - # ==================== Enable Document Tests ==================== - - def test_batch_update_enable_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful enabling of disabled documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create disabled documents - disabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=False) - mock_document_service_dependencies["get_document"].side_effect = disabled_docs - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to enable documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2"], action="enable", user=user - ) - - # Verify document attributes were updated correctly - for doc in disabled_docs: - self._assert_document_enabled(doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache operations - self._assert_redis_cache_operations(["doc-1", "doc-2"], "get") - self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") - - # Verify async tasks were triggered for indexing - self._assert_async_task_calls(mock_async_task_dependencies["add_task"], ["doc-1", "doc-2"], "add") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 2 - assert mock_db.commit.call_count == 1 - - def test_batch_update_enable_already_enabled_document_skipped(self, mock_document_service_dependencies): - """Test enabling documents that are already enabled.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already enabled document - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = enabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to enable already enabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # ==================== Disable Document Tests ==================== - - def test_batch_update_disable_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful disabling of enabled and completed documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create enabled documents - enabled_docs = DocumentBatchUpdateTestDataFactory.create_multiple_documents(["doc-1", "doc-2"], enabled=True) - mock_document_service_dependencies["get_document"].side_effect = enabled_docs - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to disable documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2"], action="disable", user=user - ) - - # Verify document attributes were updated correctly - for doc in enabled_docs: - self._assert_document_disabled(doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache operations for indexing prevention - self._assert_redis_cache_operations(["doc-1", "doc-2"], "setex") - - # Verify async tasks were triggered to remove from index - self._assert_async_task_calls(mock_async_task_dependencies["remove_task"], ["doc-1", "doc-2"], "remove") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 2 - assert mock_db.commit.call_count == 1 - - def test_batch_update_disable_already_disabled_document_skipped( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test disabling documents that are already disabled.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already disabled document - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) - mock_document_service_dependencies["get_document"].return_value = disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to disable already disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="disable", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - def test_batch_update_disable_non_completed_document_error(self, mock_document_service_dependencies): - """Test that DocumentIndexingError is raised when trying to disable non-completed documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create a document that's not completed - non_completed_doc = DocumentBatchUpdateTestDataFactory.create_document_mock( - enabled=True, - indexing_status="indexing", # Not completed - completed_at=None, # Not completed - ) - mock_document_service_dependencies["get_document"].return_value = non_completed_doc - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="disable", user=user - ) - - # Verify error message indicates document is not completed - assert "is not completed" in str(exc_info.value) - - # ==================== Archive Document Tests ==================== - - def test_batch_update_archive_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful archiving of unarchived documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create unarchived enabled document - unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) - mock_document_service_dependencies["get_document"].return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to archive documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="archive", user=user - ) - - # Verify document attributes were updated correctly - self._assert_document_archived(unarchived_doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify Redis cache was set (because document was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to remove from index (because enabled) - mock_async_task_dependencies["remove_task"].delay.assert_called_once_with("doc-1") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_batch_update_archive_already_archived_document_skipped(self, mock_document_service_dependencies): - """Test archiving documents that are already archived.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already archived document - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to archive already archived document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-3"], action="archive", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - def test_batch_update_archive_disabled_document_no_index_removal( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test archiving disabled documents (should not trigger index removal).""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Set up disabled, unarchived document - disabled_unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=False) - mock_document_service_dependencies["get_document"].return_value = disabled_unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Archive the disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="archive", user=user - ) - - # Verify document was archived - self._assert_document_archived( - disabled_unarchived_doc, user.id, mock_document_service_dependencies["current_time"] - ) - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index removal task was triggered (document is disabled) - mock_async_task_dependencies["remove_task"].delay.assert_not_called() - - # Verify database operations still occurred - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # ==================== Unarchive Document Tests ==================== - - def test_batch_update_unarchive_documents_success( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test successful unarchiving of archived documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock archived document - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Call the method to unarchive documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify document attributes were updated correctly - self._assert_document_unarchived(archived_doc) - assert archived_doc.updated_at == mock_document_service_dependencies["current_time"] - - # Verify Redis cache was set (because document is enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify async task was triggered to add back to index (because enabled) - mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_batch_update_unarchive_already_unarchived_document_skipped( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test unarchiving documents that are already unarchived.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create already unarchived document - unarchived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True, archived=False) - mock_document_service_dependencies["get_document"].return_value = unarchived_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Attempt to unarchive already unarchived document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify no database operations occurred (document was skipped) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.commit.assert_not_called() - - # Verify no Redis setex operations occurred (document was skipped) - redis_mock.setex.assert_not_called() - - # Verify no async tasks were triggered (document was skipped) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - def test_batch_update_unarchive_disabled_document_no_index_addition( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test unarchiving disabled documents (should not trigger index addition).""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock archived but disabled document - archived_disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False, archived=True) - mock_document_service_dependencies["get_document"].return_value = archived_disabled_doc - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Unarchive the disabled document - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="un_archive", user=user - ) - - # Verify document was unarchived - self._assert_document_unarchived(archived_disabled_doc) - assert archived_disabled_doc.updated_at == mock_document_service_dependencies["current_time"] - - # Verify no Redis cache was set (document is disabled) - redis_mock.setex.assert_not_called() - - # Verify no index addition task was triggered (document is disabled) - mock_async_task_dependencies["add_task"].delay.assert_not_called() - - # Verify database operations still occurred - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # ==================== Error Handling Tests ==================== - - def test_batch_update_document_indexing_error_redis_cache_hit(self, mock_document_service_dependencies): - """Test that DocumentIndexingError is raised when documents are currently being indexed.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock enabled document - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) - mock_document_service_dependencies["get_document"].return_value = enabled_doc - - # Set up mock to indicate document is being indexed - redis_mock.reset_mock() - redis_mock.get.return_value = "indexing" - - # Verify that DocumentIndexingError is raised - with pytest.raises(DocumentIndexingError) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify error message contains document name - assert "test_document.pdf" in str(exc_info.value) - assert "is being indexed" in str(exc_info.value) - - # Verify Redis cache was checked - redis_mock.get.assert_called_once_with("document_doc-1_indexing") - def test_batch_update_invalid_action_error(self, mock_document_service_dependencies): """Test that ValueError is raised when an invalid action is provided.""" dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() user = DocumentBatchUpdateTestDataFactory.create_user_mock() - # Create mock document doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=True) mock_document_service_dependencies["get_document"].return_value = doc - # Reset module-level Redis mock redis_mock.reset_mock() redis_mock.get.return_value = None - # Test with invalid action invalid_action = "invalid_action" with pytest.raises(ValueError) as exc_info: DocumentService.batch_update_document_status( dataset=dataset, document_ids=["doc-1"], action=invalid_action, user=user ) - # Verify error message contains the invalid action assert invalid_action in str(exc_info.value) assert "Invalid action" in str(exc_info.value) - # Verify no Redis operations occurred redis_mock.setex.assert_not_called() - - def test_batch_update_async_task_error_handling( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test handling of async task errors during batch operations.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create mock disabled document - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock(enabled=False) - mock_document_service_dependencies["get_document"].return_value = disabled_doc - - # Mock async task to raise an exception - mock_async_task_dependencies["add_task"].delay.side_effect = Exception("Celery task error") - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Verify that async task error is propagated - with pytest.raises(Exception) as exc_info: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1"], action="enable", user=user - ) - - # Verify error message - assert "Celery task error" in str(exc_info.value) - - # Verify database operations completed successfully - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - # Verify Redis cache was set successfully - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Verify document was updated - self._assert_document_enabled(disabled_doc, user.id, mock_document_service_dependencies["current_time"]) - - # ==================== Edge Case Tests ==================== - - def test_batch_update_empty_document_list(self, mock_document_service_dependencies): - """Test batch operations with an empty document ID list.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Call method with empty document list - result = DocumentService.batch_update_document_status( - dataset=dataset, document_ids=[], action="enable", user=user - ) - - # Verify no document lookups were performed - mock_document_service_dependencies["get_document"].assert_not_called() - - # Verify method returns None (early return) - assert result is None - - def test_batch_update_document_not_found_skipped(self, mock_document_service_dependencies): - """Test behavior when some documents don't exist in the database.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Mock document service to return None (document not found) - mock_document_service_dependencies["get_document"].return_value = None - - # Call method with non-existent document ID - # This should not raise an error, just skip the missing document - try: - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["non-existent-doc"], action="enable", user=user - ) - except Exception as e: - pytest.fail(f"Method should not raise exception for missing documents: {e}") - - # Verify document lookup was attempted - mock_document_service_dependencies["get_document"].assert_called_once_with(dataset.id, "non-existent-doc") - - def test_batch_update_mixed_document_states_and_actions( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test batch operations on documents with mixed states and various scenarios.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create documents in various states - disabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) - enabled_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-2", enabled=True) - archived_doc = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-3", enabled=True, archived=True) - - # Mix of different document states - documents = [disabled_doc, enabled_doc, archived_doc] - mock_document_service_dependencies["get_document"].side_effect = documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform enable operation on mixed state documents - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=["doc-1", "doc-2", "doc-3"], action="enable", user=user - ) - - # Verify only the disabled document was processed - # (enabled and archived documents should be skipped for enable action) - - # Only one add should occur (for the disabled document that was enabled) - mock_db = mock_document_service_dependencies["db_session"] - mock_db.add.assert_called_once() - # Only one commit should occur - mock_db.commit.assert_called_once() - - # Only one Redis setex should occur (for the document that was enabled) - redis_mock.setex.assert_called_once_with("document_doc-1_indexing", 600, 1) - - # Only one async task should be triggered (for the document that was enabled) - mock_async_task_dependencies["add_task"].delay.assert_called_once_with("doc-1") - - # ==================== Performance Tests ==================== - - def test_batch_update_large_document_list_performance( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test batch operations with a large number of documents.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create large list of document IDs - document_ids = [f"doc-{i}" for i in range(1, 101)] # 100 documents - - # Create mock documents - mock_documents = DocumentBatchUpdateTestDataFactory.create_multiple_documents( - document_ids, - enabled=False, # All disabled, will be enabled - ) - mock_document_service_dependencies["get_document"].side_effect = mock_documents - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform batch enable operation - DocumentService.batch_update_document_status( - dataset=dataset, document_ids=document_ids, action="enable", user=user - ) - - # Verify all documents were processed - assert mock_document_service_dependencies["get_document"].call_count == 100 - - # Verify all documents were updated - for mock_doc in mock_documents: - self._assert_document_enabled(mock_doc, user.id, mock_document_service_dependencies["current_time"]) - - # Verify database operations - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 100 - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for each document - assert redis_mock.setex.call_count == 100 - - # Verify async tasks were triggered for each document - assert mock_async_task_dependencies["add_task"].delay.call_count == 100 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call(f"document_doc-{i}_indexing", 600, 1) for i in range(1, 101)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task calls - expected_task_calls = [call(f"doc-{i}") for i in range(1, 101)] - mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) - - def test_batch_update_mixed_document_states_complex_scenario( - self, mock_document_service_dependencies, mock_async_task_dependencies - ): - """Test complex batch operations with documents in various states.""" - dataset = DocumentBatchUpdateTestDataFactory.create_dataset_mock() - user = DocumentBatchUpdateTestDataFactory.create_user_mock() - - # Create documents in various states - doc1 = DocumentBatchUpdateTestDataFactory.create_document_mock("doc-1", enabled=False) # Will be enabled - doc2 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-2", enabled=True - ) # Already enabled, will be skipped - doc3 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-3", enabled=True - ) # Already enabled, will be skipped - doc4 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-4", enabled=True - ) # Not affected by enable action - doc5 = DocumentBatchUpdateTestDataFactory.create_document_mock( - "doc-5", enabled=True, archived=True - ) # Not affected by enable action - doc6 = None # Non-existent, will be skipped - - mock_document_service_dependencies["get_document"].side_effect = [doc1, doc2, doc3, doc4, doc5, doc6] - - # Reset module-level Redis mock - redis_mock.reset_mock() - redis_mock.get.return_value = None - - # Perform mixed batch operations - DocumentService.batch_update_document_status( - dataset=dataset, - document_ids=["doc-1", "doc-2", "doc-3", "doc-4", "doc-5", "doc-6"], - action="enable", # This will only affect doc1 - user=user, - ) - - # Verify document 1 was enabled - self._assert_document_enabled(doc1, user.id, mock_document_service_dependencies["current_time"]) - - # Verify other documents were skipped appropriately - assert doc2.enabled == True # No change - assert doc3.enabled == True # No change - assert doc4.enabled == True # No change - assert doc5.enabled == True # No change - - # Verify database commits occurred for processed documents - # Only doc1 should be added (others were skipped, doc6 doesn't exist) - mock_db = mock_document_service_dependencies["db_session"] - assert mock_db.add.call_count == 1 - assert mock_db.commit.call_count == 1 - - # Verify Redis cache operations occurred for processed documents - # Only doc1 should have Redis operations - assert redis_mock.setex.call_count == 1 - - # Verify async tasks were triggered for processed documents - # Only doc1 should trigger tasks - assert mock_async_task_dependencies["add_task"].delay.call_count == 1 - - # Verify correct Redis cache keys were set - expected_redis_calls = [call("document_doc-1_indexing", 600, 1)] - redis_mock.setex.assert_has_calls(expected_redis_calls) - - # Verify correct async task calls - expected_task_calls = [call("doc-1")] - mock_async_task_dependencies["add_task"].delay.assert_has_calls(expected_task_calls) diff --git a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py index 7c7a70f962..f8c5270656 100644 --- a/api/tests/unit_tests/services/test_dataset_service_create_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_create_dataset.py @@ -1,726 +1,39 @@ -""" -Comprehensive unit tests for DatasetService creation methods. +"""Unit tests for non-SQL validation paths in DatasetService dataset creation.""" -This test suite covers: -- create_empty_dataset for internal datasets -- create_empty_dataset for external datasets -- create_empty_rag_pipeline_dataset -- Error conditions and edge cases -""" - -from unittest.mock import Mock, create_autospec, patch +from unittest.mock import Mock, patch from uuid import uuid4 import pytest -from core.model_runtime.entities.model_entities import ModelType -from models.account import Account -from models.dataset import Dataset, Pipeline from services.dataset_service import DatasetService -from services.entities.knowledge_entities.knowledge_entities import RetrievalModel -from services.entities.knowledge_entities.rag_pipeline_entities import ( - IconInfo, - RagPipelineDatasetCreateEntity, -) -from services.errors.dataset import DatasetNameDuplicateError +from services.entities.knowledge_entities.rag_pipeline_entities import IconInfo, RagPipelineDatasetCreateEntity -class DatasetCreateTestDataFactory: - """Factory class for creating test data and mock objects for dataset creation tests.""" - - @staticmethod - def create_account_mock( - account_id: str = "account-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock account.""" - account = create_autospec(Account, instance=True) - account.id = account_id - account.current_tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(account, key, value) - return account - - @staticmethod - def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: - """Create a mock embedding model.""" - embedding_model = Mock() - embedding_model.model_name = model - embedding_model.provider = provider - return embedding_model - - @staticmethod - def create_retrieval_model_mock() -> Mock: - """Create a mock retrieval model.""" - retrieval_model = Mock(spec=RetrievalModel) - retrieval_model.model_dump.return_value = { - "search_method": "semantic_search", - "top_k": 2, - "score_threshold": 0.0, - } - retrieval_model.reranking_model = None - return retrieval_model - - @staticmethod - def create_external_knowledge_api_mock(api_id: str = "api-123", **kwargs) -> Mock: - """Create a mock external knowledge API.""" - api = Mock() - api.id = api_id - for key, value in kwargs.items(): - setattr(api, key, value) - return api - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - name: str = "Test Dataset", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """Create a mock dataset.""" - dataset = create_autospec(Dataset, instance=True) - dataset.id = dataset_id - dataset.name = name - dataset.tenant_id = tenant_id - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_pipeline_mock( - pipeline_id: str = "pipeline-123", - name: str = "Test Pipeline", - **kwargs, - ) -> Mock: - """Create a mock pipeline.""" - pipeline = Mock(spec=Pipeline) - pipeline.id = pipeline_id - pipeline.name = name - for key, value in kwargs.items(): - setattr(pipeline, key, value) - return pipeline - - -class TestDatasetServiceCreateEmptyDataset: - """ - Comprehensive unit tests for DatasetService.create_empty_dataset method. - - This test suite covers: - - Internal dataset creation (vendor provider) - - External dataset creation - - High quality indexing technique with embedding models - - Economy indexing technique - - Retrieval model configuration - - Error conditions (duplicate names, missing external knowledge IDs) - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.db.session") as mock_db, - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, - patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, - patch("services.dataset_service.ExternalDatasetService") as mock_external_service, - ): - yield { - "db_session": mock_db, - "model_manager": mock_model_manager, - "check_embedding": mock_check_embedding, - "check_reranking": mock_check_reranking, - "external_service": mock_external_service, - } - - # ==================== Internal Dataset Creation Tests ==================== - - def test_create_internal_dataset_basic_success(self, mock_dataset_service_dependencies): - """Test successful creation of basic internal dataset.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Test Dataset" - description = "Test description" - - # Mock database query to return None (no duplicate name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock database session operations - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=description, - indexing_technique=None, - account=account, - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == account.id - assert result.updated_by == account.id - assert result.provider == "vendor" - assert result.permission == "only_me" - mock_db.add.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_economy_indexing(self, mock_dataset_service_dependencies): - """Test successful creation of internal dataset with economy indexing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Economy Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="economy", - account=account, - ) - - # Assert - assert result.indexing_technique == "economy" - assert result.embedding_model_provider is None - assert result.embedding_model is None - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing_default_embedding( - self, mock_dataset_service_dependencies - ): - """Test creation with high_quality indexing using default embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "High Quality Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_model.provider - assert result.embedding_model == embedding_model.model_name - mock_model_manager_instance.get_default_model_instance.assert_called_once_with( - tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_high_quality_indexing_custom_embedding( - self, mock_dataset_service_dependencies - ): - """Test creation with high_quality indexing using custom embedding model.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Custom Embedding Dataset" - embedding_provider = "openai" - embedding_model_name = "text-embedding-3-small" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock( - model=embedding_model_name, provider=embedding_provider - ) - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - embedding_model_provider=embedding_provider, - embedding_model_name=embedding_model_name, - ) - - # Assert - assert result.indexing_technique == "high_quality" - assert result.embedding_model_provider == embedding_provider - assert result.embedding_model == embedding_model_name - mock_dataset_service_dependencies["check_embedding"].assert_called_once_with( - tenant_id, embedding_provider, embedding_model_name - ) - mock_model_manager_instance.get_model_instance.assert_called_once_with( - tenant_id=tenant_id, - provider=embedding_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=embedding_model_name, - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_retrieval_model(self, mock_dataset_service_dependencies): - """Test creation with retrieval model configuration.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Retrieval Model Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock retrieval model - retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock() - retrieval_model_dict = {"search_method": "semantic_search", "top_k": 2, "score_threshold": 0.0} - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - assert result.retrieval_model == retrieval_model_dict - retrieval_model.model_dump.assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_retrieval_model_reranking(self, mock_dataset_service_dependencies): - """Test creation with retrieval model that includes reranking.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Reranking Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock model manager - embedding_model = DatasetCreateTestDataFactory.create_embedding_model_mock() - mock_model_manager_instance = Mock() - mock_model_manager_instance.get_default_model_instance.return_value = embedding_model - mock_dataset_service_dependencies["model_manager"].return_value = mock_model_manager_instance - - # Mock retrieval model with reranking - reranking_model = Mock() - reranking_model.reranking_provider_name = "cohere" - reranking_model.reranking_model_name = "rerank-english-v3.0" - - retrieval_model = DatasetCreateTestDataFactory.create_retrieval_model_mock() - retrieval_model.reranking_model = reranking_model - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique="high_quality", - account=account, - retrieval_model=retrieval_model, - ) - - # Assert - mock_dataset_service_dependencies["check_reranking"].assert_called_once_with( - tenant_id, "cohere", "rerank-english-v3.0" - ) - mock_db.commit.assert_called_once() - - def test_create_internal_dataset_with_custom_permission(self, mock_dataset_service_dependencies): - """Test creation with custom permission setting.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Custom Permission Dataset" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - permission="all_team_members", - ) - - # Assert - assert result.permission == "all_team_members" - mock_db.commit.assert_called_once() - - # ==================== External Dataset Creation Tests ==================== - - def test_create_external_dataset_success(self, mock_dataset_service_dependencies): - """Test successful creation of external dataset.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "external-api-123" - external_knowledge_id = "external-knowledge-456" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id) - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Act - result = DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id=external_knowledge_id, - ) - - # Assert - assert result.provider == "external" - assert mock_db.add.call_count == 2 # Dataset + ExternalKnowledgeBindings - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.assert_called_once_with( - external_api_id - ) - mock_db.commit.assert_called_once() - - def test_create_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge API is not found.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "non-existent-api" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API not found - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = None - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - - # Act & Assert - with pytest.raises(ValueError, match="External API template not found"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id="knowledge-123", - ) - - def test_create_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge ID is missing.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "External Dataset" - external_api_id = "external-api-123" - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Mock external knowledge API - external_api = DatasetCreateTestDataFactory.create_external_knowledge_api_mock(api_id=external_api_id) - mock_dataset_service_dependencies["external_service"].get_external_knowledge_api.return_value = external_api - - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - - # Act & Assert - with pytest.raises(ValueError, match="external_knowledge_id is required"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - provider="external", - external_knowledge_api_id=external_api_id, - external_knowledge_id=None, - ) - - # ==================== Error Handling Tests ==================== - - def test_create_dataset_duplicate_name_error(self, mock_dataset_service_dependencies): - """Test error when dataset name already exists.""" - # Arrange - tenant_id = str(uuid4()) - account = DatasetCreateTestDataFactory.create_account_mock(tenant_id=tenant_id) - name = "Duplicate Dataset" - - # Mock database query to return existing dataset - existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_dataset_service_dependencies["db_session"].query.return_value = mock_query - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"): - DatasetService.create_empty_dataset( - tenant_id=tenant_id, - name=name, - description=None, - indexing_technique=None, - account=account, - ) - - -class TestDatasetServiceCreateEmptyRagPipelineDataset: - """ - Comprehensive unit tests for DatasetService.create_empty_rag_pipeline_dataset method. - - This test suite covers: - - RAG pipeline dataset creation with provided name - - RAG pipeline dataset creation with auto-generated name - - Pipeline creation - - Error conditions (duplicate names, missing current user) - """ +class TestDatasetServiceCreateRagPipelineDatasetNonSQL: + """Unit coverage for non-SQL validation in create_empty_rag_pipeline_dataset.""" @pytest.fixture def mock_rag_pipeline_dependencies(self): - """Common mock setup for RAG pipeline dataset creation.""" + """Patch database session and current_user for validation-only unit coverage.""" with ( patch("services.dataset_service.db.session") as mock_db, patch("services.dataset_service.current_user") as mock_current_user, - patch("services.dataset_service.generate_incremental_name") as mock_generate_name, ): - # Configure mock_current_user to behave like a Flask-Login proxy - # Default: no user (falsy) - mock_current_user.id = None yield { "db_session": mock_db, "current_user_mock": mock_current_user, - "generate_name": mock_generate_name, } - def test_create_rag_pipeline_dataset_with_name_success(self, mock_rag_pipeline_dependencies): - """Test successful creation of RAG pipeline dataset with provided name.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "RAG Pipeline Dataset" - description = "RAG Pipeline Description" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query (no duplicate name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description=description, - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result is not None - assert result.name == name - assert result.description == description - assert result.tenant_id == tenant_id - assert result.created_by == user_id - assert result.provider == "vendor" - assert result.runtime_mode == "rag_pipeline" - assert result.permission == "only_me" - assert mock_db.add.call_count == 2 # Pipeline + Dataset - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_with_auto_generated_name(self, mock_rag_pipeline_dependencies): - """Test creation of RAG pipeline dataset with auto-generated name.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - auto_name = "Untitled 1" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query (empty name, need to generate) - mock_query = Mock() - mock_query.filter_by.return_value.all.return_value = [] - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock name generation - mock_rag_pipeline_dependencies["generate_name"].return_value = auto_name - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity with empty name - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name="", - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.name == auto_name - mock_rag_pipeline_dependencies["generate_name"].assert_called_once() - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_duplicate_name_error(self, mock_rag_pipeline_dependencies): - """Test error when RAG pipeline dataset name already exists.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Duplicate RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query to return existing dataset - existing_dataset = DatasetCreateTestDataFactory.create_dataset_mock(name=name) - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = existing_dataset - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act & Assert - with pytest.raises(DatasetNameDuplicateError, match=f"Dataset with name {name} already exists"): - DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - def test_create_rag_pipeline_dataset_missing_current_user_error(self, mock_rag_pipeline_dependencies): - """Test error when current user is not available.""" + """Raise ValueError when current_user.id is unavailable before SQL persistence.""" # Arrange tenant_id = str(uuid4()) - - # Mock current user as None - set id to None so the check fails mock_rag_pipeline_dependencies["current_user_mock"].id = None - # Mock database query mock_query = Mock() mock_query.filter_by.return_value.first.return_value = None mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - # Create entity icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") entity = RagPipelineDatasetCreateEntity( name="Test Dataset", @@ -729,91 +42,9 @@ class TestDatasetServiceCreateEmptyRagPipelineDataset: permission="only_me", ) - # Act & Assert + # Act / Assert with pytest.raises(ValueError, match="Current user or current user id not found"): DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity + tenant_id=tenant_id, + rag_pipeline_dataset_create_entity=entity, ) - - def test_create_rag_pipeline_dataset_with_custom_permission(self, mock_rag_pipeline_dependencies): - """Test creation with custom permission setting.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Custom Permission RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity - icon_info = IconInfo(icon="📙", icon_background="#FFF4ED", icon_type="emoji") - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="all_team", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.permission == "all_team" - mock_db.commit.assert_called_once() - - def test_create_rag_pipeline_dataset_with_icon_info(self, mock_rag_pipeline_dependencies): - """Test creation with icon info configuration.""" - # Arrange - tenant_id = str(uuid4()) - user_id = str(uuid4()) - name = "Icon Info RAG Dataset" - - # Mock current user - set up the mock to have id attribute accessible directly - mock_rag_pipeline_dependencies["current_user_mock"].id = user_id - - # Mock database query - mock_query = Mock() - mock_query.filter_by.return_value.first.return_value = None - mock_rag_pipeline_dependencies["db_session"].query.return_value = mock_query - - # Mock database operations - mock_db = mock_rag_pipeline_dependencies["db_session"] - mock_db.add = Mock() - mock_db.flush = Mock() - mock_db.commit = Mock() - - # Create entity with icon info - icon_info = IconInfo( - icon="📚", - icon_background="#E8F5E9", - icon_type="emoji", - icon_url="https://example.com/icon.png", - ) - entity = RagPipelineDatasetCreateEntity( - name=name, - description="", - icon_info=icon_info, - permission="only_me", - ) - - # Act - result = DatasetService.create_empty_rag_pipeline_dataset( - tenant_id=tenant_id, rag_pipeline_dataset_create_entity=entity - ) - - # Assert - assert result.icon_info == icon_info.model_dump() - mock_db.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py b/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py deleted file mode 100644 index cc718c9997..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_delete_dataset.py +++ /dev/null @@ -1,216 +0,0 @@ -from unittest.mock import Mock, patch - -import pytest - -from models.account import Account, TenantAccountRole -from models.dataset import Dataset -from services.dataset_service import DatasetService - - -class DatasetDeleteTestDataFactory: - """Factory class for creating test data and mock objects for dataset delete tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - tenant_id: str = "test-tenant-123", - created_by: str = "creator-456", - doc_form: str | None = None, - indexing_technique: str | None = "high_quality", - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.created_by = created_by - dataset.doc_form = doc_form - dataset.indexing_technique = indexing_technique - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock( - user_id: str = "user-789", - tenant_id: str = "test-tenant-123", - role: TenantAccountRole = TenantAccountRole.ADMIN, - **kwargs, - ) -> Mock: - """Create a mock user with specified attributes.""" - user = Mock(spec=Account) - user.id = user_id - user.current_tenant_id = tenant_id - user.current_role = role - for key, value in kwargs.items(): - setattr(user, key, value) - return user - - -class TestDatasetServiceDeleteDataset: - """ - Comprehensive unit tests for DatasetService.delete_dataset method. - - This test suite covers all deletion scenarios including: - - Normal dataset deletion with documents - - Empty dataset deletion (no documents, doc_form is None) - - Dataset deletion with missing indexing_technique - - Permission checks - - Event handling - - This test suite provides regression protection for issue #27073. - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.dataset_was_deleted") as mock_dataset_was_deleted, - ): - yield { - "get_dataset": mock_get_dataset, - "check_permission": mock_check_perm, - "db_session": mock_db, - "dataset_was_deleted": mock_dataset_was_deleted, - } - - def test_delete_dataset_with_documents_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of a dataset with documents. - - This test verifies: - - Dataset is retrieved correctly - - Permission check is performed - - dataset_was_deleted event is sent - - Dataset is deleted from database - - Method returns True - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock( - doc_form="text_model", indexing_technique="high_quality" - ) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_empty_dataset_success(self, mock_dataset_service_dependencies): - """ - Test successful deletion of an empty dataset (no documents, doc_form is None). - - This test verifies that: - - Empty datasets can be deleted without errors - - dataset_was_deleted event is sent (event handler will skip cleanup if doc_form is None) - - Dataset is deleted from database - - Method returns True - - This is the primary test for issue #27073 where deleting an empty dataset - caused internal server error due to assertion failure in event handlers. - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique=None) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_with_partial_none_values(self, mock_dataset_service_dependencies): - """ - Test deletion of dataset with partial None values. - - This test verifies that datasets with partial None values (e.g., doc_form exists - but indexing_technique is None) can be deleted successfully. The event handler - will skip cleanup if any required field is None. - - Improvement based on Gemini Code Assist suggestion: Added comprehensive assertions - to verify all core deletion operations are performed, not just event sending. - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form="text_model", indexing_technique=None) - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow (Gemini suggestion implemented) - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_with_doc_form_none_indexing_technique_exists(self, mock_dataset_service_dependencies): - """ - Test deletion of dataset where doc_form is None but indexing_technique exists. - - This edge case can occur in certain dataset configurations and should be handled - gracefully by the event handler's conditional check. - """ - # Arrange - dataset = DatasetDeleteTestDataFactory.create_dataset_mock(doc_form=None, indexing_technique="high_quality") - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - # Act - result = DatasetService.delete_dataset(dataset.id, user) - - # Assert - Verify complete deletion flow - assert result is True - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset.id) - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].delete.assert_called_once_with(dataset) - mock_dataset_service_dependencies["db_session"].commit.assert_called_once() - - def test_delete_dataset_not_found(self, mock_dataset_service_dependencies): - """ - Test deletion attempt when dataset doesn't exist. - - This test verifies that: - - Method returns False when dataset is not found - - No deletion operations are performed - - No events are sent - """ - # Arrange - dataset_id = "non-existent-dataset" - user = DatasetDeleteTestDataFactory.create_user_mock() - - mock_dataset_service_dependencies["get_dataset"].return_value = None - - # Act - result = DatasetService.delete_dataset(dataset_id, user) - - # Assert - assert result is False - mock_dataset_service_dependencies["get_dataset"].assert_called_once_with(dataset_id) - mock_dataset_service_dependencies["check_permission"].assert_not_called() - mock_dataset_service_dependencies["dataset_was_deleted"].send.assert_not_called() - mock_dataset_service_dependencies["db_session"].delete.assert_not_called() - mock_dataset_service_dependencies["db_session"].commit.assert_not_called() diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py new file mode 100644 index 0000000000..105ef7ba48 --- /dev/null +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -0,0 +1,760 @@ +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy.orm import Session + +from core.plugin.entities.plugin_daemon import CredentialType +from dify_graph.model_runtime.entities.provider_entities import FormType +from models.account import Account +from models.model import EndUser +from models.oauth import DatasourceProvider +from models.provider_ids import DatasourceProviderID +from services.datasource_provider_service import DatasourceProviderService, get_current_user + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID: + return DatasourceProviderID(s) + + +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + + +class TestDatasourceProviderService: + """Comprehensive tests for DatasourceProviderService targeting >95% coverage.""" + + @pytest.fixture + def service(self): + return DatasourceProviderService() + + @pytest.fixture + def mock_db_session(self): + """ + Robust, chainable query mock. + q returns itself for .filter_by(), .order_by(), .where() so any + SQLAlchemy chaining pattern works without multiple brittle sub-mocks. + """ + with patch("services.datasource_provider_service.Session") as mock_cls: + sess = MagicMock(spec=Session) + + q = MagicMock() + sess.query.return_value = q + + # Self-returning chain — any method called on q returns q + q.filter_by.return_value = q + q.order_by.return_value = q + q.where.return_value = q + + # Default terminal values (tests override per-case) + q.first.return_value = None + q.all.return_value = [] + q.count.return_value = 0 + q.delete.return_value = 1 + + mock_cls.return_value.__enter__.return_value = sess + mock_cls.return_value.no_autoflush.__enter__.return_value = sess + + yield sess + + @pytest.fixture(autouse=True) + def patch_db(self, mock_db_session): + with patch("services.datasource_provider_service.db") as mock_db: + mock_db.session = mock_db_session + mock_db.engine = MagicMock() + yield mock_db + + @pytest.fixture(autouse=True) + def patch_externals(self): + with ( + patch("httpx.request") as mock_httpx, + patch("services.datasource_provider_service.dify_config") as mock_cfg, + patch("services.datasource_provider_service.encrypter") as mock_enc, + patch("services.datasource_provider_service.redis_client") as mock_redis, + patch("services.datasource_provider_service.generate_incremental_name") as mock_genname, + patch("services.datasource_provider_service.OAuthHandler") as mock_oauth, + ): + mock_cfg.CONSOLE_API_URL = "http://localhost" + mock_enc.encrypt_token.return_value = "enc_tok" + mock_enc.decrypt_token.return_value = "dec_tok" + mock_enc.decrypt.return_value = {"k": "dec"} + mock_enc.encrypt.return_value = {"k": "enc"} + mock_enc.obfuscated_token.return_value = "obf" + mock_enc.mask_plugin_credentials.return_value = {"k": "mask"} + + mock_redis.lock.return_value.__enter__.return_value = MagicMock() + mock_genname.return_value = "gen_name" + + mock_oauth.return_value.refresh_credentials.return_value = MagicMock( + credentials={"k": "v"}, expires_at=9999 + ) + + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = { + "code": 0, + "message": "ok", + "data": { + "provider": "prov", + "plugin_unique_identifier": "pui", + "plugin_id": "org/plug", + "is_authorized": False, + "declaration": { + "identity": { + "author": "a", + "name": "n", + "description": {"en_US": "d"}, + "icon": "i", + "label": {"en_US": "l"}, + }, + "credentials_schema": [], + "oauth_schema": {"credentials_schema": [], "client_schema": []}, + "provider_type": "local_file", + "datasources": [], + }, + }, + } + mock_httpx.return_value = resp + + # Store handles for assertions + self._enc = mock_enc + self._redis = mock_redis + yield + + @pytest.fixture + def mock_user(self): + u = MagicMock() + u.id = "uid-1" + return u + + # ----------------------------------------------------------------------- + # get_current_user (lines 27-40) + # ----------------------------------------------------------------------- + + def test_should_return_proxy_when_current_object_is_account(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + user_obj = MagicMock() + user_obj.__class__ = Account + proxy._get_current_object.return_value = user_obj + assert get_current_user() is proxy + + def test_should_return_proxy_when_current_object_is_enduser(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + user_obj = MagicMock() + user_obj.__class__ = EndUser + proxy._get_current_object.return_value = user_obj + assert get_current_user() is proxy + + def test_should_return_proxy_when_get_current_object_raises_attribute_error(self): + """AttributeError from LocalProxy falls back to the proxy itself.""" + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + proxy._get_current_object.side_effect = AttributeError("no attr") + proxy.__class__ = Account # make the proxy itself satisfy isinstance + assert get_current_user() is proxy + + def test_should_raise_type_error_when_user_is_not_account_or_enduser(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + proxy._get_current_object.return_value = "plain_string" + with pytest.raises(TypeError, match="current_user must be Account or EndUser"): + get_current_user() + + # ----------------------------------------------------------------------- + # is_system_oauth_params_exist (line 357-363) + # ----------------------------------------------------------------------- + + def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session): + mock_db_session.query().first.return_value = MagicMock() + assert service.is_system_oauth_params_exist(make_id()) is True + + def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + assert service.is_system_oauth_params_exist(make_id()) is False + + # ----------------------------------------------------------------------- + # is_tenant_oauth_params_enabled (lines 365-379) + # NOTE: uses .count() not .first() + # ----------------------------------------------------------------------- + + def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session): + mock_db_session.query().count.return_value = 1 + assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True + + def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False + + # ----------------------------------------------------------------------- + # remove_oauth_custom_client_params (lines 55-61) + # ----------------------------------------------------------------------- + + def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session): + service.remove_oauth_custom_client_params("t1", make_id()) + mock_db_session.query().delete.assert_called_once() + + # ----------------------------------------------------------------------- + # setup_oauth_custom_client_params (315-351) + # ----------------------------------------------------------------------- + + def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session): + """When credentials=None, should return immediately without any DB write.""" + service.setup_oauth_custom_client_params("t1", make_id(), None, None) + mock_db_session.add.assert_not_called() + + def test_should_create_new_config_when_none_exists(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True) + mock_db_session.add.assert_called_once() + + def test_should_update_existing_config_when_record_found(self, service, mock_db_session): + existing = MagicMock() + mock_db_session.query().first.return_value = existing + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False) + mock_db_session.add.assert_not_called() # update in place, no add + + # ----------------------------------------------------------------------- + # decrypt / encrypt credentials (lines 70-98) + # ----------------------------------------------------------------------- + + def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "enc_val"} + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.decrypt_datasource_provider_credentials("t1", p, "org/plug", "prov") + assert result["sk"] == "dec_tok" + + def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.encrypt_datasource_provider_credentials("t1", "prov", "org/plug", {"sk": "plain"}, p) + assert result["sk"] == "enc_tok" + self._enc.encrypt_token.assert_called() + + # ----------------------------------------------------------------------- + # get_datasource_credentials (lines 113-165) + # ----------------------------------------------------------------------- + + def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user): + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + mock_db_session.query().first.return_value = None + assert service.get_datasource_credentials("t1", "prov", "org/plug") == {} + + def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user): + """Expired OAuth credential (expires_at near zero) triggers a silent refresh.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = 0 # expired + p.encrypted_credentials = {"tok": "x"} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "get_oauth_client", return_value={"oc": "v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}), + ): + service.get_datasource_credentials("t1", "prov", "org/plug") + mock_db_session.commit.assert_called_once() + + def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user): + """API key credentials with expires_at=-1 skip refresh and return directly.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.expires_at = -1 # sentinel: never expires + p.encrypted_credentials = {"k": "v"} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}), + ): + result = service.get_datasource_credentials("t1", "prov", "org/plug") + assert result == {"k": "plain"} + + def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user): + """When credential_id is passed, the credential_id filter path (line 113) is taken.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.expires_at = -1 + p.encrypted_credentials = {} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}), + ): + result = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id") + assert result == {"k": "v"} + + # ----------------------------------------------------------------------- + # get_all_datasource_credentials_by_provider (lines 176-228) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user): + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + mock_db_session.query().all.return_value = [] + assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == [] + + def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = 0 + p.encrypted_credentials = {"t": "x"} + mock_db_session.query().all.return_value = [p] + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "get_oauth_client", return_value={"oc": "v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}), + ): + result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # update_datasource_provider_name (lines 236-303) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with pytest.raises(ValueError, match="not found"): + service.update_datasource_provider_name("t1", make_id(), "new", "cred-id") + + def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "same" + mock_db_session.query().first.return_value = p + service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") + mock_db_session.commit.assert_not_called() + + def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.is_default = False + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 # conflict + with pytest.raises(ValueError, match="already exists"): + service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") + + def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.is_default = False + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") + assert p.name == "new_name" + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # set_default_datasource_provider (lines 277-303) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with pytest.raises(ValueError, match="not found"): + service.set_default_datasource_provider("t1", make_id(), "bad-id") + + def test_should_mark_target_as_default_and_commit(self, service, mock_db_session): + target = MagicMock(spec=DatasourceProvider) + target.provider = "provider" + target.plugin_id = "org/plug" + mock_db_session.query().first.return_value = target + service.set_default_datasource_provider("t1", make_id(), "new-id") + assert target.is_default is True + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # get_oauth_encrypter (lines 404-420) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_oauth_schema_missing(self, service): + pm = MagicMock() + pm.declaration.oauth_schema = None + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + with pytest.raises(ValueError, match="oauth schema not found"): + service.get_oauth_encrypter("t1", make_id()) + + def test_should_return_encrypter_when_oauth_schema_exists(self, service): + schema_item = MagicMock() + schema_item.to_basic_provider_config.return_value = MagicMock() + pm = MagicMock() + pm.declaration.oauth_schema.client_schema = [schema_item] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm), + patch( + "services.datasource_provider_service.create_provider_encrypter", + return_value=(MagicMock(), MagicMock()), + ), + ): + result = service.get_oauth_encrypter("t1", make_id()) + assert result is not None + + # ----------------------------------------------------------------------- + # get_tenant_oauth_client (lines 381-402) + # ----------------------------------------------------------------------- + + def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session): + tenant_params = MagicMock() + tenant_params.client_params = {"k": "v"} + mock_db_session.query().first.return_value = tenant_params + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_tenant_oauth_client("t1", make_id(), mask=True) + assert result == {"k": "mask"} + + def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session): + tenant_params = MagicMock() + tenant_params.client_params = {"k": "v"} + mock_db_session.query().first.return_value = tenant_params + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_tenant_oauth_client("t1", make_id(), mask=False) + assert result == {"k": "dec"} + + def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + assert service.get_tenant_oauth_client("t1", make_id()) is None + + # ----------------------------------------------------------------------- + # get_oauth_client (lines 423-457) + # ----------------------------------------------------------------------- + + def test_should_use_tenant_config_when_available(self, service, mock_db_session): + mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"}) + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_oauth_client("t1", make_id()) + assert result == {"k": "dec"} + + def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session): + mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider"), + patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True), + ): + result = service.get_oauth_client("t1", make_id()) + assert result == {"k": "sys"} + + def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session): + """Neither tenant nor system credentials → raises ValueError.""" + mock_db_session.query().first.side_effect = [None, None] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider"), + patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False), + ): + with pytest.raises(ValueError, match="Please configure oauth client params"): + service.get_oauth_client("t1", make_id()) + + # ----------------------------------------------------------------------- + # add_datasource_oauth_provider (lines 539-607) + # ----------------------------------------------------------------------- + + def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): + """Conflict on name results in auto-incremented name, not an error.""" + mock_db_session.query().count.return_value = 1 # conflict first, then auto-named + mock_db_session.query().all.return_value = [] + with ( + patch.object(service, "extract_secret_variables", return_value=[]), + patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"), + ): + service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + + def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session): + """name=None causes auto-generation via generate_next_datasource_provider_name.""" + mock_db_session.query().count.return_value = 0 + mock_db_session.query().all.return_value = [] + with ( + patch.object(service, "extract_secret_variables", return_value=[]), + patch.object(service, "generate_next_datasource_provider_name", return_value="auto"), + ): + service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + + def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=["secret_key"]): + service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"}) + self._enc.encrypt_token.assert_called() + + def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {}) + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # reauthorize_datasource_oauth_provider (lines 477-537) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with patch.object(service, "extract_secret_variables", return_value=[]): + with pytest.raises(ValueError, match="not found"): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id") + + def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") + mock_db_session.commit.assert_called_once() + + def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 # conflict + mock_db_session.query().all.return_value = [] + with patch.object(service, "extract_secret_variables", return_value=["tok"]): + service.reauthorize_datasource_oauth_provider( + "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" + ) + mock_db_session.commit.assert_called_once() + + def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=["tok"]): + service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id") + self._enc.encrypt_token.assert_called() + + def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # add_datasource_api_key_provider (lines 608-675) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user): + """explicit name supplied + conflict → raises ValueError immediately.""" + mock_db_session.query().count.return_value = 1 + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="already exists"): + service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"}) + + def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")), + patch.object(service, "extract_secret_variables", return_value=[]), + ): + with pytest.raises(ValueError, match="Failed to validate"): + service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"}) + + def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials"), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + ): + service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"}) + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials"), + patch.object(service, "extract_secret_variables", return_value=[]), + ): + service.add_datasource_api_key_provider(None, "t1", make_id(), {}) + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # extract_secret_variables (lines 666-699) + # ----------------------------------------------------------------------- + + def test_should_extract_secret_variable_names_for_api_key_schema(self, service): + schema = MagicMock() + schema.name = "my_secret" + schema.type = MagicMock() + schema.type.value = FormType.SECRET_INPUT # "secret-input" + pm = MagicMock() + pm.declaration.credentials_schema = [schema] + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY) + assert "my_secret" in result + + def test_should_extract_secret_variable_names_for_oauth2_schema(self, service): + schema = MagicMock() + schema.name = "oauth_secret" + schema.type = MagicMock() + schema.type.value = FormType.SECRET_INPUT + pm = MagicMock() + pm.declaration.oauth_schema.credentials_schema = [schema] + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.OAUTH2) + assert "oauth_secret" in result + + def test_should_raise_value_error_when_credential_type_is_invalid(self, service): + pm = MagicMock() + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + with pytest.raises(ValueError, match="Invalid credential type"): + service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED) + + # ----------------------------------------------------------------------- + # list_datasource_credentials (lines 721-754) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session): + mock_db_session.query().all.return_value = [] + assert service.list_datasource_credentials("t1", "prov", "org/plug") == [] + + def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "v"} + p.is_default = False + mock_db_session.query().all.return_value = [p] + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.list_datasource_credentials("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # get_all_datasource_credentials (lines 808-871) + # ----------------------------------------------------------------------- + + def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service): + with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr: + ds = MagicMock() + ds.provider = "prov" + ds.plugin_id = "org/plug" + ds.declaration.identity.label.model_dump.return_value = {"en_US": "Label"} + mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds] + cred = {"credential": {"k": "v"}, "is_default": True} + with patch.object(service, "list_datasource_credentials", return_value=[cred]): + results = service.get_all_datasource_credentials("t1") + assert len(results) == 1 + + def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session): + """Lines 819-871: get_all_datasource_credentials covers hardcoded langgenius plugin IDs.""" + with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr: + ds = MagicMock() + ds.plugin_id = "langgenius/firecrawl_datasource" + ds.provider = "firecrawl" + ds.plugin_unique_identifier = "pui" + ds.declaration.identity.icon = "icon" + ds.declaration.identity.name = "langgenius/firecrawl_datasource" + ds.declaration.identity.label.model_dump.return_value = {"en_US": "Firecrawl"} + ds.declaration.identity.description.model_dump.return_value = {"en_US": "desc"} + ds.declaration.identity.author = "langgenius" + ds.declaration.credentials_schema = [] + ds.declaration.oauth_schema.client_schema = [] + ds.declaration.oauth_schema.credentials_schema = [] + mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds] + with ( + patch.object(service, "list_datasource_credentials", return_value=[]), + patch.object(service, "get_tenant_oauth_client", return_value=None), + patch.object(service, "is_tenant_oauth_params_enabled", return_value=False), + patch.object(service, "is_system_oauth_params_exist", return_value=False), + ): + results = service.get_all_datasource_credentials("t1") + assert len(results) == 1 + assert results[0]["oauth_schema"] is not None + + # ----------------------------------------------------------------------- + # get_real_datasource_credentials (lines 873-915) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session): + mock_db_session.query().all.return_value = [] + assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == [] + + def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "v"} + mock_db_session.query().all.return_value = [p] + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.get_real_datasource_credentials("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # update_datasource_credentials (lines 917-978) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user): + mock_db_session.query().first.return_value = None + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="not found"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name") + + def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "e"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="already exists"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name") + + def test_should_raise_value_error_when_credential_validation_fails_on_update( + self, service, mock_db_session, mock_user + ): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "e"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")), + ): + with pytest.raises(ValueError, match="Failed to validate"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name") + + def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user): + """Verifies that encrypted_credentials is reassigned with encrypted value and commit is called.""" + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "old_enc"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + patch.object(service.provider_manager, "validate_provider_credentials"), + ): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name") + # encrypter must have been called with the new secret value + self._enc.encrypt_token.assert_called() + # commit must be called exactly once + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # remove_datasource_credentials (lines 980-997) + # ----------------------------------------------------------------------- + + def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + service.remove_datasource_credentials("t1", "id", "prov", "org/plug") + mock_db_session.delete.assert_called_once_with(p) + mock_db_session.commit.assert_called_once() + + def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session): + """No error raised; no delete called when record doesn't exist (lines 994 branch).""" + mock_db_session.query().first.return_value = None + service.remove_datasource_credentials("t1", "id", "prov", "org/plug") + mock_db_session.delete.assert_not_called() diff --git a/api/tests/unit_tests/services/test_document_service_rename_document.py b/api/tests/unit_tests/services/test_document_service_rename_document.py deleted file mode 100644 index 94850ecb09..0000000000 --- a/api/tests/unit_tests/services/test_document_service_rename_document.py +++ /dev/null @@ -1,176 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import Mock, create_autospec, patch - -import pytest - -from models import Account -from services.dataset_service import DocumentService - - -@pytest.fixture -def mock_env(): - """Patch dependencies used by DocumentService.rename_document. - - Mocks: - - DatasetService.get_dataset - - DocumentService.get_document - - current_user (with current_tenant_id) - - db.session - """ - with ( - patch("services.dataset_service.DatasetService.get_dataset") as get_dataset, - patch("services.dataset_service.DocumentService.get_document") as get_document, - patch("services.dataset_service.current_user", create_autospec(Account, instance=True)) as current_user, - patch("extensions.ext_database.db.session") as db_session, - ): - current_user.current_tenant_id = "tenant-123" - yield { - "get_dataset": get_dataset, - "get_document": get_document, - "current_user": current_user, - "db_session": db_session, - } - - -def make_dataset(dataset_id="dataset-123", tenant_id="tenant-123", built_in_field_enabled=False): - return SimpleNamespace(id=dataset_id, tenant_id=tenant_id, built_in_field_enabled=built_in_field_enabled) - - -def make_document( - document_id="document-123", - dataset_id="dataset-123", - tenant_id="tenant-123", - name="Old Name", - data_source_info=None, - doc_metadata=None, -): - doc = Mock() - doc.id = document_id - doc.dataset_id = dataset_id - doc.tenant_id = tenant_id - doc.name = name - doc.data_source_info = data_source_info or {} - # property-like usage in code relies on a dict - doc.data_source_info_dict = dict(doc.data_source_info) - doc.doc_metadata = dict(doc_metadata or {}) - return doc - - -def test_rename_document_success(mock_env): - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "New Document Name" - - dataset = make_dataset(dataset_id) - document = make_document(document_id=document_id, dataset_id=dataset_id) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - result = DocumentService.rename_document(dataset_id, document_id, new_name) - - assert result is document - assert document.name == new_name - mock_env["db_session"].add.assert_called_once_with(document) - mock_env["db_session"].commit.assert_called_once() - - -def test_rename_document_with_built_in_fields(mock_env): - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "Renamed" - - dataset = make_dataset(dataset_id, built_in_field_enabled=True) - document = make_document(document_id=document_id, dataset_id=dataset_id, doc_metadata={"foo": "bar"}) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - DocumentService.rename_document(dataset_id, document_id, new_name) - - assert document.name == new_name - # BuiltInField.document_name == "document_name" in service code - assert document.doc_metadata["document_name"] == new_name - assert document.doc_metadata["foo"] == "bar" - - -def test_rename_document_updates_upload_file_when_present(mock_env): - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "Renamed" - file_id = "file-123" - - dataset = make_dataset(dataset_id) - document = make_document( - document_id=document_id, - dataset_id=dataset_id, - data_source_info={"upload_file_id": file_id}, - ) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - # Intercept UploadFile rename UPDATE chain - mock_query = Mock() - mock_query.where.return_value = mock_query - mock_env["db_session"].query.return_value = mock_query - - DocumentService.rename_document(dataset_id, document_id, new_name) - - assert document.name == new_name - mock_env["db_session"].query.assert_called() # update executed - - -def test_rename_document_does_not_update_upload_file_when_missing_id(mock_env): - """ - When data_source_info_dict exists but does not contain "upload_file_id", - UploadFile should not be updated. - """ - dataset_id = "dataset-123" - document_id = "document-123" - new_name = "Another Name" - - dataset = make_dataset(dataset_id) - # Ensure data_source_info_dict is truthy but lacks the key - document = make_document( - document_id=document_id, - dataset_id=dataset_id, - data_source_info={"url": "https://example.com"}, - ) - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - DocumentService.rename_document(dataset_id, document_id, new_name) - - assert document.name == new_name - # Should NOT attempt to update UploadFile - mock_env["db_session"].query.assert_not_called() - - -def test_rename_document_dataset_not_found(mock_env): - mock_env["get_dataset"].return_value = None - - with pytest.raises(ValueError, match="Dataset not found"): - DocumentService.rename_document("missing", "doc", "x") - - -def test_rename_document_not_found(mock_env): - dataset = make_dataset("dataset-123") - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = None - - with pytest.raises(ValueError, match="Document not found"): - DocumentService.rename_document(dataset.id, "missing", "x") - - -def test_rename_document_permission_denied_when_tenant_mismatch(mock_env): - dataset = make_dataset("dataset-123") - # different tenant than current_user.current_tenant_id - document = make_document(dataset_id=dataset.id, tenant_id="tenant-other") - - mock_env["get_dataset"].return_value = dataset - mock_env["get_document"].return_value = document - - with pytest.raises(ValueError, match="No permission"): - DocumentService.rename_document(dataset.id, document.id, "x") diff --git a/api/tests/unit_tests/services/test_end_user_service.py b/api/tests/unit_tests/services/test_end_user_service.py index 7f087a17d8..a3b1f46436 100644 --- a/api/tests/unit_tests/services/test_end_user_service.py +++ b/api/tests/unit_tests/services/test_end_user_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest from core.app.entities.app_invoke_entities import InvokeFrom -from models.model import App, EndUser +from models.model import App, DefaultEndUserSessionID, EndUser from services.end_user_service import EndUserService @@ -44,6 +44,145 @@ class TestEndUserServiceFactory: return end_user +class TestEndUserServiceGetEndUserById: + """Unit tests for EndUserService.get_end_user_by_id method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_success(self, mock_db, mock_session_class, factory): + """Test successful retrieval of end user by ID.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_end_user = factory.create_end_user_mock(user_id=end_user_id, tenant_id=tenant_id, app_id=app_id) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = mock_end_user + + # Act + result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + assert result == mock_end_user + mock_session.query.assert_called_once_with(EndUser) + mock_query.where.assert_called_once() + mock_query.first.assert_called_once() + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_not_found(self, mock_db, mock_session_class): + """Test retrieval of non-existent end user returns None.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act + result = EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + assert result is None + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_get_end_user_by_id_query_parameters(self, mock_db, mock_session_class): + """Test that query parameters are correctly applied.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + end_user_id = "user-789" + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act + EndUserService.get_end_user_by_id(tenant_id=tenant_id, app_id=app_id, end_user_id=end_user_id) + + # Assert + # Verify the where clause was called with the correct conditions + call_args = mock_query.where.call_args[0] + assert len(call_args) == 3 + # Check that the conditions match the expected filters + # (We can't easily test the exact conditions without importing SQLAlchemy) + + +class TestEndUserServiceGetOrCreateEndUser: + """Unit tests for EndUserService.get_or_create_end_user method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") + def test_get_or_create_end_user_with_user_id(self, mock_get_or_create_by_type, factory): + """Test get_or_create_end_user with specific user_id.""" + # Arrange + app_mock = factory.create_app_mock() + user_id = "user-123" + expected_end_user = factory.create_end_user_mock() + mock_get_or_create_by_type.return_value = expected_end_user + + # Act + result = EndUserService.get_or_create_end_user(app_mock, user_id) + + # Assert + assert result == expected_end_user + mock_get_or_create_by_type.assert_called_once_with( + InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, user_id + ) + + @patch("services.end_user_service.EndUserService.get_or_create_end_user_by_type") + def test_get_or_create_end_user_without_user_id(self, mock_get_or_create_by_type, factory): + """Test get_or_create_end_user without user_id (None).""" + # Arrange + app_mock = factory.create_app_mock() + expected_end_user = factory.create_end_user_mock() + mock_get_or_create_by_type.return_value = expected_end_user + + # Act + result = EndUserService.get_or_create_end_user(app_mock, None) + + # Assert + assert result == expected_end_user + mock_get_or_create_by_type.assert_called_once_with( + InvokeFrom.SERVICE_API, app_mock.tenant_id, app_mock.id, None + ) + + class TestEndUserServiceGetOrCreateEndUserByType: """ Unit tests for EndUserService.get_or_create_end_user_by_type method. @@ -60,6 +199,191 @@ class TestEndUserServiceGetOrCreateEndUserByType: """Provide test data factory.""" return TestEndUserServiceFactory() + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_new_end_user_with_user_id(self, mock_db, mock_session_class, factory): + """Test creating a new end user with specific user_id.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None # No existing user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + # Verify new EndUser was created with correct parameters + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + added_user = mock_session.add.call_args[0][0] + assert added_user.tenant_id == tenant_id + assert added_user.app_id == app_id + assert added_user.type == type_enum + assert added_user.session_id == user_id + assert added_user.external_user_id == user_id + assert added_user._is_anonymous is False + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_new_end_user_default_session(self, mock_db, mock_session_class, factory): + """Test creating a new end user with default session ID.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = None + type_enum = InvokeFrom.WEB_APP + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None # No existing user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + added_user = mock_session.add.call_args[0][0] + assert added_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert added_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert added_user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + @patch("services.end_user_service.logger") + def test_existing_user_same_type(self, mock_logger, mock_db, mock_session_class, factory): + """Test retrieving existing user with same type.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user = factory.create_end_user_mock( + tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = existing_user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=type_enum, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + assert result == existing_user + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + mock_logger.info.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + @patch("services.end_user_service.logger") + def test_existing_user_different_type_upgrade(self, mock_logger, mock_db, mock_session_class, factory): + """Test upgrading existing user with different type.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + old_type = InvokeFrom.WEB_APP + new_type = InvokeFrom.SERVICE_API + + existing_user = factory.create_end_user_mock( + tenant_id=tenant_id, app_id=app_id, session_id=user_id, type=old_type + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = existing_user + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=new_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + assert result == existing_user + assert existing_user.type == new_type + mock_session.commit.assert_called_once() + mock_logger.info.assert_called_once() + logger_call_args = mock_logger.info.call_args[0] + assert "Upgrading legacy EndUser" in logger_call_args[0] + # The old and new types are passed as separate arguments + assert mock_logger.info.call_args[0][1] == existing_user.id + assert mock_logger.info.call_args[0][2] == old_type + assert mock_logger.info.call_args[0][3] == new_type + assert mock_logger.info.call_args[0][4] == user_id + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_query_ordering_prioritizes_exact_type_match(self, mock_db, mock_session_class, factory): + """Test that query ordering prioritizes exact type matches.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + target_type = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None + + # Act + EndUserService.get_or_create_end_user_by_type( + type=target_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + mock_query.order_by.assert_called_once() + # Verify that case statement is used for ordering + order_by_call = mock_query.order_by.call_args[0][0] + # The exact structure depends on SQLAlchemy's case implementation + # but we can verify it was called + # Test 10: Session context manager properly closes @patch("services.end_user_service.Session") @patch("services.end_user_service.db") @@ -93,3 +417,425 @@ class TestEndUserServiceGetOrCreateEndUserByType: # Verify context manager was entered and exited mock_context.__enter__.assert_called_once() mock_context.__exit__.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_all_invokefrom_types_supported(self, mock_db, mock_session_class): + """Test that all InvokeFrom enum values are supported.""" + # Arrange + tenant_id = "tenant-123" + app_id = "app-456" + user_id = "user-789" + + for invoke_type in InvokeFrom: + with patch("services.end_user_service.Session") as mock_session_class: + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.first.return_value = None + + # Act + result = EndUserService.get_or_create_end_user_by_type( + type=invoke_type, tenant_id=tenant_id, app_id=app_id, user_id=user_id + ) + + # Assert + added_user = mock_session.add.call_args[0][0] + assert added_user.type == invoke_type + + +class TestEndUserServiceCreateEndUserBatch: + """Unit tests for EndUserService.create_end_user_batch method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestEndUserServiceFactory() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_empty_app_ids(self, mock_db, mock_session_class): + """Test batch creation with empty app_ids list.""" + # Arrange + tenant_id = "tenant-123" + app_ids: list[str] = [] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert result == {} + mock_session_class.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_default_session_id(self, mock_db, mock_session_class): + """Test batch creation with empty user_id (uses default session).""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + user_id = "" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 2 + for app_id, end_user in result.items(): + assert end_user.session_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user.external_user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID + assert end_user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_deduplicate_app_ids(self, mock_db, mock_session_class): + """Test that duplicate app_ids are deduplicated while preserving order.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-456", "app-123", "app-789"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + # Should have 3 unique app_ids in original order + assert len(result) == 3 + assert "app-456" in result + assert "app-789" in result + assert "app-123" in result + + # Verify the order is preserved + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 3 + assert added_users[0].app_id == "app-456" + assert added_users[1].app_id == "app-789" + assert added_users[2].app_id == "app-123" + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_all_existing_users(self, mock_db, mock_session_class, factory): + """Test batch creation when all users already exist.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user1 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + existing_user2 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-789", session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1, existing_user2] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 2 + assert result["app-456"] == existing_user1 + assert result["app-789"] == existing_user2 + mock_session.add_all.assert_not_called() + mock_session.commit.assert_not_called() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_partial_existing_users(self, mock_db, mock_session_class, factory): + """Test batch creation with some existing and some new users.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-123"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + existing_user1 = factory.create_end_user_mock( + tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + # app-789 and app-123 don't exist + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 3 + assert result["app-456"] == existing_user1 + assert "app-789" in result + assert "app-123" in result + + # Should create 2 new users + mock_session.add_all.assert_called_once() + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 2 + + mock_session.commit.assert_called_once() + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_handles_duplicates_in_existing(self, mock_db, mock_session_class, factory): + """Test batch creation handles duplicates in existing users gracefully.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + # Simulate duplicate records in database + existing_user1 = factory.create_end_user_mock( + user_id="user-1", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + existing_user2 = factory.create_end_user_mock( + user_id="user-2", tenant_id=tenant_id, app_id="app-456", session_id=user_id, type=type_enum + ) + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [existing_user1, existing_user2] + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 1 + # Should prefer the first one found + assert result["app-456"] == existing_user1 + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_all_invokefrom_types(self, mock_db, mock_session_class): + """Test batch creation with all InvokeFrom types.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + + for invoke_type in InvokeFrom: + with patch("services.end_user_service.Session") as mock_session_class: + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=invoke_type, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + added_user = mock_session.add_all.call_args[0][0][0] + assert added_user.type == invoke_type + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_single_app_id(self, mock_db, mock_session_class, factory): + """Test batch creation with single app_id.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + result = EndUserService.create_end_user_batch( + type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id + ) + + # Assert + assert len(result) == 1 + assert "app-456" in result + mock_session.add_all.assert_called_once() + added_users = mock_session.add_all.call_args[0][0] + assert len(added_users) == 1 + assert added_users[0].app_id == "app-456" + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_anonymous_vs_authenticated(self, mock_db, mock_session_class): + """Test batch creation correctly sets anonymous flag.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789"] + + # Test with regular user ID + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act - authenticated user + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, tenant_id=tenant_id, app_ids=app_ids, user_id="user-789" + ) + + # Assert + added_users = mock_session.add_all.call_args[0][0] + for user in added_users: + assert user._is_anonymous is False + + # Test with default session ID + mock_session.reset_mock() + mock_query.reset_mock() + mock_query.all.return_value = [] + + # Act - anonymous user + result = EndUserService.create_end_user_batch( + type=InvokeFrom.SERVICE_API, + tenant_id=tenant_id, + app_ids=app_ids, + user_id=DefaultEndUserSessionID.DEFAULT_SESSION_ID, + ) + + # Assert + added_users = mock_session.add_all.call_args[0][0] + for user in added_users: + assert user._is_anonymous is True + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_efficient_single_query(self, mock_db, mock_session_class): + """Test that batch creation uses efficient single query for existing users.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456", "app-789", "app-123"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) + + # Assert + # Should make exactly one query to check for existing users + mock_session.query.assert_called_once_with(EndUser) + mock_query.where.assert_called_once() + mock_query.all.assert_called_once() + + # Verify the where clause uses .in_() for app_ids + where_call = mock_query.where.call_args[0] + # The exact structure depends on SQLAlchemy implementation + # but we can verify it was called with the right parameters + + @patch("services.end_user_service.Session") + @patch("services.end_user_service.db") + def test_create_batch_session_context_manager(self, mock_db, mock_session_class): + """Test that batch creation properly uses session context manager.""" + # Arrange + tenant_id = "tenant-123" + app_ids = ["app-456"] + user_id = "user-789" + type_enum = InvokeFrom.SERVICE_API + + mock_session = MagicMock() + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_session + mock_session_class.return_value = mock_context + + mock_query = MagicMock() + mock_session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.all.return_value = [] # No existing users + + # Act + EndUserService.create_end_user_batch(type=type_enum, tenant_id=tenant_id, app_ids=app_ids, user_id=user_id) + + # Assert + mock_context.__enter__.assert_called_once() + mock_context.__exit__.assert_called_once() + mock_session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_export_app_messages.py b/api/tests/unit_tests/services/test_export_app_messages.py new file mode 100644 index 0000000000..5f2d3f21c0 --- /dev/null +++ b/api/tests/unit_tests/services/test_export_app_messages.py @@ -0,0 +1,43 @@ +import datetime + +import pytest + +from services.retention.conversation.message_export_service import AppMessageExportService + + +def test_validate_export_filename_accepts_relative_path(): + assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01" + + +@pytest.mark.parametrize( + "filename", + [ + "test01.jsonl.gz", + "test01.jsonl", + "test01.gz", + "/tmp/test01", + "exports/../test01", + "bad\x00name", + "bad\tname", + "a" * 1025, + ], +) +def test_validate_export_filename_rejects_invalid_values(filename: str): + with pytest.raises(ValueError): + AppMessageExportService.validate_export_filename(filename) + + +def test_service_derives_output_names_from_filename_base(): + service = AppMessageExportService( + app_id="736b9b03-20f2-4697-91da-8d00f6325900", + start_from=None, + end_before=datetime.datetime(2026, 3, 1), + filename="exports/2026/test01", + batch_size=1000, + use_cloud_storage=True, + dry_run=True, + ) + + assert service._filename_base == "exports/2026/test01" + assert service.output_gz_name == "exports/2026/test01.jsonl.gz" + assert service.output_jsonl_name == "exports/2026/test01.jsonl" diff --git a/api/tests/unit_tests/services/test_file_service.py b/api/tests/unit_tests/services/test_file_service.py new file mode 100644 index 0000000000..b7259c3e82 --- /dev/null +++ b/api/tests/unit_tests/services/test_file_service.py @@ -0,0 +1,420 @@ +import base64 +import hashlib +import os +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy import Engine +from sqlalchemy.orm import Session, sessionmaker +from werkzeug.exceptions import NotFound + +from configs import dify_config +from models.enums import CreatorUserRole +from models.model import Account, EndUser, UploadFile +from services.errors.file import BlockedFileExtensionError, FileTooLargeError, UnsupportedFileTypeError +from services.file_service import FileService + + +class TestFileService: + @pytest.fixture + def mock_db_session(self): + session = MagicMock(spec=Session) + # Mock context manager behavior + session.__enter__.return_value = session + return session + + @pytest.fixture + def mock_session_maker(self, mock_db_session): + maker = MagicMock(spec=sessionmaker) + maker.return_value = mock_db_session + return maker + + @pytest.fixture + def file_service(self, mock_session_maker): + return FileService(session_factory=mock_session_maker) + + def test_init_with_engine(self): + engine = MagicMock(spec=Engine) + service = FileService(session_factory=engine) + assert isinstance(service._session_maker, sessionmaker) + + def test_init_with_sessionmaker(self): + maker = MagicMock(spec=sessionmaker) + service = FileService(session_factory=maker) + assert service._session_maker == maker + + def test_init_invalid_factory(self): + with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."): + FileService(session_factory="invalid") + + @patch("services.file_service.storage") + @patch("services.file_service.naive_utc_now") + @patch("services.file_service.extract_tenant_id") + @patch("services.file_service.file_helpers.get_signed_file_url") + def test_upload_file_success( + self, mock_get_url, mock_tenant_id, mock_now, mock_storage, file_service, mock_db_session + ): + # Setup + mock_tenant_id.return_value = "tenant_id" + mock_now.return_value = "2024-01-01" + mock_get_url.return_value = "http://signed-url" + + user = MagicMock(spec=Account) + user.id = "user_id" + content = b"file content" + filename = "test.jpg" + mimetype = "image/jpeg" + + # Execute + result = file_service.upload_file(filename=filename, content=content, mimetype=mimetype, user=user) + + # Assert + assert isinstance(result, UploadFile) + assert result.name == filename + assert result.tenant_id == "tenant_id" + assert result.size == len(content) + assert result.extension == "jpg" + assert result.mime_type == mimetype + assert result.created_by_role == CreatorUserRole.ACCOUNT + assert result.created_by == "user_id" + assert result.hash == hashlib.sha3_256(content).hexdigest() + assert result.source_url == "http://signed-url" + + mock_storage.save.assert_called_once() + mock_db_session.add.assert_called_once_with(result) + mock_db_session.commit.assert_called_once() + + def test_upload_file_invalid_characters(self, file_service): + with pytest.raises(ValueError, match="Filename contains invalid characters"): + file_service.upload_file(filename="invalid/file.txt", content=b"", mimetype="text/plain", user=MagicMock()) + + def test_upload_file_long_filename(self, file_service, mock_db_session): + # Setup + long_name = "a" * 210 + ".txt" + user = MagicMock(spec=Account) + user.id = "user_id" + + with ( + patch("services.file_service.storage"), + patch("services.file_service.extract_tenant_id") as mock_tenant, + patch("services.file_service.file_helpers.get_signed_file_url"), + ): + mock_tenant.return_value = "tenant" + result = file_service.upload_file(filename=long_name, content=b"test", mimetype="text/plain", user=user) + assert len(result.name) <= 205 # 200 + . + extension + assert result.name.endswith(".txt") + + def test_upload_file_blocked_extension(self, file_service): + with patch.object(dify_config, "inner_UPLOAD_FILE_EXTENSION_BLACKLIST", "exe"): + with pytest.raises(BlockedFileExtensionError): + file_service.upload_file( + filename="test.exe", content=b"", mimetype="application/octet-stream", user=MagicMock() + ) + + def test_upload_file_unsupported_type_for_datasets(self, file_service): + with pytest.raises(UnsupportedFileTypeError): + file_service.upload_file( + filename="test.jpg", content=b"", mimetype="image/jpeg", user=MagicMock(), source="datasets" + ) + + def test_upload_file_too_large(self, file_service): + # 16MB file for an image with 15MB limit + content = b"a" * (16 * 1024 * 1024) + with patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 15): + with pytest.raises(FileTooLargeError): + file_service.upload_file(filename="test.jpg", content=content, mimetype="image/jpeg", user=MagicMock()) + + def test_upload_file_end_user(self, file_service, mock_db_session): + user = MagicMock(spec=EndUser) + user.id = "end_user_id" + + with ( + patch("services.file_service.storage"), + patch("services.file_service.extract_tenant_id") as mock_tenant, + patch("services.file_service.file_helpers.get_signed_file_url"), + ): + mock_tenant.return_value = "tenant" + result = file_service.upload_file(filename="test.txt", content=b"test", mimetype="text/plain", user=user) + assert result.created_by_role == CreatorUserRole.END_USER + + def test_is_file_size_within_limit(self): + with ( + patch.object(dify_config, "UPLOAD_IMAGE_FILE_SIZE_LIMIT", 10), + patch.object(dify_config, "UPLOAD_VIDEO_FILE_SIZE_LIMIT", 20), + patch.object(dify_config, "UPLOAD_AUDIO_FILE_SIZE_LIMIT", 30), + patch.object(dify_config, "UPLOAD_FILE_SIZE_LIMIT", 5), + ): + # Image + assert FileService.is_file_size_within_limit(extension="jpg", file_size=10 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="png", file_size=11 * 1024 * 1024) is False + + # Video + assert FileService.is_file_size_within_limit(extension="mp4", file_size=20 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="avi", file_size=21 * 1024 * 1024) is False + + # Audio + assert FileService.is_file_size_within_limit(extension="mp3", file_size=30 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="wav", file_size=31 * 1024 * 1024) is False + + # Default + assert FileService.is_file_size_within_limit(extension="txt", file_size=5 * 1024 * 1024) is True + assert FileService.is_file_size_within_limit(extension="pdf", file_size=6 * 1024 * 1024) is False + + def test_get_file_base64_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "test_key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load_once.return_value = b"test content" + + # Execute + result = file_service.get_file_base64("file_id") + + # Assert + assert result == base64.b64encode(b"test content").decode() + mock_storage.load_once.assert_called_once_with("test_key") + + def test_get_file_base64_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_base64("non_existent") + + def test_upload_text_success(self, file_service, mock_db_session): + # Setup + text = "sample text" + text_name = "test.txt" + user_id = "user_id" + tenant_id = "tenant_id" + + with patch("services.file_service.storage") as mock_storage: + # Execute + result = file_service.upload_text(text, text_name, user_id, tenant_id) + + # Assert + assert result.name == text_name + assert result.size == len(text) + assert result.tenant_id == tenant_id + assert result.created_by == user_id + assert result.used is True + assert result.extension == "txt" + mock_storage.save.assert_called_once() + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_upload_text_long_name(self, file_service, mock_db_session): + long_name = "a" * 210 + with patch("services.file_service.storage"): + result = file_service.upload_text("text", long_name, "user", "tenant") + assert len(result.name) == 200 + + def test_get_file_preview_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "pdf" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.ExtractProcessor.load_from_upload_file") as mock_extract: + mock_extract.return_value = "Extracted text content" + + # Execute + result = file_service.get_file_preview("file_id") + + # Assert + assert result == "Extracted text content" + + def test_get_file_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_preview("non_existent") + + def test_get_file_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "exe" + mock_db_session.query().where().first.return_value = upload_file + with pytest.raises(UnsupportedFileTypeError): + file_service.get_file_preview("file_id") + + def test_get_image_preview_success(self, file_service, mock_db_session): + # Setup + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "jpg" + upload_file.mime_type = "image/jpeg" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with ( + patch("services.file_service.file_helpers.verify_image_signature") as mock_verify, + patch("services.file_service.storage") as mock_storage, + ): + mock_verify.return_value = True + mock_storage.load.return_value = iter([b"chunk1"]) + + # Execute + gen, mime = file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + # Assert + assert list(gen) == [b"chunk1"] + assert mime == "image/jpeg" + + def test_get_image_preview_invalid_sig(self, file_service): + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = False + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_image_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_image_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "txt" + mock_db_session.query().where().first.return_value = upload_file + with patch("services.file_service.file_helpers.verify_image_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(UnsupportedFileTypeError): + file_service.get_image_preview("file_id", "ts", "nonce", "sign") + + def test_get_file_generator_by_file_id_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with ( + patch("services.file_service.file_helpers.verify_file_signature") as mock_verify, + patch("services.file_service.storage") as mock_storage, + ): + mock_verify.return_value = True + mock_storage.load.return_value = iter([b"chunk"]) + + gen, file = file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + assert list(gen) == [b"chunk"] + assert file == upload_file + + def test_get_file_generator_by_file_id_invalid_sig(self, file_service): + with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: + mock_verify.return_value = False + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + + def test_get_file_generator_by_file_id_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with patch("services.file_service.file_helpers.verify_file_signature") as mock_verify: + mock_verify.return_value = True + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_file_generator_by_file_id("file_id", "ts", "nonce", "sign") + + def test_get_public_image_preview_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "png" + upload_file.mime_type = "image/png" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load.return_value = b"image content" + gen, mime = file_service.get_public_image_preview("file_id") + assert gen == b"image content" + assert mime == "image/png" + + def test_get_public_image_preview_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found or signature is invalid"): + file_service.get_public_image_preview("file_id") + + def test_get_public_image_preview_unsupported_type(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.extension = "txt" + mock_db_session.query().where().first.return_value = upload_file + with pytest.raises(UnsupportedFileTypeError): + file_service.get_public_image_preview("file_id") + + def test_get_file_content_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + mock_db_session.query().where().first.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + mock_storage.load.return_value = b"hello world" + result = file_service.get_file_content("file_id") + assert result == "hello world" + + def test_get_file_content_not_found(self, file_service, mock_db_session): + mock_db_session.query().where().first.return_value = None + with pytest.raises(NotFound, match="File not found"): + file_service.get_file_content("file_id") + + def test_delete_file_success(self, file_service, mock_db_session): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "file_id" + upload_file.key = "key" + # For session.scalar(select(...)) + mock_db_session.scalar.return_value = upload_file + + with patch("services.file_service.storage") as mock_storage: + file_service.delete_file("file_id") + mock_storage.delete.assert_called_once_with("key") + mock_db_session.delete.assert_called_once_with(upload_file) + + def test_delete_file_not_found(self, file_service, mock_db_session): + mock_db_session.scalar.return_value = None + file_service.delete_file("file_id") + # Should return without doing anything + + @patch("services.file_service.db") + def test_get_upload_files_by_ids_empty(self, mock_db): + result = FileService.get_upload_files_by_ids("tenant_id", []) + assert result == {} + + @patch("services.file_service.db") + def test_get_upload_files_by_ids(self, mock_db): + upload_file = MagicMock(spec=UploadFile) + upload_file.id = "550e8400-e29b-41d4-a716-446655440000" + upload_file.tenant_id = "tenant_id" + mock_db.session.scalars().all.return_value = [upload_file] + + result = FileService.get_upload_files_by_ids("tenant_id", ["550e8400-e29b-41d4-a716-446655440000"]) + assert result["550e8400-e29b-41d4-a716-446655440000"] == upload_file + + def test_sanitize_zip_entry_name(self): + assert FileService._sanitize_zip_entry_name("path/to/file.txt") == "file.txt" + assert FileService._sanitize_zip_entry_name("../../../etc/passwd") == "passwd" + assert FileService._sanitize_zip_entry_name(" ") == "file" + assert FileService._sanitize_zip_entry_name("a\\b") == "a_b" + + def test_dedupe_zip_entry_name(self): + used = {"a.txt"} + assert FileService._dedupe_zip_entry_name("b.txt", used) == "b.txt" + assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (1).txt" + used.add("a (1).txt") + assert FileService._dedupe_zip_entry_name("a.txt", used) == "a (2).txt" + + def test_build_upload_files_zip_tempfile(self): + upload_file = MagicMock(spec=UploadFile) + upload_file.name = "test.txt" + upload_file.key = "key" + + with ( + patch("services.file_service.storage") as mock_storage, + patch("services.file_service.os.remove") as mock_remove, + ): + mock_storage.load.return_value = [b"chunk1", b"chunk2"] + + with FileService.build_upload_files_zip_tempfile(upload_files=[upload_file]) as tmp_path: + assert os.path.exists(tmp_path) + + mock_remove.assert_called_once() diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/unit_tests/services/test_hit_testing_service.py new file mode 100644 index 0000000000..80e9729f5b --- /dev/null +++ b/api/tests/unit_tests/services/test_hit_testing_service.py @@ -0,0 +1,385 @@ +import json +from typing import Any, cast +from unittest.mock import ANY, MagicMock, patch + +import pytest + +from core.rag.models.document import Document +from models.dataset import Dataset +from services.hit_testing_service import HitTestingService + + +class TestHitTestingService: + """Test suite for HitTestingService""" + + # ===== Utility Method Tests ===== + + def test_escape_query_for_search_should_escape_double_quotes(self): + """Test that escape_query_for_search escapes double quotes correctly""" + # Arrange + query = 'test "query" with quotes' + expected = 'test \\"query\\" with quotes' + + # Act + result = HitTestingService.escape_query_for_search(query) + + # Assert + assert result == expected + + def test_hit_testing_args_check_should_pass_with_valid_query(self): + """Test that hit_testing_args_check passes with a valid query""" + # Arrange + args = {"query": "valid query"} + + # Act & Assert (should not raise) + HitTestingService.hit_testing_args_check(args) + + def test_hit_testing_args_check_should_pass_with_valid_attachments(self): + """Test that hit_testing_args_check passes with valid attachment_ids""" + # Arrange + args = {"attachment_ids": ["id1", "id2"]} + + # Act & Assert (should not raise) + HitTestingService.hit_testing_args_check(args) + + def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): + """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" + # Arrange + args = {} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Query or attachment_ids is required" in str(exc_info.value) + + def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): + """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" + # Arrange + args = {"query": "a" * 251} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Query cannot exceed 250 characters" in str(exc_info.value) + + def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): + """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" + # Arrange + args = {"attachment_ids": "not a list"} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Attachment_ids must be a list" in str(exc_info.value) + + # ===== Response Formatting Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") + def test_compact_retrieve_response_should_format_correctly(self, mock_format): + """Test that compact_retrieve_response formats the response correctly""" + # Arrange + query = "test query" + mock_doc = MagicMock(spec=Document) + documents = [mock_doc] + + mock_record = MagicMock() + mock_record.model_dump.return_value = {"content": "formatted content"} + mock_format.return_value = [mock_record] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert len(result["records"]) == 1 + assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content" + mock_format.assert_called_once_with(documents) + + def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): + """Test that compact_external_retrieve_response returns records when dataset provider is external""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "external" + query = "test query" + documents = [ + {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, + {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, + ] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert len(result["records"]) == 2 + assert cast(dict[str, Any], result["records"][0])["content"] == "c1" + assert cast(dict[str, Any], result["records"][1])["title"] == "t2" + + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): + """Test that compact_external_retrieve_response returns empty records for non-external provider""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "not_external" + query = "test query" + documents = [{"content": "c1"}] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert result["records"] == [] + + # ===== External Retrieve Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): + """Test that external_retrieve successfully retrieves from external provider and commits query""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + dataset.provider = "external" + query = 'test "query"' + account = MagicMock() + account.id = "account_id" + + mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] + + # Act + result = cast( + dict[str, Any], + HitTestingService.external_retrieve( + dataset=dataset, + query=query, + account=account, + external_retrieval_model={"model": "test"}, + metadata_filtering_conditions={"key": "val"}, + ), + ) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" + + # Verify call to RetrievalService.external_retrieve with escaped query + mock_ext_retrieve.assert_called_once_with( + dataset_id="dataset_id", + query='test \\"query\\"', + external_retrieval_model={"model": "test"}, + metadata_filtering_conditions={"key": "val"}, + ) + + # Verify DatasetQuery record was added and committed + mock_add.assert_called_once() + mock_commit.assert_called_once() + + def test_external_retrieve_should_return_empty_for_non_external_provider(self): + """Test that external_retrieve returns empty results immediately if provider is not external""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "not_external" + query = "test query" + account = MagicMock() + + # Act + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert result["records"] == [] + + # ===== Retrieve Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve uses default model when retrieval_model is not provided""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + dataset.retrieval_model = None + query = "test query" + account = MagicMock() + account.id = "account_id" + + mock_retrieve.return_value = [] + + # Act + result = cast( + dict[str, Any], + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + ), + ) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + mock_retrieve.assert_called_once() + # Verify top_k from default_retrieval_model (4) + assert mock_retrieve.call_args.kwargs["top_k"] == 4 + mock_commit.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): + """Test that retrieve correctly calls metadata filtering when conditions are present""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + + retrieval_model = { + "search_method": "semantic_search", + "metadata_filtering_conditions": {"some": "condition"}, + "top_k": 5, + "reranking_enable": False, + "score_threshold_enabled": False, + } + + # Mock metadata filtering response + mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + ) + + # Assert + mock_get_meta.assert_called_once() + mock_retrieve.assert_called_once() + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") + def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve): + """Test that retrieve returns empty response if metadata filtering returns condition but no document IDs""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + + retrieval_model = { + "search_method": "semantic_search", + "metadata_filtering_conditions": {"some": "condition"}, + "top_k": 5, + "reranking_enable": False, + "score_threshold_enabled": False, + } + + # Mock metadata filtering response: condition returned but no IDs + mock_get_meta.return_value = ({}, "condition_string") + + # Act + result = cast( + dict[str, Any], + HitTestingService.retrieve( + dataset=dataset, + query=query, + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, + ), + ) + + # Assert + assert result["records"] == [] + mock_retrieve.assert_not_called() + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve handles attachment_ids and adds them to DatasetQuery""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + attachment_ids = ["att1", "att2"] + + retrieval_model = { + "search_method": "semantic_search", + "top_k": 4, + "reranking_enable": False, + "score_threshold_enabled": False, + } + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, + query=query, + account=account, + retrieval_model=retrieval_model, + external_retrieval_model={}, + attachment_ids=attachment_ids, + ) + + # Assert + mock_retrieve.assert_called_once_with( + retrieval_method=ANY, + dataset_id="dataset_id", + query=query, + attachment_ids=attachment_ids, + top_k=4, + score_threshold=0.0, + reranking_model=None, + reranking_mode="reranking_model", + weights=None, + document_ids_filter=None, + ) + # Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images) + # The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}]) + called_query = mock_add.call_args[0][0] + query_content = json.loads(called_query.content) + assert len(query_content) == 3 # 1 text + 2 images + assert query_content[0]["content_type"] == "text_query" + assert query_content[1]["content_type"] == "image_query" + assert query_content[1]["content"] == "att1" + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve passes reranking and threshold parameters correctly""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + + retrieval_model = { + "search_method": "hybrid_search", + "top_k": 10, + "reranking_enable": True, + "reranking_model": {"provider": "test"}, + "reranking_mode": "weighted_sum", + "score_threshold_enabled": True, + "score_threshold": 0.5, + "weights": {"vector": 0.5, "keyword": 0.5}, + } + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + ) + + # Assert + mock_retrieve.assert_called_once() + kwargs = mock_retrieve.call_args.kwargs + assert kwargs["score_threshold"] == 0.5 + assert kwargs["reranking_model"] == {"provider": "test"} + assert kwargs["reranking_mode"] == "weighted_sum" + assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5} diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py index e0d6ad1b39..a23c44b26e 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/unit_tests/services/test_human_input_delivery_test_service.py @@ -1,97 +1,330 @@ from types import SimpleNamespace +from unittest.mock import MagicMock, patch import pytest +from sqlalchemy.engine import Engine -from core.workflow.nodes.human_input.entities import ( +from configs import dify_config +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, + MemberRecipient, ) -from core.workflow.runtime import VariablePool +from dify_graph.runtime import VariablePool from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, + DeliveryTestEmailRecipient, DeliveryTestError, + DeliveryTestRegistry, + DeliveryTestResult, + DeliveryTestStatus, + DeliveryTestUnsupportedError, EmailDeliveryTestHandler, + HumanInputDeliveryTestService, + _build_form_link, ) -def _make_email_method() -> EmailDeliveryMethod: - return EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients( - whole_workspace=False, - items=[ExternalRecipient(email="tester@example.com")], - ), - subject="Test subject", - body="Test body", +@pytest.fixture +def mock_db(monkeypatch): + mock_db = MagicMock() + monkeypatch.setattr(service_module, "db", mock_db) + return mock_db + + +def _make_valid_email_config(): + return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + + +def test_build_form_link(): + with patch.object(dify_config, "APP_WEB_URL", "http://example.com/"): + assert _build_form_link("token123") == "http://example.com/form/token123" + + with patch.object(dify_config, "APP_WEB_URL", "http://example.com"): + assert _build_form_link("token123") == "http://example.com/form/token123" + + assert _build_form_link(None) is None + + with patch.object(dify_config, "APP_WEB_URL", None): + assert _build_form_link("token123") is None + + +class TestDeliveryTestRegistry: + def test_register(self): + registry = DeliveryTestRegistry() + assert len(registry._handlers) == 0 + handler = MagicMock() + registry.register(handler) + assert len(registry._handlers) == 1 + assert registry._handlers[0] == handler + + def test_register_and_dispatch(self): + handler = MagicMock() + handler.supports.return_value = True + handler.send_test.return_value = DeliveryTestResult(status=DeliveryTestStatus.OK) + + registry = DeliveryTestRegistry([handler]) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + result = registry.dispatch(context=context, method=method) + + assert result.status == DeliveryTestStatus.OK + handler.supports.assert_called_once_with(method) + handler.send_test.assert_called_once_with(context=context, method=method) + + def test_dispatch_unsupported(self): + handler = MagicMock() + handler.supports.return_value = False + + registry = DeliveryTestRegistry([handler]) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): + registry.dispatch(context=context, method=method) + + def test_default(self, mock_db): + registry = DeliveryTestRegistry.default() + assert len(registry._handlers) == 1 + assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) + + +def test_human_input_delivery_test_service(): + registry = MagicMock(spec=DeliveryTestRegistry) + service = HumanInputDeliveryTestService(registry=registry) + context = MagicMock(spec=DeliveryTestContext) + method = MagicMock() + + service.send_test(context=context, method=method) + registry.dispatch.assert_called_once_with(context=context, method=method) + + +class TestEmailDeliveryTestHandler: + def test_init_with_engine(self): + engine = MagicMock(spec=Engine) + handler = EmailDeliveryTestHandler(session_factory=engine) + assert handler._session_factory.kw["bind"] == engine + + def test_supports(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + assert handler.supports(method) is True + assert handler.supports(MagicMock()) is False + + def test_send_test_unsupported_method(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + with pytest.raises(DeliveryTestUnsupportedError): + handler.send_test(context=MagicMock(), method=MagicMock()) + + def test_send_test_feature_disabled(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), ) - ) + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + with pytest.raises(DeliveryTestError, match="Email delivery is not available"): + handler.send_test(context=context, method=method) -def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch): - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False), - ) + def test_send_test_mail_not_inited(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: False) - handler = EmailDeliveryTestHandler(session_factory=object()) - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - ) - method = _make_email_method() + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + with pytest.raises(DeliveryTestError, match="Mail client is not initialized."): + handler.send_test(context=context, method=method) + + def test_send_test_no_recipients(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=[]) + + context = DeliveryTestContext( + tenant_id="t1", app_id="a1", node_id="n1", node_title="title", rendered_content="content" + ) + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + with pytest.raises(DeliveryTestError, match="No recipients configured"): + handler.send_test(context=context, method=method) + + def test_send_test_success(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + mock_mail_send = MagicMock() + monkeypatch.setattr(service_module.mail, "send", mock_mail_send) + monkeypatch.setattr(service_module, "render_email_template", lambda t, s: f"RENDERED_{t}") + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=["test@example.com"]) + + variable_pool = VariablePool() + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + variable_pool=variable_pool, + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + + method = EmailDeliveryMethod(config=_make_valid_email_config()) + + result = handler.send_test(context=context, method=method) + + assert result.status == DeliveryTestStatus.OK + assert result.delivered_to == ["test@example.com"] + mock_mail_send.assert_called_once() + args, kwargs = mock_mail_send.call_args + assert kwargs["to"] == "test@example.com" + assert "RENDERED_Subj" in kwargs["subject"] + + def test_send_test_sanitizes_subject(self, monkeypatch): + monkeypatch.setattr( + service_module.FeatureService, + "get_features", + lambda _id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(service_module.mail, "is_inited", lambda: True) + mock_mail_send = MagicMock() + monkeypatch.setattr(service_module.mail, "send", mock_mail_send) + monkeypatch.setattr( + service_module, + "render_email_template", + lambda template, substitutions: template.replace("{{ recipient_email }}", substitutions["recipient_email"]), + ) + + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) + handler._resolve_recipients = MagicMock(return_value=["test@example.com"]) + + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(whole_workspace=False, items=[]), + subject="Notice\r\nBCC:{{ recipient_email }}", + body="Body", + ) + ) - with pytest.raises(DeliveryTestError, match="Email delivery is not available"): handler.send_test(context=context, method=method) + _, kwargs = mock_mail_send.call_args + assert kwargs["subject"] == "Notice BCC:test@example.com" -def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch): - class DummyMail: - def __init__(self): - self.sent: list[dict[str, str]] = [] + def test_resolve_recipients(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) - def is_inited(self) -> bool: - return True - - def send(self, *, to: str, subject: str, html: str): - self.sent.append({"to": to, "subject": subject, "html": html}) - - mail = DummyMail() - monkeypatch.setattr(service_module, "mail", mail) - monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template) - monkeypatch.setattr( - service_module.FeatureService, - "get_features", - lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), - ) - - handler = EmailDeliveryTestHandler(session_factory=object()) - handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment] - - method = EmailDeliveryMethod( - config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]), - subject="Subject", - body="Value {{#node1.value#}}", + # Test Case 1: External Recipient + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + subject="", + body="", + ) ) - ) - variable_pool = VariablePool() - variable_pool.add(["node1", "value"], "OK") - context = DeliveryTestContext( - tenant_id="tenant-1", - app_id="app-1", - node_id="node-1", - node_title="Human Input", - rendered_content="content", - variable_pool=variable_pool, - ) + assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - handler.send_test(context=context, method=method) + # Test Case 2: Member Recipient + method = EmailDeliveryMethod( + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + subject="", + body="", + ) + ) + handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) + assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] - assert mail.sent[0]["html"] == "Value OK" + # Test Case 3: Whole Workspace + method = EmailDeliveryMethod( + config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + ) + handler._query_workspace_member_emails = MagicMock( + return_value={"u1": "u1@example.com", "u2": "u2@example.com"} + ) + recipients = handler._resolve_recipients(tenant_id="t1", method=method) + assert set(recipients) == {"u1@example.com", "u2@example.com"} + + def test_query_workspace_member_emails(self): + mock_session = MagicMock() + mock_session_factory = MagicMock(return_value=mock_session) + mock_session.__enter__.return_value = mock_session + + handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) + + # Empty user_ids + assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} + + # user_ids is None (all) + mock_execute = MagicMock() + mock_session.execute.return_value = mock_execute + mock_execute.all.return_value = [("u1", "u1@example.com")] + + result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) + assert result == {"u1": "u1@example.com"} + + # user_ids with values + result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) + assert result == {"u1": "u1@example.com"} + + def test_build_substitutions(self): + context = DeliveryTestContext( + tenant_id="t1", + app_id="a1", + node_id="n1", + node_title="title", + rendered_content="content", + template_vars={"custom": "var"}, + recipients=[DeliveryTestEmailRecipient(email="test@example.com", form_token="token123")], + ) + + subs = EmailDeliveryTestHandler._build_substitutions(context=context, recipient_email="test@example.com") + + assert subs["node_title"] == "title" + assert subs["form_content"] == "content" + assert subs["recipient_email"] == "test@example.com" + assert subs["custom"] == "var" + assert subs["form_token"] == "token123" + assert "form/token123" in subs["form_link"] + + # Without matching recipient + subs_no_match = EmailDeliveryTestHandler._build_substitutions( + context=context, recipient_email="other@example.com" + ) + assert subs_no_match["form_token"] == "" + assert subs_no_match["form_link"] == "" diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 5800d029ca..375e47d7fc 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ b/api/tests/unit_tests/services/test_human_input_service.py @@ -9,14 +9,20 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from core.workflow.nodes.human_input.entities import ( +from dify_graph.nodes.human_input.entities import ( FormDefinition, FormInput, UserAction, ) -from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus from models.human_input import RecipientType -from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError +from services.human_input_service import ( + Form, + FormExpiredError, + FormSubmittedError, + HumanInputService, + InvalidFormDataError, +) @pytest.fixture @@ -285,3 +291,172 @@ def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_fa assert "Missing required inputs" in str(exc_info.value) repo.mark_submitted.assert_not_called() + + +def test_form_properties(sample_form_record): + form = Form(sample_form_record) + assert form.id == "form-id" + assert form.workflow_run_id == "workflow-run-id" + assert form.tenant_id == "tenant-id" + assert form.app_id == "app-id" + assert form.recipient_id == "recipient-id" + assert form.recipient_type == RecipientType.STANDALONE_WEB_APP + assert form.status == HumanInputFormStatus.WAITING + assert form.form_kind == HumanInputFormKind.RUNTIME + assert isinstance(form.created_at, datetime) + assert isinstance(form.expiration_time, datetime) + + +def test_form_submitted_error_init(): + error = FormSubmittedError(form_id="test-form") + assert "form_id=test-form" in error.description + assert error.code == 412 + + +def test_human_input_service_init_with_engine(mocker): + engine = MagicMock(spec=human_input_service_module.Engine) + sessionmaker_mock = mocker.patch("services.human_input_service.sessionmaker") + + HumanInputService(session_factory=engine) + sessionmaker_mock.assert_called_once_with(bind=engine) + + +def test_get_form_by_token_none(mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = None + + service = HumanInputService(session_factory, form_repository=repo) + assert service.get_form_by_token("invalid") is None + + +def test_get_form_definition_by_token_mismatch(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + service = HumanInputService(session_factory, form_repository=repo) + # RecipientType mismatch + assert service.get_form_definition_by_token(RecipientType.CONSOLE, "token") is None + + +def test_get_form_definition_by_token_success(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + service = HumanInputService(session_factory, form_repository=repo) + form = service.get_form_definition_by_token(RecipientType.STANDALONE_WEB_APP, "token") + assert form is not None + assert form.id == sample_form_record.form_id + + +def test_get_form_definition_by_token_for_console_mismatch(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record # is STANDALONE_WEB_APP + + service = HumanInputService(session_factory, form_repository=repo) + assert service.get_form_definition_by_token_for_console("token") is None + + +def test_submit_form_by_token_delivery_not_enabled(mock_session_factory): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = None + + service = HumanInputService(session_factory, form_repository=repo) + with pytest.raises(human_input_service_module.WebAppDeliveryNotEnabledError): + service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "action", {}) + + +def test_submit_form_by_token_no_workflow_run_id(sample_form_record, mock_session_factory, mocker): + session_factory, _ = mock_session_factory + repo = MagicMock(spec=HumanInputFormSubmissionRepository) + repo.get_by_token.return_value = sample_form_record + + # Return record with no workflow_run_id + result_record = dataclasses.replace(sample_form_record, workflow_run_id=None) + repo.mark_submitted.return_value = result_record + + service = HumanInputService(session_factory, form_repository=repo) + enqueue_spy = mocker.patch.object(service, "enqueue_resume") + + service.submit_form_by_token(RecipientType.STANDALONE_WEB_APP, "token", "submit", {}) + enqueue_spy.assert_not_called() + + +def test_ensure_form_active_errors(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + # Submitted + submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + with pytest.raises(human_input_service_module.FormSubmittedError): + service.ensure_form_active(Form(submitted_record)) + + # Timeout status + timeout_record = dataclasses.replace(sample_form_record, status=HumanInputFormStatus.TIMEOUT) + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(timeout_record)) + + # Expired time + expired_time_record = dataclasses.replace( + sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + ) + with pytest.raises(FormExpiredError): + service.ensure_form_active(Form(expired_time_record)) + + +def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + + with pytest.raises(human_input_service_module.FormSubmittedError): + service._ensure_not_submitted(Form(submitted_record)) + + +def test_enqueue_resume_workflow_not_found(mocker, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = None + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + with pytest.raises(AssertionError) as excinfo: + service.enqueue_resume("workflow-run-id") + assert "WorkflowRun not found" in str(excinfo.value) + + +def test_enqueue_resume_app_not_found(mocker, mock_session_factory): + session_factory, session = mock_session_factory + service = HumanInputService(session_factory) + + workflow_run = MagicMock() + workflow_run.app_id = "app-id" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run + mocker.patch( + "services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + session.execute.return_value.scalar_one_or_none.return_value = None + logger_spy = mocker.patch("services.human_input_service.logger") + + service.enqueue_resume("workflow-run-id") + logger_spy.error.assert_called_once() + + +def test_is_globally_expired_zero_timeout(monkeypatch, sample_form_record, mock_session_factory): + session_factory, _ = mock_session_factory + service = HumanInputService(session_factory) + + monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 0) + assert service._is_globally_expired(Form(sample_form_record)) is False diff --git a/api/tests/unit_tests/services/test_knowledge_service.py b/api/tests/unit_tests/services/test_knowledge_service.py new file mode 100644 index 0000000000..bc0caee071 --- /dev/null +++ b/api/tests/unit_tests/services/test_knowledge_service.py @@ -0,0 +1,146 @@ +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pytest + +from services.knowledge_service import ExternalDatasetTestService + + +class TestKnowledgeService: + """Test suite for ExternalDatasetTestService""" + + # ===== Happy Path Tests ===== + + @patch("services.knowledge_service.boto3.client") + @patch("services.knowledge_service.dify_config") + def test_knowledge_retrieval_should_succeed_with_valid_results( + self, mock_dify_config: MagicMock, mock_boto_client: MagicMock + ): + """Test that knowledge_retrieval successfully parses results from Bedrock""" + # Arrange + mock_dify_config.AWS_SECRET_ACCESS_KEY = "dummy_secret" + mock_dify_config.AWS_ACCESS_KEY_ID = "dummy_id" + + mock_client = MagicMock() + mock_boto_client.return_value = mock_client + + retrieval_setting = {"top_k": 4, "score_threshold": 0.5} + query = "test query" + knowledge_id = "kb-123" + + # Mock successful response + mock_client.retrieve.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "retrievalResults": [ + { + "score": 0.9, + "metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc1.pdf"}, + "content": {"text": "content from doc1"}, + }, + { + "score": 0.4, # Below threshold + "metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc2.pdf"}, + "content": {"text": "content from doc2"}, + }, + ], + } + + # Act + result = cast( + dict[str, Any], ExternalDatasetTestService.knowledge_retrieval(retrieval_setting, query, knowledge_id) + ) + + # Assert + assert len(result["records"]) == 1 + record = result["records"][0] + assert record["score"] == 0.9 + assert record["title"] == "s3://bucket/doc1.pdf" + assert record["content"] == "content from doc1" + + # verify retrieve called correctly + mock_client.retrieve.assert_called_once_with( + knowledgeBaseId=knowledge_id, + retrievalConfiguration={ + "vectorSearchConfiguration": { + "numberOfResults": 4, + "overrideSearchType": "HYBRID", + } + }, + retrievalQuery={"text": query}, + ) + + # NEW: verify boto3.client created with proper service name and config values + mock_boto_client.assert_called_once_with( + "bedrock-agent-runtime", + aws_secret_access_key="dummy_secret", + aws_access_key_id="dummy_id", + region_name="us-east-1", + ) + + @patch("services.knowledge_service.boto3.client") + def test_knowledge_retrieval_should_return_empty_when_no_results(self, mock_boto: MagicMock): + """Test that knowledge_retrieval returns empty records when Bedrock returns nothing""" + # Arrange + mock_client = MagicMock() + mock_boto.return_value = mock_client + + mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []} + + # Act + result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + + # Assert + assert result["records"] == [] + + # ===== Error Handling Tests ===== + + @patch("services.knowledge_service.boto3.client") + def test_knowledge_retrieval_should_return_empty_on_http_error(self, mock_boto: MagicMock): + """Test that knowledge_retrieval returns empty records if Bedrock returns non-200 status""" + # Arrange + mock_client = MagicMock() + mock_boto.return_value = mock_client + + mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}} + + # Act + result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + + # Assert + assert result["records"] == [] + + def test_knowledge_retrieval_should_raise_when_boto_client_creation_fails(self): + """Test that exceptions from boto3.client propagate (e.g., network/credentials issues)""" + with patch("services.knowledge_service.boto3.client") as mock_boto: + mock_boto.side_effect = Exception("client init failed") + with pytest.raises(Exception) as exc_info: + ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb") + assert "client init failed" in str(exc_info.value) + + # ===== Edge Cases ===== + + @patch("services.knowledge_service.boto3.client") + def test_knowledge_retrieval_should_handle_missing_threshold_in_settings(self, mock_boto: MagicMock): + """Test that knowledge_retrieval uses 0.0 as default threshold if not provided""" + # Arrange + mock_client = MagicMock() + mock_boto.return_value = mock_client + + mock_client.retrieve.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "retrievalResults": [ + { + "score": 0.1, + "metadata": {"x-amz-bedrock-kb-source-uri": "uri"}, + "content": {"text": "text"}, + } + ], + } + + # Act + # retrieval_setting missing "score_threshold" + result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + + # Assert + assert len(result["records"]) == 1 + assert result["records"][0]["score"] == 0.1 diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index 3c38888753..e7740ef93a 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -4,9 +4,15 @@ from unittest.mock import MagicMock, patch import pytest from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.enums import FeedbackFromSource, FeedbackRating from models.model import App, AppMode, EndUser, Message -from services.errors.message import FirstMessageNotExistsError, LastMessageNotExistsError -from services.message_service import MessageService +from services.errors.message import ( + FirstMessageNotExistsError, + LastMessageNotExistsError, + MessageNotExistsError, + SuggestedQuestionsAfterAnswerDisabledError, +) +from services.message_service import MessageService, attach_message_extra_contents class TestMessageServiceFactory: @@ -244,14 +250,12 @@ class TestMessageServicePaginationByFirstId: mock_query_first = MagicMock() mock_query_history = MagicMock() + query_calls = [] + def query_side_effect(*args): if args[0] == Message: - # First call returns mock for first_message query - if not hasattr(query_side_effect, "call_count"): - query_side_effect.call_count = 0 - query_side_effect.call_count += 1 - - if query_side_effect.call_count == 1: + query_calls.append(args) + if len(query_calls) == 1: return mock_query_first else: return mock_query_history @@ -647,3 +651,410 @@ class TestMessageServicePaginationByLastId: assert len(result.data) == 10 # Last message trimmed assert result.has_more is True assert result.limit == 10 + + +class TestMessageServiceUtilities: + """Unit tests for MessageService module-level utility functions.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 16: attach_message_extra_contents with empty list + def test_attach_message_extra_contents_empty(self): + """Test attach_message_extra_contents with empty list does nothing.""" + # Act & Assert (should not raise error) + attach_message_extra_contents([]) + + # Test 17: attach_message_extra_contents with messages + @patch("services.message_service._create_execution_extra_content_repository") + def test_attach_message_extra_contents_with_messages(self, mock_create_repo, factory): + """Test attach_message_extra_contents correctly attaches content.""" + # Arrange + messages = [factory.create_message_mock(message_id="msg-1"), factory.create_message_mock(message_id="msg-2")] + + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + + # Mock extra content models + mock_content1 = MagicMock() + mock_content1.model_dump.return_value = {"key": "value1"} + mock_content2 = MagicMock() + mock_content2.model_dump.return_value = {"key": "value2"} + + mock_repo.get_by_message_ids.return_value = [[mock_content1], [mock_content2]] + + # Act + attach_message_extra_contents(messages) + + # Assert + mock_repo.get_by_message_ids.assert_called_once_with(["msg-1", "msg-2"]) + messages[0].set_extra_contents.assert_called_once_with([{"key": "value1"}]) + messages[1].set_extra_contents.assert_called_once_with([{"key": "value2"}]) + + # Test 18: attach_message_extra_contents with index out of bounds + @patch("services.message_service._create_execution_extra_content_repository") + def test_attach_message_extra_contents_index_out_of_bounds(self, mock_create_repo, factory): + """Test attach_message_extra_contents handles missing content lists.""" + # Arrange + messages = [factory.create_message_mock(message_id="msg-1")] + + mock_repo = MagicMock() + mock_create_repo.return_value = mock_repo + mock_repo.get_by_message_ids.return_value = [] # Empty returned list + + # Act + attach_message_extra_contents(messages) + + # Assert + messages[0].set_extra_contents.assert_called_once_with([]) + + # Test 19: _create_execution_extra_content_repository + @patch("services.message_service.db") + @patch("services.message_service.sessionmaker") + @patch("services.message_service.SQLAlchemyExecutionExtraContentRepository") + def test_create_execution_extra_content_repository(self, mock_repo_class, mock_sessionmaker, mock_db): + """Test _create_execution_extra_content_repository creates expected repository.""" + from services.message_service import _create_execution_extra_content_repository + + # Act + _create_execution_extra_content_repository() + + # Assert + mock_sessionmaker.assert_called_once() + mock_repo_class.assert_called_once() + + +class TestMessageServiceGetMessage: + """Unit tests for MessageService.get_message method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 20: get_message success for EndUser + @patch("services.message_service.db") + def test_get_message_end_user_success(self, mock_db, factory): + """Test get_message returns message for EndUser.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock(user_id="end-user-123") + message = factory.create_message_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = message + + # Act + result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + # Assert + assert result == message + mock_query.where.assert_called_once() + + # Test 21: get_message success for Account (Admin) + @patch("services.message_service.db") + def test_get_message_account_success(self, mock_db, factory): + """Test get_message returns message for Account.""" + # Arrange + from models import Account + + app = factory.create_app_mock() + user = MagicMock(spec=Account) + user.id = "account-123" + message = factory.create_message_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = message + + # Act + result = MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + # Assert + assert result == message + + # Test 22: get_message not found + @patch("services.message_service.db") + def test_get_message_not_found(self, mock_db, factory): + """Test get_message raises MessageNotExistsError when not found.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = None + + # Act & Assert + with pytest.raises(MessageNotExistsError): + MessageService.get_message(app_model=app, user=user, message_id="msg-123") + + +class TestMessageServiceFeedback: + """Unit tests for MessageService feedback-related methods.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 23: create_feedback - new feedback for EndUser + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_new_end_user(self, mock_get_message, mock_db, factory): + """Test creating new feedback for an end user.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + message = factory.create_message_mock() + message.user_feedback = None + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating=FeedbackRating.LIKE, + content="Good answer", + ) + + # Assert + assert result.rating == FeedbackRating.LIKE + assert result.content == "Good answer" + assert result.from_source == FeedbackFromSource.USER + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + # Test 24: create_feedback - update feedback for Account + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_update_account(self, mock_get_message, mock_db, factory): + """Test updating existing feedback for an account.""" + # Arrange + from models import Account, MessageFeedback + + app = factory.create_app_mock() + user = MagicMock(spec=Account) + user.id = "account-123" + message = factory.create_message_mock() + feedback = MagicMock(spec=MessageFeedback) + message.admin_feedback = feedback + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating=FeedbackRating.DISLIKE, + content="Bad answer", + ) + + # Assert + assert result == feedback + assert feedback.rating == FeedbackRating.DISLIKE + assert feedback.content == "Bad answer" + mock_db.session.commit.assert_called_once() + + # Test 25: create_feedback - delete feedback (rating is None) + @patch("services.message_service.db") + @patch.object(MessageService, "get_message") + def test_create_feedback_delete(self, mock_get_message, mock_db, factory): + """Test deleting feedback by passing rating=None.""" + # Arrange + app = factory.create_app_mock() + user = factory.create_end_user_mock() + message = factory.create_message_mock() + feedback = MagicMock() + message.user_feedback = feedback + mock_get_message.return_value = message + + # Act + result = MessageService.create_feedback( + app_model=app, + message_id="msg-123", + user=user, + rating=None, + content=None, + ) + + # Assert + assert result == feedback + mock_db.session.delete.assert_called_once_with(feedback) + mock_db.session.commit.assert_called_once() + + # Test 26: get_all_messages_feedbacks + @patch("services.message_service.db") + def test_get_all_messages_feedbacks(self, mock_db, factory): + """Test get_all_messages_feedbacks returns list of dicts.""" + # Arrange + app = factory.create_app_mock() + feedback = MagicMock() + feedback.to_dict.return_value = {"id": "fb-1"} + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.limit.return_value = mock_query + mock_query.offset.return_value = mock_query + mock_query.all.return_value = [feedback] + + # Act + result = MessageService.get_all_messages_feedbacks(app_model=app, page=1, limit=10) + + # Assert + assert result == [{"id": "fb-1"}] + mock_query.limit.assert_called_with(10) + mock_query.offset.assert_called_with(0) + + +class TestMessageServiceSuggestedQuestions: + """Unit tests for MessageService.get_suggested_questions_after_answer method.""" + + @pytest.fixture + def factory(self): + """Provide test data factory.""" + return TestMessageServiceFactory() + + # Test 27: get_suggested_questions_after_answer - user is None + def test_get_suggested_questions_user_none(self, factory): + app = factory.create_app_mock() + with pytest.raises(ValueError, match="user cannot be None"): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=None, message_id="msg-123", invoke_from=MagicMock() + ) + + # Test 28: get_suggested_questions_after_answer - Advanced Chat success + @patch("services.message_service.ModelManager") + @patch("services.message_service.WorkflowService") + @patch("services.message_service.AdvancedChatAppConfigManager") + @patch("services.message_service.TokenBufferMemory") + @patch("services.message_service.LLMGenerator") + @patch("services.message_service.TraceQueueManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_advanced_chat_success( + self, + mock_conversation_service, + mock_get_message, + mock_trace_manager, + mock_llm_gen, + mock_memory, + mock_config_manager, + mock_workflow_service, + mock_model_manager, + factory, + ): + """Test successful suggested questions generation in Advanced Chat mode.""" + from core.app.entities.app_invoke_entities import InvokeFrom + + # Arrange + app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value) + user = factory.create_end_user_mock() + message = factory.create_message_mock() + mock_get_message.return_value = message + + workflow = MagicMock() + mock_workflow_service.return_value.get_published_workflow.return_value = workflow + + app_config = MagicMock() + app_config.additional_features.suggested_questions_after_answer = True + mock_config_manager.get_app_config.return_value = app_config + + mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] + + # Act + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=InvokeFrom.WEB_APP + ) + + # Assert + assert result == ["Q1?"] + mock_workflow_service.return_value.get_published_workflow.assert_called_once() + mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() + + # Test 29: get_suggested_questions_after_answer - Chat app success (no override) + @patch("services.message_service.db") + @patch("services.message_service.ModelManager") + @patch("services.message_service.TokenBufferMemory") + @patch("services.message_service.LLMGenerator") + @patch("services.message_service.TraceQueueManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_chat_app_success( + self, + mock_conversation_service, + mock_get_message, + mock_trace_manager, + mock_llm_gen, + mock_memory, + mock_model_manager, + mock_db, + factory, + ): + """Test successful suggested questions generation in basic Chat mode.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.CHAT.value) + user = factory.create_end_user_mock() + message = factory.create_message_mock() + mock_get_message.return_value = message + + conversation = MagicMock() + conversation.override_model_configs = None + mock_conversation_service.get_conversation.return_value = conversation + + app_model_config = MagicMock() + app_model_config.suggested_questions_after_answer_dict = {"enabled": True} + app_model_config.model_dict = {"provider": "openai", "name": "gpt-4"} + + mock_query = MagicMock() + mock_db.session.query.return_value = mock_query + mock_query.where.return_value = mock_query + mock_query.first.return_value = app_model_config + + mock_llm_gen.generate_suggested_questions_after_answer.return_value = ["Q1?"] + + # Act + result = MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock() + ) + + # Assert + assert result == ["Q1?"] + mock_query.first.assert_called_once() + mock_llm_gen.generate_suggested_questions_after_answer.assert_called_once() + + # Test 30: get_suggested_questions_after_answer - Disabled Error + @patch("services.message_service.WorkflowService") + @patch("services.message_service.AdvancedChatAppConfigManager") + @patch.object(MessageService, "get_message") + @patch("services.message_service.ConversationService") + def test_get_suggested_questions_disabled_error( + self, mock_conversation_service, mock_get_message, mock_config_manager, mock_workflow_service, factory + ): + """Test SuggestedQuestionsAfterAnswerDisabledError is raised when feature is disabled.""" + # Arrange + app = factory.create_app_mock(mode=AppMode.ADVANCED_CHAT.value) + user = factory.create_end_user_mock() + mock_get_message.return_value = factory.create_message_mock() + + workflow = MagicMock() + mock_workflow_service.return_value.get_published_workflow.return_value = workflow + + app_config = MagicMock() + app_config.additional_features.suggested_questions_after_answer = False + mock_config_manager.get_app_config.return_value = app_config + + # Act & Assert + with pytest.raises(SuggestedQuestionsAfterAnswerDisabledError): + MessageService.get_suggested_questions_after_answer( + app_model=app, user=user, message_id="msg-123", invoke_from=MagicMock() + ) diff --git a/api/tests/unit_tests/services/test_messages_clean_service.py b/api/tests/unit_tests/services/test_messages_clean_service.py index 67ae2c9142..f3efc4463e 100644 --- a/api/tests/unit_tests/services/test_messages_clean_service.py +++ b/api/tests/unit_tests/services/test_messages_clean_service.py @@ -540,6 +540,20 @@ class TestMessagesCleanServiceFromTimeRange: assert service._batch_size == 1000 # default assert service._dry_run is False # default + def test_explicit_task_label(self): + start_from = datetime.datetime(2024, 1, 1) + end_before = datetime.datetime(2024, 1, 2) + policy = BillingDisabledPolicy() + + service = MessagesCleanService.from_time_range( + policy=policy, + start_from=start_from, + end_before=end_before, + task_label="60to30", + ) + + assert service._metrics._base_attributes["task_label"] == "60to30" + class TestMessagesCleanServiceFromDays: """Unit tests for MessagesCleanService.from_days factory method.""" @@ -554,11 +568,9 @@ class TestMessagesCleanServiceFromDays: MessagesCleanService.from_days(policy=policy, days=-1) # Act - with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: fixed_now = datetime.datetime(2024, 6, 15, 14, 0, 0) - mock_datetime.datetime.now.return_value = fixed_now - mock_datetime.timedelta = datetime.timedelta - + mock_now.return_value = fixed_now service = MessagesCleanService.from_days(policy=policy, days=0) # Assert @@ -586,11 +598,9 @@ class TestMessagesCleanServiceFromDays: dry_run = True # Act - with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) - mock_datetime.datetime.now.return_value = fixed_now - mock_datetime.timedelta = datetime.timedelta - + mock_now.return_value = fixed_now service = MessagesCleanService.from_days( policy=policy, days=days, @@ -613,11 +623,9 @@ class TestMessagesCleanServiceFromDays: policy = BillingDisabledPolicy() # Act - with patch("services.retention.conversation.messages_clean_service.datetime", autospec=True) as mock_datetime: + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: fixed_now = datetime.datetime(2024, 6, 15, 10, 30, 0) - mock_datetime.datetime.now.return_value = fixed_now - mock_datetime.timedelta = datetime.timedelta - + mock_now.return_value = fixed_now service = MessagesCleanService.from_days(policy=policy) # Assert @@ -625,3 +633,54 @@ class TestMessagesCleanServiceFromDays: assert service._end_before == expected_end_before assert service._batch_size == 1000 # default assert service._dry_run is False # default + assert service._metrics._base_attributes["task_label"] == "custom" + + +class TestMessagesCleanServiceRun: + """Unit tests for MessagesCleanService.run instrumentation behavior.""" + + def test_run_records_completion_metrics_on_success(self): + # Arrange + service = MessagesCleanService( + policy=BillingDisabledPolicy(), + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 1, 2), + batch_size=100, + dry_run=False, + ) + expected_stats = { + "batches": 1, + "total_messages": 10, + "filtered_messages": 5, + "total_deleted": 5, + } + service._clean_messages_by_time_range = MagicMock(return_value=expected_stats) # type: ignore[method-assign] + completion_calls: list[dict[str, object]] = [] + service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign] + + # Act + result = service.run() + + # Assert + assert result == expected_stats + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "success" + + def test_run_records_completion_metrics_on_failure(self): + # Arrange + service = MessagesCleanService( + policy=BillingDisabledPolicy(), + start_from=datetime.datetime(2024, 1, 1), + end_before=datetime.datetime(2024, 1, 2), + batch_size=100, + dry_run=False, + ) + service._clean_messages_by_time_range = MagicMock(side_effect=RuntimeError("clean failed")) # type: ignore[method-assign] + completion_calls: list[dict[str, object]] = [] + service._metrics.record_completion = lambda **kwargs: completion_calls.append(kwargs) # type: ignore[method-assign] + + # Act & Assert + with pytest.raises(RuntimeError, match="clean failed"): + service.run() + assert len(completion_calls) == 1 + assert completion_calls[0]["status"] == "failed" diff --git a/api/tests/unit_tests/services/test_metadata_service.py b/api/tests/unit_tests/services/test_metadata_service.py new file mode 100644 index 0000000000..bbdc16d4f8 --- /dev/null +++ b/api/tests/unit_tests/services/test_metadata_service.py @@ -0,0 +1,558 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataArgs, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +@dataclass +class _DocumentStub: + id: str + name: str + uploader: str + upload_date: datetime + last_update_date: datetime + data_source_type: str + doc_metadata: dict[str, object] | None + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + mocked_db = mocker.patch("services.metadata_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.fixture +def mock_redis_client(mocker: MockerFixture) -> MagicMock: + return mocker.patch("services.metadata_service.redis_client") + + +@pytest.fixture +def mock_current_account(mocker: MockerFixture) -> MagicMock: + mock_user = SimpleNamespace(id="user-1") + return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1")) + + +def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub: + now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC) + return _DocumentStub( + id=document_id, + name=f"doc-{document_id}", + uploader="qa@example.com", + upload_date=now, + last_update_date=now, + data_source_type="upload_file", + doc_metadata=doc_metadata, + ) + + +def _dataset(**kwargs: Any) -> Dataset: + return cast(Dataset, SimpleNamespace(**kwargs)) + + +def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name="x" * 256) + + # Act + Assert + with pytest.raises(ValueError, match="cannot exceed 255"): + MetadataService.create_metadata("dataset-1", metadata_args) + + +def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists( + mock_db: MagicMock, + mock_current_account: MagicMock, +) -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name="priority") + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + MetadataService.create_metadata("dataset-1", metadata_args) + + # Assert + mock_current_account.assert_called_once() + + +def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Built-in fields"): + MetadataService.create_metadata("dataset-1", metadata_args) + + +def test_create_metadata_should_persist_metadata_when_input_is_valid( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + metadata_args = MetadataArgs(type="number", name="score") + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = MetadataService.create_metadata("dataset-1", metadata_args) + + # Assert + assert result.tenant_id == "tenant-1" + assert result.dataset_id == "dataset-1" + assert result.type == "number" + assert result.name == "score" + assert result.created_by == "user-1" + mock_db.session.add.assert_called_once_with(result) + mock_db.session.commit.assert_called_once() + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None: + # Arrange + too_long_name = "x" * 256 + + # Act + Assert + with pytest.raises(ValueError, match="cannot exceed 255"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name) + + +def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate") + + # Assert + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin( + mock_db: MagicMock, + mock_current_account: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Built-in fields"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source) + + # Assert + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_update_bound_documents_and_return_metadata( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC) + mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now) + + metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None) + bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] + query_duplicate = MagicMock() + query_duplicate.filter_by.return_value.first.return_value = None + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = metadata + query_bindings = MagicMock() + query_bindings.filter_by.return_value.all.return_value = bindings + mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings] + + doc_1 = _build_document("1", {"old_name": "value", "other": "keep"}) + doc_2 = _build_document("2", None) + mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids") + mock_get_documents.return_value = [doc_1, doc_2] + + # Act + result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name") + + # Assert + assert result is metadata + assert metadata.name == "new_name" + assert metadata.updated_by == "user-1" + assert metadata.updated_at == fixed_now + assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"} + assert doc_2.doc_metadata == {"new_name": None} + mock_get_documents.assert_called_once_with(["doc-1", "doc-2"]) + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_return_none_when_metadata_does_not_exist( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + mock_logger = mocker.patch("services.metadata_service.logger") + + query_duplicate = MagicMock() + query_duplicate.filter_by.return_value.first.return_value = None + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = None + mock_db.session.query.side_effect = [query_duplicate, query_metadata] + + # Act + result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name") + + # Assert + assert result is None + mock_logger.exception.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + mock_current_account.assert_called_once() + + +def test_delete_metadata_should_remove_metadata_and_related_document_fields( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + metadata = SimpleNamespace(id="metadata-1", name="obsolete") + bindings = [SimpleNamespace(document_id="doc-1")] + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = metadata + query_bindings = MagicMock() + query_bindings.filter_by.return_value.all.return_value = bindings + mock_db.session.query.side_effect = [query_metadata, query_bindings] + + document = _build_document("1", {"obsolete": "legacy", "remaining": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document]) + + # Act + result = MetadataService.delete_metadata("dataset-1", "metadata-1") + + # Assert + assert result is metadata + assert document.doc_metadata == {"remaining": "value"} + mock_db.session.delete.assert_called_once_with(metadata) + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_delete_metadata_should_return_none_when_metadata_is_missing( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_logger = mocker.patch("services.metadata_service.logger") + + # Act + result = MetadataService.delete_metadata("dataset-1", "missing-id") + + # Assert + assert result is None + mock_logger.exception.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_get_built_in_fields_should_return_all_expected_fields() -> None: + # Arrange + expected_names = { + BuiltInField.document_name, + BuiltInField.uploader, + BuiltInField.upload_date, + BuiltInField.last_update_date, + BuiltInField.source, + } + + # Act + result = MetadataService.get_built_in_fields() + + # Assert + assert {item["name"] for item in result} == expected_names + assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"] + + +def test_enable_built_in_field_should_return_immediately_when_already_enabled( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") + + # Act + MetadataService.enable_built_in_field(dataset) + + # Assert + get_docs.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_enable_built_in_field_should_populate_documents_and_enable_flag( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + doc_1 = _build_document("1", {"custom": "value"}) + doc_2 = _build_document("2", None) + mocker.patch( + "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", + return_value=[doc_1, doc_2], + ) + + # Act + MetadataService.enable_built_in_field(dataset) + + # Assert + assert dataset.built_in_field_enabled is True + assert doc_1.doc_metadata is not None + assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1" + assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file + assert doc_2.doc_metadata is not None + assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com" + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_disable_built_in_field_should_return_immediately_when_already_disabled( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") + + # Act + MetadataService.disable_built_in_field(dataset) + + # Assert + get_docs.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + document = _build_document( + "1", + { + BuiltInField.document_name: "doc", + BuiltInField.uploader: "user", + BuiltInField.upload_date: 1.0, + BuiltInField.last_update_date: 2.0, + BuiltInField.source: MetadataDataSource.upload_file, + "custom": "keep", + }, + ) + mocker.patch( + "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", + return_value=[document], + ) + + # Act + MetadataService.disable_built_in_field(dataset) + + # Assert + assert dataset.built_in_field_enabled is False + assert document.doc_metadata == {"custom": "keep"} + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + document = _build_document("1", {"legacy": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) + delete_chain = mock_db.session.query.return_value.filter_by.return_value + delete_chain.delete.return_value = 1 + operation = DocumentMetadataOperation( + document_id="1", + metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + assert document.doc_metadata == {"priority": "high"} + delete_chain.delete.assert_called_once() + assert mock_db.session.commit.call_count == 1 + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") + mock_current_account.assert_called_once() + + +def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + document = _build_document("1", {"existing": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + operation = DocumentMetadataOperation( + document_id="1", + metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + assert document.doc_metadata is not None + assert document.doc_metadata["existing"] == "value" + assert document.doc_metadata["new_key"] == "new_value" + assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file + assert mock_db.session.commit.call_count == 1 + assert mock_db.session.add.call_count == 1 + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") + mock_current_account.assert_called_once() + + +def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None) + operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + Assert + with pytest.raises(ValueError, match="Document not found"): + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + mock_db.session.rollback.assert_called_once() + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404") + + +@pytest.mark.parametrize( + ("dataset_id", "document_id", "expected_key"), + [ + ("dataset-1", None, "dataset_metadata_lock_dataset-1"), + (None, "doc-1", "document_metadata_lock_doc-1"), + ], +) +def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked( + dataset_id: str | None, + document_id: str | None, + expected_key: str, + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id) + + # Assert + mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600) + + +def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = 1 + + # Act + Assert + with pytest.raises(ValueError, match="knowledge base metadata operation is running"): + MetadataService.knowledge_base_metadata_lock_check("dataset-1", None) + + +def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = 1 + + # Act + Assert + with pytest.raises(ValueError, match="document metadata operation is running"): + MetadataService.knowledge_base_metadata_lock_check(None, "doc-1") + + +def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None: + # Arrange + dataset = _dataset( + id="dataset-1", + built_in_field_enabled=True, + doc_metadata=[ + {"id": "meta-1", "name": "priority", "type": "string"}, + {"id": "built-in", "name": "ignored", "type": "string"}, + {"id": "meta-2", "name": "score", "type": "number"}, + ], + ) + count_chain = mock_db.session.query.return_value.filter_by.return_value + count_chain.count.side_effect = [3, 1] + + # Act + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert + assert result["built_in_field_enabled"] is True + assert result["doc_metadata"] == [ + {"id": "meta-1", "name": "priority", "type": "string", "count": 3}, + {"id": "meta-2", "name": "score", "type": "number", "count": 1}, + ] + + +def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None) + + # Act + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert + assert result == {"doc_metadata": [], "built_in_field_enabled": False} + mock_db.session.query.assert_not_called() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py new file mode 100644 index 0000000000..49e572584b --- /dev/null +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,808 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, +) +from models.provider import LoadBalancingModelConfig +from services.model_load_balancing_service import ModelLoadBalancingService + + +def _build_provider_credential_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ] + ) + + +def _build_model_credential_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ], + ) + + +def _build_provider_configuration( + *, + custom_provider: bool = False, + load_balancing_enabled: bool | None = None, + model_schema: ModelCredentialSchema | None = None, + provider_schema: ProviderCredentialSchema | None = None, +) -> MagicMock: + provider_configuration = MagicMock() + provider_configuration.provider = SimpleNamespace( + provider="openai", + model_credential_schema=model_schema, + provider_credential_schema=provider_schema, + ) + provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider) + provider_configuration.extract_secret_variables.return_value = ["api_key"] + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials + provider_configuration.get_provider_model_setting.return_value = ( + None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled) + ) + return provider_configuration + + +def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: + return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs)) + + +@pytest.fixture +def service(mocker: MockerFixture) -> ModelLoadBalancingService: + # Arrange + provider_manager = MagicMock() + mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) + svc = ModelLoadBalancingService() + svc.provider_manager = provider_manager + return svc + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.model_load_balancing_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.mark.parametrize( + ("method_name", "expected_provider_method"), + [ + ("enable_model_load_balancing", "enable_model_load_balancing"), + ("disable_model_load_balancing", "disable_model_load_balancing"), + ], +) +def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists( + method_name: str, + expected_provider_method: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + # Assert + getattr(provider_configuration, expected_provider_method).assert_called_once_with( + model="gpt-4o-mini", model_type=ModelType.LLM + ) + + +@pytest.mark.parametrize( + "method_name", + ["enable_model_load_balancing", "disable_model_load_balancing"], +) +def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing( + method_name: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=True, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace( + id="cfg-1", + name="primary", + encrypted_config=json.dumps({"api_key": "encrypted-key"}), + credential_id="cred-1", + enabled=True, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + return_value="plain-key", + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(False, 0), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + ) + + # Assert + assert is_enabled is True + assert len(configs) == 2 + assert configs[0]["name"] == "__inherit__" + assert configs[1]["name"] == "primary" + assert configs[1]["credentials"] == {"api_key": "plain-key"} + assert mock_db.session.add.call_count == 1 + assert mock_db.session.commit.call_count == 1 + + +def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=None, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + normal_config = SimpleNamespace( + id="cfg-1", + name="normal", + encrypted_config=json.dumps({"api_key": "bad-encrypted"}), + credential_id="cred-1", + enabled=True, + ) + inherit_config = SimpleNamespace( + id="cfg-2", + name="__inherit__", + encrypted_config="not-json", + credential_id=None, + enabled=False, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + normal_config, + inherit_config, + ] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + side_effect=ValueError("cannot decrypt"), + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(True, 15), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + config_from="predefined-model", + ) + + # Assert + assert is_enabled is False + assert configs[0]["name"] == "__inherit__" + assert configs[0]["credentials"] == {} + assert configs[1]["credentials"] == {"api_key": "bad-encrypted"} + assert configs[1]["in_cooldown"] is True + assert configs[1]["ttl"] == 15 + + +def test_get_load_balancing_config_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + +def test_get_load_balancing_config_should_return_none_when_config_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result is None + + +def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: { + "masked": credentials.get("api_key", "") + } + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True) + mock_db.session.query.return_value.where.return_value.first.return_value = config + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result == { + "id": "cfg-1", + "name": "primary", + "credentials": {"masked": ""}, + "enabled": True, + } + + +def test_init_inherit_config_should_create_and_persist_inherit_configuration( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + model_type = ModelType.LLM + + # Act + inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type) + + # Assert + assert inherit_config.tenant_id == "tenant-1" + assert inherit_config.provider_name == "openai" + assert inherit_config.model_name == "gpt-4o-mini" + assert inherit_config.model_type == "text-generation" + assert inherit_config.name == "__inherit__" + mock_db.session.add.assert_called_once_with(inherit_config) + mock_db.session.commit.assert_called_once() + + +def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list( + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing configs"): + service.update_load_balancing_configs( # type: ignore[arg-type] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], "invalid-configs"), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config"): + service.update_load_balancing_configs( # type: ignore[list-item] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], ["bad-item"]), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"enabled": True}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config enabled"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "cfg-without-enabled"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + current_config = SimpleNamespace(id="cfg-1") + mock_db.session.scalars.return_value.all.return_value = [current_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-2", "name": "invalid", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None) + mock_db.session.scalars.return_value.all.return_value = [existing_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new-config", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config_1 = SimpleNamespace( + id="cfg-1", + name="existing-one", + enabled=True, + encrypted_config=json.dumps({"api_key": "old"}), + updated_at=None, + ) + existing_config_2 = SimpleNamespace( + id="cfg-2", + name="existing-two", + enabled=True, + encrypted_config=None, + updated_at=None, + ) + mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2] + mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"}) + mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache") + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [ + {"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}}, + {"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}}, + ], + "custom-model", + ) + + # Assert + assert existing_config_1.name == "updated-name" + assert existing_config_1.enabled is False + assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"} + assert mock_db.session.add.call_count == 1 + mock_db.session.delete.assert_called_once_with(existing_config_2) + assert mock_db.session.commit.call_count >= 3 + mock_clear_cache.assert_any_call("tenant-1", "cfg-1") + mock_clear_cache.assert_any_call("tenant-1", "cfg-2") + + +def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}') + mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + # Assert + created_config = mock_db.session.add.call_args.args[0] + assert created_config.name == "Main Credential" + assert created_config.credential_id == "cred-1" + assert created_config.credential_source_type == "provider" + assert created_config.encrypted_config == '{"api_key":"enc"}' + mock_db.session.commit.assert_called() + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + + +def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1") + mock_db.session.query.return_value.where.return_value.first.return_value = existing_config + mock_validate = mocker.patch.object(service, "_custom_credentials_validate") + + # Act + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + # Assert + assert mock_validate.call_count == 2 + assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config + assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + + +def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + load_balancing_model_config = _load_balancing_model_config( + encrypted_config=json.dumps({"api_key": "old-encrypted-token"}) + ) + mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value") + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + load_balancing_model_config=load_balancing_model_config, + validate=False, + ) + + # Assert + assert result == {"api_key": "enc:old-plain-value", "region": "us"} + mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value") + + +def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema()) + load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") + mock_factory = MagicMock() + mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} + mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + load_balancing_model_config=load_balancing_model_config, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:validated"} + mock_factory.model_credentials_validate.assert_called_once() + mock_factory.provider_credentials_validate.assert_not_called() + mock_encrypt.assert_called_once_with("tenant-1", "validated") + + +def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + mock_factory = MagicMock() + mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} + mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:provider-validated"} + mock_factory.provider_credentials_validate.assert_called_once() + mock_factory.model_credentials_validate.assert_not_called() + + +def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise( + service: ModelLoadBalancingService, +) -> None: + # Arrange + model_schema = _build_model_credential_schema() + provider_schema = _build_provider_credential_schema() + provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema) + provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema) + provider_configuration_without_schema = _build_provider_configuration() + + # Act + schema_from_model = service._get_credential_schema(provider_configuration_with_model) + schema_from_provider = service._get_credential_schema(provider_configuration_with_provider) + + # Assert + assert schema_from_model is model_schema + assert schema_from_provider is provider_schema + with pytest.raises(ValueError, match="No credential schema found"): + service._get_credential_schema(provider_configuration_without_schema) + + +def test_clear_credentials_cache_should_delete_load_balancing_cache_entry( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_cache_instance = MagicMock() + mock_cache_cls = mocker.patch( + "services.model_load_balancing_service.ProviderCredentialsCache", + return_value=mock_cache_instance, + ) + + # Act + service._clear_credentials_cache("tenant-1", "cfg-1") + + # Assert + mock_cache_cls.assert_called_once() + assert mock_cache_cls.call_args.kwargs == { + "tenant_id": "tenant-1", + "identity_id": "cfg-1", + "cache_type": mocker.ANY, + } + assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL" + mock_cache_instance.delete.assert_called_once() diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index e2360b116d..6a6b63f003 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -3,9 +3,9 @@ import types import pytest from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration -from core.model_runtime.entities.common_entities import I18nObject -from core.model_runtime.entities.model_entities import ModelType -from core.model_runtime.entities.provider_entities import ConfigurateMethod +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod from models.provider import ProviderType from services.model_provider_service import ModelProviderService diff --git a/api/tests/unit_tests/services/test_oauth_server_service.py b/api/tests/unit_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..231ceb74dc --- /dev/null +++ b/api/tests/unit_tests/services/test_oauth_server_service.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from werkzeug.exceptions import BadRequest + +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +@pytest.fixture +def mock_redis_client(mocker: MockerFixture) -> MagicMock: + return mocker.patch("services.oauth_server.redis_client") + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> MagicMock: + """Mock the OAuth server Session context manager.""" + mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object())) + session = MagicMock() + session_cm = MagicMock() + session_cm.__enter__.return_value = session + mocker.patch("services.oauth_server.Session", return_value=session_cm) + return session + + +def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None: + # Arrange + mock_execute_result = MagicMock() + expected_app = MagicMock() + mock_execute_result.scalar_one_or_none.return_value = expected_app + mock_session.execute.return_value = mock_execute_result + + # Act + result = OAuthServerService.get_oauth_provider_app("client-1") + + # Assert + assert result is expected_app + mock_session.execute.assert_called_once() + mock_execute_result.scalar_one_or_none.assert_called_once() + + +def test_sign_oauth_authorization_code_should_store_code_and_return_value( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) + + # Act + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + # Assert + expected_code = str(deterministic_uuid) + assert code == expected_code + mock_redis_client.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code), + "user-1", + ex=600, + ) + + +def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + Assert + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + +def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids) + mock_redis_client.get.return_value = b"user-1" + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + + # Act + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + # Assert + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + mock_redis_client.delete.assert_called_once_with(code_key) + mock_redis_client.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis_client.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + +def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + Assert + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + +def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) + mock_redis_client.get.return_value = b"user-1" + + # Act + access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + # Assert + assert access_token == str(deterministic_uuid) + assert returned_refresh_token == "refresh-1" + mock_redis_client.set.assert_called_once_with( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + + +def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None: + # Arrange + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + # Act + result = OAuthServerService.sign_oauth_access_token( + grant_type=grant_type, + client_id="client-1", + ) + + # Assert + assert result is None + + +def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) + + # Act + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + # Assert + assert refresh_token == str(deterministic_uuid) + mock_redis_client.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + +def test_validate_oauth_access_token_should_return_none_when_token_not_found( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + # Assert + assert result is None + + +def test_validate_oauth_access_token_should_load_user_when_token_exists( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + mock_redis_client.get.return_value = b"user-88" + expected_user = MagicMock() + mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user) + + # Act + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + # Assert + assert result is expected_user + mock_load_user.assert_called_once_with("user-88") diff --git a/api/tests/unit_tests/services/test_operation_service.py b/api/tests/unit_tests/services/test_operation_service.py new file mode 100644 index 0000000000..a4c69b23ac --- /dev/null +++ b/api/tests/unit_tests/services/test_operation_service.py @@ -0,0 +1,120 @@ +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +from services.operation_service import OperationService + + +class TestOperationService: + """Test suite for OperationService""" + + # ===== Internal Method Tests ===== + + @patch("httpx.request") + def test_should_call_with_correct_parameters_when__send_request_invoked( + self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch + ): + """Test that _send_request calls httpx.request with the correct URL, headers, and data""" + # Arrange + monkeypatch.setattr(OperationService, "base_url", "https://billing.example") + monkeypatch.setattr(OperationService, "secret_key", "s3cr3t") + + mock_response = MagicMock() + mock_response.json.return_value = {"status": "success"} + mock_request.return_value = mock_response + + method = "POST" + endpoint = "/test_endpoint" + json_data = {"key": "value"} + + # Act + result = OperationService._send_request(method, endpoint, json=json_data) + + # Assert + assert result == {"status": "success"} + + # Verify call parameters + expected_url = "https://billing.example/test_endpoint" + mock_request.assert_called_once() + args, kwargs = mock_request.call_args + assert args[0] == method + assert args[1] == expected_url + assert kwargs["json"] == json_data + assert kwargs["headers"]["Billing-Api-Secret-Key"] == "s3cr3t" + assert kwargs["headers"]["Content-Type"] == "application/json" + + @patch("httpx.request") + def test_should_propagate_httpx_error_when__send_request_raises( + self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch + ): + """Test that _send_request handles httpx raising an error""" + # Arrange + monkeypatch.setattr(OperationService, "base_url", "https://billing.example") + mock_request.side_effect = httpx.RequestError("network error") + + # Act & Assert + with pytest.raises(httpx.RequestError): + OperationService._send_request("POST", "/test") + + # ===== Public Method Tests ===== + + @pytest.mark.parametrize( + ("utm_info", "expected_params"), + [ + ( + { + "utm_source": "google", + "utm_medium": "cpc", + "utm_campaign": "spring_sale", + "utm_content": "ad_1", + "utm_term": "ai_agent", + }, + { + "tenant_id": "tenant-123", + "utm_source": "google", + "utm_medium": "cpc", + "utm_campaign": "spring_sale", + "utm_content": "ad_1", + "utm_term": "ai_agent", + }, + ), + ( + {}, # Empty utm_info + { + "tenant_id": "tenant-123", + "utm_source": "", + "utm_medium": "", + "utm_campaign": "", + "utm_content": "", + "utm_term": "", + }, + ), + ( + {"utm_source": "newsletter"}, # Partial utm_info + { + "tenant_id": "tenant-123", + "utm_source": "newsletter", + "utm_medium": "", + "utm_campaign": "", + "utm_content": "", + "utm_term": "", + }, + ), + ], + ) + @patch.object(OperationService, "_send_request") + def test_should_map_parameters_correctly_when_record_utm_called( + self, mock_send: MagicMock, utm_info: dict, expected_params: dict + ): + """Test that record_utm correctly maps utm_info to parameters and calls _send_request""" + # Arrange + tenant_id = "tenant-123" + mock_send.return_value = {"status": "recorded"} + + # Act + result = OperationService.record_utm(tenant_id, utm_info) + + # Assert + assert result == {"status": "recorded"} + mock_send.assert_called_once_with("POST", "/tenant_utms", params=expected_params) diff --git a/api/tests/unit_tests/services/test_ops_service.py b/api/tests/unit_tests/services/test_ops_service.py new file mode 100644 index 0000000000..ab7b473790 --- /dev/null +++ b/api/tests/unit_tests/services/test_ops_service.py @@ -0,0 +1,381 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.config_entity import TracingProviderEnum +from models.model import App, TraceAppConfig +from services.ops_service import OpsService + + +class TestOpsService: + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + mock_db.session.query.assert_called_with(TraceAppConfig) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, None] + + # Act + result = OpsService.get_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + assert mock_db.session.query.call_count == 2 + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_none_config(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = None + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + # Act & Assert + with pytest.raises(ValueError, match="Tracing config cannot be None."): + OpsService.get_tracing_app_config("app_id", "arize") + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "default_url"), + [ + ("arize", "https://app.arize.com/"), + ("phoenix", "https://app.phoenix.arize.com/projects/"), + ("langsmith", "https://smith.langchain.com/"), + ("opik", "https://www.comet.com/opik/"), + ("weave", "https://wandb.ai/"), + ("aliyun", "https://arms.console.aliyun.com/"), + ("tencent", "https://console.cloud.tencent.com/apm"), + ("mlflow", "http://localhost:5000/"), + ("databricks", "https://www.databricks.com/"), + ], + ) + def test_get_tracing_app_config_providers_exception(self, mock_ops_trace_manager, mock_db, provider, default_url): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": default_url}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == default_url + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + "provider", ["arize", "phoenix", "langsmith", "opik", "weave", "aliyun", "tencent", "mlflow", "databricks"] + ) + def test_get_tracing_app_config_providers_success(self, mock_ops_trace_manager, mock_db, provider): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "success_url"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {} + mock_ops_trace_manager.get_trace_config_project_url.return_value = "success_url" + + # Act + result = OpsService.get_tracing_app_config("app_id", provider) + + # Assert + assert result["tracing_config"]["project_url"] == "success_url" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/project/key"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/project/key" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_get_tracing_app_config_langfuse_exception(self, mock_ops_trace_manager, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + trace_config.tracing_config = {"some": "config"} + trace_config.to_dict.return_value = {"tracing_config": {"project_url": "https://api.langfuse.com/"}} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [trace_config, app] + + mock_ops_trace_manager.decrypt_tracing_config.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.obfuscated_decrypt_token.return_value = {"host": "https://api.langfuse.com"} + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + + # Act + result = OpsService.get_tracing_app_config("app_id", "langfuse") + + # Assert + assert result["tracing_config"]["project_url"] == "https://api.langfuse.com/" + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act + result = OpsService.create_tracing_app_config("app_id", "invalid_provider", {}) + + # Assert + assert result == {"error": "Invalid tracing provider: invalid_provider"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {"public_key": "p", "secret_key": "s"}) + + # Assert + assert result == {"error": "Invalid Credentials"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + @pytest.mark.parametrize( + ("provider", "config"), + [ + (TracingProviderEnum.ARIZE, {}), + (TracingProviderEnum.LANGFUSE, {"public_key": "p", "secret_key": "s"}), + (TracingProviderEnum.LANGSMITH, {"api_key": "k", "project": "p"}), + (TracingProviderEnum.ALIYUN, {"license_key": "k", "endpoint": "https://aliyun.com"}), + ], + ) + def test_create_tracing_app_config_project_url_exception(self, mock_ops_trace_manager, mock_db, provider, config): + # Arrange + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.side_effect = Exception("error") + mock_ops_trace_manager.get_trace_config_project_key.side_effect = Exception("error") + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, config) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_langfuse_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.LANGFUSE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_key.return_value = "key" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + result = OpsService.create_tracing_app_config( + "app_id", provider, {"public_key": "p", "secret_key": "s", "host": "https://api.langfuse.com"} + ) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_already_exists(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.return_value = MagicMock(spec=TraceAppConfig) + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, None] + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_with_empty_other_keys(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {} + + # Act + # 'project' is in other_keys for Arize + # provide an empty string for the project in the tracing_config + # create_tracing_app_config will replace it with the default from the model + result = OpsService.create_tracing_app_config("app_id", provider, {"project": ""}) + + # Assert + assert result == {"result": "success"} + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_create_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + mock_ops_trace_manager.get_trace_config_project_url.return_value = "http://project_url" + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [None, app] + mock_ops_trace_manager.encrypt_tracing_config.return_value = {"encrypted": "config"} + + # Act + result = OpsService.create_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"result": "success"} + mock_db.session.add.assert_called() + mock_db.session.commit.assert_called() + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_provider(self, mock_ops_trace_manager, mock_db): + # Act & Assert + with pytest.raises(ValueError, match="Invalid tracing provider: invalid_provider"): + OpsService.update_tracing_app_config("app_id", "invalid_provider", {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_config(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_no_app(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, None] + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result is None + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_invalid_credentials(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = False + + # Act & Assert + with pytest.raises(ValueError, match="Invalid Credentials"): + OpsService.update_tracing_app_config("app_id", provider, {}) + + @patch("services.ops_service.db") + @patch("services.ops_service.OpsTraceManager") + def test_update_tracing_app_config_success(self, mock_ops_trace_manager, mock_db): + # Arrange + provider = TracingProviderEnum.ARIZE + current_config = MagicMock(spec=TraceAppConfig) + current_config.to_dict.return_value = {"some": "data"} + app = MagicMock(spec=App) + app.tenant_id = "tenant_id" + mock_db.session.query.return_value.where.return_value.first.side_effect = [current_config, app] + mock_ops_trace_manager.decrypt_tracing_config.return_value = {} + mock_ops_trace_manager.check_trace_config_is_effective.return_value = True + + # Act + result = OpsService.update_tracing_app_config("app_id", provider, {}) + + # Assert + assert result == {"some": "data"} + mock_db.session.commit.assert_called_once() + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_no_config(self, mock_db): + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is None + + @patch("services.ops_service.db") + def test_delete_tracing_app_config_success(self, mock_db): + # Arrange + trace_config = MagicMock(spec=TraceAppConfig) + mock_db.session.query.return_value.where.return_value.first.return_value = trace_config + + # Act + result = OpsService.delete_tracing_app_config("app_id", "arize") + + # Assert + assert result is True + mock_db.session.delete.assert_called_with(trace_config) + mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py new file mode 100644 index 0000000000..be64e431ba --- /dev/null +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -0,0 +1,1330 @@ +"""Unit tests for services.summary_index_service.""" + +from __future__ import annotations + +import sys +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +import services.summary_index_service as summary_module +from models.enums import SegmentStatus, SummaryStatus +from services.summary_index_service import SummaryIndexService + + +@dataclass(frozen=True) +class _SessionContext: + session: MagicMock + + def __enter__(self) -> MagicMock: + return self.session + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + +def _dataset(*, indexing_technique: str = "high_quality") -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = "dataset-1" + dataset.tenant_id = "tenant-1" + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding" + return dataset + + +def _segment(*, has_document: bool = True) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = "seg-1" + segment.document_id = "doc-1" + segment.dataset_id = "dataset-1" + segment.content = "hello world" + segment.enabled = True + segment.status = SegmentStatus.COMPLETED + segment.position = 1 + if has_document: + doc = MagicMock(name="document") + doc.doc_language = "en" + doc.doc_form = "text_model" + segment.document = doc + else: + segment.document = None + return segment + + +def _summary_record(*, summary_content: str = "summary", node_id: str | None = None) -> MagicMock: + record = MagicMock(spec=summary_module.DocumentSegmentSummary, name="summary_record") + record.id = "sum-1" + record.dataset_id = "dataset-1" + record.document_id = "doc-1" + record.chunk_id = "seg-1" + record.summary_content = summary_content + record.summary_index_node_id = node_id + record.summary_index_node_hash = None + record.tokens = None + record.status = SummaryStatus.GENERATING + record.error = None + record.enabled = True + record.created_at = datetime(2024, 1, 1, tzinfo=UTC) + record.updated_at = datetime(2024, 1, 1, tzinfo=UTC) + record.disabled_at = None + record.disabled_by = None + return record + + +def test_generate_summary_for_segment_passes_document_language(monkeypatch: pytest.MonkeyPatch) -> None: + usage = MagicMock() + usage.total_tokens = 10 + usage.prompt_tokens = 3 + usage.completion_tokens = 7 + + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("sum", usage))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + segment = _segment(has_document=True) + dataset = _dataset() + + content, got_usage = SummaryIndexService.generate_summary_for_segment(segment, dataset, {"a": 1}) + assert content == "sum" + assert got_usage is usage + + paragraph_module.ParagraphIndexProcessor.generate_summary.assert_called_once() + _, kwargs = paragraph_module.ParagraphIndexProcessor.generate_summary.call_args + assert kwargs["document_language"] == "en" + + +def test_generate_summary_for_segment_raises_when_empty(monkeypatch: pytest.MonkeyPatch) -> None: + paragraph_module = SimpleNamespace( + ParagraphIndexProcessor=SimpleNamespace(generate_summary=MagicMock(return_value=("", MagicMock()))) + ) + monkeypatch.setitem( + sys.modules, + "core.rag.index_processor.processor.paragraph_index_processor", + paragraph_module, + ) + + with pytest.raises(ValueError, match="Generated summary is empty"): + SummaryIndexService.generate_summary_for_segment(_segment(), _dataset(), {"a": 1}) + + +def test_create_summary_record_updates_existing_and_reenables(monkeypatch: pytest.MonkeyPatch) -> None: + existing = _summary_record(summary_content="old", node_id="n1") + existing.enabled = False + existing.disabled_at = datetime(2024, 1, 1) + existing.disabled_by = "u" + + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = existing + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + segment = _segment() + dataset = _dataset() + + result = SummaryIndexService.create_summary_record(segment, dataset, "new", status=SummaryStatus.GENERATING) + assert result is existing + assert existing.summary_content == "new" + assert existing.status == SummaryStatus.GENERATING + assert existing.enabled is True + assert existing.disabled_at is None + assert existing.disabled_by is None + assert existing.error is None + session.add.assert_called_once_with(existing) + session.flush.assert_called_once() + + +def test_create_summary_record_creates_new(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock(name="session") + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + record = SummaryIndexService.create_summary_record(_segment(), _dataset(), "new", status=SummaryStatus.GENERATING) + assert record.dataset_id == "dataset-1" + assert record.chunk_id == "seg-1" + assert record.summary_content == "new" + assert record.enabled is True + session.add.assert_called_once() + session.flush.assert_called_once() + + +def test_vectorize_summary_skips_non_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + vector_cls = MagicMock() + monkeypatch.setattr(summary_module, "Vector", vector_cls) + SummaryIndexService.vectorize_summary(_summary_record(), _segment(), _dataset(indexing_technique="economy")) + vector_cls.assert_not_called() + + +def test_vectorize_summary_raises_for_blank_content() -> None: + with pytest.raises(ValueError, match="Summary content is empty"): + SummaryIndexService.vectorize_summary(_summary_record(summary_content=" "), _segment(), _dataset()) + + +def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [5] + model_manager = MagicMock() + model_manager.get_model_instance.return_value = embedding_model + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + session = MagicMock(name="provided_session") + merged = _summary_record(summary_content="sum") + session.merge.return_value = merged + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=session) + + assert vector_instance.add_texts.call_count == 2 + summary_module.time.sleep.assert_called_once() # type: ignore[attr-defined] + session.flush.assert_called_once() + assert summary.status == SummaryStatus.COMPLETED + assert summary.summary_index_node_id == "uuid-1" + assert summary.summary_index_node_hash == "hash-1" + assert summary.tokens == 5 + + +def test_vectorize_summary_without_session_creates_record_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id="old-node") + + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + # Force deletion branch to run and swallow delete failures. + vector_for_delete = MagicMock() + vector_for_delete.delete_by_ids.side_effect = RuntimeError("delete failed") + vector_for_add = MagicMock() + vector_for_add.add_texts.return_value = None + vector_cls = MagicMock(side_effect=[vector_for_delete, vector_for_add]) + monkeypatch.setattr(summary_module, "Vector", vector_cls) + + model_manager = MagicMock() + model_manager.get_model_instance.side_effect = RuntimeError("no model") + monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + + # New session used after vectorization succeeds (record not found by id nor chunk_id). + session = MagicMock(name="session") + q1 = MagicMock() + q1.filter_by.return_value = q1 + q1.first.side_effect = [None, None] + session.query.return_value = q1 + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # One context for success path, no error handler session. + create_session_mock.assert_called() + session.add.assert_called() + session.commit.assert_called_once() + assert summary.status == SummaryStatus.COMPLETED + assert summary.summary_index_node_id == "old-node" # reused + + +def test_vectorize_summary_final_failure_updates_error_status(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + + vector_instance = MagicMock() + vector_instance.add_texts.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + + # error_session should find record and commit status update + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(return_value=_SessionContext(error_session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + assert summary.status == SummaryStatus.ERROR + assert "Vectorization failed" in (summary.error or "") + error_session.commit.assert_called_once() + + +def test_batch_create_summary_records_no_segments_noop(monkeypatch: pytest.MonkeyPatch) -> None: + create_session_mock = MagicMock() + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + SummaryIndexService.batch_create_summary_records([], _dataset()) + create_session_mock.assert_not_called() + + +def test_batch_create_summary_records_creates_and_updates(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + s1 = _segment() + s2 = _segment() + s2.id = "seg-2" + s2.document_id = "doc-2" + + existing = _summary_record() + existing.chunk_id = "seg-2" + existing.enabled = False + + session = MagicMock() + query = MagicMock() + query.filter.return_value = query + query.all.return_value = [existing] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.batch_create_summary_records([s1, s2], dataset, status=SummaryStatus.NOT_STARTED) + session.commit.assert_called_once() + assert existing.enabled is True + + +def test_update_summary_record_error_updates_when_exists(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + assert record.status == SummaryStatus.ERROR + assert record.error == "err" + session.commit.assert_called_once() + + +def test_generate_and_vectorize_summary_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert out is record + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def test_generate_and_vectorize_summary_vectorize_failure_sets_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", MagicMock(total_tokens=0))) + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert record.status == SummaryStatus.ERROR + # Outer exception handler overwrites the error with the raw exception message. + assert record.error == "boom" + + +def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + + vector_instance = MagicMock() + vector_instance.add_texts.return_value = None + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + existing.id = "other-id" + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, existing] # miss by id, hit by chunk_id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_id == "uuid-1" + + +def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + existing = _summary_record(summary_content="old", node_id="old-node") + session = MagicMock(name="session") + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = existing # hit by id + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + session.commit.assert_called_once() + assert existing.summary_index_node_hash == "hash-1" + + +def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + class _BadContext: + def __enter__(self): + return None + + def __exit__(self, exc_type, exc, tb) -> None: + return None + + error_session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.return_value = summary + error_session.query.return_value = q + + create_session_mock = MagicMock(side_effect=[_BadContext(), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + with pytest.raises(RuntimeError, match="Session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr( + summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) + ) + monkeypatch.setattr( + summary_module, + "ModelManager", + MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), + ) + + session = MagicMock() + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # miss by id and chunk_id + session.query.return_value = q + + error_session = MagicMock() + eq = MagicMock() + eq.filter_by.return_value = eq + eq.first.return_value = summary + error_session.query.return_value = eq + + create_session_mock = MagicMock(side_effect=[_SessionContext(session), _SessionContext(error_session)]) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + # Force the created record to be None so the "should not be None" guard triggers. + monkeypatch.setattr(summary_module, "DocumentSegmentSummary", MagicMock(return_value=None)) + + with pytest.raises(RuntimeError, match="summary_record_in_session should not be None"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + +def test_vectorize_summary_error_handler_tries_chunk_id_lookup_and_can_warn_not_found( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + summary = _summary_record(summary_content="sum", node_id=None) + + monkeypatch.setattr(summary_module.uuid, "uuid4", MagicMock(return_value="uuid-1")) + monkeypatch.setattr(summary_module.helper, "generate_text_hash", MagicMock(return_value="hash-1")) + monkeypatch.setattr(summary_module.time, "sleep", MagicMock()) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(add_texts=MagicMock(side_effect=RuntimeError("boom")))), + ) + + error_session = MagicMock(name="error_session") + q = MagicMock() + q.filter_by.return_value = q + q.first.side_effect = [None, None] # not found by id, not found by chunk_id + error_session.query.return_value = q + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(error_session))), + ) + + with pytest.raises(RuntimeError, match="boom"): + SummaryIndexService.vectorize_summary(summary, segment, dataset, session=None) + + # No record -> no commit in error session. + error_session.commit.assert_not_called() + + +def test_update_summary_record_error_warns_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_record_error(segment, dataset, "err") + logger_mock.warning.assert_called_once() + + +def test_generate_and_vectorize_summary_creates_missing_record_and_logs_usage(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + usage = MagicMock(total_tokens=4, prompt_tokens=1, completion_tokens=3) + monkeypatch.setattr(SummaryIndexService, "generate_summary_for_segment", MagicMock(return_value=("sum", usage))) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + result = SummaryIndexService.generate_and_vectorize_summary(segment, dataset, {"enable": True}) + assert result.status in {SummaryStatus.GENERATING, SummaryStatus.COMPLETED} + logger_mock.info.assert_called() + + +def test_generate_summaries_for_document_skip_conditions(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset(indexing_technique="economy") + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + dataset = _dataset() + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": False}) == [] + + document.doc_form = "qa_model" + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_runs_and_handles_errors(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + seg1 = _segment() + seg2 = _segment() + seg2.id = "seg-2" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg1, seg2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr( + SummaryIndexService, + "generate_and_vectorize_summary", + MagicMock(side_effect=[MagicMock(), RuntimeError("boom")]), + ) + update_err_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "update_summary_record_error", update_err_mock) + + records = SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) + assert len(records) == 1 + update_err_mock.assert_called_once() + + +def test_generate_summaries_for_document_no_segments_returns_empty(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.generate_summaries_for_document(dataset, document, {"enable": True}) == [] + + +def test_generate_summaries_for_document_applies_segment_ids_and_only_parent_chunks( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + document = MagicMock(spec=summary_module.DatasetDocument) + document.id = "doc-1" + document.doc_form = "text_model" + seg = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [seg] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + monkeypatch.setattr(SummaryIndexService, "batch_create_summary_records", MagicMock()) + monkeypatch.setattr(SummaryIndexService, "generate_and_vectorize_summary", MagicMock(return_value=MagicMock())) + + SummaryIndexService.generate_summaries_for_document( + dataset, + document, + {"enable": True}, + segment_ids=[seg.id], + only_parent_chunks=True, + ) + query.filter.assert_called() + + +def test_disable_summaries_for_segments_handles_vector_delete_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="s", node_id="n1") + summary2 = _summary_record(summary_content="s", node_id=None) + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary1, summary2] + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + summary_module, + "Vector", + MagicMock(return_value=MagicMock(delete_by_ids=MagicMock(side_effect=RuntimeError("boom")))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + + SummaryIndexService.disable_summaries_for_segments(dataset, segment_ids=["seg-1"], disabled_by="u") + assert summary1.enabled is False + assert summary1.disabled_by == "u" + session.commit.assert_called_once() + + +def test_disable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setitem( + sys.modules, "libs.datetime_utils", SimpleNamespace(naive_utc_now=MagicMock(return_value=datetime(2024, 1, 1))) + ) + SummaryIndexService.disable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_non_high_quality() -> None: + SummaryIndexService.enable_summaries_for_segments(_dataset(indexing_technique="economy")) + + +def test_enable_summaries_for_segments_revectorizes_and_enables(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + summary.enabled = False + + segment = _segment() + segment.id = summary.chunk_id + segment.enabled = True + segment.status = SegmentStatus.COMPLETED + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.return_value = segment + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + vec_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vec_mock) + + SummaryIndexService.enable_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vec_mock.assert_called_once() + assert summary.enabled is True + session.commit.assert_called_once() + + +def test_enable_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + SummaryIndexService.enable_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_enable_summaries_for_segments_skips_segment_or_content_and_handles_vectorize_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + summary1 = _summary_record(summary_content="sum", node_id="n1") + summary1.enabled = False + summary2 = _summary_record(summary_content="", node_id="n2") + summary2.enabled = False + summary3 = _summary_record(summary_content="sum3", node_id="n3") + summary3.enabled = False + + bad_segment = _segment() + bad_segment.enabled = False + bad_segment.status = SegmentStatus.COMPLETED + + good_segment = _segment() + good_segment.enabled = True + good_segment.status = SegmentStatus.COMPLETED + + session = MagicMock() + summary_query = MagicMock() + summary_query.filter_by.return_value = summary_query + summary_query.filter.return_value = summary_query + summary_query.all.return_value = [summary1, summary2, summary3] + + seg_query = MagicMock() + seg_query.filter_by.return_value = seg_query + seg_query.first.side_effect = [bad_segment, good_segment, good_segment] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + return summary_query + return seg_query + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + SummaryIndexService.enable_summaries_for_segments(dataset) + logger_mock.exception.assert_called_once() + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_deletes_vectors_and_records(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + summary = _summary_record(summary_content="sum", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [summary] + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids=[summary.chunk_id]) + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(summary) + session.commit.assert_called_once() + + +def test_delete_summaries_for_segments_no_summaries_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.filter.return_value = query + query.all.return_value = [] + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + SummaryIndexService.delete_summaries_for_segments(dataset) + session.commit.assert_not_called() + + +def test_update_summary_for_segment_skip_conditions() -> None: + assert ( + SummaryIndexService.update_summary_for_segment(_segment(), _dataset(indexing_technique="economy"), "x") is None + ) + seg = _segment(has_document=True) + seg.document.doc_form = "qa_model" + assert SummaryIndexService.update_summary_for_segment(seg, _dataset(), "x") is None + + +def test_update_summary_for_segment_empty_content_deletes_existing(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + vector_instance.delete_by_ids.assert_called_once_with(["n1"]) + session.delete.assert_called_once_with(record) + session.commit.assert_called_once() + + +def test_update_summary_for_segment_empty_content_delete_vector_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, "") is None + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_empty_content_no_record_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.update_summary_for_segment(segment, dataset, " ") is None + + +def test_update_summary_for_segment_updates_existing_and_vectorizes(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + + vector_instance = MagicMock() + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vectorize_mock = MagicMock() + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new summary") + assert out is record + vectorize_mock.assert_called_once() + session.refresh.assert_called_once_with(record) + session.commit.assert_called() + + +def test_update_summary_for_segment_existing_vector_delete_warns(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + vector_instance = MagicMock() + vector_instance.delete_by_ids.side_effect = RuntimeError("boom") + monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + logger_mock = MagicMock() + monkeypatch.setattr(summary_module, "logger", logger_mock) + + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + logger_mock.warning.assert_called() + + +def test_update_summary_for_segment_existing_vectorize_failure_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(side_effect=RuntimeError("boom"))) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is record + assert out.status == SummaryStatus.ERROR + assert "Vectorization failed" in (out.error or "") + + +def test_update_summary_for_segment_new_record_success(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", MagicMock(return_value=None)) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out is created + session.refresh.assert_called() + session.commit.assert_called() + + +def test_update_summary_for_segment_outer_exception_sets_error_and_reraises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _dataset() + segment = _segment() + record = _summary_record(summary_content="old", node_id="n1") + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = record + session.query.return_value = query + session.flush.side_effect = RuntimeError("flush boom") + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + with pytest.raises(RuntimeError, match="flush boom"): + SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert record.status == SummaryStatus.ERROR + assert record.error == "flush boom" + session.commit.assert_called() + + +def test_get_segment_summary_and_document_summaries(monkeypatch: pytest.MonkeyPatch) -> None: + record = _summary_record(summary_content="sum", node_id="n1") + session = MagicMock() + + q1 = MagicMock() + q1.where.return_value = q1 + q1.first.return_value = record + + q2 = MagicMock() + q2.filter.return_value = q2 + q2.all.return_value = [record] + + def query_side_effect(model: object) -> MagicMock: + if model is summary_module.DocumentSegmentSummary: + # first call used by get_segment_summary, second by get_document_summaries + if not hasattr(query_side_effect, "_called"): + query_side_effect._called = True # type: ignore[attr-defined] + return q1 + return q2 + return MagicMock() + + session.query.side_effect = query_side_effect + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + assert SummaryIndexService.get_segment_summary("seg-1", "dataset-1") is record + assert SummaryIndexService.get_document_summaries("doc-1", "dataset-1", segment_ids=["seg-1"]) == [record] + + +def test_get_segments_summaries_non_empty(monkeypatch: pytest.MonkeyPatch) -> None: + record1 = _summary_record() + record1.chunk_id = "seg-1" + record2 = _summary_record() + record2.chunk_id = "seg-2" + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [record1, record2] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + out = SummaryIndexService.get_segments_summaries(["seg-1", "seg-2"], "dataset-1") + assert set(out.keys()) == {"seg-1", "seg-2"} + + +def test_get_document_summary_index_status_no_segments_returns_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") is None + + +def test_get_documents_summary_index_status_empty_input(monkeypatch: pytest.MonkeyPatch) -> None: + assert SummaryIndexService.get_documents_summary_index_status([], "dataset-1", "tenant-1") == {} + + +def test_get_documents_summary_index_status_no_pending_sets_none(monkeypatch: pytest.MonkeyPatch) -> None: + session = MagicMock() + q = MagicMock() + q.where.return_value = q + q.all.return_value = [SimpleNamespace(id="seg-1", document_id="doc-1")] + session.query.return_value = q + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.COMPLETED)}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1"], "dataset-1", "tenant-1") + assert result["doc-1"] is None + + +def test_update_summary_for_segment_creates_new_and_vectorize_fails_returns_error_record( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _dataset() + segment = _segment() + + session = MagicMock() + query = MagicMock() + query.filter_by.return_value = query + query.first.return_value = None + session.query.return_value = query + + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session))), + ) + + created = _summary_record(summary_content="new", node_id=None) + monkeypatch.setattr(SummaryIndexService, "create_summary_record", MagicMock(return_value=created)) + session.merge.return_value = created + + vectorize_mock = MagicMock(side_effect=RuntimeError("boom")) + monkeypatch.setattr(SummaryIndexService, "vectorize_summary", vectorize_mock) + + out = SummaryIndexService.update_summary_for_segment(segment, dataset, "new") + assert out.status == SummaryStatus.ERROR + assert "Vectorization failed" in (out.error or "") + + +def test_get_segments_summaries_empty_list() -> None: + assert SummaryIndexService.get_segments_summaries([], "dataset-1") == {} + + +def test_get_document_summary_index_status_and_documents_status(monkeypatch: pytest.MonkeyPatch) -> None: + seg_row = SimpleNamespace(id="seg-1", document_id="doc-1") + session = MagicMock() + query = MagicMock() + query.where.return_value = query + query.all.return_value = [SimpleNamespace(id="seg-1")] + session.query.return_value = query + + create_session_mock = MagicMock(return_value=_SessionContext(session)) + monkeypatch.setattr(summary_module, "session_factory", SimpleNamespace(create_session=create_session_mock)) + + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.GENERATING)}), + ) + assert SummaryIndexService.get_document_summary_index_status("doc-1", "dataset-1", "tenant-1") == "SUMMARIZING" + + # Multiple docs + query2 = MagicMock() + query2.where.return_value = query2 + query2.all.return_value = [seg_row] + session2 = MagicMock() + session2.query.return_value = query2 + monkeypatch.setattr( + summary_module, + "session_factory", + SimpleNamespace(create_session=MagicMock(return_value=_SessionContext(session2))), + ) + monkeypatch.setattr( + SummaryIndexService, + "get_segments_summaries", + MagicMock(return_value={"seg-1": SimpleNamespace(status=SummaryStatus.NOT_STARTED)}), + ) + result = SummaryIndexService.get_documents_summary_index_status(["doc-1", "doc-2"], "dataset-1", "tenant-1") + assert result["doc-1"] == "SUMMARIZING" + assert result["doc-2"] is None + + +def test_get_document_summary_status_detail_counts_and_previews(monkeypatch: pytest.MonkeyPatch) -> None: + segment1 = _segment() + segment1.id = "seg-1" + segment1.position = 1 + segment2 = _segment() + segment2.id = "seg-2" + segment2.position = 2 + + summary1 = _summary_record(summary_content="x" * 150, node_id="n1") + summary1.chunk_id = "seg-1" + summary1.status = SummaryStatus.COMPLETED + summary1.error = None + summary1.created_at = datetime(2024, 1, 1, tzinfo=UTC) + summary1.updated_at = datetime(2024, 1, 2, tzinfo=UTC) + + segment_service = SimpleNamespace(get_segments_by_document_and_dataset=MagicMock(return_value=[segment1, segment2])) + monkeypatch.setitem(sys.modules, "services.dataset_service", SimpleNamespace(SegmentService=segment_service)) + + monkeypatch.setattr(SummaryIndexService, "get_document_summaries", MagicMock(return_value=[summary1])) + + detail = SummaryIndexService.get_document_summary_status_detail("doc-1", "dataset-1") + assert detail["total_segments"] == 2 + assert detail["summary_status"]["completed"] == 1 + assert detail["summary_status"]["not_started"] == 1 + assert detail["summaries"][0]["summary_preview"].endswith("...") + assert detail["summaries"][1]["status"] == "not_started" diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py index 264eac4d77..4d2d63e501 100644 --- a/api/tests/unit_tests/services/test_tag_service.py +++ b/api/tests/unit_tests/services/test_tag_service.py @@ -75,6 +75,7 @@ import pytest from werkzeug.exceptions import NotFound from models.dataset import Dataset +from models.enums import TagType from models.model import App, Tag, TagBinding from services.tag_service import TagService @@ -102,7 +103,7 @@ class TagServiceTestDataFactory: def create_tag_mock( tag_id: str = "tag-123", name: str = "Test Tag", - tag_type: str = "app", + tag_type: TagType = TagType.APP, tenant_id: str = "tenant-123", **kwargs, ) -> Mock: @@ -705,7 +706,7 @@ class TestTagServiceCRUD: # Verify tag attributes added_tag = mock_db_session.add.call_args[0][0] assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == "app", "Tag type should match" + assert added_tag.type == TagType.APP, "Tag type should match" assert added_tag.created_by == "user-123", "Created by should match current user" assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py new file mode 100644 index 0000000000..81a3b181fd --- /dev/null +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -0,0 +1,1249 @@ +from __future__ import annotations + +import contextlib +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from core.plugin.entities.plugin_daemon import CredentialType +from models.provider_ids import TriggerProviderID +from services.trigger.trigger_provider_service import TriggerProviderService + + +def _patch_redis_lock(mocker: MockerFixture) -> None: + mock_redis = mocker.patch("services.trigger.trigger_provider_service.redis_client") + mock_redis.lock.return_value = contextlib.nullcontext() + + +def _mock_get_trigger_provider(mocker: MockerFixture, provider: object | None) -> None: + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.get_trigger_provider", + return_value=provider, + ) + + +def _encrypter_mock( + *, + decrypted: dict | None = None, + encrypted: dict | None = None, + masked: dict | None = None, +) -> MagicMock: + enc = MagicMock() + enc.decrypt.return_value = decrypted or {} + enc.encrypt.return_value = encrypted or {} + enc.mask_credentials.return_value = masked or {} + enc.mask_plugin_credentials.return_value = masked or {} + return enc + + +@pytest.fixture +def provider_id() -> TriggerProviderID: + # Arrange + return TriggerProviderID("langgenius/github/github") + + +@pytest.fixture(autouse=True) +def mock_db_engine(mocker: MockerFixture) -> SimpleNamespace: + # Arrange + mocked_db = SimpleNamespace(engine=object()) + mocker.patch("services.trigger.trigger_provider_service.db", mocked_db) + return mocked_db + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> MagicMock: + """Mocks the database session context manager used by TriggerProviderService.""" + # Arrange + mock_session_instance = MagicMock() + mock_session_cm = MagicMock() + mock_session_cm.__enter__.return_value = mock_session_instance + mock_session_cm.__exit__.return_value = False + mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm) + return mock_session_instance + + +@pytest.fixture +def provider_controller() -> MagicMock: + # Arrange + controller = MagicMock() + controller.get_credential_schema_config.return_value = [] + controller.get_properties_schema.return_value = [] + controller.get_oauth_client_schema.return_value = [] + controller.plugin_unique_identifier = "langgenius/github:0.0.1" + return controller + + +def test_get_trigger_provider_should_return_api_entity_from_manager( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + provider = MagicMock() + provider.to_api_entity.return_value = {"provider": "ok"} + _mock_get_trigger_provider(mocker, provider) + + # Act + result = TriggerProviderService.get_trigger_provider("tenant-1", provider_id) + + # Assert + assert result == {"provider": "ok"} + + +def test_list_trigger_providers_should_return_api_entities_from_manager(mocker: MockerFixture) -> None: + # Arrange + provider_a = MagicMock() + provider_b = MagicMock() + provider_a.to_api_entity.return_value = {"id": "a"} + provider_b.to_api_entity.return_value = {"id": "b"} + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.list_all_trigger_providers", + return_value=[provider_a, provider_b], + ) + + # Act + result = TriggerProviderService.list_trigger_providers("tenant-1") + + # Assert + assert result == [{"id": "a"}, {"id": "b"}] + + +def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_subscriptions( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.order_by.return_value.all.return_value = [] + mock_session.query.return_value = query + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert result == [] + + +def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workflow_counts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + api_sub = SimpleNamespace( + id="sub-1", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + parameters={"event": "push"}, + workflows_in_use=0, + ) + db_sub = SimpleNamespace(to_api_entity=lambda: api_sub) + usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2) + + query_subs = MagicMock() + query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub] + query_usage = MagicMock() + query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row] + mock_session.query.side_effect = [query_subs, query_usage] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"}) + prop_enc = _encrypter_mock(decrypted={"hook": "plain"}, masked={"hook": "****"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert len(result) == 1 + assert result[0].credentials == {"token": "****"} + assert result[0].properties == {"hook": "****"} + assert result[0].workflows_in_use == 2 + + +def test_add_trigger_subscription_should_create_subscription_successfully_for_api_key( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(encrypted={"api_key": "enc"}) + prop_enc = _encrypter_mock(encrypted={"project": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(cred_enc, MagicMock()), (prop_enc, MagicMock())], + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={"event": "push"}, + properties={"project": "demo"}, + credentials={"api_key": "plain"}, + ) + + # Assert + assert result["result"] == "success" + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(encrypted={"p": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.UNAUTHORIZED, + parameters={}, + properties={"p": "v"}, + credentials={}, + subscription_id="sub-fixed", + ) + + # Assert + assert result == {"result": "success", "id": "sub-fixed"} + + +def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ + mock_session.query.return_value = query_count + _mock_get_trigger_provider(mocker, provider_controller) + mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger") + + # Act + Assert + with pytest.raises(ValueError, match="Maximum number of providers"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + mock_logger.exception.assert_called_once() + + +def test_add_trigger_subscription_should_raise_error_when_name_exists( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_count, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="Credential name 'main' already exists"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + + +def test_update_trigger_subscription_should_raise_error_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query_sub + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1") + + +def test_update_trigger_subscription_should_raise_error_when_name_conflicts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + provider_id="langgenius/github/github", + credential_type=CredentialType.API_KEY.value, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_sub, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1", name="new-name") + + +def test_update_trigger_subscription_should_update_fields_and_clear_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + properties={"project": "enc-old"}, + parameters={"event": "old"}, + credentials={"api_key": "enc-old"}, + credential_type=CredentialType.API_KEY.value, + credential_expires_at=0, + expires_at=0, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_sub, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"}) + cred_enc = _encrypter_mock(encrypted={"api_key": "new-key"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(prop_enc, MagicMock()), (cred_enc, MagicMock())], + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.update_trigger_subscription( + tenant_id="tenant-1", + subscription_id="sub-1", + name="new", + properties={"project": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "new"}, + credentials={"api_key": "plain-key"}, + credential_expires_at=100, + expires_at=200, + ) + + # Assert + assert subscription.name == "new" + assert subscription.parameters == {"event": "new"} + assert subscription.credentials == {"api_key": "new-key"} + assert subscription.credential_expires_at == 100 + assert subscription.expires_at == 200 + mock_session.commit.assert_called_once() + mock_delete_cache.assert_called_once() + + +def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is None + + +def test_get_subscription_by_id_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"project": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + prop_enc = _encrypter_mock(decrypted={"project": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"project": "plain"} + + +def test_delete_trigger_provider_should_raise_error_when_subscription_missing( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + +def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.OAUTH2.value, + credentials={"token": "enc"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + side_effect=RuntimeError("remote fail"), + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + # Assert + mock_session.delete.assert_called_once_with(subscription) + mock_delete_cache.assert_called_once() + + +def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-2", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.UNAUTHORIZED.value, + credentials={}, + to_entity=lambda: SimpleNamespace(id="sub-2"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger") + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={}), MagicMock()), + ) + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-2") + + # Assert + mock_unsubscribe.assert_not_called() + mock_session.delete.assert_called_once_with(subscription) + + +def test_refresh_oauth_token_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + Assert + with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + provider_id=str(provider_id), + user_id="user-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"access_token": "enc"}, + credential_expires_at=0, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(cred_enc, cache), + ) + mocker.patch.object(TriggerProviderService, "get_oauth_client", return_value={"client_id": "id"}) + refreshed = SimpleNamespace(credentials={"access_token": "new"}, expires_at=12345) + oauth_handler = MagicMock() + oauth_handler.refresh_credentials.return_value = refreshed + mocker.patch("services.trigger.trigger_provider_service.OAuthHandler", return_value=oauth_handler) + + # Act + result = TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + # Assert + assert result == {"result": "success", "expires_at": 12345} + assert subscription.credentials == {"access_token": "new"} + assert subscription.credential_expires_at == 12345 + mock_session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_refresh_subscription_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + +def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + subscription = SimpleNamespace(expires_at=200) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "skipped", "expires_at": 200} + + +def test_refresh_subscription_should_refresh_and_persist_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + endpoint_id="endpoint-1", + expires_at=50, + provider_id=str(provider_id), + parameters={"event": "push"}, + properties={"p": "enc"}, + credentials={"c": "enc"}, + credential_type=CredentialType.API_KEY.value, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"c": "plain"}) + prop_cache = MagicMock() + prop_enc = _encrypter_mock(decrypted={"p": "plain"}, encrypted={"p": "new-enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, prop_cache), + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + provider_controller.refresh_trigger.return_value = SimpleNamespace(properties={"p": "new"}, expires_at=999) + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "success", "expires_at": 999} + assert subscription.properties == {"p": "new-enc"} + assert subscription.expires_at == 999 + mock_session.commit.assert_called_once() + prop_cache.delete.assert_called_once() + + +def test_get_oauth_client_should_return_tenant_client_when_available( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + system_client = None + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = tenant_client + mock_session.query.return_value = query_tenant + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "plain"} + + +def test_get_oauth_client_should_return_none_when_plugin_not_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result is None + + +def test_get_oauth_client_should_return_decrypted_system_client_when_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + return_value={"client_id": "system"}, + ) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "system"} + + +def test_get_oauth_client_should_raise_error_when_system_decryption_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + side_effect=RuntimeError("bad data"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Error decrypting system oauth params"): + TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + +def test_is_oauth_system_client_exists_should_return_false_when_unverified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is False + + +@pytest.mark.parametrize("has_client", [True, False]) +def test_is_oauth_system_client_exists_should_reflect_database_record( + has_client: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is has_client + + +def test_save_custom_oauth_client_params_should_return_success_when_nothing_to_update( + provider_id: TriggerProviderID, +) -> None: + # Arrange + # Act + result = TriggerProviderService.save_custom_oauth_client_params("tenant-1", provider_id, None, None) + + # Assert + assert result == {"result": "success"} + + +def test_save_custom_oauth_client_params_should_create_record_and_clear_params_when_client_params_none( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query + _mock_get_trigger_provider(mocker, provider_controller) + fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={}) + mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params=None, + enabled=True, + ) + + # Assert + assert result == {"result": "success"} + assert fake_model.encrypted_oauth_params == "{}" + assert fake_model.enabled is True + mock_session.add.assert_called_once_with(fake_model) + mock_session.commit.assert_called_once() + + +def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(enc, cache), + ) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params={"client_id": HIDDEN_VALUE, "client_secret": "new"}, + enabled=None, + ) + + # Assert + assert result == {"result": "success"} + assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"} + cache.delete.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {} + + +def test_get_custom_oauth_client_params_should_return_masked_decrypted_values( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "pl***id"} + + +def test_delete_custom_oauth_client_params_should_delete_record_and_commit( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.delete.return_value = 1 + + # Act + result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"result": "success"} + mock_session.commit.assert_called_once() + + +@pytest.mark.parametrize("exists", [True, False]) +def test_is_oauth_custom_client_enabled_should_return_expected_boolean( + exists: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None + + # Act + result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id) + + # Assert + assert result is exists + + +def test_get_subscription_by_endpoint_should_return_none_when_not_found( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is None + + +def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={"token": "plain"}), MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(_encrypter_mock(decrypted={"hook": "plain"}), MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"hook": "plain"} + + +def test_verify_subscription_credentials_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_api_key_validation_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + provider_controller.validate_credentials.side_effect = RuntimeError("bad credentials") + + # Act + Assert + with pytest.raises(ValueError, match="Invalid credentials: bad credentials"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + +def test_verify_subscription_credentials_should_return_verified_when_api_key_validation_succeeds( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + # Assert + assert result == {"verified": True} + + +def test_verify_subscription_credentials_should_return_verified_for_non_api_key_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.OAUTH2.value, credentials={}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + # Assert + assert result == {"verified": True} + + +def test_rebuild_trigger_subscription_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_for_unsupported_credential_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.UNAUTHORIZED.value) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + Assert + with pytest.raises(ValueError, match="not supported for auto creation"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=False, message="remote error"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to delete previous subscription"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_resubscribe_and_update_existing_subscription( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old-key"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + new_subscription = SimpleNamespace(properties={"project": "new"}, expires_at=888) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=True, message="ok"), + ) + mock_subscribe = mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.subscribe_trigger", + return_value=new_subscription, + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + mock_update = mocker.patch.object(TriggerProviderService, "update_trigger_subscription") + + # Act + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "push"}, + name="updated", + ) + + # Assert + call_kwargs = mock_subscribe.call_args.kwargs + assert call_kwargs["credentials"]["api_key"] == "old-key" + assert call_kwargs["credentials"]["region"] == "us" + mock_update.assert_called_once_with( + tenant_id="tenant-1", + subscription_id="sub-1", + name="updated", + parameters={"event": "push"}, + credentials={"api_key": "old-key", "region": "us"}, + properties={"project": "new"}, + expires_at=888, + ) diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index 8199d586da..c703ab64d0 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -17,9 +17,9 @@ from uuid import uuid4 import pytest -from core.workflow.file.enums import FileTransferMethod, FileType -from core.workflow.file.models import File -from core.workflow.variables.segments import ( +from dify_graph.file.enums import FileTransferMethod, FileType +from dify_graph.file.models import File +from dify_graph.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, ArraySegment, diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py new file mode 100644 index 0000000000..7b0103a2a1 --- /dev/null +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -0,0 +1,704 @@ +"""Unit tests for `api/services/vector_service.py`.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from unittest.mock import MagicMock + +import pytest + +import services.vector_service as vector_service_module +from services.vector_service import VectorService + + +@dataclass(frozen=True) +class _UploadFileStub: + id: str + name: str + + +@dataclass(frozen=True) +class _ChildDocStub: + page_content: str + metadata: dict[str, Any] + + +@dataclass +class _ParentDocStub: + children: list[_ChildDocStub] + + +def _make_dataset( + *, + indexing_technique: str = "high_quality", + doc_form: str = "text_model", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + is_multimodal: bool = False, + embedding_model_provider: str | None = "openai", + embedding_model: str = "text-embedding", +) -> MagicMock: + dataset = MagicMock(name="dataset") + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.doc_form = doc_form + dataset.indexing_technique = indexing_technique + dataset.is_multimodal = is_multimodal + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + return dataset + + +def _make_segment( + *, + segment_id: str = "seg-1", + tenant_id: str = "tenant-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", + content: str = "hello", + index_node_id: str = "node-1", + index_node_hash: str = "hash-1", + attachments: list[dict[str, str]] | None = None, +) -> MagicMock: + segment = MagicMock(name="segment") + segment.id = segment_id + segment.tenant_id = tenant_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.index_node_id = index_node_id + segment.index_node_hash = index_node_hash + segment.attachments = attachments or [] + return segment + + +def _mock_db_session_for_update_multimodel(*, upload_files: list[_UploadFileStub] | None) -> MagicMock: + session = MagicMock(name="session") + + binding_query = MagicMock(name="binding_query") + binding_query.where.return_value = binding_query + binding_query.delete.return_value = 1 + + upload_query = MagicMock(name="upload_query") + upload_query.where.return_value = upload_query + upload_query.all.return_value = upload_files or [] + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.SegmentAttachmentBinding: + return binding_query + if model is vector_service_module.UploadFile: + return upload_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_regular_indexing_loads_documents_and_keywords(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=False) + segment = _make_segment() + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + index_processor.load.assert_called_once() + args, kwargs = index_processor.load.call_args + assert args[0] == dataset + assert len(args[1]) == 1 + assert args[2] is None + assert kwargs["with_keywords"] is True + assert kwargs["keywords_list"] == [["k1"]] + + +def test_create_segments_vector_regular_indexing_loads_multimodal_documents(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(is_multimodal=True) + segment = _make_segment( + attachments=[ + {"id": "img-1", "name": "a.png"}, + {"id": "img-2", "name": "b.png"}, + ] + ) + + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock(name="IndexProcessorFactory-instance") + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector([["k1"]], [segment], dataset, "text_model") + + assert index_processor.load.call_count == 2 + first_args, first_kwargs = index_processor.load.call_args_list[0] + assert first_args[0] == dataset + assert len(first_args[1]) == 1 + assert first_kwargs["with_keywords"] is True + + second_args, second_kwargs = index_processor.load.call_args_list[1] + assert second_args[0] == dataset + assert second_args[1] == [] + assert len(second_args[2]) == 2 + assert second_kwargs["with_keywords"] is False + + +def test_create_segments_vector_with_no_segments_does_not_load(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + index_processor = MagicMock(name="index_processor") + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector(None, [], dataset, "text_model") + index_processor.load.assert_not_called() + + +def _mock_parent_child_queries( + *, + dataset_document: object | None, + processing_rule: object | None, +) -> MagicMock: + session = MagicMock(name="session") + + doc_query = MagicMock(name="doc_query") + doc_query.filter_by.return_value = doc_query + doc_query.first.return_value = dataset_document + + rule_query = MagicMock(name="rule_query") + rule_query.where.return_value = rule_query + rule_query.first.return_value = processing_rule + + def query_side_effect(model: object) -> MagicMock: + if model is vector_service_module.DatasetDocument: + return doc_query + if model is vector_service_module.DatasetProcessRule: + return rule_query + return MagicMock(name=f"query({model})") + + session.query.side_effect = query_side_effect + db_mock = MagicMock(name="db") + db_mock.session = session + return db_mock + + +def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_explicit_model( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider="openai", + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock(name="dataset_document") + dataset_document.id = segment.document_id + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock(name="processing_rule") + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock(name="embedding_model_instance") + model_manager_instance = MagicMock(name="model_manager_instance") + model_manager_instance.get_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once_with( + segment, dataset_document, dataset, embedding_model_instance, processing_rule, False + ) + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_uses_default_embedding_model_when_provider_missing( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + embedding_model_provider=None, + indexing_technique="high_quality", + ) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + embedding_model_instance = MagicMock() + model_manager_instance = MagicMock() + model_manager_instance.get_default_model_instance.return_value = embedding_model_instance + monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + + generate_child_chunks_mock = MagicMock() + monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + model_manager_instance.get_default_model_instance.assert_called_once() + generate_child_chunks_mock.assert_called_once() + + +def test_create_segments_vector_parent_child_missing_document_logs_warning_and_continues( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=None, processing_rule=processing_rule), + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + index_processor = MagicMock() + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + logger_mock.warning.assert_called_once() + index_processor.load.assert_not_called() + + +def test_create_segments_vector_parent_child_missing_processing_rule_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX) + segment = _make_segment() + + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=None), + ) + + with pytest.raises(ValueError, match="No processing rule found"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_create_segments_vector_parent_child_non_high_quality_raises(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset( + doc_form=vector_service_module.IndexStructureType.PARENT_CHILD_INDEX, + indexing_technique="economy", + ) + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.dataset_process_rule_id = "rule-1" + processing_rule = MagicMock() + monkeypatch.setattr( + vector_service_module, + "db", + _mock_parent_child_queries(dataset_document=dataset_document, processing_rule=processing_rule), + ) + + with pytest.raises(ValueError, match="not high quality"): + VectorService.create_segments_vector( + None, [segment], dataset, vector_service_module.IndexStructureType.PARENT_CHILD_INDEX + ) + + +def test_update_segment_vector_high_quality_uses_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + segment = _make_segment() + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_segment_vector(["k"], segment, dataset) + + vector_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + vector_instance.add_texts.assert_called_once() + add_args, add_kwargs = vector_instance.add_texts.call_args + assert len(add_args[0]) == 1 + assert add_kwargs["duplicate_check"] is True + + +def test_update_segment_vector_economy_uses_keyword_with_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(["a", "b"], segment, dataset) + + keyword_instance.delete_by_ids.assert_called_once_with([segment.index_node_id]) + keyword_instance.add_texts.assert_called_once() + args, kwargs = keyword_instance.add_texts.call_args + assert len(args[0]) == 1 + assert kwargs["keywords_list"] == [["a", "b"]] + + +def test_update_segment_vector_economy_uses_keyword_without_keywords_list(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + segment = _make_segment() + + keyword_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Keyword", MagicMock(return_value=keyword_instance)) + + VectorService.update_segment_vector(None, segment, dataset) + keyword_instance.add_texts.assert_called_once() + _, kwargs = keyword_instance.add_texts.call_args + assert "keywords_list" not in kwargs + + +def test_generate_child_chunks_regenerate_cleans_then_saves_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model", tenant_id="tenant-1", dataset_id="dataset-1") + segment = _make_segment(segment_id="seg-1") + + dataset_document = MagicMock() + dataset_document.id = segment.document_id + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + child1 = _ChildDocStub(page_content="c1", metadata={"doc_id": "c1-id", "doc_hash": "c1-h"}) + child2 = _ChildDocStub(page_content="c2", metadata={"doc_id": "c2-id", "doc_hash": "c2-h"}) + transformed = [_ParentDocStub(children=[child1, child2])] + + index_processor = MagicMock() + index_processor.transform.return_value = transformed + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + child_chunk_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "ChildChunk", child_chunk_ctor) + + db_mock = MagicMock() + db_mock.session.add = MagicMock() + db_mock.session.commit = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=True, + ) + + index_processor.clean.assert_called_once() + _, transform_kwargs = index_processor.transform.call_args + assert transform_kwargs["process_rule"]["rules"]["parent_mode"] == vector_service_module.ParentMode.FULL_DOC + index_processor.load.assert_called_once() + assert db_mock.session.add.call_count == 2 + db_mock.session.commit.assert_called_once() + + +def test_generate_child_chunks_commits_even_when_no_children(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(doc_form="text_model") + segment = _make_segment() + dataset_document = MagicMock() + dataset_document.doc_language = "en" + dataset_document.created_by = "user-1" + + processing_rule = MagicMock() + processing_rule.to_dict.return_value = {"rules": {}} + + index_processor = MagicMock() + index_processor.transform.return_value = [_ParentDocStub(children=[])] + factory_instance = MagicMock() + factory_instance.init_index_processor.return_value = index_processor + monkeypatch.setattr(vector_service_module, "IndexProcessorFactory", MagicMock(return_value=factory_instance)) + + db_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.generate_child_chunks( + segment=segment, + dataset_document=dataset_document, + dataset=dataset, + embedding_model_instance=MagicMock(), + processing_rule=processing_rule, + regenerate=False, + ) + + index_processor.load.assert_not_called() + db_mock.session.add.assert_not_called() + db_mock.session.commit.assert_called_once() + + +def test_create_child_chunk_vector_high_quality_adds_texts(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_instance.add_texts.assert_called_once() + + +def test_create_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + + child_chunk = MagicMock() + child_chunk.content = "child" + child_chunk.index_node_id = "id" + child_chunk.index_node_hash = "h" + child_chunk.document_id = "doc-1" + child_chunk.dataset_id = "dataset-1" + + VectorService.create_child_chunk_vector(child_chunk, dataset) + vector_cls.assert_not_called() + + +def test_update_child_chunk_vector_high_quality_updates_vector(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality") + + new_chunk = MagicMock() + new_chunk.content = "n" + new_chunk.index_node_id = "nid" + new_chunk.index_node_hash = "nh" + new_chunk.document_id = "d" + new_chunk.dataset_id = "ds" + + upd_chunk = MagicMock() + upd_chunk.content = "u" + upd_chunk.index_node_id = "uid" + upd_chunk.index_node_hash = "uh" + upd_chunk.document_id = "d" + upd_chunk.dataset_id = "ds" + + del_chunk = MagicMock() + del_chunk.index_node_id = "did" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.update_child_chunk_vector([new_chunk], [upd_chunk], [del_chunk], dataset) + + vector_instance.delete_by_ids.assert_called_once_with(["uid", "did"]) + vector_instance.add_texts.assert_called_once() + docs = vector_instance.add_texts.call_args.args[0] + assert len(docs) == 2 + + +def test_update_child_chunk_vector_economy_noop(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy") + vector_cls = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + VectorService.update_child_chunk_vector([], [], [], dataset) + vector_cls.assert_not_called() + + +def test_delete_child_chunk_vector_deletes_by_id(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset() + child_chunk = MagicMock() + child_chunk.index_node_id = "cid" + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + + VectorService.delete_child_chunk_vector(child_chunk, dataset) + vector_instance.delete_by_ids.assert_called_once_with(["cid"]) + + +# --------------------------------------------------------------------------- +# update_multimodel_vector (missing coverage in previous suites) +# --------------------------------------------------------------------------- + + +def test_update_multimodel_vector_returns_when_not_high_quality(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="economy", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["a"], dataset=dataset) + vector_cls.assert_not_called() + db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_returns_when_no_actual_change(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="t", attachments=[{"id": "a"}, {"id": "b"}]) + + vector_cls = MagicMock() + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["b", "a"], dataset=dataset) + vector_cls.assert_not_called() + db_mock.session.query.assert_not_called() + + +def test_update_multimodel_vector_deletes_bindings_and_commits_on_empty_new_ids( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}, {"id": "old-2"}]) + + vector_instance = MagicMock(name="vector_instance") + vector_cls = MagicMock(return_value=vector_instance) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + + monkeypatch.setattr(vector_service_module, "Vector", vector_cls) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=[], dataset=dataset) + + vector_cls.assert_called_once_with(dataset=dataset) + vector_instance.delete_by_ids.assert_called_once_with(["old-1", "old-2"]) + db_mock.session.query.assert_called_once_with(vector_service_module.SegmentAttachmentBinding) + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_commits_when_no_upload_files_found(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["new-1"], dataset=dataset) + + db_mock.session.commit.assert_called_once() + db_mock.session.add_all.assert_not_called() + vector_instance.add_texts.assert_not_called() + + +def test_update_multimodel_vector_adds_bindings_and_vectors_and_skips_missing_upload_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + + binding_ctor = MagicMock(side_effect=lambda **kwargs: kwargs) + monkeypatch.setattr(vector_service_module, "SegmentAttachmentBinding", binding_ctor) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1", "missing"], dataset=dataset) + + logger_mock.warning.assert_called_once() + db_mock.session.add_all.assert_called_once() + bindings = db_mock.session.add_all.call_args.args[0] + assert len(bindings) == 1 + assert bindings[0]["attachment_id"] == "file-1" + + vector_instance.add_texts.assert_called_once() + documents = vector_instance.add_texts.call_args.args[0] + assert len(documents) == 1 + assert documents[0].page_content == "img.png" + assert documents[0].metadata["doc_id"] == "file-1" + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_updates_bindings_without_multimodal_vector_ops( + monkeypatch: pytest.MonkeyPatch, +) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=False) + segment = _make_segment(tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + vector_instance.delete_by_ids.assert_not_called() + vector_instance.add_texts.assert_not_called() + db_mock.session.add_all.assert_called_once() + db_mock.session.commit.assert_called_once() + + +def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch: pytest.MonkeyPatch) -> None: + dataset = _make_dataset(indexing_technique="high_quality", is_multimodal=True) + segment = _make_segment(segment_id="seg-1", tenant_id="tenant-1", attachments=[{"id": "old-1"}]) + + vector_instance = MagicMock() + monkeypatch.setattr(vector_service_module, "Vector", MagicMock(return_value=vector_instance)) + db_mock = _mock_db_session_for_update_multimodel(upload_files=[_UploadFileStub(id="file-1", name="img.png")]) + db_mock.session.commit.side_effect = RuntimeError("boom") + monkeypatch.setattr(vector_service_module, "db", db_mock) + monkeypatch.setattr( + vector_service_module, "SegmentAttachmentBinding", MagicMock(side_effect=lambda **kwargs: kwargs) + ) + + logger_mock = MagicMock() + monkeypatch.setattr(vector_service_module, "logger", logger_mock) + + with pytest.raises(RuntimeError, match="boom"): + VectorService.update_multimodel_vector(segment=segment, attachment_ids=["file-1"], dataset=dataset) + + logger_mock.exception.assert_called_once() + db_mock.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_web_conversation_service.py b/api/tests/unit_tests/services/test_web_conversation_service.py new file mode 100644 index 0000000000..7687d355e9 --- /dev/null +++ b/api/tests/unit_tests/services/test_web_conversation_service.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.app.entities.app_invoke_entities import InvokeFrom +from models import Account +from models.model import App, EndUser +from services.web_conversation_service import WebConversationService + + +@pytest.fixture +def app_model() -> App: + return cast(App, SimpleNamespace(id="app-1")) + + +def _account(**kwargs: Any) -> Account: + return cast(Account, SimpleNamespace(**kwargs)) + + +def _end_user(**kwargs: Any) -> EndUser: + return cast(EndUser, SimpleNamespace(**kwargs)) + + +def test_pagination_by_last_id_should_raise_error_when_user_is_none( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + + # Act + Assert + with pytest.raises(ValueError, match="User is required"): + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=None, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + +def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + fake_user = _account(id="user-1") + mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + mock_pagination.return_value = MagicMock() + + # Act + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=fake_user, + last_id="conv-9", + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=None, + ) + + # Assert + call_kwargs = mock_pagination.call_args.kwargs + assert call_kwargs["include_ids"] is None + assert call_kwargs["exclude_ids"] is None + assert call_kwargs["last_id"] == "conv-9" + assert call_kwargs["sort_by"] == "-updated_at" + + +def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + fake_account_cls = type("FakeAccount", (), {}) + fake_user = cast(Account, fake_account_cls()) + fake_user.id = "account-1" + mocker.patch("services.web_conversation_service.Account", fake_account_cls) + mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {})) + session.scalars.return_value.all.return_value = ["conv-1", "conv-2"] + mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + mock_pagination.return_value = MagicMock() + + # Act + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=fake_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + pinned=True, + ) + + # Assert + call_kwargs = mock_pagination.call_args.kwargs + assert call_kwargs["include_ids"] == ["conv-1", "conv-2"] + assert call_kwargs["exclude_ids"] is None + + +def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + fake_end_user_cls = type("FakeEndUser", (), {}) + fake_user = cast(EndUser, fake_end_user_cls()) + fake_user.id = "end-user-1" + mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) + mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) + session.scalars.return_value.all.return_value = ["conv-3"] + mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + mock_pagination.return_value = MagicMock() + + # Act + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=fake_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + pinned=False, + ) + + # Assert + call_kwargs = mock_pagination.call_args.kwargs + assert call_kwargs["include_ids"] is None + assert call_kwargs["exclude_ids"] == ["conv-3"] + + +def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: + # Arrange + mock_db = mocker.patch("services.web_conversation_service.db") + mocker.patch("services.web_conversation_service.ConversationService.get_conversation") + + # Act + WebConversationService.pin(app_model, "conv-1", None) + + # Assert + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_pin_should_return_early_when_conversation_is_already_pinned( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_account_cls = type("FakeAccount", (), {}) + fake_user = cast(Account, fake_account_cls()) + fake_user.id = "account-1" + mocker.patch("services.web_conversation_service.Account", fake_account_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + mock_db.session.query.return_value.where.return_value.first.return_value = object() + mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation") + + # Act + WebConversationService.pin(app_model, "conv-1", fake_user) + + # Assert + mock_get_conversation.assert_not_called() + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_pin_should_create_pinned_conversation_when_not_already_pinned( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_account_cls = type("FakeAccount", (), {}) + fake_user = cast(Account, fake_account_cls()) + fake_user.id = "account-2" + mocker.patch("services.web_conversation_service.Account", fake_account_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_conversation = SimpleNamespace(id="conv-2") + mock_get_conversation = mocker.patch( + "services.web_conversation_service.ConversationService.get_conversation", + return_value=mock_conversation, + ) + + # Act + WebConversationService.pin(app_model, "conv-2", fake_user) + + # Assert + mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user) + added_obj = mock_db.session.add.call_args.args[0] + assert added_obj.app_id == "app-1" + assert added_obj.conversation_id == "conv-2" + assert added_obj.created_by_role == "account" + assert added_obj.created_by == "account-2" + mock_db.session.commit.assert_called_once() + + +def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: + # Arrange + mock_db = mocker.patch("services.web_conversation_service.db") + + # Act + WebConversationService.unpin(app_model, "conv-1", None) + + # Assert + mock_db.session.delete.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_unpin_should_return_early_when_conversation_is_not_pinned( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_end_user_cls = type("FakeEndUser", (), {}) + fake_user = cast(EndUser, fake_end_user_cls()) + fake_user.id = "end-user-3" + mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) + mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + WebConversationService.unpin(app_model, "conv-7", fake_user) + + # Assert + mock_db.session.delete.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_unpin_should_delete_pinned_conversation_when_exists( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_end_user_cls = type("FakeEndUser", (), {}) + fake_user = cast(EndUser, fake_end_user_cls()) + fake_user.id = "end-user-4" + mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) + mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + pinned_obj = SimpleNamespace(id="pin-1") + mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj + + # Act + WebConversationService.unpin(app_model, "conv-8", fake_user) + + # Assert + mock_db.session.delete.assert_called_once_with(pinned_obj) + mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_webapp_auth_service.py b/api/tests/unit_tests/services/test_webapp_auth_service.py new file mode 100644 index 0000000000..262c1f1524 --- /dev/null +++ b/api/tests/unit_tests/services/test_webapp_auth_service.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from werkzeug.exceptions import NotFound, Unauthorized + +from models import Account, AccountStatus +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from services.webapp_auth_service import WebAppAuthService, WebAppAuthType + +ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback" +TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token" +TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data" + + +def _account(**kwargs: Any) -> Account: + return cast(Account, SimpleNamespace(**kwargs)) + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.webapp_auth_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None: + # Arrange + mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) + + # Act + Assert + with pytest.raises(AccountNotFoundError): + WebAppAuthService.authenticate("user@example.com", "pwd") + + +def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt") + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + + # Act + Assert + with pytest.raises(AccountLoginError, match="Account is banned"): + WebAppAuthService.authenticate("user@example.com", "pwd") + + +@pytest.mark.parametrize("password_value", [None, "hash"]) +def test_authenticate_should_raise_password_error_when_password_is_invalid( + password_value: str | None, + mocker: MockerFixture, +) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt") + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + mocker.patch("services.webapp_auth_service.compare_password", return_value=False) + + # Act + Assert + with pytest.raises(AccountPasswordError, match="Invalid email or password"): + WebAppAuthService.authenticate("user@example.com", "pwd") + + +def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt") + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + mocker.patch("services.webapp_auth_service.compare_password", return_value=True) + + # Act + result = WebAppAuthService.authenticate("user@example.com", "pwd") + + # Assert + assert result is account + + +def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None: + # Arrange + account = _account(id="a1", email="u@example.com") + mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token") + + # Act + result = WebAppAuthService.login(account) + + # Assert + assert result == "jwt-token" + mock_get_token.assert_called_once_with(account=account) + + +def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None: + # Arrange + mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) + + # Act + result = WebAppAuthService.get_user_through_email("missing@example.com") + + # Assert + assert result is None + + +def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.BANNED) + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + + # Act + Assert + with pytest.raises(Unauthorized, match="Account is banned"): + WebAppAuthService.get_user_through_email("user@example.com") + + +def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.ACTIVE) + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + + # Act + result = WebAppAuthService.get_user_through_email("user@example.com") + + # Assert + assert result is account + + +def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Email must be provided"): + WebAppAuthService.send_email_code_login_email(account=None, email=None) + + +def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account( + mocker: MockerFixture, +) -> None: + # Arrange + account = _account(email="user@example.com") + mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6]) + mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1") + mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") + + # Act + result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") + + # Assert + assert result == "token-1" + mock_generate_token.assert_called_once() + assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"} + mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456") + + +def test_send_email_code_login_email_should_send_mail_for_email_without_account( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0]) + mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2") + mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") + + # Act + result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans") + + # Assert + assert result == "token-2" + mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000") + + +def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None: + # Arrange + mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"}) + + # Act + result = WebAppAuthService.get_email_code_login_data("token-abc") + + # Assert + assert result == {"code": "123"} + mock_get_data.assert_called_once_with("token-abc", "email_code_login") + + +def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None: + # Arrange + mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token") + + # Act + WebAppAuthService.revoke_email_code_login_token("token-xyz") + + # Assert + mock_revoke.assert_called_once_with("token-xyz", "email_code_login") + + +def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(NotFound, match="Site not found"): + WebAppAuthService.create_end_user("app-code", "user@example.com") + + +def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None: + # Arrange + site = SimpleNamespace(app_id="app-1") + app_query = MagicMock() + app_query.where.return_value.first.return_value = None + mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None] + + # Act + Assert + with pytest.raises(NotFound, match="App not found"): + WebAppAuthService.create_end_user("app-code", "user@example.com") + + +def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None: + # Arrange + site = SimpleNamespace(app_id="app-1") + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model] + + # Act + result = WebAppAuthService.create_end_user("app-code", "user@example.com") + + # Assert + assert result.tenant_id == "tenant-1" + assert result.app_id == "app-1" + assert result.session_id == "user@example.com" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None: + # Arrange + account = _account(id="a1", email="user@example.com") + mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60) + mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1") + + # Act + token = WebAppAuthService._get_account_jwt_token(account) + + # Assert + assert token == "jwt-1" + payload = mock_issue.call_args.args[0] + assert payload["user_id"] == "a1" + assert payload["session_id"] == "user@example.com" + assert payload["token_source"] == "webapp_login_token" + assert payload["auth_type"] == "internal" + assert payload["exp"] > int(datetime.now(UTC).timestamp()) + + +@pytest.mark.parametrize( + ("access_mode", "expected"), + [ + ("private", True), + ("private_all", True), + ("public", False), + ], +) +def test_is_app_require_permission_check_should_use_access_mode_when_provided( + access_mode: str, + expected: bool, +) -> None: + # Arrange + # Act + result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode) + + # Assert + assert result is expected + + +def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Either app_code or app_id must be provided"): + WebAppAuthService.is_app_require_permission_check() + + +def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="App ID could not be determined"): + WebAppAuthService.is_app_require_permission_check(app_code="app-code") + + +def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") + mocker.patch( + "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=SimpleNamespace(access_mode="private"), + ) + + # Act + result = WebAppAuthService.is_app_require_permission_check(app_code="app-code") + + # Assert + assert result is True + + +def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=SimpleNamespace(access_mode="public"), + ) + + # Act + result = WebAppAuthService.is_app_require_permission_check(app_id="app-1") + + # Assert + assert result is False + + +@pytest.mark.parametrize( + ("access_mode", "expected"), + [ + ("public", WebAppAuthType.PUBLIC), + ("private", WebAppAuthType.INTERNAL), + ("private_all", WebAppAuthType.INTERNAL), + ("sso_verified", WebAppAuthType.EXTERNAL), + ], +) +def test_get_app_auth_type_should_map_access_modes_correctly( + access_mode: str, + expected: WebAppAuthType, +) -> None: + # Arrange + # Act + result = WebAppAuthService.get_app_auth_type(access_mode=access_mode) + + # Assert + assert result == expected + + +def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") + mocker.patch( + "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=SimpleNamespace(access_mode="private_all"), + ) + + # Act + result = WebAppAuthService.get_app_auth_type(app_code="app-code") + + # Assert + assert result == WebAppAuthType.INTERNAL + + +def test_get_app_auth_type_should_raise_when_no_input_provided() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"): + WebAppAuthService.get_app_auth_type() + + +def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Could not determine app authentication type"): + WebAppAuthService.get_app_auth_type(access_mode="unknown") diff --git a/api/tests/unit_tests/services/test_website_service.py b/api/tests/unit_tests/services/test_website_service.py new file mode 100644 index 0000000000..e973da7d56 --- /dev/null +++ b/api/tests/unit_tests/services/test_website_service.py @@ -0,0 +1,718 @@ +"""Unit tests for services.website_service. + +Focuses on provider dispatching, argument validation, and provider-specific branches +without making any real network/storage/redis calls. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +import services.website_service as website_service_module +from services.website_service import ( + CrawlOptions, + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +@dataclass(frozen=True) +class _DummyHttpxResponse: + payload: dict[str, Any] + + def json(self) -> dict[str, Any]: + return self.payload + + +@pytest.fixture(autouse=True) +def stub_current_user(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module, + "current_user", + type("User", (), {"current_tenant_id": "tenant-1"})(), + ) + + +def test_crawl_options_include_exclude_paths() -> None: + options = CrawlOptions(includes="a,b", excludes="x,y") + assert options.get_include_paths() == ["a", "b"] + assert options.get_exclude_paths() == ["x", "y"] + + empty = CrawlOptions(includes=None, excludes=None) + assert empty.get_include_paths() == [] + assert empty.get_exclude_paths() == [] + + +def test_website_crawl_api_request_from_args_valid_and_to_crawl_request() -> None: + args = { + "provider": "firecrawl", + "url": "https://example.com", + "options": { + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a,b", + "excludes": "x", + "prompt": "hi", + "max_depth": 3, + "use_sitemap": False, + }, + } + + api_req = WebsiteCrawlApiRequest.from_args(args) + crawl_req = api_req.to_crawl_request() + + assert crawl_req.provider == "firecrawl" + assert crawl_req.url == "https://example.com" + assert crawl_req.options.limit == 2 + assert crawl_req.options.crawl_sub_pages is True + assert crawl_req.options.only_main_content is True + assert crawl_req.options.get_include_paths() == ["a", "b"] + assert crawl_req.options.get_exclude_paths() == ["x"] + assert crawl_req.options.prompt == "hi" + assert crawl_req.options.max_depth == 3 + assert crawl_req.options.use_sitemap is False + + +@pytest.mark.parametrize( + ("args", "missing_msg"), + [ + ({}, "Provider is required"), + ({"provider": "firecrawl"}, "URL is required"), + ({"provider": "firecrawl", "url": "https://example.com"}, "Options are required"), + ], +) +def test_website_crawl_api_request_from_args_requires_fields(args: dict, missing_msg: str) -> None: + with pytest.raises(ValueError, match=missing_msg): + WebsiteCrawlApiRequest.from_args(args) + + +def test_website_crawl_status_api_request_from_args_requires_fields() -> None: + with pytest.raises(ValueError, match="Provider is required"): + WebsiteCrawlStatusApiRequest.from_args({}, job_id="job-1") + + with pytest.raises(ValueError, match="Job ID is required"): + WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="") + + req = WebsiteCrawlStatusApiRequest.from_args({"provider": "firecrawl"}, job_id="job-1") + assert req.provider == "firecrawl" + assert req.job_id == "job-1" + + +def test_get_credentials_and_config_selects_plugin_id_and_key_firecrawl(monkeypatch: pytest.MonkeyPatch) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", "firecrawl") + assert api_key == "k" + assert config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider="firecrawl", + plugin_id="langgenius/firecrawl_datasource", + ) + + +@pytest.mark.parametrize( + ("provider", "plugin_id"), + [ + ("watercrawl", "watercrawl/watercrawl_datasource"), + ("jinareader", "langgenius/jina_datasource"), + ], +) +def test_get_credentials_and_config_selects_plugin_id_and_key_api_key( + monkeypatch: pytest.MonkeyPatch, provider: str, plugin_id: str +) -> None: + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"api_key": "enc-key", "base_url": "b"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + api_key, config = WebsiteService._get_credentials_and_config("tenant-1", provider) + assert api_key == "enc-key" + assert config["base_url"] == "b" + + service_instance.get_datasource_credentials.assert_called_once_with( + tenant_id="tenant-1", + provider=provider, + plugin_id=plugin_id, + ) + + +def test_get_credentials_and_config_rejects_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", "unknown") + + +def test_get_credentials_and_config_hits_unreachable_guard_branch(monkeypatch: pytest.MonkeyPatch) -> None: + class FlakyProvider: + def __init__(self) -> None: + self._eq_calls = 0 + + def __hash__(self) -> int: + return 1 + + def __eq__(self, other: object) -> bool: + if other == "firecrawl": + self._eq_calls += 1 + return self._eq_calls == 1 + return False + + def __repr__(self) -> str: + return "FlakyProvider()" + + service_instance = MagicMock(name="DatasourceProviderService-instance") + service_instance.get_datasource_credentials.return_value = {"firecrawl_api_key": "k"} + monkeypatch.setattr(website_service_module, "DatasourceProviderService", MagicMock(return_value=service_instance)) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService._get_credentials_and_config("tenant-1", FlakyProvider()) # type: ignore[arg-type] + + +def test_get_decrypted_api_key_requires_api_key(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", MagicMock()) + with pytest.raises(ValueError, match="API key not found in configuration"): + WebsiteService._get_decrypted_api_key("tenant-1", {}) + + +def test_get_decrypted_api_key_decrypts(monkeypatch: pytest.MonkeyPatch) -> None: + decrypt_mock = MagicMock(return_value="plain") + monkeypatch.setattr(website_service_module.encrypter, "decrypt_token", decrypt_mock) + + assert WebsiteService._get_decrypted_api_key("tenant-1", {"api_key": "enc"}) == "plain" + decrypt_mock.assert_called_once_with(tenant_id="tenant-1", token="enc") + + +def test_document_create_args_validate_wraps_error_message() -> None: + with pytest.raises(ValueError, match=r"^Invalid arguments: Provider is required$"): + WebsiteService.document_create_args_validate({}) + + +def test_crawl_url_dispatches_by_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="firecrawl", url="https://example.com", options={"limit": 1}) + crawl_request = api_request.to_crawl_request() + + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_mock = MagicMock(return_value={"status": "active", "job_id": "j1"}) + monkeypatch.setattr(WebsiteService, "_crawl_with_firecrawl", firecrawl_mock) + + result = WebsiteService.crawl_url(api_request) + + assert result == {"status": "active", "job_id": "j1"} + firecrawl_mock.assert_called_once() + assert firecrawl_mock.call_args.kwargs["request"] == crawl_request + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("watercrawl", "_crawl_with_watercrawl"), + ("jinareader", "_crawl_with_jinareader"), + ], +) +def test_crawl_url_dispatches_other_providers(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + api_request = WebsiteCrawlApiRequest(provider=provider, url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + impl_mock = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + assert WebsiteService.crawl_url(api_request) == {"status": "active"} + impl_mock.assert_called_once() + + +def test_crawl_url_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + api_request = WebsiteCrawlApiRequest(provider="bad", url="https://example.com", options={"limit": 1}) + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.crawl_url(api_request) + + +def test_crawl_with_firecrawl_builds_params_single_page_and_sets_redis(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-1" + firecrawl_cls = MagicMock(return_value=firecrawl_instance) + monkeypatch.setattr(website_service_module, "FirecrawlApp", firecrawl_cls) + + redis_mock = MagicMock() + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + fixed_now = datetime(2024, 1, 1, tzinfo=UTC) + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = fixed_now + + req = WebsiteCrawlApiRequest( + provider="firecrawl", url="https://example.com", options={"limit": 5} + ).to_crawl_request() + req.options.crawl_sub_pages = False + req.options.only_main_content = True + + result = WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": "b"}) + + assert result == {"status": "active", "job_id": "job-1"} + + firecrawl_cls.assert_called_once_with(api_key="k", base_url="b") + firecrawl_instance.crawl_url.assert_called_once() + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["limit"] == 1 + assert params["includePaths"] == [] + assert params["excludePaths"] == [] + assert params["scrapeOptions"] == {"onlyMainContent": True} + + redis_mock.setex.assert_called_once() + key, ttl, value = redis_mock.setex.call_args.args + assert key == "website_crawl_job-1" + assert ttl == 3600 + assert float(value) == pytest.approx(fixed_now.timestamp(), rel=0, abs=1e-6) + + +def test_crawl_with_firecrawl_builds_params_multi_page_including_prompt(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock(name="FirecrawlApp-instance") + firecrawl_instance.crawl_url.return_value = "job-2" + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + monkeypatch.setattr(website_service_module, "redis_client", MagicMock()) + + req = WebsiteCrawlApiRequest( + provider="firecrawl", + url="https://example.com", + options={ + "crawl_sub_pages": True, + "limit": 3, + "only_main_content": False, + "includes": "a,b", + "excludes": "x", + "prompt": "use this", + }, + ).to_crawl_request() + + WebsiteService._crawl_with_firecrawl(request=req, api_key="k", config={"base_url": None}) + _, params = firecrawl_instance.crawl_url.call_args.args + assert params["includePaths"] == ["a", "b"] + assert params["excludePaths"] == ["x"] + assert params["limit"] == 3 + assert params["scrapeOptions"] == {"onlyMainContent": False} + assert params["prompt"] == "use this" + + +def test_crawl_with_watercrawl_passes_options_dict(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.crawl_url.return_value = {"status": "active", "job_id": "w1"} + provider_cls = MagicMock(return_value=provider_instance) + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", provider_cls) + + req = WebsiteCrawlApiRequest( + provider="watercrawl", + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ).to_crawl_request() + + result = WebsiteService._crawl_with_watercrawl(request=req, api_key="k", config={"base_url": "b"}) + assert result == {"status": "active", "job_id": "w1"} + + provider_cls.assert_called_once_with(api_key="k", base_url="b") + provider_instance.crawl_url.assert_called_once_with( + url="https://example.com", + options={ + "limit": 2, + "crawl_sub_pages": True, + "only_main_content": True, + "includes": "a", + "excludes": None, + "max_depth": 5, + "use_sitemap": False, + }, + ) + + +def test_crawl_with_jinareader_single_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + get_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"title": "t"}})) + monkeypatch.setattr(website_service_module.httpx, "get", get_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "data": {"title": "t"}} + get_mock.assert_called_once() + + +def test_crawl_with_jinareader_single_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + req = WebsiteCrawlApiRequest( + provider="jinareader", url="https://example.com", options={"crawl_sub_pages": False} + ).to_crawl_request() + req.options.crawl_sub_pages = False + + with pytest.raises(ValueError, match="Failed to crawl:"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_crawl_with_jinareader_multi_page_success(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"taskId": "t1"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 5, "use_sitemap": True}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + result = WebsiteService._crawl_with_jinareader(request=req, api_key="k") + assert result == {"status": "active", "job_id": "t1"} + post_mock.assert_called_once() + + +def test_crawl_with_jinareader_multi_page_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module.httpx, "post", MagicMock(return_value=_DummyHttpxResponse({"code": 400})) + ) + req = WebsiteCrawlApiRequest( + provider="jinareader", + url="https://example.com", + options={"crawl_sub_pages": True, "limit": 2, "use_sitemap": False}, + ).to_crawl_request() + req.options.crawl_sub_pages = True + + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._crawl_with_jinareader(request=req, api_key="k") + + +def test_get_crawl_status_dispatches(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + firecrawl_status = MagicMock(return_value={"status": "active"}) + monkeypatch.setattr(WebsiteService, "_get_firecrawl_status", firecrawl_status) + + result = WebsiteService.get_crawl_status("job-1", "firecrawl") + assert result == {"status": "active"} + firecrawl_status.assert_called_once_with("job-1", "k", {"base_url": "b"}) + + watercrawl_status = MagicMock(return_value={"status": "active", "job_id": "w"}) + monkeypatch.setattr(WebsiteService, "_get_watercrawl_status", watercrawl_status) + assert WebsiteService.get_crawl_status("job-2", "watercrawl") == {"status": "active", "job_id": "w"} + watercrawl_status.assert_called_once_with("job-2", "k", {"base_url": "b"}) + + jinareader_status = MagicMock(return_value={"status": "active", "job_id": "j"}) + monkeypatch.setattr(WebsiteService, "_get_jinareader_status", jinareader_status) + assert WebsiteService.get_crawl_status("job-3", "jinareader") == {"status": "active", "job_id": "j"} + jinareader_status.assert_called_once_with("job-3", "k") + + +def test_get_crawl_status_typed_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_status_typed(WebsiteCrawlStatusApiRequest(provider="bad", job_id="j")) + + +def test_get_firecrawl_status_adds_time_consuming_when_completed_and_cached(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 2, "current": 2, "data": []} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = b"100.0" + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + with patch.object(website_service_module.datetime, "datetime") as datetime_mock: + datetime_mock.now.return_value = datetime.fromtimestamp(105.0, tz=UTC) + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": "b"}) + + assert result["status"] == "completed" + assert result["time_consuming"] == "5.00" + redis_mock.delete.assert_called_once_with("website_crawl_job-1") + + +def test_get_firecrawl_status_completed_without_cache_does_not_add_time(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "total": 1, "current": 1, "data": []} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + redis_mock = MagicMock() + redis_mock.get.return_value = None + monkeypatch.setattr(website_service_module, "redis_client", redis_mock) + + result = WebsiteService._get_firecrawl_status(job_id="job-1", api_key="k", config={"base_url": None}) + assert result["status"] == "completed" + assert "time_consuming" not in result + redis_mock.delete.assert_not_called() + + +def test_get_watercrawl_status_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_status.return_value = {"status": "active", "job_id": "w1"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + assert WebsiteService._get_watercrawl_status("job-1", "k", {"base_url": "b"}) == { + "status": "active", + "job_id": "w1", + } + provider_instance.get_crawl_status.assert_called_once_with("job-1") + + +def test_get_jinareader_status_active(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock( + return_value=_DummyHttpxResponse( + { + "data": { + "status": "active", + "urls": ["a", "b"], + "processed": {"a": {}}, + "failed": {"b": {}}, + "duration": 3000, + } + } + ) + ) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "active" + assert result["total"] == 2 + assert result["current"] == 2 + assert result["time_consuming"] == 3.0 + assert result["data"] == [] + post_mock.assert_called_once() + + +def test_get_jinareader_status_completed_formats_processed_items(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = { + "data": { + "status": "completed", + "urls": ["u1"], + "processed": {"u1": {}}, + "failed": {}, + "duration": 1000, + } + } + processed_payload = { + "data": { + "processed": { + "u1": { + "data": { + "title": "t", + "url": "u1", + "description": "d", + "content": "md", + } + } + } + } + } + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + result = WebsiteService._get_jinareader_status("job-1", "k") + assert result["status"] == "completed" + assert result["data"] == [{"title": "t", "source_url": "u1", "description": "d", "markdown": "md"}] + assert post_mock.call_count == 2 + + +def test_get_crawl_url_data_dispatches_invalid_provider() -> None: + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", "bad", "https://example.com", "tenant-1") + + +def test_get_crawl_url_data_hits_invalid_provider_branch_when_credentials_stubbed( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {}))) + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_crawl_url_data("job-1", object(), "u", "tenant-1") # type: ignore[arg-type] + + +@pytest.mark.parametrize( + ("provider", "method_name"), + [ + ("firecrawl", "_get_firecrawl_url_data"), + ("watercrawl", "_get_watercrawl_url_data"), + ("jinareader", "_get_jinareader_url_data"), + ], +) +def test_get_crawl_url_data_dispatches(monkeypatch: pytest.MonkeyPatch, provider: str, method_name: str) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + impl_mock = MagicMock(return_value={"ok": True}) + monkeypatch.setattr(WebsiteService, method_name, impl_mock) + + result = WebsiteService.get_crawl_url_data("job-1", provider, "u", "tenant-1") + assert result == {"ok": True} + impl_mock.assert_called_once() + + +def test_get_firecrawl_url_data_reads_from_storage_when_present(monkeypatch: pytest.MonkeyPatch) -> None: + stored_list = [{"source_url": "https://example.com", "title": "t"}] + stored = json.dumps(stored_list).encode("utf-8") + + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = stored + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock()) + + result = WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) + assert result == {"source_url": "https://example.com", "title": "t"} + assert result is not stored_list[0] + + +def test_get_firecrawl_url_data_returns_none_when_storage_empty(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = True + storage_mock.load_once.return_value = b"" + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {}) is None + + +def test_get_firecrawl_url_data_raises_when_job_not_completed(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "active"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + with pytest.raises(ValueError, match="Crawl job is not completed"): + WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": None}) + + +def test_get_firecrawl_url_data_returns_none_when_not_found(monkeypatch: pytest.MonkeyPatch) -> None: + storage_mock = MagicMock() + storage_mock.exists.return_value = False + monkeypatch.setattr(website_service_module, "storage", storage_mock) + + firecrawl_instance = MagicMock() + firecrawl_instance.check_crawl_status.return_value = {"status": "completed", "data": [{"source_url": "x"}]} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + assert WebsiteService._get_firecrawl_url_data("job-1", "https://example.com", "k", {"base_url": "b"}) is None + + +def test_get_watercrawl_url_data_delegates(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.get_crawl_url_data.return_value = {"source_url": "u"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._get_watercrawl_url_data("job-1", "u", "k", {"base_url": "b"}) + assert result == {"source_url": "u"} + provider_instance.get_crawl_url_data.assert_called_once_with("job-1", "u") + + +def test_get_jinareader_url_data_without_job_id_success(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + website_service_module.httpx, + "get", + MagicMock(return_value=_DummyHttpxResponse({"code": 200, "data": {"url": "u"}})), + ) + assert WebsiteService._get_jinareader_url_data("", "u", "k") == {"url": "u"} + + +def test_get_jinareader_url_data_without_job_id_failure(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(website_service_module.httpx, "get", MagicMock(return_value=_DummyHttpxResponse({"code": 500}))) + with pytest.raises(ValueError, match="Failed to crawl$"): + WebsiteService._get_jinareader_url_data("", "u", "k") + + +def test_get_jinareader_url_data_with_job_id_completed_returns_matching_item(monkeypatch: pytest.MonkeyPatch) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "u", "title": "t"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") == {"url": "u", "title": "t"} + assert post_mock.call_count == 2 + + +def test_get_jinareader_url_data_with_job_id_not_completed_raises(monkeypatch: pytest.MonkeyPatch) -> None: + post_mock = MagicMock(return_value=_DummyHttpxResponse({"data": {"status": "active"}})) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + with pytest.raises(ValueError, match=r"Crawl job is no\s*t completed"): + WebsiteService._get_jinareader_url_data("job-1", "u", "k") + + +def test_get_jinareader_url_data_with_job_id_completed_but_not_found_returns_none( + monkeypatch: pytest.MonkeyPatch, +) -> None: + status_payload = {"data": {"status": "completed", "processed": {"u1": {}}}} + processed_payload = {"data": {"processed": {"u1": {"data": {"url": "other"}}}}} + + post_mock = MagicMock(side_effect=[_DummyHttpxResponse(status_payload), _DummyHttpxResponse(processed_payload)]) + monkeypatch.setattr(website_service_module.httpx, "post", post_mock) + + assert WebsiteService._get_jinareader_url_data("job-1", "u", "k") is None + + +def test_get_scrape_url_data_dispatches_and_rejects_invalid_provider(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(WebsiteService, "_get_credentials_and_config", MagicMock(return_value=("k", {"base_url": "b"}))) + + scrape_mock = MagicMock(return_value={"data": "x"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_firecrawl", scrape_mock) + assert WebsiteService.get_scrape_url_data("firecrawl", "u", "tenant-1", True) == {"data": "x"} + scrape_mock.assert_called_once() + + watercrawl_mock = MagicMock(return_value={"data": "y"}) + monkeypatch.setattr(WebsiteService, "_scrape_with_watercrawl", watercrawl_mock) + assert WebsiteService.get_scrape_url_data("watercrawl", "u", "tenant-1", False) == {"data": "y"} + watercrawl_mock.assert_called_once() + + with pytest.raises(ValueError, match="Invalid provider"): + WebsiteService.get_scrape_url_data("jinareader", "u", "tenant-1", True) + + +def test_scrape_with_firecrawl_calls_app(monkeypatch: pytest.MonkeyPatch) -> None: + firecrawl_instance = MagicMock() + firecrawl_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "FirecrawlApp", MagicMock(return_value=firecrawl_instance)) + + result = WebsiteService._scrape_with_firecrawl( + request=website_service_module.ScrapeRequest( + provider="firecrawl", + url="u", + tenant_id="tenant-1", + only_main_content=True, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + firecrawl_instance.scrape_url.assert_called_once_with(url="u", params={"onlyMainContent": True}) + + +def test_scrape_with_watercrawl_calls_provider(monkeypatch: pytest.MonkeyPatch) -> None: + provider_instance = MagicMock() + provider_instance.scrape_url.return_value = {"markdown": "m"} + monkeypatch.setattr(website_service_module, "WaterCrawlProvider", MagicMock(return_value=provider_instance)) + + result = WebsiteService._scrape_with_watercrawl( + request=website_service_module.ScrapeRequest( + provider="watercrawl", + url="u", + tenant_id="tenant-1", + only_main_content=False, + ), + api_key="k", + config={"base_url": "b"}, + ) + assert result == {"markdown": "m"} + provider_instance.scrape_url.assert_called_once_with("u") diff --git a/api/tests/unit_tests/services/test_workflow_app_service.py b/api/tests/unit_tests/services/test_workflow_app_service.py new file mode 100644 index 0000000000..fa76521f2d --- /dev/null +++ b/api/tests/unit_tests/services/test_workflow_app_service.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import json +import uuid +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from dify_graph.enums import WorkflowExecutionStatus +from models import App, WorkflowAppLog +from models.enums import AppTriggerType, CreatorUserRole +from services.workflow_app_service import LogView, WorkflowAppService + + +@pytest.fixture +def service() -> WorkflowAppService: + # Arrange + return WorkflowAppService() + + +@pytest.fixture +def app_model() -> App: + # Arrange + return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + +def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog: + return cast(WorkflowAppLog, SimpleNamespace(**kwargs)) + + +def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None: + # Arrange + log = _workflow_app_log(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + # Act + details = view.details + proxied_status = view.status + + # Assert + assert details == {"trigger_metadata": {"type": "plugin"}} + assert proxied_status == "succeeded" + + +def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + log_1 = SimpleNamespace(id="log-1") + log_2 = SimpleNamespace(id="log-2") + session.scalar.return_value = 3 + session.scalars.return_value.all.return_value = [log_1, log_2] + + # Act + result = service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + page=1, + limit=2, + detail=False, + ) + + # Assert + assert result["page"] == 1 + assert result["limit"] == 2 + assert result["total"] == 3 + assert result["has_more"] is True + assert len(result["data"]) == 2 + assert isinstance(result["data"][0], LogView) + assert result["data"][0].details is None + + +def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true( + service: WorkflowAppService, + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + session.scalar.side_effect = [1] + log_1 = SimpleNamespace(id="log-1") + session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')] + mock_handle = mocker.patch.object( + service, + "handle_trigger_metadata", + return_value={"type": "trigger_plugin", "icon": "url"}, + ) + + # Act + result = service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + keyword="run-1", + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_before=None, + created_at_after=None, + page=1, + limit=20, + detail=True, + ) + + # Assert + assert result["total"] == 1 + assert len(result["data"]) == 1 + assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}} + mock_handle.assert_called_once() + + +def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + session.scalar.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Account not found: account@example.com"): + service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + created_by_account="account@example.com", + ) + + +def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0] + session.scalars.return_value.all.return_value = [] + + # Act + result = service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + created_by_account="account@example.com", + ) + + # Assert + assert result["total"] == 0 + assert result["data"] == [] + + +def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + log_account = SimpleNamespace( + id="log-1", + created_by="acc-1", + created_by_role=CreatorUserRole.ACCOUNT, + workflow_run_summary={"run": "1"}, + trigger_metadata='{"type":"trigger-webhook"}', + log_created_at="2026-01-01", + ) + log_end_user = SimpleNamespace( + id="log-2", + created_by="end-1", + created_by_role=CreatorUserRole.END_USER, + workflow_run_summary={"run": "2"}, + trigger_metadata='{"type":"trigger-webhook"}', + log_created_at="2026-01-02", + ) + log_unknown = SimpleNamespace( + id="log-3", + created_by="other", + created_by_role="system", + workflow_run_summary={"run": "3"}, + trigger_metadata='{"type":"trigger-webhook"}', + log_created_at="2026-01-03", + ) + session.scalar.return_value = 3 + session.scalars.side_effect = [ + SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]), + SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]), + SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]), + ] + + # Act + result = service.get_paginate_workflow_archive_logs( + session=session, + app_model=app_model, + page=1, + limit=20, + ) + + # Assert + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert result["data"][0]["created_by_account"].id == "acc-1" + assert result["data"][1]["created_by_end_user"].id == "end-1" + assert result["data"][2]["created_by_account"] is None + assert result["data"][2]["created_by_end_user"] is None + + +def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing( + service: WorkflowAppService, +) -> None: + # Arrange + # Act + result = service.handle_trigger_metadata("tenant-1", None) + + # Assert + assert result == {} + + +def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin( + service: WorkflowAppService, + mocker: MockerFixture, +) -> None: + # Arrange + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + mock_icon = mocker.patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) + + # Act + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + # Assert + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + +def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup( + service: WorkflowAppService, + mocker: MockerFixture, +) -> None: + # Arrange + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url") + + # Act + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + # Assert + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], +) +def test_safe_json_loads_should_handle_various_inputs( + value: object, + expected: object, + service: WorkflowAppService, +) -> None: + # Arrange + # Act + result = service._safe_json_loads(value) + + # Assert + assert result == expected + + +def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None: + # Arrange + # Act + short_result = service._safe_parse_uuid("short") + invalid_result = service._safe_parse_uuid("x" * 40) + + # Assert + assert short_result is None + assert invalid_result is None + + +def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None: + # Arrange + raw_uuid = str(uuid.uuid4()) + + # Act + result = service._safe_parse_uuid(raw_uuid) + + # Assert + assert result is not None + assert str(result) == raw_uuid diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 1f92ff590c..27664c7e29 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -16,7 +16,7 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from core.workflow.enums import WorkflowExecutionStatus +from dify_graph.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 3a4f2d392a..d26c2f674f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -10,18 +10,36 @@ This test suite covers: """ import json +import uuid +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + ErrorStrategy, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.variables.input_entities import VariableEntityType from libs.datetime_utils import naive_utc_now +from models.human_input import RecipientType from models.model import App, AppMode from models.workflow import Workflow, WorkflowType from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from services.workflow_service import WorkflowService +from services.workflow_service import ( + WorkflowService, + _rebuild_file_for_user_inputs_in_start_node, + _rebuild_single_file, + _setup_variable_pool, +) class TestWorkflowAssociatedDataFactory: @@ -134,7 +152,7 @@ class TestWorkflowAssociatedDataFactory: return ( (node["id"], node["data"]) for node in nodes - if node.get("data", {}).get("type") == specific_node_type.value + if node.get("data", {}).get("type") == str(specific_node_type) ) # Return all nodes if no filter specified return ((node["id"], node["data"]) for node in nodes) @@ -179,7 +197,7 @@ class TestWorkflowAssociatedDataFactory: { "id": "start", "data": { - "type": NodeType.START.value, + "type": BuiltinNodeTypes.START, "title": "START", "variables": [], }, @@ -204,7 +222,7 @@ class TestWorkflowAssociatedDataFactory: { "id": "llm-1", "data": { - "type": NodeType.LLM.value, + "type": BuiltinNodeTypes.LLM, "title": "LLM", "model": { "provider": "openai", @@ -544,6 +562,89 @@ class TestWorkflowService: conversation_variables=[], ) + def test_restore_published_workflow_to_draft_keeps_source_features_unmodified( + self, workflow_service, mock_db_session + ): + app = TestWorkflowAssociatedDataFactory.create_app_mock() + account = TestWorkflowAssociatedDataFactory.create_account_mock() + legacy_features = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + normalized_features = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, + } + source_workflow = Workflow( + id="published-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version="2026-03-19T00:00:00", + graph=json.dumps(TestWorkflowAssociatedDataFactory.create_valid_workflow_graph()), + features=json.dumps(legacy_features), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = Workflow( + id="draft-workflow-id", + tenant_id=app.tenant_id, + app_id=app.id, + type=WorkflowType.WORKFLOW.value, + version=Workflow.VERSION_DRAFT, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps({}), + created_by=account.id, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + with ( + patch.object(workflow_service, "get_published_workflow_by_id", return_value=source_workflow), + patch.object(workflow_service, "get_draft_workflow", return_value=draft_workflow), + patch.object(workflow_service, "validate_graph_structure"), + patch.object(workflow_service, "validate_features_structure") as mock_validate_features, + patch("services.workflow_service.app_draft_workflow_was_synced"), + ): + result = workflow_service.restore_published_workflow_to_draft( + app_model=app, + workflow_id=source_workflow.id, + account=account, + ) + + mock_validate_features.assert_called_once_with(app_model=app, features=normalized_features) + assert result is draft_workflow + assert source_workflow.serialized_features == json.dumps(legacy_features) + assert draft_workflow.serialized_features == json.dumps(legacy_features) + mock_db_session.session.commit.assert_called_once() + # ==================== Workflow Validation Tests ==================== # These tests verify graph structure and feature configuration validation @@ -1001,12 +1102,12 @@ class TestWorkflowService: Used by the UI to populate the node palette and provide sensible defaults when users add new nodes to their workflow. """ - with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping: + with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping: # Mock node class with default config mock_node_class = MagicMock() mock_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.items.return_value = [(NodeType.LLM, {"latest": mock_node_class})] + mock_mapping.return_value = {BuiltinNodeTypes.LLM: {"latest": mock_node_class}} with patch("services.workflow_service.LATEST_VERSION", "latest"): result = workflow_service.get_default_block_configs() @@ -1025,7 +1126,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1036,10 +1137,10 @@ class TestWorkflowService: mock_http_node_class.get_default_config.return_value = {"type": "http-request", "config": {}} mock_llm_node_class = MagicMock() mock_llm_node_class.get_default_config.return_value = {"type": "llm", "config": {}} - mock_mapping.items.return_value = [ - (NodeType.HTTP_REQUEST, {"latest": mock_http_node_class}), - (NodeType.LLM, {"latest": mock_llm_node_class}), - ] + mock_mapping.return_value = { + BuiltinNodeTypes.HTTP_REQUEST: {"latest": mock_http_node_class}, + BuiltinNodeTypes.LLM: {"latest": mock_llm_node_class}, + } result = workflow_service.get_default_block_configs() @@ -1060,7 +1161,7 @@ class TestWorkflowService: This includes default values for all required and optional parameters. """ with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), ): # Mock node class with default config @@ -1068,23 +1169,21 @@ class TestWorkflowService: mock_config = {"type": "llm", "config": {"provider": "openai"}} mock_node_class.get_default_config.return_value = mock_config - # Create a mock mapping that includes NodeType.LLM - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + # Create a mock mapping that includes BuiltinNodeTypes.LLM + mock_mapping.return_value = {BuiltinNodeTypes.LLM: {"latest": mock_node_class}} - result = workflow_service.get_default_block_config(NodeType.LLM.value) + result = workflow_service.get_default_block_config(BuiltinNodeTypes.LLM) assert result == mock_config mock_node_class.get_default_config.assert_called_once() def test_get_default_block_config_invalid_node_type(self, workflow_service): """Test get_default_block_config returns empty dict for invalid node type.""" - with patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping: - # Mock mapping to not contain the node type - mock_mapping.__contains__.return_value = False + with patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping: + mock_mapping.return_value = {} # Use a valid NodeType but one that's not in the mapping - result = workflow_service.get_default_block_config(NodeType.LLM.value) + result = workflow_service.get_default_block_config(BuiltinNodeTypes.LLM) assert result == {} @@ -1100,7 +1199,7 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch( "services.workflow_service.build_http_request_config", @@ -1110,10 +1209,9 @@ class TestWorkflowService: mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {BuiltinNodeTypes.HTTP_REQUEST: {"latest": mock_node_class}} - result = workflow_service.get_default_block_config(NodeType.HTTP_REQUEST.value) + result = workflow_service.get_default_block_config(BuiltinNodeTypes.HTTP_REQUEST) assert result == expected mock_build_config.assert_called_once() @@ -1132,18 +1230,17 @@ class TestWorkflowService: ) with ( - patch("services.workflow_service.NODE_TYPE_CLASSES_MAPPING") as mock_mapping, + patch("services.workflow_service.get_node_type_classes_mapping") as mock_mapping, patch("services.workflow_service.LATEST_VERSION", "latest"), patch("services.workflow_service.build_http_request_config") as mock_build_config, ): mock_node_class = MagicMock() expected = {"type": "http-request", "config": {}} mock_node_class.get_default_config.return_value = expected - mock_mapping.__contains__.return_value = True - mock_mapping.__getitem__.return_value = {"latest": mock_node_class} + mock_mapping.return_value = {BuiltinNodeTypes.HTTP_REQUEST: {"latest": mock_node_class}} result = workflow_service.get_default_block_config( - NodeType.HTTP_REQUEST.value, + BuiltinNodeTypes.HTTP_REQUEST, filters={HTTP_REQUEST_CONFIG_FILTER_KEY: provided_config}, ) @@ -1155,14 +1252,14 @@ class TestWorkflowService: def test_get_default_block_config_http_request_malformed_config_raises_value_error(self, workflow_service): with ( patch( - "services.workflow_service.NODE_TYPE_CLASSES_MAPPING", - {NodeType.HTTP_REQUEST: {"latest": HttpRequestNode}}, + "services.workflow_service.get_node_type_classes_mapping", + return_value={BuiltinNodeTypes.HTTP_REQUEST: {"latest": HttpRequestNode}}, ), patch("services.workflow_service.LATEST_VERSION", "latest"), ): with pytest.raises(ValueError, match="http_request_config must be an HttpRequestNodeConfig instance"): workflow_service.get_default_block_config( - NodeType.HTTP_REQUEST.value, + BuiltinNodeTypes.HTTP_REQUEST, filters={HTTP_REQUEST_CONFIG_FILTER_KEY: "invalid"}, ) @@ -1230,3 +1327,1416 @@ class TestWorkflowService: with pytest.raises(ValueError, match="not supported convert to workflow"): workflow_service.convert_to_workflow(app, account, args) + + +# =========================================================================== +# TestWorkflowServiceCredentialValidation +# Tests for _validate_workflow_credentials and related private helpers +# =========================================================================== + + +class TestWorkflowServiceCredentialValidation: + """ + Tests for the private credential-validation helpers on WorkflowService. + + These helpers gate `publish_workflow` when `PluginManager` is enabled. + Each test focuses on a distinct branch inside `_validate_workflow_credentials`, + `_validate_llm_model_config`, `_check_default_tool_credential`, and the + load-balancing path. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + @staticmethod + def _make_workflow(nodes: list[dict]) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.tenant_id = "tenant-1" + wf.app_id = "app-1" + wf.graph_dict = {"nodes": nodes} + return wf + + # --- _validate_workflow_credentials: tool node (with credential_id) --- + + def test_validate_workflow_credentials_should_check_tool_credential_when_credential_id_present( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + "credential_id": "cred-123", + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + # Should not raise; mock allows the call + service._validate_workflow_credentials(workflow) + mock_check.assert_called_once() + + def test_validate_workflow_credentials_should_check_default_credential_when_no_credential_id( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + # No credential_id — should fall back to default + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + + # Assert + mock_default.assert_called_once_with("tenant-1", "my-provider") + + def test_validate_workflow_credentials_should_skip_tool_node_without_provider( + self, service: WorkflowService + ) -> None: + """Tool nodes without a provider_id should be silently skipped.""" + # Arrange + nodes = [{"id": "tool-node", "data": {"type": "tool"}}] + workflow = self._make_workflow(nodes) + + # Act + Assert (no error raised) + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + mock_default.assert_not_called() + + def test_validate_workflow_credentials_should_validate_llm_node_with_model_config( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_raise_for_llm_node_missing_model( + self, service: WorkflowService + ) -> None: + """LLM nodes without provider AND name should raise ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": {"type": "llm", "model": {"provider": "openai"}}, # name missing + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with pytest.raises(ValueError, match="Missing provider or model configuration"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_wrap_unexpected_exception_in_value_error( + self, service: WorkflowService + ) -> None: + """Non-ValueError exceptions from validation must be re-raised as ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch.object(service, "_validate_llm_model_config", side_effect=RuntimeError("boom")): + with pytest.raises(ValueError, match="boom"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_validate_agent_node_model(self, service: WorkflowService) -> None: + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {"provider": "openai", "model": "gpt-4"}}, + "tools": {"value": []}, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_validate_agent_tools(self, service: WorkflowService) -> None: + """Each agent tool with a provider should be checked for credential compliance.""" + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {}}, # no model config + "tools": { + "value": [ + {"provider_name": "provider-a", "credential_id": "cred-a"}, + {"provider_name": "provider-b"}, # uses default + ] + }, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check, + patch.object(service, "_check_default_tool_credential") as mock_default, + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_check.assert_called_once() # provider-a has credential_id + mock_default.assert_called_once_with("tenant-1", "provider-b") + + # --- _validate_llm_model_config --- + + def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: + """If ModelManager raises any exception it must be wrapped into ValueError.""" + # Arrange + with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")): + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + def test_validate_llm_model_config_success(self, service: WorkflowService) -> None: + """Test success path with ProviderManager and Model entities.""" + mock_model = MagicMock() + mock_model.model = "gpt-4" + mock_model.provider.provider = "openai" + + mock_configs = MagicMock() + mock_configs.get_models.return_value = [mock_model] + + with ( + patch("core.model_manager.ModelManager.get_model_instance"), + patch("core.provider_manager.ProviderManager") as mock_pm_cls, + ): + mock_pm_cls.return_value.get_configurations.return_value = mock_configs + + # Act + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # Assert + mock_model.raise_for_status.assert_called_once() + + def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: + """Test ValueError when model is not found in provider configurations.""" + mock_configs = MagicMock() + mock_configs.get_models.return_value = [] # No models + + with ( + patch("core.model_manager.ModelManager.get_model_instance"), + patch("core.provider_manager.ProviderManager") as mock_pm_cls, + ): + mock_pm_cls.return_value.get_configurations.return_value = mock_configs + + # Act + Assert + with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # --- _check_default_tool_credential --- + + def test_check_default_tool_credential_should_silently_pass_when_no_provider_found( + self, service: WorkflowService + ) -> None: + """Missing BuiltinToolProvider → plugin requires no credentials → no error.""" + # Arrange + with patch("services.workflow_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + # Act + Assert (should NOT raise) + service._check_default_tool_credential("tenant-1", "some-provider") + + def test_check_default_tool_credential_should_raise_when_compliance_fails(self, service: WorkflowService) -> None: + # Arrange + mock_provider = MagicMock() + mock_provider.id = "builtin-cred-id" + with ( + patch("services.workflow_service.db") as mock_db, + patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_provider + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate default credential"): + service._check_default_tool_credential("tenant-1", "some-provider") + + # --- _is_load_balancing_enabled --- + + def test_is_load_balancing_enabled_should_return_false_when_provider_not_found( + self, service: WorkflowService + ) -> None: + # Arrange + with patch("services.workflow_service.db"): + service_instance = WorkflowService() + + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_configs = MagicMock() + mock_configs.get.return_value = None # provider not found + mock_get_configs.return_value = mock_configs + + # Act + result = service_instance._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + def test_is_load_balancing_enabled_should_return_true_when_setting_enabled(self, service: WorkflowService) -> None: + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_provider_config = MagicMock() + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = True + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + mock_configs = MagicMock() + mock_configs.get.return_value = mock_provider_config + mock_get_configs.return_value = mock_configs + + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is True + + def test_is_load_balancing_enabled_should_return_false_on_exception(self, service: WorkflowService) -> None: + """Any exception should be swallowed and return False.""" + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations", side_effect=RuntimeError("db down")): + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + # --- _get_load_balancing_configs --- + + def test_get_load_balancing_configs_should_return_empty_list_on_exception(self, service: WorkflowService) -> None: + """Any exception during LB config retrieval should return an empty list.""" + # Arrange + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=RuntimeError("fail"), + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert + assert result == [] + + def test_get_load_balancing_configs_should_merge_predefined_and_custom(self, service: WorkflowService) -> None: + # Arrange + predefined = [{"credential_id": "cred-a"}, {"credential_id": None}] + custom = [{"credential_id": "cred-b"}] + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=[ + (None, predefined), # first call: predefined-model + (None, custom), # second call: custom-model + ], + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert — only entries with a credential_id should be returned + assert len(result) == 2 + assert all(c["credential_id"] for c in result) + + # --- _validate_load_balancing_credentials --- + + def test_validate_load_balancing_credentials_should_skip_when_no_model_config( + self, service: WorkflowService + ) -> None: + """Missing provider or model in node_data should be a no-op.""" + # Arrange + workflow = self._make_workflow([]) + node_data: dict = {} # no model key + + # Act + Assert (no error expected) + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_skip_when_lb_not_enabled( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + + # Act + Assert (no error expected) + with patch.object(service, "_is_load_balancing_enabled", return_value=False): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_raise_when_compliance_fails( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + lb_configs = [{"credential_id": "cred-lb-1"}] + + # Act + Assert + with ( + patch.object(service, "_is_load_balancing_enabled", return_value=True), + patch.object(service, "_get_load_balancing_configs", return_value=lb_configs), + patch( + "core.helper.credential_utils.check_credential_policy_compliance", + side_effect=Exception("policy violation"), + ), + ): + with pytest.raises(ValueError, match="Invalid load balancing credentials"): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + +# =========================================================================== +# TestWorkflowServiceExecutionHelpers +# Tests for _apply_error_strategy, _populate_execution_result, _execute_node_safely +# =========================================================================== + + +class TestWorkflowServiceExecutionHelpers: + """ + Tests for the private execution-result handling methods: + _apply_error_strategy, _populate_execution_result, _execute_node_safely. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + # --- _apply_error_strategy --- + + def test_apply_error_strategy_should_return_exception_status_noderunresult(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="something went wrong", + error_type="SomeError", + inputs={"x": 1}, + outputs={}, + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + assert result.error == "something went wrong" + assert result.metadata[WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY] == ErrorStrategy.FAIL_BRANCH + + def test_apply_error_strategy_should_include_default_values_for_default_value_strategy( + self, service: WorkflowService + ) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.DEFAULT_VALUE + node.default_value_dict = {"output_key": "fallback"} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="err", + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.outputs.get("output_key") == "fallback" + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + + # --- _populate_execution_result --- + + def test_populate_execution_result_should_set_succeeded_fields_when_run_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"q": "hello"}, + process_data={"steps": 3}, + outputs={"answer": "hi"}, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 10}, + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert node_execution.outputs == {"answer": "hi"} + assert node_execution.error is None # SUCCEEDED status doesn't set error + + def test_populate_execution_result_should_set_failed_status_and_error_when_not_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + + # Act + service._populate_execution_result(node_execution, None, False, "catastrophic failure") + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_execution.error == "catastrophic failure" + + def test_populate_execution_result_should_set_error_field_for_exception_status( + self, service: WorkflowService + ) -> None: + """A succeeded=True result with EXCEPTION status should still populate the error field.""" + # Arrange + node_execution = MagicMock() + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + error="constraint violated", + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.EXCEPTION + assert node_execution.error == "constraint violated" + + # --- _execute_node_safely --- + + def test_execute_node_safely_should_return_succeeded_result_on_happy_path(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_run_result.error = None + + succeeded_event = MagicMock(spec=NodeRunSucceededEvent) + succeeded_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield succeeded_event + + return node, _gen() + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert run_succeeded is True + assert error is None + + def test_execute_node_safely_should_return_failed_result_on_failed_event(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.FAILED + node_run_result.error = "node exploded" + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, _, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert run_succeeded is False + assert error == "node exploded" + + def test_execute_node_safely_should_handle_workflow_node_run_failed_error(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + exc = WorkflowNodeRunFailedError(node, "runtime failure") + + def invoke_fn(): + raise exc + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert out_result is None + assert run_succeeded is False + assert error == "runtime failure" + + def test_execute_node_safely_should_raise_when_no_result_event(self, service: WorkflowService) -> None: + """If the generator produces no NodeRunSucceededEvent/NodeRunFailedEvent, ValueError is expected.""" + # Arrange + node = MagicMock() + node.error_strategy = None + + def invoke_fn(): + def _gen(): + yield from [] + + return node, _gen() + + # Act + Assert + with pytest.raises(ValueError, match="no result returned"): + service._execute_node_safely(invoke_fn) + + # --- _apply_error_strategy with FAIL_BRANCH strategy --- + + def test_execute_node_safely_should_apply_error_strategy_on_failed_status(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + + original_result = MagicMock() + original_result.status = WorkflowNodeExecutionStatus.FAILED + original_result.error = "oops" + original_result.error_type = "ValueError" + original_result.inputs = {} + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = original_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, result, run_succeeded, _ = service._execute_node_safely(invoke_fn) + + # Assert — after applying error strategy status becomes EXCEPTION + assert result is not None + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + # run_succeeded should be True because EXCEPTION is in the succeeded set + assert run_succeeded is True + + +# =========================================================================== +# TestWorkflowServiceGetNodeLastRun +# Tests for get_node_last_run delegation to repository +# =========================================================================== + + +class TestWorkflowServiceGetNodeLastRun: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_node_last_run_should_delegate_to_repository(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "tenant-1" + app.id = "app-1" + workflow = MagicMock(spec=Workflow) + workflow.id = "wf-1" + expected = MagicMock() + + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = expected + + # Act + result = service.get_node_last_run(app, workflow, "node-42") + + # Assert + assert result is expected + service._node_execution_service_repo.get_node_last_execution.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="wf-1", + node_id="node-42", + ) + + def test_get_node_last_run_should_return_none_when_repository_returns_none(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "t" + app.id = "a" + workflow = MagicMock(spec=Workflow) + workflow.id = "w" + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = None + + # Act + result = service.get_node_last_run(app, workflow, "node-x") + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceModuleLevelHelpers +# Tests for module-level helper functions exported from workflow_service +# =========================================================================== + + +class TestSetupVariablePool: + """ + Tests for the module-level `_setup_variable_pool` function. + This helper initialises the VariablePool used for single-step workflow execution. + """ + + def _make_workflow(self, workflow_type: str = WorkflowType.WORKFLOW.value) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.app_id = "app-1" + wf.id = "wf-1" + wf.type = workflow_type + wf.environment_variables = [] + return wf + + def test_setup_variable_pool_should_use_full_system_variables_for_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with patch("services.workflow_service.VariablePool") as MockPool: + _setup_variable_pool( + query="hello", + files=[], + user_id="u-1", + user_inputs={"k": "v"}, + workflow=workflow, + node_type=BuiltinNodeTypes.START, + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — VariablePool should be called with a SystemVariable (non-default) + MockPool.assert_called_once() + call_kwargs = MockPool.call_args.kwargs + assert call_kwargs["user_inputs"] == {"k": "v"} + + def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.SystemVariable.default") as mock_default, + ): + _setup_variable_pool( + query="", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_type=BuiltinNodeTypes.LLM, # not a start/trigger node + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — SystemVariable.default() should be used for non-start nodes + mock_default.assert_called_once() + MockPool.assert_called_once() + + def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( + self, + ) -> None: + """For ADVANCED_CHAT workflows on a START node, query/conversation_id/dialogue_count should be set.""" + from models.workflow import WorkflowType + + # Arrange + workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) + + # Act + with patch("services.workflow_service.VariablePool") as MockPool: + _setup_variable_pool( + query="what is AI?", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_type=BuiltinNodeTypes.START, + conversation_id="conv-abc", + conversation_variables=[], + ) + + # Assert — we just verify VariablePool was called (chatflow path executed) + MockPool.assert_called_once() + + +class TestRebuildSingleFile: + """ + Tests for the module-level `_rebuild_single_file` function. + Ensures correct delegation to `build_from_mapping` / `build_from_mappings`. + """ + + def test_rebuild_single_file_should_call_build_from_mapping_for_file_type( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = {"url": "https://example.com/file.pdf", "type": "document"} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE) + + # Assert + assert result is mock_file + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) + + def test_rebuild_single_file_should_raise_when_file_value_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for file object"): + _rebuild_single_file("tenant-1", "not-a-dict", VariableEntityType.FILE) + + def test_rebuild_single_file_should_call_build_from_mappings_for_file_list( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = [{"url": "https://example.com/a.pdf"}, {"url": "https://example.com/b.pdf"}] + mock_files = [MagicMock(), MagicMock()] + + # Act + with patch("services.workflow_service.build_from_mappings", return_value=mock_files) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE_LIST) + + # Assert + assert result is mock_files + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) + + def test_rebuild_single_file_should_raise_when_file_list_value_not_list( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected list for file list object"): + _rebuild_single_file("tenant-1", "not-a-list", VariableEntityType.FILE_LIST) + + def test_rebuild_single_file_should_return_empty_list_for_empty_file_list( + self, + ) -> None: + # Arrange + Act + result = _rebuild_single_file("tenant-1", [], VariableEntityType.FILE_LIST) + + # Assert + assert result == [] + + def test_rebuild_single_file_should_raise_when_first_element_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for first element"): + _rebuild_single_file("tenant-1", ["not-a-dict"], VariableEntityType.FILE_LIST) + + +class TestRebuildFileForUserInputsInStartNode: + """ + Tests for the module-level `_rebuild_file_for_user_inputs_in_start_node` function. + """ + + def _make_start_node_data(self, variables: list) -> MagicMock: + start_data = MagicMock() + start_data.variables = variables + return start_data + + def _make_variable(self, name: str, var_type: VariableEntityType) -> MagicMock: + var = MagicMock() + var.variable = name + var.type = var_type + return var + + def test_rebuild_should_pass_through_non_file_variables( + self, + ) -> None: + # Arrange + text_var = self._make_variable("query", VariableEntityType.TEXT_INPUT) + start_data = self._make_start_node_data([text_var]) + user_inputs = {"query": "hello world"} + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — non-file inputs are untouched + assert result["query"] == "hello world" + + def test_rebuild_should_rebuild_file_variable( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + file_value = {"url": "https://example.com/file.pdf"} + user_inputs = {"attachment": file_value} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file): + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — the dict value should be replaced by the rebuilt File object + assert result["attachment"] is mock_file + + def test_rebuild_should_skip_variable_not_in_inputs( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + user_inputs: dict = {} # attachment not provided + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — no key should be added for missing inputs + assert "attachment" not in result + + +class TestWorkflowServiceResolveDeliveryMethod: + """ + Tests for the static helper `_resolve_human_input_delivery_method`. + """ + + def _make_method(self, method_id) -> MagicMock: + m = MagicMock() + m.id = method_id + return m + + def test_resolve_delivery_method_should_return_method_when_id_matches(self) -> None: + # Arrange + method_a = self._make_method("method-1") + method_b = self._make_method("method-2") + node_data = MagicMock() + node_data.delivery_methods = [method_a, method_b] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="method-2" + ) + + # Assert + assert result is method_b + + def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: + # Arrange + method_a = self._make_method("method-1") + node_data = MagicMock() + node_data.delivery_methods = [method_a] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="does-not-exist" + ) + + # Assert + assert result is None + + def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: + # Arrange + node_data = MagicMock() + node_data.delivery_methods = [] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="method-1" + ) + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceDraftExecution +# Tests for run_draft_workflow_node +# =========================================================================== + + +class TestWorkflowServiceDraftExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_draft_workflow_node_should_execute_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.id = "app-1" + app.tenant_id = "tenant-1" + account = MagicMock() + account.id = "user-1" + + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.id = "wf-1" + draft_workflow.tenant_id = "tenant-1" + draft_workflow.app_id = "app-1" + draft_workflow.graph_dict = {"nodes": []} + + node_id = "start-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.START)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + # Mocking complex dependencies + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.StartNodeData") as mock_start_data, + patch( + "services.workflow_service._rebuild_file_for_user_inputs_in_start_node", + side_effect=lambda **kwargs: kwargs["user_inputs"], + ), + patch("services.workflow_service._setup_variable_pool"), + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory") as mock_repo_factory, + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.START + mock_node.title = "Start Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="start-node", + node_type=BuiltinNodeTypes.START, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + mock_repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = mock_repo + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "start" + mock_execution_record.node_id = "start-node" + mock_execution_record.load_full_outputs.return_value = {} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + result = service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={"key": "val"}, + query="hi", + files=[], + ) + + # Assert + assert result is not None + mock_run.assert_called_once() + mock_repo.save.assert_called_once() + mock_saver_cls.return_value.save.assert_called_once() + + def test_run_draft_workflow_node_should_execute_non_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + account = MagicMock() + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.graph_dict = {"nodes": []} + node_id = "llm-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.LLM)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory"), + patch("services.workflow_service.DraftVariableSaver"), + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.LLM + mock_node.title = "LLM Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "llm" + mock_execution_record.node_id = "llm-node" + mock_execution_record.load_full_outputs.return_value = {"answer": "hello"} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={}, + query="", + files=None, + ) + + # Assert + # For non-start nodes, VariablePool should be initialized with environment_variables + mock_pool_cls.assert_called_once() + args, kwargs = mock_pool_cls.call_args + assert "environment_variables" in kwargs + + +# =========================================================================== +# TestWorkflowServiceHumanInputOperations +# Tests for Human Input related methods +# =========================================================================== + + +class TestWorkflowServiceHumanInputOperations: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_human_input_form_preview_should_raise_if_workflow_not_init(self, service: WorkflowService) -> None: + service.get_draft_workflow = MagicMock(return_value=None) + with pytest.raises(ValueError, match="Workflow not initialized"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_should_raise_if_wrong_node_type(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "llm"}} + service.get_draft_workflow = MagicMock(return_value=draft) + with patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.LLM): + with pytest.raises(ValueError, match="Node type must be human-input"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = { + "id": "node-1", + "data": MagicMock(type=BuiltinNodeTypes.HUMAN_INPUT), + } + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.render_form_content_before_submission.return_value = "rendered" + mock_node.resolve_default_values.return_value = {"def": 1} + mock_node.title = "Form Title" + mock_node.node_data = MagicMock() + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.HumanInputRequired") as mock_required_cls, + ): + service.get_human_input_form_preview(app_model=app_model, account=account, node_id="node-1") + mock_node.render_form_content_before_submission.assert_called_once() + mock_required_cls.return_value.model_dump.assert_called_once() + + def test_submit_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.node_data = MagicMock() + mock_node.node_data.outputs_field_names.return_value = ["field1"] + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.validate_human_input_submission"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + ): + result = service.submit_human_input_form_preview( + app_model=app_model, account=account, node_id="node-1", form_inputs={"field1": "val1"}, action="submit" + ) + assert result["__action_id"] == "submit" + mock_saver_cls.return_value.save.assert_called_once() + + def test_test_human_input_delivery_success(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, + patch("services.workflow_service.apply_debug_email_recipient"), + patch.object(service, "_build_human_input_variable_pool"), + patch.object(service, "_build_human_input_node"), + patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), + patch("services.workflow_service.HumanInputDeliveryTestService") as mock_test_srv, + ): + mock_resolve.return_value = MagicMock() + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="method-1" + ) + mock_test_srv.return_value.send_test.assert_called_once() + + def test_test_human_input_delivery_failure_cases(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method", return_value=None), + ): + with pytest.raises(ValueError, match="Delivery method not found"): + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="none" + ) + + def test_load_email_recipients_parsing_failure(self, service: WorkflowService) -> None: + # Arrange + mock_recipient = MagicMock() + mock_recipient.recipient_payload = "invalid json" + mock_recipient.recipient_type = RecipientType.EMAIL_MEMBER + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.Session") as mock_session_cls, + patch("services.workflow_service.select"), + patch("services.workflow_service.json.loads", side_effect=ValueError("bad json")), + ): + mock_session = mock_session_cls.return_value.__enter__.return_value + # sqlalchemy assertions check for .bind + mock_session.bind = MagicMock() # removed spec=Engine to avoid import issues for now + mock_session.scalars.return_value.all.return_value = [mock_recipient] + + # Act + # _load_email_recipients(form_id: str) is a static method + result = WorkflowService._load_email_recipients("form-1") + + # Assert + assert result == [] # Should fall back to empty list on parsing error + + def test_build_human_input_variable_pool(self, service: WorkflowService) -> None: + workflow = MagicMock() + workflow.environment_variables = [] + workflow.graph_dict = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.HumanInputNode.extract_variable_selector_to_variable_mapping"), + patch("services.workflow_service.load_into_variable_pool"), + patch("services.workflow_service.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + service._build_human_input_variable_pool( + app_model=MagicMock(), workflow=workflow, node_config={}, manual_inputs={}, user_id="user-1" + ) + mock_pool_cls.assert_called_once() + + +# =========================================================================== +# TestWorkflowServiceFreeNodeExecution +# Tests for run_free_workflow_node and handle_single_step_result +# =========================================================================== + + +class TestWorkflowServiceFreeNodeExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_free_workflow_node_success(self, service: WorkflowService) -> None: + node_execution = MagicMock() + with ( + patch.object(service, "_handle_single_step_result", return_value=node_execution), + patch("services.workflow_service.WorkflowEntry.run_free_node"), + ): + result = service.run_free_workflow_node({}, "tenant-1", "user-1", "node-1", {}) + assert result == node_execution + + def test_validate_graph_structure_coexist_error(self, service: WorkflowService) -> None: + graph = { + "nodes": [ + {"data": {"type": "start"}}, + {"data": {"type": "trigger-webhook"}}, # is_trigger_node=True + ] + } + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + service.validate_graph_structure(graph) + + def test_validate_features_structure_success(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "workflow" + features = {} + with patch("services.workflow_service.WorkflowAppConfigManager.config_validate") as mock_val: + service.validate_features_structure(app, features) + mock_val.assert_called_once() + + def test_validate_features_structure_invalid_mode(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "invalid" + with pytest.raises(ValueError, match="Invalid app mode"): + service.validate_features_structure(app, {}) + + def test_validate_human_input_node_data_error(self, service: WorkflowService) -> None: + with patch( + "dify_graph.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") + ): + with pytest.raises(ValueError, match="Invalid HumanInput node data"): + service._validate_human_input_node_data({}) + + def test_rebuild_single_file_unreachable(self) -> None: + # Test line 1523 (unreachable) + with pytest.raises(Exception, match="unreachable"): + _rebuild_single_file("tenant-1", {}, cast(Any, "invalid_type")) + + def test_build_human_input_node(self, service: WorkflowService) -> None: + """Cover _build_human_input_node (lines 1065-1088).""" + workflow = MagicMock() + workflow.id = "wf-1" + workflow.tenant_id = "t-1" + workflow.app_id = "app-1" + account = MagicMock() + account.id = "u-1" + node_config = {"id": "n-1"} + variable_pool = MagicMock() + + with ( + patch("services.workflow_service.GraphInitParams"), + patch("services.workflow_service.GraphRuntimeState"), + patch("services.workflow_service.HumanInputNode") as mock_node_cls, + patch("services.workflow_service.HumanInputFormRepositoryImpl"), + ): + node = service._build_human_input_node( + workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool + ) + assert node == mock_node_cls.return_value + mock_node_cls.assert_called_once() diff --git a/api/tests/unit_tests/services/test_workspace_service.py b/api/tests/unit_tests/services/test_workspace_service.py new file mode 100644 index 0000000000..9bfd7eb2c5 --- /dev/null +++ b/api/tests/unit_tests/services/test_workspace_service.py @@ -0,0 +1,576 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from models.account import Tenant + +# --------------------------------------------------------------------------- +# Constants used throughout the tests +# --------------------------------------------------------------------------- + +TENANT_ID = "tenant-abc" +ACCOUNT_ID = "account-xyz" +FILES_BASE_URL = "https://files.example.com" + +DB_PATH = "services.workspace_service.db" +FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features" +TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles" +DIFY_CONFIG_PATH = "services.workspace_service.dify_config" +CURRENT_USER_PATH = "services.workspace_service.current_user" +CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool" + + +# --------------------------------------------------------------------------- +# Helpers / factories +# --------------------------------------------------------------------------- + + +def _make_tenant( + tenant_id: str = TENANT_ID, + name: str = "My Workspace", + plan: str = "sandbox", + status: str = "active", + custom_config: dict | None = None, +) -> Tenant: + """Create a minimal Tenant-like namespace.""" + return cast( + Tenant, + SimpleNamespace( + id=tenant_id, + name=name, + plan=plan, + status=status, + created_at="2024-01-01T00:00:00Z", + custom_config_dict=custom_config or {}, + ), + ) + + +def _make_feature( + can_replace_logo: bool = False, + next_credit_reset_date: str | None = None, + billing_plan: str = "sandbox", +) -> MagicMock: + """Create a feature namespace matching what FeatureService.get_features returns.""" + feature = MagicMock() + feature.can_replace_logo = can_replace_logo + feature.next_credit_reset_date = next_credit_reset_date + feature.billing.subscription.plan = billing_plan + return feature + + +def _make_pool(quota_limit: int, quota_used: int) -> MagicMock: + pool = MagicMock() + pool.quota_limit = quota_limit + pool.quota_used = quota_used + return pool + + +def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace: + return SimpleNamespace(role=role) + + +def _tenant_info(result: object) -> dict[str, Any] | None: + return cast(dict[str, Any] | None, result) + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_current_user() -> SimpleNamespace: + """Return a lightweight current_user stand-in.""" + return SimpleNamespace(id=ACCOUNT_ID) + + +@pytest.fixture +def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: + """ + Patch the common external boundaries used by WorkspaceService.get_tenant_info. + + Returns a dict of named mocks so individual tests can customise them. + """ + mocker.patch(CURRENT_USER_PATH, mock_current_user) + + mock_db_session = mocker.patch(f"{DB_PATH}.session") + mock_query_chain = MagicMock() + mock_db_session.query.return_value = mock_query_chain + mock_query_chain.where.return_value = mock_query_chain + mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") + + mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature()) + mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False) + mock_config = mocker.patch(DIFY_CONFIG_PATH) + mock_config.EDITION = "SELF_HOSTED" + mock_config.FILES_URL = FILES_BASE_URL + + return { + "db_session": mock_db_session, + "query_chain": mock_query_chain, + "get_features": mock_feature, + "has_roles": mock_has_roles, + "config": mock_config, + } + + +# --------------------------------------------------------------------------- +# 1. None Tenant Handling +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None: + """get_tenant_info should short-circuit and return None for a falsy tenant.""" + from services.workspace_service import WorkspaceService + + # Arrange + tenant = None + + # Act + result = WorkspaceService.get_tenant_info(cast(Tenant, tenant)) + + # Assert + assert result is None + + +def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None: + """get_tenant_info treats any falsy value as absent (e.g. empty string, 0).""" + from services.workspace_service import WorkspaceService + + # Arrange / Act / Assert + assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# 2. Basic Tenant Info — happy path +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_return_base_fields( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """get_tenant_info should always return the six base scalar fields.""" + from services.workspace_service import WorkspaceService + + # Arrange + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["id"] == TENANT_ID + assert result["name"] == "My Workspace" + assert result["plan"] == "sandbox" + assert result["status"] == "active" + assert result["created_at"] == "2024-01-01T00:00:00Z" + assert result["trial_end_reason"] is None + + +def test_get_tenant_info_should_populate_role_from_tenant_account_join( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """The 'role' field should be taken from TenantAccountJoin, not the default.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin") + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["role"] == "admin" + + +def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """ + The service asserts that TenantAccountJoin exists. + Missing join should raise AssertionError. + """ + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["query_chain"].first.return_value = None + tenant = _make_tenant() + + # Act + Assert + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + +# --------------------------------------------------------------------------- +# 3. Logo Customisation +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant( + custom_config={ + "replace_webapp_logo": True, + "remove_webapp_brand": True, + } + ) + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" in result + assert result["custom_config"]["remove_webapp_brand"] is True + expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo" + assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url + + +def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + +def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config should be absent when can_replace_logo is False.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" not in result + + +def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config block is gated on OWNER or ADMIN role.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = False # regular member + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" not in result + + +def test_get_tenant_info_should_use_files_url_for_logo_url( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """The logo URL should use dify_config.FILES_URL as the base.""" + from services.workspace_service import WorkspaceService + + # Arrange + custom_base = "https://cdn.mycompany.io" + basic_mocks["config"].FILES_URL = custom_base + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant(custom_config={"replace_webapp_logo": True}) + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + +# --------------------------------------------------------------------------- +# 4. Cloud-Edition Credit Features +# --------------------------------------------------------------------------- + +CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX + + +@pytest.fixture +def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: + """Patches for CLOUD edition tests, billing plan = professional by default.""" + mocker.patch(CURRENT_USER_PATH, mock_current_user) + + mock_db_session = mocker.patch(f"{DB_PATH}.session") + mock_query_chain = MagicMock() + mock_db_session.query.return_value = mock_query_chain + mock_query_chain.where.return_value = mock_query_chain + mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") + + mock_feature = mocker.patch( + FEATURE_SERVICE_PATH, + return_value=_make_feature( + can_replace_logo=False, + next_credit_reset_date="2025-02-01", + billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX, + ), + ) + mocker.patch(TENANT_SERVICE_PATH, return_value=False) + mock_config = mocker.patch(DIFY_CONFIG_PATH) + mock_config.EDITION = "CLOUD" + mock_config.FILES_URL = FILES_BASE_URL + + return { + "db_session": mock_db_session, + "query_chain": mock_query_chain, + "get_features": mock_feature, + "config": mock_config, + } + + +def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """next_credit_reset_date should be present in CLOUD edition.""" + from services.workspace_service import WorkspaceService + + # Arrange + mocker.patch( + CREDIT_POOL_SERVICE_PATH, + side_effect=[None, None], # both paid and trial pools absent + ) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + +def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """trial_credits/trial_credits_used come from the paid pool when conditions are met.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=1000, quota_used=200) + mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + +def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """quota_limit == -1 means unlimited; service should still use the paid pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=-1, quota_used=999) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When paid pool is exhausted (used >= limit), switch to trial pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full + trial_pool = _make_pool(quota_limit=100, quota_used=10) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When paid_pool is None, fall back to trial pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + trial_pool = _make_pool(quota_limit=50, quota_used=5) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """ + When the subscription plan IS SANDBOX, the paid pool branch is skipped + entirely and we fall back to the trial pool. + """ + from enums.cloud_plan import CloudPlan + from services.workspace_service import WorkspaceService + + # Arrange — override billing plan to SANDBOX + cloud_mocks["get_features"].return_value = _make_feature( + next_credit_reset_date="2025-02-01", + billing_plan=CloudPlan.SANDBOX, + ) + paid_pool = _make_pool(quota_limit=1000, quota_used=0) + trial_pool = _make_pool(quota_limit=200, quota_used=20) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + +def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When both paid and trial pools are absent, trial_credits should not be set.""" + from services.workspace_service import WorkspaceService + + # Arrange + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + +# --------------------------------------------------------------------------- +# 5. Self-hosted / Non-Cloud Edition +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + from services.workspace_service import WorkspaceService + + # Arrange (basic_mocks already sets EDITION = "SELF_HOSTED") + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + +# --------------------------------------------------------------------------- +# 6. DB query integrity +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """ + The DB query for TenantAccountJoin must be scoped to the correct + tenant_id and current_user.id. + """ + from services.workspace_service import WorkspaceService + + # Arrange + tenant = _make_tenant(tenant_id="my-special-tenant") + mock_current_user = mocker.patch(CURRENT_USER_PATH) + mock_current_user.id = "special-user-id" + + # Act + WorkspaceService.get_tenant_info(tenant) + + # Assert — db.session.query was invoked (at least once) + basic_mocks["db_session"].query.assert_called() diff --git a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py new file mode 100644 index 0000000000..ce44818886 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py @@ -0,0 +1,643 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.tools.entities.tool_entities import ApiProviderSchemaType +from services.tools.api_tools_manage_service import ApiToolManageService + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.tools.api_tools_manage_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace: + return SimpleNamespace(operation_id=operation_id) + + +def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value), + ) + + # Act + result = ApiToolManageService.parser_api_schema("valid-schema") + + # Assert + assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value + assert len(result["credentials_schema"]) == 3 + assert "warning" in result + + +def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + side_effect=RuntimeError("bad schema"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"): + ApiToolManageService.parser_api_schema("invalid") + + +def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None: + # Arrange + expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER) + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=expected, + ) + extra_info: dict[str, str] = {} + + # Act + result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info) + + # Assert + assert result == expected + + +def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + side_effect=ValueError("parse failed"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema: parse failed"): + ApiToolManageService.convert_schema_to_tool_bundles("schema") + + +def test_create_api_tool_provider_should_raise_error_when_provider_already_exists( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="provider provider-a already exists"): + ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name=" provider-a ", + icon={"emoji": "X"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=[], + ) + + +def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + many_tools = [_tool_bundle(str(i)) for i in range(101)] + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=(many_tools, ApiProviderSchemaType.OPENAPI), + ) + + # Act + Assert + with pytest.raises(ValueError, match="the number of apis should be less than 100"): + ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + icon={"emoji": "X"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=[], + ) + + +def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + + # Act + Assert + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + icon={"emoji": "X"}, + credentials={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=[], + ) + + +def test_create_api_tool_provider_should_create_provider_when_input_is_valid( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + mock_controller = MagicMock() + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=mock_controller, + ) + mock_encrypter = MagicMock() + mock_encrypter.encrypt.return_value = {"auth_type": "none"} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(mock_encrypter, MagicMock()), + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") + + # Act + result = ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + icon={"emoji": "X"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=["news"], + ) + + # Assert + assert result == {"result": "success"} + mock_controller.load_bundled_tools.assert_called_once() + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.get", + return_value=SimpleNamespace(status_code=200, text="schema-content"), + ) + mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True}) + + # Act + result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") + + # Assert + assert result == {"schema": "schema-content"} + + +@pytest.mark.parametrize("status_code", [400, 404, 500]) +def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid( + status_code: int, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.get", + return_value=SimpleNamespace(status_code=status_code, text="schema-content"), + ) + mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger") + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema, please check the url you provided"): + ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") + mock_logger.exception.assert_called_once() + + +def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found( + mock_db: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="you have not added provider provider-a"): + ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") + + +def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")]) + mock_db.session.query.return_value.where.return_value.first.return_value = provider + controller = MagicMock() + mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", + return_value=controller, + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"]) + mock_convert = mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", + side_effect=[{"name": "tool-a"}, {"name": "tool-b"}], + ) + + # Act + result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") + + # Assert + assert result == [{"name": "tool-a"}, {"name": "tool-b"}] + assert mock_convert.call_count == 2 + + +def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found( + mock_db: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="api provider provider-a does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + original_provider="provider-a", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy=None, + custom_disclaimer="custom", + labels=[], + ) + + +def test_update_api_tool_provider_should_raise_error_when_auth_type_missing( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = SimpleNamespace(credentials={}, name="old") + mock_db.session.query.return_value.where.return_value.first.return_value = provider + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + + # Act + Assert + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + original_provider="provider-a", + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy=None, + custom_disclaimer="custom", + labels=[], + ) + + +def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = SimpleNamespace( + credentials={"auth_type": "none", "api_key_value": "encrypted-old"}, + name="old", + icon="", + schema="", + description="", + schema_type_str="", + tools_str="", + privacy_policy="", + custom_disclaimer="", + credentials_str="", + ) + mock_db.session.query.return_value.where.return_value.first.return_value = provider + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + controller = MagicMock() + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=controller, + ) + cache = MagicMock() + encrypter = MagicMock() + encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"} + encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} + encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(encrypter, cache), + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") + + # Act + result = ApiToolManageService.update_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-new", + original_provider="provider-old", + icon={"emoji": "E"}, + credentials={"auth_type": "none", "api_key_value": "***"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=["news"], + ) + + # Assert + assert result == {"result": "success"} + assert provider.name == "provider-new" + assert provider.privacy_policy == "privacy" + assert provider.credentials_str != "" + cache.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + +def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="you have not added provider provider-a"): + ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") + + +def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None: + # Arrange + provider = object() + mock_db.session.query.return_value.where.return_value.first.return_value = provider + + # Act + result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") + + # Assert + assert result == {"result": "success"} + mock_db.session.delete.assert_called_once_with(provider) + mock_db.session.commit.assert_called_once() + + +def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None: + # Arrange + expected = {"provider": "value"} + mock_get = mocker.patch( + "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider", + return_value=expected, + ) + + # Act + result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") + + # Assert + assert result == expected + mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1") + + +def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None: + # Arrange + schema_type = "bad-schema-type" + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=schema_type, # type: ignore[arg-type] + schema="schema", + ) + + +def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + side_effect=RuntimeError("invalid"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + +def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") + + # Act + Assert + with pytest.raises(ValueError, match="invalid tool name tool-b"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-b", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + +def test_test_api_tool_preview_should_raise_error_when_auth_type_missing( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") + + # Act + Assert + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + +def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) + mock_db.session.query.return_value.where.return_value.first.return_value = db_provider + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + provider_controller = MagicMock() + tool_obj = MagicMock() + tool_obj.fork_tool_runtime.return_value = tool_obj + tool_obj.validate_credentials.side_effect = ValueError("validation failed") + provider_controller.get_tool.return_value = tool_obj + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=provider_controller, + ) + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"auth_type": "none"} + mock_encrypter.mask_plugin_credentials.return_value = {} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(mock_encrypter, MagicMock()), + ) + + # Act + result = ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + # Assert + assert result == {"error": "validation failed"} + + +def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) + mock_db.session.query.return_value.where.return_value.first.return_value = db_provider + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + provider_controller = MagicMock() + tool_obj = MagicMock() + tool_obj.fork_tool_runtime.return_value = tool_obj + tool_obj.validate_credentials.return_value = {"ok": True} + provider_controller.get_tool.return_value = tool_obj + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=provider_controller, + ) + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"auth_type": "none"} + mock_encrypter.mask_plugin_credentials.return_value = {} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(mock_encrypter, MagicMock()), + ) + + # Act + result = ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={"x": "1"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + # Assert + assert result == {"result": {"ok": True}} + + +def test_list_api_tools_should_return_all_user_providers_with_converted_tools( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_one = SimpleNamespace(name="p1") + provider_two = SimpleNamespace(name="p2") + mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two] + + controller_one = MagicMock() + controller_one.get_tools.return_value = ["tool-a"] + controller_two = MagicMock() + controller_two.get_tools.return_value = ["tool-b", "tool-c"] + + user_provider_one = SimpleNamespace(labels=[], tools=[]) + user_provider_two = SimpleNamespace(labels=[], tools=[]) + + mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", + side_effect=[controller_one, controller_two], + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"]) + mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider", + side_effect=[user_provider_one, user_provider_two], + ) + mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider") + mock_convert = mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", + side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}], + ) + + # Act + result = ApiToolManageService.list_api_tools("tenant-1") + + # Assert + assert len(result) == 2 + assert user_provider_one.tools == [{"name": "tool-a"}] + assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}] + assert mock_convert.call_count == 3 diff --git a/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py new file mode 100644 index 0000000000..439d203c58 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_builtin_tools_manage_service.py @@ -0,0 +1,455 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +MODULE = "services.tools.builtin_tools_manage_service" + + +def _mock_session(mock_session_cls): + """Helper: set up a Session context manager mock and return the inner session.""" + session = MagicMock() + mock_session_cls.return_value.__enter__ = MagicMock(return_value=session) + mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) + return session + + +class TestDeleteCustomOauthClientParams: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_deletes_and_returns_success(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + + result = BuiltinToolManageService.delete_custom_oauth_client_params("tenant-1", "google") + + assert result == {"result": "success"} + session.query.return_value.filter_by.return_value.delete.assert_called_once() + session.commit.assert_called_once() + + +class TestListBuiltinToolProviderTools: + @patch(f"{MODULE}.ToolLabelManager") + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.ToolManager") + def test_transforms_each_tool(self, mock_manager, mock_transform, mock_labels): + mock_controller = MagicMock() + mock_controller.get_tools.return_value = [MagicMock(), MagicMock()] + mock_manager.get_builtin_provider.return_value = mock_controller + mock_transform.convert_tool_entity_to_api_entity.return_value = MagicMock() + + result = BuiltinToolManageService.list_builtin_tool_provider_tools("tenant-1", "google") + + assert len(result) == 2 + + @patch(f"{MODULE}.ToolLabelManager") + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.ToolManager") + def test_empty_tools(self, mock_manager, mock_transform, mock_labels): + mock_controller = MagicMock() + mock_controller.get_tools.return_value = [] + mock_manager.get_builtin_provider.return_value = mock_controller + + assert BuiltinToolManageService.list_builtin_tool_provider_tools("t", "p") == [] + + +class TestGetBuiltinToolProviderInfo: + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider") + @patch(f"{MODULE}.ToolManager") + def test_raises_when_not_found(self, mock_manager, mock_get, mock_transform): + mock_get.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.get_builtin_tool_provider_info("t", "no") + + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_provider") + @patch(f"{MODULE}.ToolManager") + def test_clears_original_credentials(self, mock_manager, mock_get, mock_transform): + mock_get.return_value = MagicMock() + entity = MagicMock() + mock_transform.builtin_provider_to_user_provider.return_value = entity + + result = BuiltinToolManageService.get_builtin_tool_provider_info("t", "google") + + assert result.original_credentials == {} + + +class TestListBuiltinProviderCredentialsSchema: + @patch(f"{MODULE}.ToolManager") + def test_returns_schema(self, mock_manager): + mock_manager.get_builtin_provider.return_value.get_credentials_schema_by_type.return_value = [{"f": "k"}] + + result = BuiltinToolManageService.list_builtin_provider_credentials_schema("g", "api_key", "t") + + assert result == [{"f": "k"}] + + +class TestGetBuiltinToolProviderIcon: + @patch(f"{MODULE}.Path") + @patch(f"{MODULE}.ToolManager") + def test_returns_bytes_and_mime(self, mock_manager, mock_path): + mock_manager.get_hardcoded_provider_icon.return_value = ("/icon.svg", "image/svg+xml") + mock_path.return_value.read_bytes.return_value = b"" + + icon, mime = BuiltinToolManageService.get_builtin_tool_provider_icon("google") + + assert icon == b"" + assert mime == "image/svg+xml" + + +class TestIsOauthSystemClientExists: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_true_when_exists(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = MagicMock() + + assert BuiltinToolManageService.is_oauth_system_client_exists("google") is True + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_false_when_missing(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + assert BuiltinToolManageService.is_oauth_system_client_exists("google") is False + + +class TestIsOauthCustomClientEnabled: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_true_when_enabled(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = MagicMock(enabled=True) + + assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is True + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_false_when_none(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + assert BuiltinToolManageService.is_oauth_custom_client_enabled("t", "g") is False + + +class TestDeleteBuiltinToolProvider: + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_not_found(self, mock_db, mock_session_cls, mock_tm, mock_enc): + session = _mock_session(mock_session_cls) + session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "id") + + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_deletes_provider_and_clears_cache(self, mock_db, mock_session_cls, mock_tm, mock_enc): + session = _mock_session(mock_session_cls) + db_provider = MagicMock() + session.query.return_value.where.return_value.first.return_value = db_provider + mock_cache = MagicMock() + mock_enc.return_value = (MagicMock(), mock_cache) + + result = BuiltinToolManageService.delete_builtin_tool_provider("t", "p", "c") + + assert result == {"result": "success"} + session.delete.assert_called_once_with(db_provider) + session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + +class TestSetDefaultProvider: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_not_found(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="provider not found"): + BuiltinToolManageService.set_default_provider("t", "u", "p", "id") + + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_sets_default_and_clears_old(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + target = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = target + + result = BuiltinToolManageService.set_default_provider("t", "u", "p", "id") + + assert result == {"result": "success"} + assert target.is_default is True + session.commit.assert_called_once() + + +class TestUpdateBuiltinToolProvider: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_raises_when_provider_not_exists(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(ValueError, match="you have not added provider"): + BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c") + + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.CredentialType") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_updates_credentials_and_commits(self, mock_db, mock_session_cls, mock_tm, mock_cred_type, mock_enc): + session = _mock_session(mock_session_cls) + db_provider = MagicMock(credential_type="api_key", credentials="{}") + session.query.return_value.where.return_value.first.return_value = db_provider + + mock_cred_instance = MagicMock() + mock_cred_instance.is_editable.return_value = True + mock_cred_instance.is_validate_allowed.return_value = False + mock_cred_type.of.return_value = mock_cred_instance + + mock_controller = MagicMock(need_credentials=True) + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"key": "old"} + mock_encrypter.encrypt.return_value = {"key": "new"} + mock_cache = MagicMock() + mock_enc.return_value = (mock_encrypter, mock_cache) + + result = BuiltinToolManageService.update_builtin_tool_provider("u", "t", "p", "c", credentials={"key": "val"}) + + assert result == {"result": "success"} + session.commit.assert_called_once() + mock_cache.delete.assert_called_once() + + +class TestGetOauthClientSchema: + @patch(f"{MODULE}.BuiltinToolManageService.get_custom_oauth_client_params", return_value={}) + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_system_client_exists", return_value=False) + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=True) + @patch(f"{MODULE}.dify_config") + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.ToolManager") + def test_returns_schema_dict(self, mock_tm, mock_plugin, mock_config, mock_enabled, mock_sys, mock_params): + mock_config.CONSOLE_API_URL = "https://api.example.com" + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + result = BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema("t", "google") + + assert "schema" in result + assert result["is_oauth_custom_client_enabled"] is True + assert "redirect_uri" in result + + +class TestGetOauthClient: + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.create_provider_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_user_client_params_when_exists( + self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin + ): + session = _mock_session(mock_session_cls) + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"client_id": "id", "client_secret": "secret"} + mock_create_enc.return_value = (mock_encrypter, MagicMock()) + + user_client = MagicMock(oauth_params='{"encrypted": "data"}') + session.query.return_value.filter_by.return_value.first.return_value = user_client + + result = BuiltinToolManageService.get_oauth_client("t", "google") + + assert result == {"client_id": "id", "client_secret": "secret"} + + @patch(f"{MODULE}.decrypt_system_oauth_params", return_value={"sys_key": "sys_val"}) + @patch(f"{MODULE}.PluginService") + @patch(f"{MODULE}.create_provider_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_falls_back_to_system_client( + self, mock_db, mock_session_cls, mock_tm, mock_create_enc, mock_plugin, mock_decrypt + ): + session = _mock_session(mock_session_cls) + mock_controller = MagicMock() + mock_controller.get_oauth_client_schema.return_value = [] + mock_tm.get_builtin_provider.return_value = mock_controller + + mock_create_enc.return_value = (MagicMock(), MagicMock()) + + system_client = MagicMock(encrypted_oauth_params="enc") + session.query.return_value.filter_by.return_value.first.side_effect = [ + None, # user client + system_client, # system client + ] + + result = BuiltinToolManageService.get_oauth_client("t", "google") + + assert result == {"sys_key": "sys_val"} + + +class TestSaveCustomOauthClientParams: + def test_returns_early_when_no_params(self): + result = BuiltinToolManageService.save_custom_oauth_client_params("t", "p") + assert result == {"result": "success"} + + @patch(f"{MODULE}.ToolManager") + def test_raises_when_provider_not_found(self, mock_tm): + mock_tm.get_builtin_provider.return_value = None + + with pytest.raises((ValueError, Exception), match="not found|Provider"): + BuiltinToolManageService.save_custom_oauth_client_params("t", "p", enable_oauth_custom_client=True) + + +class TestGetCustomOauthClientParams: + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_empty_when_none(self, mock_db, mock_session_cls): + session = _mock_session(mock_session_cls) + session.query.return_value.filter_by.return_value.first.return_value = None + + result = BuiltinToolManageService.get_custom_oauth_client_params("t", "p") + + assert result == {} + + +class TestGetBuiltinToolProviderCredentialInfo: + @patch(f"{MODULE}.BuiltinToolManageService.is_oauth_custom_client_enabled", return_value=False) + @patch(f"{MODULE}.BuiltinToolManageService.get_builtin_tool_provider_credentials", return_value=[]) + @patch(f"{MODULE}.ToolManager") + def test_returns_credential_info(self, mock_tm, mock_creds, mock_oauth): + mock_tm.get_builtin_provider.return_value.get_supported_credential_types.return_value = ["api-key"] + + result = BuiltinToolManageService.get_builtin_tool_provider_credential_info("t", "google") + + assert result.credentials == [] + assert result.supported_credential_types == ["api-key"] + assert result.is_oauth_custom_client_enabled is False + + +class TestGetBuiltinToolProviderCredentials: + @patch(f"{MODULE}.db") + def test_returns_empty_when_no_providers(self, mock_db): + mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) + mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [] + + result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") + + assert result == [] + + @patch(f"{MODULE}.ToolTransformService") + @patch(f"{MODULE}.BuiltinToolManageService.create_tool_encrypter") + @patch(f"{MODULE}.ToolManager") + @patch(f"{MODULE}.db") + def test_returns_credential_entities(self, mock_db, mock_tm, mock_enc, mock_transform): + mock_db.session.no_autoflush.__enter__ = MagicMock(return_value=None) + mock_db.session.no_autoflush.__exit__ = MagicMock(return_value=False) + + provider = MagicMock(provider="google", is_default=False) + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.all.return_value = [provider] + + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"key": "decrypted"} + mock_encrypter.mask_plugin_credentials.return_value = {"key": "***"} + mock_enc.return_value = (mock_encrypter, MagicMock()) + + credential_entity = MagicMock() + mock_transform.convert_builtin_provider_to_credential_entity.return_value = credential_entity + + result = BuiltinToolManageService.get_builtin_tool_provider_credentials("t", "google") + + assert len(result) == 1 + assert result[0] is credential_entity + assert provider.is_default is True + + +class TestGetBuiltinProvider: + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_none_when_not_found(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.return_value.provider_name = "google" + mock_prov_id.return_value.organization = "langgenius" + session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + result = BuiltinToolManageService.get_builtin_provider("google", "t") + + assert result is None + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_provider_for_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.return_value.provider_name = "google" + mock_prov_id.return_value.organization = "langgenius" + db_provider = MagicMock(provider="google") + mock_prov_id_result = MagicMock() + mock_prov_id_result.to_string.return_value = "langgenius/google/google" + + def prov_id_side_effect(name): + m = MagicMock() + m.provider_name = "google" + m.organization = "langgenius" + m.to_string.return_value = "langgenius/google/google" + m.plugin_id = "langgenius/google" + return m + + mock_prov_id.side_effect = prov_id_side_effect + session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + + result = BuiltinToolManageService.get_builtin_provider("google", "t") + + assert result is db_provider + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_returns_provider_for_non_langgenius_org(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + + def prov_id_side_effect(name): + m = MagicMock() + m.provider_name = "custom-tool" + m.organization = "third-party" + m.to_string.return_value = "third-party/custom/custom-tool" + m.plugin_id = "third-party/custom" + return m + + mock_prov_id.side_effect = prov_id_side_effect + db_provider = MagicMock(provider="third-party/custom/custom-tool") + session.query.return_value.where.return_value.order_by.return_value.first.return_value = db_provider + + result = BuiltinToolManageService.get_builtin_provider("third-party/custom/custom-tool", "t") + + assert result is db_provider + + @patch(f"{MODULE}.ToolProviderID") + @patch(f"{MODULE}.Session") + @patch(f"{MODULE}.db") + def test_falls_back_on_exception(self, mock_db, mock_session_cls, mock_prov_id): + session = _mock_session(mock_session_cls) + mock_prov_id.side_effect = Exception("parse error") + fallback = MagicMock() + session.query.return_value.where.return_value.order_by.return_value.first.return_value = fallback + + result = BuiltinToolManageService.get_builtin_provider("old-provider", "t") + + assert result is fallback diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py new file mode 100644 index 0000000000..d35e014fab --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py @@ -0,0 +1,1045 @@ +from __future__ import annotations + +import hashlib +import json +from datetime import datetime +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.exc import IntegrityError + +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity +from core.mcp.entities import AuthActionType +from core.mcp.error import MCPAuthError, MCPError +from models.tools import MCPToolProvider +from services.tools.mcp_tools_manage_service import ( + EMPTY_CREDENTIALS_JSON, + EMPTY_TOOLS_JSON, + UNCHANGED_SERVER_URL_PLACEHOLDER, + MCPToolManageService, + OAuthDataType, + ProviderUrlValidationData, + ReconnectResult, + ServerUrlValidationResult, +) + + +class _ToolStub: + def __init__(self, name: str, description: str | None) -> None: + self._name = name + self._description = description + + def model_dump(self) -> dict[str, str | None]: + return {"name": self._name, "description": self._description} + + +@pytest.fixture +def mock_session() -> MagicMock: + # Arrange + return MagicMock() + + +@pytest.fixture +def service(mock_session: MagicMock) -> MCPToolManageService: + # Arrange + return MCPToolManageService(session=mock_session) + + +def _provider_entity_stub(*, authed: bool = True) -> MCPProviderEntity: + return cast( + MCPProviderEntity, + SimpleNamespace( + authed=authed, + timeout=30.0, + sse_read_timeout=300.0, + provider_id="server-1", + headers={"x-api-key": "enc"}, + decrypt_headers=lambda: {"x-api-key": "key"}, + retrieve_tokens=lambda: SimpleNamespace(token_type="bearer", access_token="token-1"), + decrypt_server_url=lambda: "https://mcp.example.com/sse", + to_api_response=lambda user_name=None: { + "id": "provider-1", + "author": user_name or "Anonymous", + "name": "MCP Tool", + "description": {"en_US": "", "zh_Hans": ""}, + "icon": "icon", + "label": {"en_US": "MCP Tool", "zh_Hans": "MCP Tool"}, + "type": "mcp", + "is_team_authorization": True, + "server_url": "https://mcp.example.com/******", + "updated_at": 1, + "server_identifier": "server-1", + "configuration": {"timeout": "30", "sse_read_timeout": "300"}, + "masked_headers": {}, + "is_dynamic_registration": True, + }, + decrypt_credentials=lambda: {"client_id": "plain-id", "client_secret": "plain-secret"}, + masked_credentials=lambda: {"client_id": "pl***id", "client_secret": "pl***et"}, + masked_headers=lambda: {"x-api-key": "ke***ey"}, + ), + ) + + +def _provider_stub(*, authed: bool = True) -> MCPToolProvider: + entity = _provider_entity_stub(authed=authed) + return cast( + MCPToolProvider, + SimpleNamespace( + id="provider-1", + tenant_id="tenant-1", + user_id="user-1", + name="Provider A", + server_identifier="server-1", + server_url="encrypted-url", + server_url_hash="old-hash", + authed=authed, + tools=EMPTY_TOOLS_JSON, + encrypted_credentials=json.dumps({"existing": "credential"}), + encrypted_headers=json.dumps({"x-api-key": "enc"}), + credentials={"existing": "credential"}, + timeout=30.0, + sse_read_timeout=300.0, + updated_at=datetime.now(), + icon="icon", + to_entity=lambda: entity, + load_user=lambda: SimpleNamespace(name="Tester"), + ), + ) + + +def test_server_url_validation_result_should_update_server_url_when_all_conditions_match() -> None: + # Arrange + result = ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}"), + ) + + # Act + should_update = result.should_update_server_url + + # Assert + assert should_update is True + + +def test_get_provider_should_return_provider_when_exists( + service: MCPToolManageService, + mock_session: MagicMock, +) -> None: + # Arrange + provider = _provider_stub() + mock_session.scalar.return_value = provider + + # Act + result = service.get_provider(provider_id="provider-1", tenant_id="tenant-1") + + # Assert + assert result is provider + + +def test_get_provider_should_raise_error_when_provider_not_found( + service: MCPToolManageService, mock_session: MagicMock +) -> None: + # Arrange + mock_session.scalar.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="MCP tool not found"): + service.get_provider(provider_id="provider-404", tenant_id="tenant-1") + + +def test_get_provider_entity_should_get_entity_by_provider_id_when_by_server_id_is_false( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + result = service.get_provider_entity("provider-1", "tenant-1", by_server_id=False) + + # Assert + assert result is provider.to_entity() + mock_get_provider.assert_called_once_with(provider_id="provider-1", tenant_id="tenant-1") + + +def test_get_provider_entity_should_get_entity_by_server_identifier_when_by_server_id_is_true( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + result = service.get_provider_entity("server-1", "tenant-1", by_server_id=True) + + # Assert + assert result is provider.to_entity() + mock_get_provider.assert_called_once_with(server_identifier="server-1", tenant_id="tenant-1") + + +def test_create_provider_should_raise_error_when_server_url_is_invalid(service: MCPToolManageService) -> None: + # Arrange + config = MCPConfiguration(timeout=30, sse_read_timeout=300) + + # Act + Assert + with pytest.raises(ValueError, match="Server URL is not valid"): + service.create_provider( + tenant_id="tenant-1", + name="Provider A", + server_url="invalid-url", + user_id="user-1", + icon="icon", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=config, + ) + + +def test_create_provider_should_create_and_return_user_provider_when_input_is_valid( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + config = MCPConfiguration(timeout=42, sse_read_timeout=123) + auth_data = MCPAuthentication(client_id="client-id", client_secret="secret") + mocker.patch.object(service, "_check_provider_exists") + mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="encrypted-url") + mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x":"enc"}') + mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') + mocker.patch.object(service, "_prepare_icon", return_value='{"content":"😀"}') + expected_user_provider = {"id": "provider-1"} + mock_convert = mocker.patch( + "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", + return_value=expected_user_provider, + ) + + # Act + result = service.create_provider( + tenant_id="tenant-1", + name="Provider A", + server_url="https://mcp.example.com", + user_id="user-1", + icon="😀", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=config, + authentication=auth_data, + headers={"x-api-key": "v1"}, + ) + + # Assert + assert result == expected_user_provider + mock_session.add.assert_called_once() + mock_session.flush.assert_called_once() + mock_convert.assert_called_once() + + +def test_update_provider_should_raise_error_when_new_name_conflicts( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + mock_session.scalar.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + service.update_provider( + tenant_id="tenant-1", + provider_id="provider-1", + name="New Name", + server_url="https://mcp.example.com", + icon="😀", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=MCPConfiguration(), + ) + + +def test_update_provider_should_update_fields_when_input_is_valid( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + validation = ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=ReconnectResult(authed=True, tools='[{"name":"t"}]', encrypted_credentials='{"x":"y"}'), + encrypted_server_url="new-encrypted-url", + server_url_hash="new-hash", + ) + mocker.patch.object(service, "get_provider", return_value=provider) + mock_session.scalar.return_value = None + mocker.patch.object(service, "_prepare_icon", return_value="new-icon") + mocker.patch.object(service, "_process_headers", return_value='{"x":"enc"}') + mocker.patch.object(service, "_process_credentials", return_value='{"client":"enc"}') + + # Act + service.update_provider( + tenant_id="tenant-1", + provider_id="provider-1", + name="Provider B", + server_url="https://mcp.example.com/new", + icon="😎", + icon_type="emoji", + icon_background="#000", + server_identifier="server-2", + headers={"x-api-key": "v2"}, + configuration=MCPConfiguration(timeout=50, sse_read_timeout=120), + authentication=MCPAuthentication(client_id="new-id", client_secret="new-secret"), + validation_result=validation, + ) + + # Assert + assert provider.name == "Provider B" + assert provider.server_identifier == "server-2" + assert provider.server_url == "new-encrypted-url" + assert provider.server_url_hash == "new-hash" + assert provider.authed is True + assert provider.tools == '[{"name":"t"}]' + assert provider.encrypted_credentials == '{"client":"enc"}' + assert provider.encrypted_headers == '{"x":"enc"}' + assert provider.timeout == 50 + assert provider.sse_read_timeout == 120 + mock_session.flush.assert_called_once() + + +def test_update_provider_should_handle_integrity_error_with_readable_message( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + mock_session.scalar.return_value = None + mocker.patch.object(service, "_prepare_icon", return_value="icon") + mock_session.flush.side_effect = IntegrityError("stmt", {}, Exception("unique_mcp_provider_name")) + + # Act + Assert + with pytest.raises(ValueError, match="MCP tool Provider A already exists"): + service.update_provider( + tenant_id="tenant-1", + provider_id="provider-1", + name="Provider A", + server_url="https://mcp.example.com", + icon="😀", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=MCPConfiguration(), + ) + + +def test_delete_provider_should_delete_existing_provider( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + service.delete_provider(tenant_id="tenant-1", provider_id="provider-1") + + # Assert + mock_session.delete.assert_called_once_with(provider) + + +def test_list_providers_should_return_empty_list_when_no_provider_exists( + service: MCPToolManageService, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.scalars.return_value.all.return_value = [] + + # Act + result = service.list_providers(tenant_id="tenant-1") + + # Assert + assert result == [] + + +def test_list_providers_should_convert_all_providers_and_attach_user_names( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_1 = _provider_stub() + provider_2 = _provider_stub() + provider_2.user_id = "user-2" + mock_session.scalars.return_value.all.return_value = [provider_1, provider_2] + mock_session.query.return_value.where.return_value.all.return_value = [ + SimpleNamespace(id="user-1", name="Alice"), + SimpleNamespace(id="user-2", name="Bob"), + ] + mock_convert = mocker.patch( + "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", + side_effect=[{"id": "1"}, {"id": "2"}], + ) + + # Act + result = service.list_providers(tenant_id="tenant-1", for_list=True, include_sensitive=False) + + # Assert + assert result == [{"id": "1"}, {"id": "2"}] + assert mock_convert.call_count == 2 + + +def test_list_provider_tools_should_raise_error_when_provider_is_not_authenticated( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=False) + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + Assert + with pytest.raises(ValueError, match="Please auth the tool first"): + service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") + + +def test_list_provider_tools_should_raise_error_when_remote_client_fails( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + mocker.patch.object(service, "get_provider", return_value=provider) + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.side_effect = MCPError("connection failed") + mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + Assert + with pytest.raises(ValueError, match="Failed to connect to MCP server"): + service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") + + +def test_list_provider_tools_should_update_db_and_return_response_on_success( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + mocker.patch.object(service, "get_provider", return_value=provider) + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.return_value = [ + _ToolStub("tool-a", None), + _ToolStub("tool-b", "desc"), + ] + mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) + + # Act + result = service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") + + # Assert + assert result.plugin_unique_identifier == "server-1" + assert provider.authed is True + payload = json.loads(provider.tools) + assert payload[0]["description"] == "" + assert payload[1]["description"] == "desc" + mock_session.flush.assert_called_once() + + +def test_update_provider_credentials_should_update_encrypted_credentials_and_auth_state( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + provider.encrypted_credentials = json.dumps({"existing": "value"}) + mocker.patch.object(service, "get_provider", return_value=provider) + mock_controller = MagicMock() + mocker.patch("core.tools.mcp_tool.provider.MCPToolProviderController.from_db", return_value=mock_controller) + mock_encryptor = MagicMock() + mock_encryptor.encrypt.return_value = {"access_token": "encrypted-token"} + mocker.patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter", return_value=mock_encryptor) + + # Act + service.update_provider_credentials( + provider_id="provider-1", + tenant_id="tenant-1", + credentials={"access_token": "plain-token"}, + authed=False, + ) + + # Assert + assert provider.authed is False + assert provider.tools == EMPTY_TOOLS_JSON + assert json.loads(cast(str, provider.encrypted_credentials))["access_token"] == "encrypted-token" + mock_session.flush.assert_called_once() + + +@pytest.mark.parametrize( + ("data_type", "data", "expected_authed"), + [ + (OAuthDataType.TOKENS, {"access_token": "token"}, True), + (OAuthDataType.MIXED, {"access_token": "token"}, True), + (OAuthDataType.MIXED, {"client_id": "id"}, None), + (OAuthDataType.CLIENT_INFO, {"client_id": "id"}, None), + ], +) +def test_save_oauth_data_should_delegate_with_expected_authed_value( + data_type: OAuthDataType, + data: dict[str, str], + expected_authed: bool | None, + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_update = mocker.patch.object(service, "update_provider_credentials") + + # Act + service.save_oauth_data("provider-1", "tenant-1", data, data_type) + + # Assert + assert mock_update.call_args.kwargs["authed"] == expected_authed + + +def test_clear_provider_credentials_should_reset_provider_state( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + service.clear_provider_credentials(provider_id="provider-1", tenant_id="tenant-1") + + # Assert + assert provider.tools == EMPTY_TOOLS_JSON + assert provider.encrypted_credentials == EMPTY_CREDENTIALS_JSON + assert provider.authed is False + + +def test_check_provider_exists_should_raise_different_errors_for_conflicts( + service: MCPToolManageService, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.scalar.return_value = SimpleNamespace( + name="name-a", + server_url_hash="hash-a", + server_identifier="server-a", + ) + + # Act + Assert + with pytest.raises(ValueError, match="MCP tool name-a already exists"): + service._check_provider_exists("tenant-1", "name-a", "hash-b", "server-b") + with pytest.raises(ValueError, match="MCP tool with this server URL already exists"): + service._check_provider_exists("tenant-1", "name-b", "hash-a", "server-b") + with pytest.raises(ValueError, match="MCP tool server-a already exists"): + service._check_provider_exists("tenant-1", "name-b", "hash-b", "server-a") + + +def test_prepare_icon_should_return_json_for_emoji_and_raw_value_for_non_emoji(service: MCPToolManageService) -> None: + # Arrange + # Act + emoji_icon = service._prepare_icon("😀", "emoji", "#fff") + raw_icon = service._prepare_icon("https://icon.png", "file", "#000") + + # Assert + assert json.loads(emoji_icon)["content"] == "😀" + assert raw_icon == "https://icon.png" + + +def test_encrypt_dict_fields_should_encrypt_secret_fields(service: MCPToolManageService, mocker: MockerFixture) -> None: + # Arrange + mock_encryptor = MagicMock() + mock_encryptor.encrypt.return_value = {"Authorization": "enc-token"} + mocker.patch("core.tools.utils.encryption.create_provider_encrypter", return_value=(mock_encryptor, MagicMock())) + + # Act + result = service._encrypt_dict_fields({"Authorization": "token"}, ["Authorization"], "tenant-1") + + # Assert + assert result == {"Authorization": "enc-token"} + + +def test_prepare_encrypted_dict_should_return_json_string(service: MCPToolManageService, mocker: MockerFixture) -> None: + # Arrange + mocker.patch.object(service, "_encrypt_dict_fields", return_value={"x": "enc"}) + + # Act + result = service._prepare_encrypted_dict({"x": "v"}, "tenant-1") + + # Assert + assert result == '{"x": "enc"}' + + +def test_prepare_auth_headers_should_append_authorization_when_tokens_exist(service: MCPToolManageService) -> None: + # Arrange + provider_entity = _provider_entity_stub() + + # Act + headers = service._prepare_auth_headers(provider_entity) + + # Assert + assert headers["Authorization"] == "Bearer token-1" + + +def test_retrieve_remote_mcp_tools_should_return_tools_from_client( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", "desc")] + mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + tools = service._retrieve_remote_mcp_tools("https://mcp.example.com", {}, _provider_entity_stub()) + + # Assert + assert len(tools) == 1 + assert tools[0].model_dump()["name"] == "tool-a" + + +def test_execute_auth_actions_should_dispatch_supported_actions( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_save = mocker.patch.object(service, "save_oauth_data") + auth_result = SimpleNamespace( + actions=[ + SimpleNamespace( + action_type=AuthActionType.SAVE_CLIENT_INFO, + data={"client_id": "c1"}, + provider_id="provider-1", + tenant_id="tenant-1", + ), + SimpleNamespace( + action_type=AuthActionType.SAVE_TOKENS, + data={"access_token": "t1"}, + provider_id="provider-1", + tenant_id="tenant-1", + ), + SimpleNamespace( + action_type=AuthActionType.SAVE_CODE_VERIFIER, + data={"code_verifier": "cv"}, + provider_id="provider-1", + tenant_id="tenant-1", + ), + SimpleNamespace( + action_type=AuthActionType.SAVE_TOKENS, + data={"access_token": "skip"}, + provider_id=None, + tenant_id="tenant-1", + ), + ], + response={"ok": "1"}, + ) + + # Act + result = service.execute_auth_actions(auth_result) + + # Assert + assert result == {"ok": "1"} + assert mock_save.call_count == 3 + + +def test_auth_with_actions_should_call_auth_and_execute_actions( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_entity = _provider_entity_stub() + auth_result = SimpleNamespace(actions=[], response={"status": "ok"}) + mocker.patch("services.tools.mcp_tools_manage_service.auth", return_value=auth_result) + mock_execute = mocker.patch.object(service, "execute_auth_actions", return_value={"status": "ok"}) + + # Act + result = service.auth_with_actions(provider_entity=provider_entity, authorization_code="code-1") + + # Assert + assert result == {"status": "ok"} + mock_execute.assert_called_once_with(auth_result) + + +def test_get_provider_for_url_validation_should_return_validation_data( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + result = service.get_provider_for_url_validation(tenant_id="tenant-1", provider_id="provider-1") + + # Assert + assert result.current_server_url_hash == "old-hash" + assert result.headers == {"x-api-key": "enc"} + + +def test_validate_server_url_standalone_should_skip_validation_for_unchanged_placeholder() -> None: + # Arrange + data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) + + # Act + result = MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url=UNCHANGED_SERVER_URL_PLACEHOLDER, + validation_data=data, + ) + + # Assert + assert result.needs_validation is False + + +def test_validate_server_url_standalone_should_raise_error_for_invalid_url() -> None: + # Arrange + data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) + + # Act + Assert + with pytest.raises(ValueError, match="Server URL is not valid"): + MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url="bad-url", + validation_data=data, + ) + + +def test_validate_server_url_standalone_should_return_no_validation_when_hash_unchanged(mocker: MockerFixture) -> None: + # Arrange + url = "https://mcp.example.com" + current_hash = hashlib.sha256(url.encode()).hexdigest() + data = ProviderUrlValidationData(current_server_url_hash=current_hash, headers={}, timeout=30, sse_read_timeout=300) + mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-url") + + # Act + result = MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url=url, + validation_data=data, + ) + + # Assert + assert result.needs_validation is False + assert result.encrypted_server_url == "enc-url" + assert result.server_url_hash == current_hash + + +def test_validate_server_url_standalone_should_reconnect_when_url_changes(mocker: MockerFixture) -> None: + # Arrange + url = "https://mcp-new.example.com" + data = ProviderUrlValidationData(current_server_url_hash="old", headers={}, timeout=30, sse_read_timeout=300) + reconnect_result = ReconnectResult(authed=True, tools='[{"name":"x"}]', encrypted_credentials="{}") + mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-new") + mock_reconnect = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=reconnect_result) + + # Act + result = MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url=url, + validation_data=data, + ) + + # Assert + assert result.validation_passed is True + assert result.reconnect_result == reconnect_result + mock_reconnect.assert_called_once() + + +def test_reconnect_with_url_should_delegate_to_private_method(mocker: MockerFixture) -> None: + # Arrange + expected = ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}") + mock_delegate = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=expected) + + # Act + result = MCPToolManageService.reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + # Assert + assert result == expected + mock_delegate.assert_called_once() + + +def test_private_reconnect_with_url_should_return_authed_true_when_connection_succeeds(mocker: MockerFixture) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", None)] + mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + result = MCPToolManageService._reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + # Assert + assert result.authed is True + assert json.loads(result.tools)[0]["description"] == "" + + +def test_private_reconnect_with_url_should_return_authed_false_on_auth_error(mocker: MockerFixture) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.side_effect = MCPAuthError("auth required") + mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + result = MCPToolManageService._reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + # Assert + assert result.authed is False + assert result.tools == EMPTY_TOOLS_JSON + + +def test_private_reconnect_with_url_should_raise_value_error_on_mcp_error(mocker: MockerFixture) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.side_effect = MCPError("network failure") + mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + Assert + with pytest.raises(ValueError, match="Failed to re-connect MCP server: network failure"): + MCPToolManageService._reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + +def test_build_tool_provider_response_should_build_api_entity_with_tools( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + db_provider = _provider_stub() + provider_entity = _provider_entity_stub() + tools = [_ToolStub("tool-a", "desc")] + mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) + + # Act + result = service._build_tool_provider_response(db_provider, provider_entity, tools) + + # Assert + assert result.plugin_unique_identifier == "server-1" + assert result.name == "MCP Tool" + + +@pytest.mark.parametrize( + ("orig_message", "expected_error"), + [ + ("unique_mcp_provider_name", "MCP tool name already exists"), + ("unique_mcp_provider_server_url", "MCP tool https://mcp.example.com already exists"), + ("unique_mcp_provider_server_identifier", "MCP tool server-1 already exists"), + ], +) +def test_handle_integrity_error_should_raise_readable_value_errors( + orig_message: str, + expected_error: str, + service: MCPToolManageService, +) -> None: + """Test that known integrity errors raise readable value errors.""" + # Arrange + error = IntegrityError("stmt", {}, Exception(orig_message)) + + # Act + Assert + with pytest.raises(ValueError, match=expected_error): + service._handle_integrity_error(error, "name", "https://mcp.example.com", "server-1") + + +def test_handle_integrity_error_should_reraise_unknown_error(service: MCPToolManageService) -> None: + """Test that unknown integrity errors are re-raised.""" + # Arrange + error = IntegrityError("stmt", {}, Exception("unknown-constraint")) + + # Act + Assert + with pytest.raises(IntegrityError) as exc_info: + service._handle_integrity_error(error, "name", "url", "identifier") + + assert exc_info.value is error + + +@pytest.mark.parametrize( + ("url", "expected"), + [ + ("https://mcp.example.com", True), + ("http://mcp.example.com", True), + ("", False), + ("invalid", False), + ("ftp://mcp.example.com", False), + ], +) +def test_is_valid_url_should_validate_supported_schemes( + url: str, + expected: bool, + service: MCPToolManageService, +) -> None: + # Arrange + # Act + result = service._is_valid_url(url) + + # Assert + assert result is expected + + +def test_update_optional_fields_should_update_only_non_none_values(service: MCPToolManageService) -> None: + # Arrange + provider = _provider_stub() + configuration = MCPConfiguration(timeout=99, sse_read_timeout=300) + + # Act + service._update_optional_fields(provider, configuration) + + # Assert + assert provider.timeout == 99 + assert provider.sse_read_timeout == 300 + + +def test_process_headers_should_return_none_when_empty_headers(service: MCPToolManageService) -> None: + # Arrange + provider = _provider_stub() + + # Act + result = service._process_headers({}, provider, "tenant-1") + + # Assert + assert result is None + + +def test_process_headers_should_merge_and_encrypt_headers( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "_merge_headers_with_masked", return_value={"x-api-key": "plain"}) + mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x-api-key":"enc"}') + + # Act + result = service._process_headers({"x-api-key": "*****"}, provider, "tenant-1") + + # Assert + assert result == '{"x-api-key":"enc"}' + + +def test_process_credentials_should_merge_and_encrypt_credentials( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + authentication = MCPAuthentication(client_id="masked-id", client_secret="masked-secret") + mocker.patch.object(service, "_merge_credentials_with_masked", return_value=("plain-id", "plain-secret")) + mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') + + # Act + result = service._process_credentials(authentication, provider, "tenant-1") + + # Assert + assert result == '{"client_information":{}}' + + +def test_merge_headers_with_masked_should_preserve_original_values_for_unchanged_masked_inputs( + service: MCPToolManageService, +) -> None: + # Arrange + provider = _provider_stub() + incoming_headers = {"x-api-key": "ke***ey", "new-header": "new-value", "dropped": "*****"} + + # Act + result = service._merge_headers_with_masked(incoming_headers, provider) + + # Assert + assert result["x-api-key"] == "key" + assert result["new-header"] == "new-value" + assert result["dropped"] == "*****" + + +def test_merge_credentials_with_masked_should_preserve_decrypted_values_when_masked_match( + service: MCPToolManageService, +) -> None: + # Arrange + provider = _provider_stub() + + # Act + client_id, client_secret = service._merge_credentials_with_masked("pl***id", "pl***et", provider) + + # Assert + assert client_id == "plain-id" + assert client_secret == "plain-secret" + + +def test_build_and_encrypt_credentials_should_encrypt_secret_when_client_secret_present( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch.object( + service, + "_encrypt_dict_fields", + return_value={ + "client_id": "id", + "client_name": "Dify", + "is_dynamic_registration": False, + "encrypted_client_secret": "enc-secret", + }, + ) + + # Act + result = service._build_and_encrypt_credentials("id", "secret", "tenant-1") + + # Assert + payload = json.loads(result) + assert payload["client_information"]["encrypted_client_secret"] == "enc-secret" + + +def test_build_and_encrypt_credentials_should_skip_secret_field_when_client_secret_is_none( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch.object( + service, + "_encrypt_dict_fields", + return_value={"client_id": "id", "client_name": "Dify", "is_dynamic_registration": False}, + ) + + # Act + result = service._build_and_encrypt_credentials("id", None, "tenant-1") + + # Assert + payload = json.loads(result) + assert "encrypted_client_secret" not in payload["client_information"] diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py index 7511fd6f0c..9537d207f0 100644 --- a/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_transform.py @@ -7,7 +7,7 @@ import pytest from core.mcp.types import Tool as MCPTool from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -175,6 +175,137 @@ class TestMCPToolTransform: # The actual parameter conversion is handled by convert_mcp_schema_to_parameter # which should be tested separately + def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self): + """Nullable object schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "anyOf": [{"type": "object"}, {"type": "null"}], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self): + """Nullable oneOf object schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "oneOf": [{"type": "object"}, {"type": "null"}], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_handles_null_type(self): + """Schemas with only a null type should fall back to string.""" + schema = { + "type": "object", + "properties": { + "null_prop_str": {"type": "null"}, + "null_prop_list": {"type": ["null"]}, + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 2 + param_map = {parameter.name: parameter for parameter in result} + assert "null_prop_str" in param_map + assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING + assert "null_prop_list" in param_map + assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING + + def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self): + """Property-level allOf with multiple object items should still resolve to object.""" + schema = { + "type": "object", + "properties": { + "config": { + "allOf": [ + { + "type": "object", + "properties": { + "enabled": {"type": "boolean"}, + }, + "required": ["enabled"], + }, + { + "type": "object", + "properties": { + "priority": {"type": "integer", "minimum": 1, "maximum": 10}, + }, + "required": ["priority"], + }, + ], + "description": "Config must match all schemas (allOf)", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "config" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["config"] + + def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self): + """Composed property schemas should keep the object parameter type.""" + schema = { + "type": "object", + "properties": { + "retrieval_model": { + "allOf": [ + {"type": "object"}, + {"properties": {"top_k": {"type": "integer"}}}, + ], + "description": "检索模型配置", + } + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "retrieval_model" + assert result[0].type == ToolParameter.ToolParameterType.OBJECT + assert result[0].input_schema == schema["properties"]["retrieval_model"] + + def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self): + """Self-referential composed schemas should stop resolving after the configured max depth.""" + recursive_property: dict[str, object] = {"description": "Recursive schema"} + recursive_property["anyOf"] = [recursive_property] + schema = { + "type": "object", + "properties": { + "recursive_config": recursive_property, + }, + } + + result = ToolTransformService.convert_mcp_schema_to_parameter(schema) + + assert len(result) == 1 + assert result[0].name == "recursive_config" + assert result[0].type == ToolParameter.ToolParameterType.STRING + assert result[0].input_schema is None + def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full): """Test mcp_provider_to_user_provider with for_list=True.""" # Set tools data with null description diff --git a/api/tests/unit_tests/services/tools/test_tool_labels_service.py b/api/tests/unit_tests/services/tools/test_tool_labels_service.py new file mode 100644 index 0000000000..6acdbb7901 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tool_labels_service.py @@ -0,0 +1,21 @@ +from services.tools.tool_labels_service import ToolLabelsService + + +def test_list_tool_labels_returns_default_labels(): + result = ToolLabelsService.list_tool_labels() + assert isinstance(result, list) + assert len(result) > 0 + + +def test_list_tool_labels_items_are_tool_labels(): + from core.tools.entities.tool_entities import ToolLabel + + result = ToolLabelsService.list_tool_labels() + for label in result: + assert isinstance(label, ToolLabel) + + +def test_list_tool_labels_matches_default_values(): + from core.tools.entities.values import default_tool_labels + + assert ToolLabelsService.list_tool_labels() is default_tool_labels diff --git a/api/tests/unit_tests/services/tools/test_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_tools_manage_service.py new file mode 100644 index 0000000000..73ac9a10c6 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_tools_manage_service.py @@ -0,0 +1,40 @@ +from unittest.mock import MagicMock, patch + +from services.tools.tools_manage_service import ToolCommonService + + +class TestToolCommonService: + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_transforms_and_returns(self, mock_manager, mock_transform): + mock_provider1 = MagicMock() + mock_provider1.to_dict.return_value = {"name": "provider1"} + mock_provider2 = MagicMock() + mock_provider2.to_dict.return_value = {"name": "provider2"} + mock_manager.list_providers_from_api.return_value = [mock_provider1, mock_provider2] + + result = ToolCommonService.list_tool_providers("user-1", "tenant-1") + + mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", None) + assert mock_transform.repack_provider.call_count == 2 + assert result == [{"name": "provider1"}, {"name": "provider2"}] + + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_with_type_filter(self, mock_manager, mock_transform): + mock_manager.list_providers_from_api.return_value = [] + + result = ToolCommonService.list_tool_providers("user-1", "tenant-1", typ="builtin") + + mock_manager.list_providers_from_api.assert_called_once_with("user-1", "tenant-1", "builtin") + assert result == [] + + @patch("services.tools.tools_manage_service.ToolTransformService") + @patch("services.tools.tools_manage_service.ToolManager") + def test_list_tool_providers_empty(self, mock_manager, mock_transform): + mock_manager.list_providers_from_api.return_value = [] + + result = ToolCommonService.list_tool_providers("u", "t") + + assert result == [] + mock_transform.repack_provider.assert_not_called() diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py index ae59da0a3d..e9bcc89445 100644 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py @@ -1,3 +1,9 @@ +""" +Unit tests for services.tools.workflow_tools_manage_service + +Covers WorkflowToolManageService: create, update, list, delete, get, list_single. +""" + import json from types import SimpleNamespace from unittest.mock import MagicMock @@ -9,9 +15,16 @@ from core.tools.errors import WorkflowToolHumanInputNotSupportedError from models.model import App from models.tools import WorkflowToolProvider from services.tools import workflow_tools_manage_service +from services.tools.workflow_tools_manage_service import WorkflowToolManageService + +# --------------------------------------------------------------------------- +# Shared helpers / fake infrastructure +# --------------------------------------------------------------------------- class DummyWorkflow: + """Minimal in-memory Workflow substitute.""" + def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: self._graph_dict = graph_dict self.version = version @@ -22,72 +35,42 @@ class DummyWorkflow: class FakeQuery: - def __init__(self, result): + """Chainable query object that always returns a fixed result.""" + + def __init__(self, result: object) -> None: self._result = result - def where(self, *args, **kwargs): + def where(self, *args: object, **kwargs: object) -> "FakeQuery": return self - def first(self): + def first(self) -> object: return self._result + def delete(self) -> int: + return 1 + class DummySession: + """Minimal SQLAlchemy session substitute.""" + def __init__(self) -> None: - self.added: list[object] = [] + self.added: list[WorkflowToolProvider] = [] + self.committed: bool = False def __enter__(self) -> "DummySession": return self - def __exit__(self, exc_type, exc, tb) -> bool: + def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: return False - def add(self, obj) -> None: + def add(self, obj: WorkflowToolProvider) -> None: self.added.append(obj) - def begin(self): - return DummyBegin(self) + def begin(self) -> "DummySession": + return self - -class DummyBegin: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionContext: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionFactory: - def __init__(self, session: DummySession) -> None: - self._session = session - - def create_session(self) -> DummySessionContext: - return DummySessionContext(self._session) - - -def _build_fake_session(app) -> SimpleNamespace: - def query(model): - if model is WorkflowToolProvider: - return FakeQuery(None) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - return SimpleNamespace(query=query) + def commit(self) -> None: + self.committed = True def _build_parameters() -> list[WorkflowToolParameterConfiguration]: @@ -96,67 +79,877 @@ def _build_parameters() -> list[WorkflowToolParameterConfiguration]: ] -def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) +def _build_fake_db( + *, + existing_tool: WorkflowToolProvider | None = None, + app: object | None = None, + tool_by_id: WorkflowToolProvider | None = None, +) -> tuple[MagicMock, DummySession]: + """ + Build a fake db object plus a DummySession for Session context-manager. - fake_session = _build_fake_session(app) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + query(WorkflowToolProvider) returns existing_tool on first call, + then tool_by_id on subsequent calls (or None if not provided). + query(App) returns app. + """ + call_counts: dict[str, int] = {"wftp": 0} - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - mock_invalidate = MagicMock() + def query(model: type) -> FakeQuery: + if model is WorkflowToolProvider: + call_counts["wftp"] += 1 + if call_counts["wftp"] == 1: + return FakeQuery(existing_tool) + return FakeQuery(tool_by_id) + if model is App: + return FakeQuery(app) + return FakeQuery(None) - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query, commit=MagicMock()) + dummy_session = DummySession() + return fake_db, dummy_session + + +# --------------------------------------------------------------------------- +# TestCreateWorkflowTool +# --------------------------------------------------------------------------- + + +class TestCreateWorkflowTool: + """Tests for WorkflowToolManageService.create_workflow_tool.""" + + def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Human-input nodes must be rejected before any provider is created.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]}) + app = SimpleNamespace(workflow=workflow) + fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app)) + monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + mock_from_db = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + + # Act + Assert + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id="user-id", + tenant_id="tenant-id", + workflow_app_id="app-id", + name="tool_name", + label="Tool", + icon={"type": "emoji", "emoji": "🔧"}, + description="desc", + parameters=_build_parameters(), + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + mock_from_db.assert_not_called() + + def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Existing provider with same name or app_id raises ValueError.""" + # Arrange + existing = MagicMock(spec=WorkflowToolProvider) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(existing)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-1", + name="dup", + label="Dup", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the referenced App does not exist.""" + # Arrange + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(None) # App returns None + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="missing-app", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the App has no attached Workflow.""" + # Arrange + app_no_workflow = SimpleNamespace(workflow=None) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app_no_workflow) + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="Workflow not found"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-id", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app) + + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query) + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + monkeypatch.setattr( + workflow_tools_manage_service.WorkflowToolProviderController, + "from_db", + MagicMock(side_effect=RuntimeError("bad config")), + ) + + # Act + Assert + with pytest.raises(ValueError, match="bad config"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-id", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: provider is added to session and success dict is returned.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0") + app = SimpleNamespace(workflow=workflow) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app) + + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query) + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + + icon = {"type": "emoji", "emoji": "🔧"} + + # Act + result = WorkflowToolManageService.create_workflow_tool( user_id="user-id", tenant_id="tenant-id", workflow_app_id="app-id", name="tool_name", label="Tool", - icon={"type": "emoji", "emoji": "tool"}, + icon=icon, description="desc", parameters=_build_parameters(), ) - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - mock_invalidate.assert_not_called() + # Assert + assert result == {"result": "success"} + assert len(dummy_session.added) == 1 + created: WorkflowToolProvider = dummy_session.added[0] + assert created.name == "tool_name" + assert created.label == "Tool" + assert created.icon == json.dumps(icon) + assert created.version == "2.0.0" + + def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Labels are forwarded to ToolLabelManager when provided.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app) + + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query) + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + mock_label_mgr = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) + mock_to_ctrl = MagicMock() + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl + ) + + # Act + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-id", + name="n", + label="L", + icon={}, + description="", + parameters=[], + labels=["tag1", "tag2"], + ) + + # Assert + mock_label_mgr.assert_called_once() -def test_create_workflow_tool_success(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) - app = SimpleNamespace(workflow=workflow) +# --------------------------------------------------------------------------- +# TestUpdateWorkflowTool +# --------------------------------------------------------------------------- - fake_db = MagicMock() - fake_session = _build_fake_session(app) - fake_db.session = fake_session - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) +class TestUpdateWorkflowTool: + """Tests for WorkflowToolManageService.update_workflow_tool.""" - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + def _make_provider(self) -> WorkflowToolProvider: + p = MagicMock(spec=WorkflowToolProvider) + p.app_id = "app-id" + p.tenant_id = "tenant-id" + return p - icon = {"type": "emoji", "emoji": "tool"} + def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None: + """If another tool with the given name already exists, raise ValueError.""" + # Arrange + existing = MagicMock(spec=WorkflowToolProvider) - result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) + def query(m: type) -> FakeQuery: + return FakeQuery(existing) - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created_provider = dummy_session.added[0] - assert created_provider.name == "tool_name" - assert created_provider.label == "Tool" - assert created_provider.icon == json.dumps(icon) - assert created_provider.version == workflow.version - mock_from_db.assert_called_once() + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="dup", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the workflow tool to update does not exist.""" + # Arrange + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + # 1st call: name uniqueness check → None (no duplicate) + # 2nd call: fetch tool by id → None (not found) + return FakeQuery(None) + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="missing", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the tool's referenced App has been removed.""" + # Arrange + provider = self._make_provider() + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + # 1st: duplicate name check (None), 2nd: fetch provider + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(None) # App not found + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the App exists but has no Workflow.""" + # Arrange + provider = self._make_provider() + app_no_wf = SimpleNamespace(workflow=None) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app_no_wf) + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="Workflow not found"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Exceptions from from_db are re-raised as ValueError.""" + # Arrange + provider = self._make_provider() + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app) + + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=query, commit=MagicMock()), + ) + monkeypatch.setattr( + workflow_tools_manage_service.WorkflowToolProviderController, + "from_db", + MagicMock(side_effect=RuntimeError("from_db error")), + ) + + # Act + Assert + with pytest.raises(ValueError, match="from_db error"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: provider fields are updated and session committed.""" + # Arrange + provider = self._make_provider() + workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0") + app = SimpleNamespace(workflow=workflow) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app) + + mock_commit = MagicMock() + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=query, commit=mock_commit), + ) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + + icon = {"type": "emoji", "emoji": "🛠"} + + # Act + result = WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="new_name", + label="New Label", + icon=icon, + description="new desc", + parameters=_build_parameters(), + ) + + # Assert + assert result == {"result": "success"} + mock_commit.assert_called_once() + assert provider.name == "new_name" + assert provider.label == "New Label" + assert provider.icon == json.dumps(icon) + assert provider.version == "3.0.0" + + def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Labels are forwarded to ToolLabelManager during update.""" + # Arrange + provider = self._make_provider() + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app) + + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=query, commit=MagicMock()), + ) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + mock_label_mgr = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock() + ) + + # Act + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + labels=["a"], + ) + + # Assert + mock_label_mgr.assert_called_once() + + +# --------------------------------------------------------------------------- +# TestListTenantWorkflowTools +# --------------------------------------------------------------------------- + + +class TestListTenantWorkflowTools: + """Tests for WorkflowToolManageService.list_tenant_workflow_tools.""" + + def test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """An empty database yields an empty result list.""" + # Arrange + fake_scalars = MagicMock() + fake_scalars.all.return_value = [] + fake_db = MagicMock() + fake_db.session.scalars.return_value = fake_scalars + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + # Act + result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") + + # Assert + assert result == [] + + def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Providers that fail to load are logged and skipped.""" + # Arrange + good_provider = MagicMock(spec=WorkflowToolProvider) + good_provider.id = "good-id" + good_provider.app_id = "app-good" + bad_provider = MagicMock(spec=WorkflowToolProvider) + bad_provider.id = "bad-id" + bad_provider.app_id = "app-bad" + + fake_scalars = MagicMock() + fake_scalars.all.return_value = [good_provider, bad_provider] + fake_db = MagicMock() + fake_db.session.scalars.return_value = fake_scalars + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + good_ctrl = MagicMock() + good_ctrl.provider_id = "good-id" + + def to_controller(provider: WorkflowToolProvider) -> MagicMock: + if provider is bad_provider: + raise RuntimeError("broken provider") + return good_ctrl + + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", to_controller + ) + mock_get_labels = MagicMock(return_value={}) + monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", mock_get_labels) + mock_to_user = MagicMock() + mock_to_user.return_value.tools = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_user_provider", mock_to_user + ) + monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) + mock_get_tools = MagicMock(return_value=[MagicMock()]) + good_ctrl.get_tools = mock_get_tools + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() + ) + + # Act + result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") + + # Assert - only good provider contributed + assert len(result) == 1 + + def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None: + """All successfully loaded providers appear in the result.""" + # Arrange + provider = MagicMock(spec=WorkflowToolProvider) + provider.id = "p-1" + provider.app_id = "app-1" + + fake_scalars = MagicMock() + fake_scalars.all.return_value = [provider] + fake_db = MagicMock() + fake_db.session.scalars.return_value = fake_scalars + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + ctrl = MagicMock() + ctrl.provider_id = "p-1" + ctrl.get_tools.return_value = [MagicMock()] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + monkeypatch.setattr( + workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={"p-1": []}) + ) + user_provider = MagicMock() + user_provider.tools = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_user_provider", + MagicMock(return_value=user_provider), + ) + monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() + ) + + # Act + result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") + + # Assert + assert len(result) == 1 + assert result[0] is user_provider + + +# --------------------------------------------------------------------------- +# TestDeleteWorkflowTool +# --------------------------------------------------------------------------- + + +class TestDeleteWorkflowTool: + """Tests for WorkflowToolManageService.delete_workflow_tool.""" + + def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: + """delete_workflow_tool queries, deletes, commits, and returns success.""" + # Arrange + mock_query = MagicMock() + mock_query.where.return_value.delete.return_value = 1 + mock_commit = MagicMock() + fake_session = SimpleNamespace(query=lambda m: mock_query, commit=mock_commit) + monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + + # Act + result = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1") + + # Assert + assert result == {"result": "success"} + mock_commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# TestGetWorkflowToolByToolId / ByAppId +# --------------------------------------------------------------------------- + + +class TestGetWorkflowToolByToolIdAndAppId: + """Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id.""" + + def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Raises ValueError when no WorkflowToolProvider found by tool id.""" + # Arrange + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing") + + def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Raises ValueError when no WorkflowToolProvider found by app id.""" + # Arrange + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app") + + +# --------------------------------------------------------------------------- +# TestGetWorkflowTool (private _get_workflow_tool) +# --------------------------------------------------------------------------- + + +class TestGetWorkflowTool: + """Tests for the internal _get_workflow_tool helper.""" + + def test_should_raise_when_db_tool_none(self) -> None: + """_get_workflow_tool raises ValueError when db_tool is None.""" + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService._get_workflow_tool("t", None) + + def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the corresponding App row is missing.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService._get_workflow_tool("t", db_tool) + + def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when App has no attached Workflow.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + app = SimpleNamespace(workflow=None) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(app)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Workflow not found"): + WorkflowToolManageService._get_workflow_tool("t", db_tool) + + def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the controller returns no WorkflowTool instances.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + db_tool.id = "tool-1" + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(app)), + ) + ctrl = MagicMock() + ctrl.get_tools.return_value = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService._get_workflow_tool("t", db_tool) + + def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: returns a dict with name, label, icon, synced, etc.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + db_tool.id = "tool-1" + db_tool.name = "my_tool" + db_tool.label = "My Tool" + db_tool.icon = json.dumps({"emoji": "🔧"}) + db_tool.description = "some desc" + db_tool.privacy_policy = "" + db_tool.version = "1.0" + db_tool.parameter_configurations = [] + workflow = DummyWorkflow(graph_dict={"nodes": []}, version="1.0") + app = SimpleNamespace(workflow=workflow) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(app)), + ) + + workflow_tool = MagicMock() + workflow_tool.entity.output_schema = {"type": "object"} + ctrl = MagicMock() + ctrl.get_tools.return_value = [workflow_tool] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + mock_convert = MagicMock(return_value={"tool": "api_entity"}) + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", mock_convert + ) + monkeypatch.setattr( + workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) + ) + + # Act + result = WorkflowToolManageService._get_workflow_tool("t", db_tool) + + # Assert + assert result["name"] == "my_tool" + assert result["label"] == "My Tool" + assert result["synced"] is True + assert "icon" in result + assert "output_schema" in result + + +# --------------------------------------------------------------------------- +# TestListSingleWorkflowTools +# --------------------------------------------------------------------------- + + +class TestListSingleWorkflowTools: + """Tests for WorkflowToolManageService.list_single_workflow_tools.""" + + def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the specified tool does not exist in DB.""" + # Arrange + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") + + def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the controller yields no tools for the provider.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.id = "tool-1" + db_tool.tenant_id = "t" + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(db_tool)), + ) + ctrl = MagicMock() + ctrl.get_tools.return_value = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") + + def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: returns list with one ToolApiEntity.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.id = "tool-1" + db_tool.tenant_id = "t" + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(db_tool)), + ) + workflow_tool = MagicMock() + ctrl = MagicMock() + ctrl.get_tools.return_value = [workflow_tool] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + api_entity = MagicMock() + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "convert_tool_entity_to_api_entity", + MagicMock(return_value=api_entity), + ) + monkeypatch.setattr( + workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) + ) + + # Act + result = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") + + # Assert + assert result == [api_entity] diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index 83642fc209..f3391d6380 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -6,8 +6,8 @@ from unittest.mock import Mock, patch import pytest from sqlalchemy import Engine -from core.workflow.variables.segments import ObjectSegment, StringSegment -from core.workflow.variables.types import SegmentType +from dify_graph.variables.segments import ObjectSegment, StringSegment +from dify_graph.variables.types import SegmentType from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader @@ -24,7 +24,11 @@ class TestDraftVarLoaderSimple: def draft_var_loader(self, mock_engine): """Create DraftVarLoader instance for testing.""" return DraftVarLoader( - engine=mock_engine, app_id="test-app-id", tenant_id="test-tenant-id", fallback_variables=[] + engine=mock_engine, + app_id="test-app-id", + tenant_id="test-tenant-id", + user_id="test-user-id", + fallback_variables=[], ) def test_load_offloaded_variable_string_type_unit(self, draft_var_loader): @@ -174,7 +178,7 @@ class TestDraftVarLoaderSimple: mock_storage.load.return_value = test_json_content.encode() with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from core.workflow.variables.segments import FloatSegment + from dify_graph.variables.segments import FloatSegment mock_segment = FloatSegment(value=test_number) mock_build_segment.return_value = mock_segment @@ -224,7 +228,7 @@ class TestDraftVarLoaderSimple: mock_storage.load.return_value = test_json_content.encode() with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from core.workflow.variables.segments import ArrayAnySegment + from dify_graph.variables.segments import ArrayAnySegment mock_segment = ArrayAnySegment(value=test_array) mock_build_segment.return_value = mock_segment @@ -323,7 +327,9 @@ class TestDraftVarLoaderSimple: # Verify service method was called mock_service.get_draft_variables_by_selectors.assert_called_once_with( - draft_var_loader._app_id, selectors + draft_var_loader._app_id, + selectors, + user_id=draft_var_loader._user_id, ) # Verify offloaded variable loading was called diff --git a/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py new file mode 100644 index 0000000000..bbfc1cc294 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_queue_dispatcher.py @@ -0,0 +1,110 @@ +from unittest.mock import patch + +import pytest + +from services.workflow.queue_dispatcher import ( + BaseQueueDispatcher, + ProfessionalQueueDispatcher, + QueueDispatcherManager, + QueuePriority, + SandboxQueueDispatcher, + TeamQueueDispatcher, +) + + +class TestQueuePriority: + def test_priority_values(self): + assert QueuePriority.PROFESSIONAL == "workflow_professional" + assert QueuePriority.TEAM == "workflow_team" + assert QueuePriority.SANDBOX == "workflow_sandbox" + + +class TestDispatchers: + def test_professional_dispatcher(self): + d = ProfessionalQueueDispatcher() + assert d.get_queue_name() == QueuePriority.PROFESSIONAL + assert d.get_priority() == 100 + + def test_team_dispatcher(self): + d = TeamQueueDispatcher() + assert d.get_queue_name() == QueuePriority.TEAM + assert d.get_priority() == 50 + + def test_sandbox_dispatcher(self): + d = SandboxQueueDispatcher() + assert d.get_queue_name() == QueuePriority.SANDBOX + assert d.get_priority() == 10 + + def test_base_dispatcher_is_abstract(self): + with pytest.raises(TypeError): + BaseQueueDispatcher() + + +class TestQueueDispatcherManager: + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_professional_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "professional"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, ProfessionalQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_team_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "team"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, TeamQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_sandbox_plan(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "sandbox"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_unknown_plan_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {"subscription": {"plan": "enterprise"}} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_enabled_service_failure_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.side_effect = Exception("billing unavailable") + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.dify_config") + def test_billing_disabled_defaults_to_team(self, mock_config): + mock_config.BILLING_ENABLED = False + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, TeamQueueDispatcher) + + @patch("services.workflow.queue_dispatcher.BillingService") + @patch("services.workflow.queue_dispatcher.dify_config") + def test_missing_subscription_key_defaults_to_sandbox(self, mock_config, mock_billing): + mock_config.BILLING_ENABLED = True + mock_billing.get_info.return_value = {} + + dispatcher = QueueDispatcherManager.get_dispatcher("tenant-1") + + assert isinstance(dispatcher, SandboxQueueDispatcher) diff --git a/api/tests/unit_tests/services/workflow/test_scheduler.py b/api/tests/unit_tests/services/workflow/test_scheduler.py new file mode 100644 index 0000000000..90b6cb2d8b --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_scheduler.py @@ -0,0 +1,89 @@ +import pytest + +from services.workflow.entities import WorkflowScheduleCFSPlanEntity +from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand + + +class TestSchedulerCommand: + def test_enum_values(self): + assert SchedulerCommand.RESOURCE_LIMIT_REACHED == "resource_limit_reached" + assert SchedulerCommand.NONE == "none" + + def test_enum_is_str(self): + for member in SchedulerCommand: + assert isinstance(member, str) + + +class TestCFSPlanScheduler: + def test_stores_plan(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop, + granularity=-1, + ) + + class ConcretePlanScheduler(CFSPlanScheduler): + def can_schedule(self): + return SchedulerCommand.NONE + + scheduler = ConcretePlanScheduler(plan) + + assert scheduler.plan is plan + assert scheduler.plan.schedule_strategy == WorkflowScheduleCFSPlanEntity.Strategy.Nop + assert scheduler.plan.granularity == -1 + + def test_cannot_instantiate_abstract(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=10, + ) + with pytest.raises(TypeError): + CFSPlanScheduler(plan) + + def test_concrete_subclass_can_schedule(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=5, + ) + + class TimedScheduler(CFSPlanScheduler): + def can_schedule(self): + if self.plan.granularity > 0: + return SchedulerCommand.NONE + return SchedulerCommand.RESOURCE_LIMIT_REACHED + + scheduler = TimedScheduler(plan) + assert scheduler.can_schedule() == SchedulerCommand.NONE + + def test_concrete_subclass_resource_limit(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=-1, + ) + + class TimedScheduler(CFSPlanScheduler): + def can_schedule(self): + if self.plan.granularity > 0: + return SchedulerCommand.NONE + return SchedulerCommand.RESOURCE_LIMIT_REACHED + + scheduler = TimedScheduler(plan) + assert scheduler.can_schedule() == SchedulerCommand.RESOURCE_LIMIT_REACHED + + +class TestWorkflowScheduleCFSPlanEntity: + def test_strategy_values(self): + assert WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice == "time-slice" + assert WorkflowScheduleCFSPlanEntity.Strategy.Nop == "nop" + + def test_default_granularity(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.Nop, + ) + assert plan.granularity == -1 + + def test_explicit_granularity(self): + plan = WorkflowScheduleCFSPlanEntity( + schedule_strategy=WorkflowScheduleCFSPlanEntity.Strategy.TimeSlice, + granularity=100, + ) + assert plan.granularity == 100 diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py index 8ccbfbb16e..a847c2b4d1 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_converter.py @@ -15,9 +15,9 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.helper import encrypter -from core.model_runtime.entities.llm_entities import LLMMode -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.workflow.variables.input_entities import VariableEntity, VariableEntityType +from dify_graph.model_runtime.entities.llm_entities import LLMMode +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import AppMode from services.workflow.workflow_converter import WorkflowConverter diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py deleted file mode 100644 index dfe325648d..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ /dev/null @@ -1,127 +0,0 @@ -from unittest.mock import MagicMock - -import pytest -from sqlalchemy.orm import Session - -from models.model import App -from models.workflow import Workflow -from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService - - -@pytest.fixture -def workflow_setup(): - mock_session_maker = MagicMock() - workflow_service = WorkflowService(mock_session_maker) - session = MagicMock(spec=Session) - tenant_id = "test-tenant-id" - workflow_id = "test-workflow-id" - - # Mock workflow - workflow = MagicMock(spec=Workflow) - workflow.id = workflow_id - workflow.tenant_id = tenant_id - workflow.version = "1.0" # Not a draft - workflow.tool_published = False # Not published as a tool by default - - # Mock app - app = MagicMock(spec=App) - app.id = "test-app-id" - app.name = "Test App" - app.workflow_id = None # Not used by an app by default - - return { - "workflow_service": workflow_service, - "session": session, - "tenant_id": tenant_id, - "workflow_id": workflow_id, - "workflow": workflow, - "app": app, - } - - -def test_delete_workflow_success(workflow_setup): - # Setup mocks - - # Mock the tool provider query to return None (not published as a tool) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = None - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method - result = workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - assert result is True - workflow_setup["session"].delete.assert_called_once_with(workflow_setup["workflow"]) - - -def test_delete_workflow_draft_error(workflow_setup): - # Setup mocks - workflow_setup["workflow"].version = "draft" - workflow_setup["session"].scalar = MagicMock(return_value=workflow_setup["workflow"]) - - # Call the method and verify exception - with pytest.raises(DraftWorkflowDeletionError): - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_in_use_by_app_error(workflow_setup): - # Setup mocks - workflow_setup["app"].workflow_id = workflow_setup["workflow_id"] - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], workflow_setup["app"]] - ) # Return workflow first, then app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message contains app name - assert "Cannot delete workflow that is currently in use by app" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() - - -def test_delete_workflow_published_as_tool_error(workflow_setup): - # Setup mocks - from models.tools import WorkflowToolProvider - - # Mock the tool provider query - mock_tool_provider = MagicMock(spec=WorkflowToolProvider) - workflow_setup["session"].query.return_value.where.return_value.first.return_value = mock_tool_provider - - workflow_setup["session"].scalar = MagicMock( - side_effect=[workflow_setup["workflow"], None] - ) # Return workflow first, then None for app - - # Call the method and verify exception - with pytest.raises(WorkflowInUseError) as excinfo: - workflow_setup["workflow_service"].delete_workflow( - session=workflow_setup["session"], - workflow_id=workflow_setup["workflow_id"], - tenant_id=workflow_setup["tenant_id"], - ) - - # Verify error message - assert "Cannot delete workflow that is published as a tool" in str(excinfo.value) - - # Verify - workflow_setup["session"].delete.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 792257848f..0c2be9c79f 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -7,10 +7,10 @@ import pytest from sqlalchemy import Engine from sqlalchemy.orm import Session -from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import NodeType -from core.workflow.variables.segments import StringSegment -from core.workflow.variables.types import SegmentType +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey +from dify_graph.variables.segments import StringSegment +from dify_graph.variables.types import SegmentType from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType @@ -54,12 +54,12 @@ class TestDraftVariableSaver: session=mock_session, app_id=test_app_id, node_id="test_node_id", - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_execution_id="test_execution_id", user=mock_user, ) - assert saver._should_variable_be_visible("123_456", NodeType.IF_ELSE, "output") == False - assert saver._should_variable_be_visible("123", NodeType.START, "output") == True + assert saver._should_variable_be_visible("123_456", BuiltinNodeTypes.IF_ELSE, "output") == False + assert saver._should_variable_be_visible("123", BuiltinNodeTypes.START, "output") == True def test__normalize_variable_for_start_node(self): @dataclasses.dataclass(frozen=True) @@ -102,7 +102,7 @@ class TestDraftVariableSaver: session=mock_session, app_id=test_app_id, node_id=_NODE_ID, - node_type=NodeType.START, + node_type=BuiltinNodeTypes.START, node_execution_id="test_execution_id", user=mock_user, ) @@ -134,7 +134,7 @@ class TestDraftVariableSaver: session=mock_session, app_id="test-app-id", node_id="test-node-id", - node_type=NodeType.LLM, + node_type=BuiltinNodeTypes.LLM, node_execution_id="test-execution-id", user=mock_user, ) @@ -182,6 +182,42 @@ class TestDraftVariableSaver: draft_vars = mock_batch_upsert.call_args[0][1] assert len(draft_vars) == 2 + @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True) + def test_start_node_save_persists_sys_timestamp_and_workflow_run_id(self, mock_batch_upsert): + """Start node should persist common `sys.*` variables, not only `sys.files`.""" + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = "test-user-id" + mock_user.tenant_id = "test-tenant-id" + + saver = DraftVariableSaver( + session=mock_session, + app_id="test-app-id", + node_id="start-node-id", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-id", + user=mock_user, + ) + + outputs = { + f"{SYSTEM_VARIABLE_NODE_ID}.{SystemVariableKey.TIMESTAMP}": 1700000000, + f"{SYSTEM_VARIABLE_NODE_ID}.{SystemVariableKey.WORKFLOW_EXECUTION_ID}": "run-id-123", + } + + saver.save(outputs=outputs) + + mock_batch_upsert.assert_called_once() + draft_vars = mock_batch_upsert.call_args[0][1] + + # plus one dummy output because there are no non-sys Start inputs + assert len(draft_vars) == 3 + + sys_vars = [v for v in draft_vars if v.node_id == SYSTEM_VARIABLE_NODE_ID] + assert {v.name for v in sys_vars} == { + str(SystemVariableKey.TIMESTAMP), + str(SystemVariableKey.WORKFLOW_EXECUTION_ID), + } + class TestWorkflowDraftVariableService: def _get_test_app_id(self): @@ -331,7 +367,7 @@ class TestWorkflowDraftVariableService: mock_node_config = {"type": "test_node"} with ( patch.object(workflow, "get_node_config_by_id", return_value=mock_node_config, autospec=True), - patch.object(workflow, "get_node_type_from_node_config", return_value=NodeType.LLM, autospec=True), + patch.object(workflow, "get_node_type_from_node_config", return_value=BuiltinNodeTypes.LLM, autospec=True), ): result = service._reset_node_var_or_sys_var(workflow, variable) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 844dab8976..6c1adba2b8 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -12,9 +12,9 @@ import pytest from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from core.workflow.entities.pause_reason import HumanInputRequired -from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from core.workflow.runtime import GraphRuntimeState, VariablePool +from dify_graph.entities.pause_reason import HumanInputRequired +from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from dify_graph.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index 5ac5ac8ad2..c890ab6a65 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -5,8 +5,9 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.orm import sessionmaker -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import ( +from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.human_input.entities import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, @@ -22,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods): +def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: node_data = HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -30,8 +31,8 @@ def _build_node_config(delivery_methods): inputs=[], user_actions=[], ).model_dump(mode="json") - node_data["type"] = NodeType.HUMAN_INPUT.value - return {"id": "node-1", "data": node_data} + node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT + return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod: diff --git a/api/tests/unit_tests/services/workflow/test_workflow_restore.py b/api/tests/unit_tests/services/workflow/test_workflow_restore.py new file mode 100644 index 0000000000..179361de45 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_restore.py @@ -0,0 +1,77 @@ +import json +from types import SimpleNamespace + +from models.workflow import Workflow +from services.workflow_restore import apply_published_workflow_snapshot_to_draft + +LEGACY_FEATURES = { + "file_upload": { + "image": { + "enabled": True, + "number_limits": 6, + "transfer_methods": ["remote_url", "local_file"], + } + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + +NORMALIZED_FEATURES = { + "file_upload": { + "enabled": True, + "allowed_file_types": ["image"], + "allowed_file_extensions": [], + "allowed_file_upload_methods": ["remote_url", "local_file"], + "number_limits": 6, + }, + "opening_statement": "", + "retriever_resource": {"enabled": True}, + "sensitive_word_avoidance": {"enabled": False}, + "speech_to_text": {"enabled": False}, + "suggested_questions": [], + "suggested_questions_after_answer": {"enabled": False}, + "text_to_speech": {"enabled": False, "language": "", "voice": ""}, +} + + +def _create_workflow(*, workflow_id: str, version: str, features: dict[str, object]) -> Workflow: + return Workflow( + id=workflow_id, + tenant_id="tenant-id", + app_id="app-id", + type="workflow", + version=version, + graph=json.dumps({"nodes": [], "edges": []}), + features=json.dumps(features), + created_by="account-id", + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + + +def test_apply_published_workflow_snapshot_to_draft_copies_serialized_features_without_mutating_source() -> None: + source_workflow = _create_workflow( + workflow_id="published-workflow-id", + version="2026-03-19T00:00:00", + features=LEGACY_FEATURES, + ) + + draft_workflow, is_new_draft = apply_published_workflow_snapshot_to_draft( + tenant_id="tenant-id", + app_id="app-id", + source_workflow=source_workflow, + draft_workflow=None, + account=SimpleNamespace(id="account-id"), + updated_at_factory=lambda: source_workflow.updated_at, + ) + + assert is_new_draft is True + assert source_workflow.serialized_features == json.dumps(LEGACY_FEATURES) + assert source_workflow.normalized_features_dict == NORMALIZED_FEATURES + assert draft_workflow.serialized_features == json.dumps(LEGACY_FEATURES) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 015dac257e..538c1b3595 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -4,9 +4,10 @@ from unittest.mock import MagicMock import pytest -from core.workflow.enums import NodeType -from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction -from core.workflow.nodes.human_input.enums import FormInputType +from dify_graph.entities.graph_config import NodeConfigDictAdapter +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction +from dify_graph.nodes.human_input.enums import FormInputType from models.model import App from models.workflow import Workflow from services import workflow_service as workflow_service_module @@ -40,6 +41,23 @@ class TestWorkflowService: workflows.append(workflow) return workflows + @pytest.fixture + def dummy_session_cls(self): + class DummySession: + def __init__(self, *args, **kwargs): + self.commit = MagicMock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def begin(self): + return nullcontext() + + return DummySession + def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): mock_app.workflow_id = None mock_session = MagicMock() @@ -169,7 +187,10 @@ class TestWorkflowService: mock_session.scalars.assert_called_once() def test_submit_human_input_form_preview_uses_rendered_content( - self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch + self, + workflow_service: WorkflowService, + monkeypatch: pytest.MonkeyPatch, + dummy_session_cls, ) -> None: service = workflow_service node_data = HumanInputNodeData( @@ -187,25 +208,15 @@ class TestWorkflowService: service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + node_config = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} + ) + workflow.get_node_config_by_id.return_value = node_config workflow.get_enclosing_node_type_and_id.return_value = None service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] saved_outputs: dict[str, object] = {} - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - class DummySaver: def __init__(self, *args, **kwargs): pass @@ -213,7 +224,7 @@ class TestWorkflowService: def save(self, outputs, process_data): saved_outputs.update(outputs) - monkeypatch.setattr(workflow_service_module, "Session", DummySession) + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) @@ -232,8 +243,9 @@ class TestWorkflowService: service._build_human_input_variable_pool.assert_called_once_with( app_model=app_model, workflow=workflow, - node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}, + node_config=node_config, manual_inputs={"#node-0.result#": "LLM output"}, + user_id="account-1", ) node.render_form_content_with_outputs.assert_called_once() @@ -267,12 +279,13 @@ class TestWorkflowService: service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] workflow = MagicMock() - workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}} + workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} + ) service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") account = SimpleNamespace(id="account-1") - with pytest.raises(ValueError) as exc_info: service.submit_human_input_form_preview( app_model=app_model, @@ -284,3 +297,119 @@ class TestWorkflowService: ) assert "Missing required inputs" in str(exc_info.value) + + def test_run_draft_workflow_node_successful_behavior( + self, workflow_service, mock_app, monkeypatch, dummy_session_cls + ): + """Behavior: When a basic workflow node runs, it correctly sets up context, + executes the node, and saves outputs.""" + service = workflow_service + account = SimpleNamespace(id="account-1") + mock_workflow = MagicMock() + mock_workflow.id = "wf-1" + mock_workflow.tenant_id = "tenant-1" + mock_workflow.environment_variables = [] + mock_workflow.conversation_variables = [] + + # Mock node config + mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} + ) + mock_workflow.get_enclosing_node_type_and_id.return_value = None + + # Mock class methods + monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) + monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) + + # Mock workflow entry execution + mock_node_exec = MagicMock() + mock_node_exec.id = "exec-1" + mock_node_exec.process_data = {} + mock_run = MagicMock() + monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run) + + # Mock execution handling + service._handle_single_step_result = MagicMock(return_value=mock_node_exec) + + # Mock repository + mock_repo = MagicMock() + mock_repo.get_execution_by_id.return_value = mock_node_exec + mock_repo_factory = MagicMock(return_value=mock_repo) + monkeypatch.setattr( + workflow_service_module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + mock_repo_factory, + ) + service._node_execution_service_repo = mock_repo + + # Set up node execution service repo mock to return our exec node + mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"} + mock_node_exec.node_id = "node-1" + mock_node_exec.node_type = "llm" + + # Mock draft variable saver + mock_saver = MagicMock() + monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver)) + + # Mock DB + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) + + # Act + result = service.run_draft_workflow_node( + app_model=mock_app, + draft_workflow=mock_workflow, + node_id="node-1", + user_inputs={"input_val": "test"}, + account=account, + ) + + # Assert + assert result == mock_node_exec + service._handle_single_step_result.assert_called_once() + mock_repo.save.assert_called_once_with(mock_node_exec) + mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"}) + + def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls): + """Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations.""" + service = workflow_service + account = SimpleNamespace(id="account-1") + mock_workflow = MagicMock() + mock_workflow.tenant_id = "tenant-1" + mock_workflow.environment_variables = [] + mock_workflow.conversation_variables = [] + mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( + {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} + ) + mock_workflow.get_enclosing_node_type_and_id.return_value = None + + monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) + monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) + monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock()) + + mock_node_exec = MagicMock() + mock_node_exec.id = "exec-invalid" + service._handle_single_step_result = MagicMock(return_value=mock_node_exec) + + mock_repo = MagicMock() + mock_repo_factory = MagicMock(return_value=mock_repo) + monkeypatch.setattr( + workflow_service_module.DifyCoreRepositoryFactory, + "create_workflow_node_execution_repository", + mock_repo_factory, + ) + service._node_execution_service_repo = mock_repo + + # Simulate failure to retrieve the saved execution + mock_repo.get_execution_by_id.return_value = None + + monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) + + monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) + + # Act & Assert + with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"): + service.run_draft_workflow_node( + app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account + ) diff --git a/api/tests/unit_tests/tasks/test_clean_dataset_task.py b/api/tests/unit_tests/tasks/test_clean_dataset_task.py index df33f20c9b..74ba7f9c34 100644 --- a/api/tests/unit_tests/tasks/test_clean_dataset_task.py +++ b/api/tests/unit_tests/tasks/test_clean_dataset_task.py @@ -16,6 +16,7 @@ from unittest.mock import MagicMock, patch import pytest +from models.enums import DataSourceType from tasks.clean_dataset_task import clean_dataset_task # ============================================================================ @@ -116,7 +117,7 @@ def mock_document(): doc.id = str(uuid.uuid4()) doc.tenant_id = str(uuid.uuid4()) doc.dataset_id = str(uuid.uuid4()) - doc.data_source_type = "upload_file" + doc.data_source_type = DataSourceType.UPLOAD_FILE doc.data_source_info = '{"upload_file_id": "test-file-id"}' doc.data_source_info_dict = {"upload_file_id": "test-file-id"} return doc diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index 11b4663187..8a721124d6 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -10,14 +10,24 @@ This module tests the document indexing task functionality including: """ import uuid -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from core.indexing_runner import DocumentIsPausedError from core.rag.pipeline.queue import TenantIsolatedTaskQueue from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client +from models.dataset import Dataset, Document +from models.enums import IndexingStatus from services.document_indexing_proxy.document_indexing_task_proxy import DocumentIndexingTaskProxy +from tasks.document_indexing_task import ( + _document_indexing, + _document_indexing_with_tenant_queue, + document_indexing_task, + normal_document_indexing_task, + priority_document_indexing_task, +) # ============================================================================ # Fixtures @@ -56,6 +66,190 @@ def mock_redis(): return redis_client +# Additional fixtures required by tests in this module + + +@pytest.fixture +def mock_db_session(): + """Mock session_factory.create_session() to return a session whose queries use shared test data. + + Tests set session._shared_data = {"dataset": , "documents": [, ...]} + This fixture makes session.query(Dataset).first() return the shared dataset, + and session.query(Document).all()/first() return from the shared documents. + """ + with patch("tasks.document_indexing_task.session_factory") as mock_sf: + session = MagicMock() + session._shared_data = {"dataset": None, "documents": []} + + # Keep a pointer so repeated Document.first() calls iterate across provided docs + session._doc_first_idx = 0 + + def _query_side_effect(model): + q = MagicMock() + + # Capture filters passed via where(...) so first()/all() can honor them. + q._filters = {} + + def _extract_filters(*conds, **kw): + # Support both SQLAlchemy expressions (BinaryExpression) and kwargs + # We only need the simple fields used by production code: id, dataset_id, and id.in_(...) + for cond in conds: + left = getattr(cond, "left", None) + right = getattr(cond, "right", None) + key = None + if left is not None: + key = getattr(left, "key", None) or getattr(left, "name", None) + if not key: + continue + # Right side might be a BindParameter with .value, or a raw value/sequence + val = getattr(right, "value", right) + q._filters[key] = val + # Also accept kwargs (e.g., where(id=...)) just in case + for k, v in kw.items(): + q._filters[k] = v + + def _where_side_effect(*conds, **kw): + _extract_filters(*conds, **kw) + return q + + q.where.side_effect = _where_side_effect + + # Dataset queries + if model.__name__ == "Dataset": + + def _dataset_first(): + ds = session._shared_data.get("dataset") + if not ds: + return None + if "id" in q._filters: + val = q._filters["id"] + if isinstance(val, (list, tuple, set)): + return ds if ds.id in val else None + return ds if ds.id == val else None + return ds + + def _dataset_all(): + ds = session._shared_data.get("dataset") + if not ds: + return [] + first = _dataset_first() + return [first] if first else [] + + q.first.side_effect = _dataset_first + q.all.side_effect = _dataset_all + return q + + # Document queries + if model.__name__ == "Document": + + def _apply_doc_filters(docs): + result = list(docs) + for key in ("id", "dataset_id"): + if key in q._filters: + val = q._filters[key] + if isinstance(val, (list, tuple, set)): + result = [d for d in result if getattr(d, key, None) in val] + else: + result = [d for d in result if getattr(d, key, None) == val] + return result + + def _docs_all(): + docs = session._shared_data.get("documents", []) + return _apply_doc_filters(docs) + + def _docs_first(): + docs = _docs_all() + return docs[0] if docs else None + + q.all.side_effect = _docs_all + q.first.side_effect = _docs_first + return q + + # Default fallback + q.first.return_value = None + q.all.return_value = [] + return q + + session.query.side_effect = _query_side_effect + + # Implement session.begin() context manager that commits on exit + session.commit = MagicMock() + bm = MagicMock() + bm.__enter__.return_value = session + + def _bm_exit_side_effect(*args, **kwargs): + session.commit() + + bm.__exit__.side_effect = _bm_exit_side_effect + session.begin.return_value = bm + + # Context manager behavior for create_session(): ensure close() is called on exit + session.close = MagicMock() + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + mock_sf.create_session.return_value = cm + + yield session + + +@pytest.fixture +def mock_dataset(dataset_id, tenant_id): + """Create a mock Dataset object.""" + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "openai" + dataset.embedding_model = "text-embedding-ada-002" + return dataset + + +@pytest.fixture +def mock_documents(document_ids, dataset_id): + """Create mock Document objects.""" + documents = [] + for doc_id in document_ids: + doc = Mock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + # optional attribute used in some code paths + doc.doc_form = "text_model" + documents.append(doc) + return documents + + +@pytest.fixture +def mock_indexing_runner(): + """Mock IndexingRunner for document_indexing_task module.""" + with patch("tasks.document_indexing_task.IndexingRunner") as mock_runner_class: + mock_runner = MagicMock() + mock_runner_class.return_value = mock_runner + yield mock_runner + + +@pytest.fixture +def mock_feature_service(): + """Mock FeatureService for document_indexing_task module.""" + with patch("tasks.document_indexing_task.FeatureService") as mock_service: + mock_features = Mock() + mock_features.billing = Mock() + mock_features.billing.enabled = False + mock_features.vector_space = Mock() + mock_features.vector_space.size = 0 + mock_features.vector_space.limit = 1000 + mock_service.get_features.return_value = mock_features + yield mock_service + + # ============================================================================ # Test Task Enqueuing # ============================================================================ @@ -166,6 +360,492 @@ class TestTaskEnqueuing: assert mock_redis.lpush.called mock_task.delay.assert_not_called() + def test_legacy_document_indexing_task_still_works( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_documents, mock_indexing_runner + ): + """ + Test that the legacy document_indexing_task function still works. + + This ensures backward compatibility for existing code that may still + use the deprecated function. + """ + # Arrange + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + document_indexing_task(dataset_id, document_ids) + + # Assert + mock_indexing_runner.run.assert_called_once() + + +# ============================================================================ +# Test Batch Processing +# ============================================================================ + + +class TestBatchProcessing: + """Test cases for batch processing of multiple documents.""" + + def test_batch_processing_multiple_documents( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test batch processing of multiple documents. + + All documents in the batch should be processed together and their + status should be updated to 'parsing'. + """ + # Arrange - Create actual document objects that can be modified + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should be set to 'parsing' status + for doc in mock_documents: + assert doc.indexing_status == IndexingStatus.PARSING + assert doc.processing_started_at is not None + + # IndexingRunner should be called with all documents + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == len(document_ids) + + def test_batch_processing_with_limit_check(self, dataset_id, mock_db_session, mock_dataset, mock_feature_service): + """ + Test batch processing respects upload limits. + + When the number of documents exceeds the batch upload limit, + an error should be raised and all documents should be marked as error. + """ + # Arrange + batch_limit = 10 + document_ids = [str(uuid.uuid4()) for _ in range(batch_limit + 1)] + + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 1000 + mock_feature_service.get_features.return_value.vector_space.size = 0 + + with patch("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", str(batch_limit)): + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should have error status + for doc in mock_documents: + assert doc.indexing_status == "error" + assert doc.error is not None + assert "batch upload limit" in doc.error + + def test_batch_processing_sandbox_plan_single_document_only( + self, dataset_id, mock_db_session, mock_dataset, mock_feature_service + ): + """ + Test that sandbox plan only allows single document upload. + + Sandbox plan should reject batch uploads (more than 1 document). + """ + # Arrange + document_ids = [str(uuid.uuid4()) for _ in range(2)] + + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX + mock_feature_service.get_features.return_value.vector_space.limit = 1000 + mock_feature_service.get_features.return_value.vector_space.size = 0 + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should have error status + for doc in mock_documents: + assert doc.indexing_status == "error" + assert "does not support batch upload" in doc.error + + def test_batch_processing_empty_document_list( + self, dataset_id, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test batch processing with empty document list. + + Should handle empty list gracefully without errors. + """ + # Arrange + document_ids = [] + + # Set shared mock data with empty documents list + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [] + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - IndexingRunner should still be called with empty list + mock_indexing_runner.run.assert_called_once_with([]) + + +# ============================================================================ +# Test Progress Tracking +# ============================================================================ + + +class TestProgressTracking: + """Test cases for progress tracking through task lifecycle.""" + + def test_document_status_progression( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test document status progresses correctly through lifecycle. + + Documents should transition from 'waiting' -> 'parsing' -> processed. + """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - Status should be 'parsing' + for doc in mock_documents: + assert doc.indexing_status == IndexingStatus.PARSING + assert doc.processing_started_at is not None + + # Verify commit was called to persist status + assert mock_db_session.commit.called + + def test_processing_started_timestamp_set( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that processing_started_at timestamp is set correctly. + + When documents start processing, the timestamp should be recorded. + """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert + for doc in mock_documents: + assert doc.processing_started_at is not None + + def test_tenant_queue_processes_next_task_after_completion( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that tenant queue processes next waiting task after completion. + + After a task completes, the system should check for waiting tasks + and process the next one. + """ + # Arrange + next_task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["next_doc_id"]} + + # Simulate next task in queue + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=next_task_data) + mock_redis.rpop.return_value = wrapper.serialize() + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Next task should be enqueued + mock_task.apply_async.assert_called() + # Task key should be set for next task + assert mock_redis.setex.called + + def test_tenant_queue_clears_flag_when_no_more_tasks( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that tenant queue clears flag when no more tasks are waiting. + + When there are no more tasks in the queue, the task key should be deleted. + """ + # Arrange + mock_redis.rpop.return_value = None # No more tasks + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Task key should be deleted + assert mock_redis.delete.called + + +# ============================================================================ +# Test Error Handling and Retries +# ============================================================================ + + +class TestErrorHandling: + """Test cases for error handling and retry mechanisms.""" + + def test_error_handling_sets_document_error_status( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_feature_service + ): + """ + Test that errors during validation set document error status. + + When validation fails (e.g., limit exceeded), documents should be + marked with error status and error message. + """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Set up to trigger vector space limit error + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 100 + mock_feature_service.get_features.return_value.vector_space.size = 100 # At limit + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert + for doc in mock_documents: + assert doc.indexing_status == "error" + assert doc.error is not None + assert "over the limit" in doc.error + assert doc.stopped_at is not None + + def test_error_handling_during_indexing_runner( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_documents, mock_indexing_runner + ): + """ + Test error handling when IndexingRunner raises an exception. + + Errors during indexing should be caught and logged, but not crash the task. + """ + # Arrange + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Make IndexingRunner raise an exception + mock_indexing_runner.run.side_effect = Exception("Indexing failed") + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act - Should not raise exception + _document_indexing(dataset_id, document_ids) + + # Assert - Session should be closed even after error + assert mock_db_session.close.called + + def test_document_paused_error_handling( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_documents, mock_indexing_runner + ): + """ + Test handling of DocumentIsPausedError. + + When a document is paused, the error should be caught and logged + but not treated as a failure. + """ + # Arrange + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Make IndexingRunner raise DocumentIsPausedError + mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act - Should not raise exception + _document_indexing(dataset_id, document_ids) + + # Assert - Session should be closed + assert mock_db_session.close.called + + def test_dataset_not_found_error_handling(self, dataset_id, document_ids, mock_db_session): + """ + Test handling when dataset is not found. + + If the dataset doesn't exist, the task should exit gracefully. + """ + # Arrange + mock_db_session.query.return_value.where.return_value.first.return_value = None + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - Session should be closed + assert mock_db_session.close.called + + def test_tenant_queue_error_handling_still_processes_next_task( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that errors don't prevent processing next task in tenant queue. + + Even if the current task fails, the next task should still be processed. + """ + # Arrange + next_task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": ["next_doc_id"]} + + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=next_task_data) + # Set up rpop to return task once for concurrency check + mock_redis.rpop.side_effect = [wrapper.serialize(), None] + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + # Make _document_indexing raise an error + with patch("tasks.document_indexing_task._document_indexing") as mock_indexing: + mock_indexing.side_effect = Exception("Processing failed") + + # Patch logger to avoid format string issue in actual code + with patch("tasks.document_indexing_task.logger"): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Next task should still be enqueued despite error + mock_task.apply_async.assert_called() + + def test_concurrent_task_limit_respected( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset + ): + """ + Test that tenant isolated task concurrency limit is respected. + + Should pull only TENANT_ISOLATED_TASK_CONCURRENCY tasks at a time. + """ + # Arrange + concurrency_limit = 2 + + # Create multiple tasks in queue + tasks = [] + for i in range(5): + task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": [f"doc_{i}"]} + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks one by one + mock_redis.rpop.side_effect = tasks[:concurrency_limit] + [None] + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Should enqueue exactly concurrency_limit tasks + assert mock_task.apply_async.call_count == concurrency_limit + # ============================================================================ # Test Task Cancellation @@ -198,6 +878,407 @@ class TestTaskCancellation: assert tenant_2 in queue_2._queue +# ============================================================================ +# Integration Tests +# ============================================================================ + + +class TestAdvancedScenarios: + """Advanced test scenarios for edge cases and complex workflows.""" + + def test_multiple_documents_with_mixed_success_and_failure( + self, dataset_id, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test handling of mixed success and failure scenarios in batch processing. + + When processing multiple documents, some may succeed while others fail. + This tests that the system handles partial failures gracefully. + + Scenario: + - Process 3 documents in a batch + - First document succeeds + - Second document is not found (skipped) + - Third document succeeds + + Expected behavior: + - Only found documents are processed + - Missing documents are skipped without crashing + - IndexingRunner receives only valid documents + """ + # Arrange - Create document IDs with one missing + document_ids = [str(uuid.uuid4()) for _ in range(3)] + + # Create only 2 documents (simulate one missing) + # The new code uses .all() which will only return existing documents + mock_documents = [] + for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data - .all() will only return existing documents + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - Only 2 documents should be processed (missing one skipped) + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == 2 # Only found documents + + def test_tenant_queue_with_multiple_concurrent_tasks( + self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset + ): + """ + Test concurrent task processing with tenant isolation. + + This tests the scenario where multiple tasks are queued for the same tenant + and need to be processed respecting the concurrency limit. + + Scenario: + - 5 tasks are waiting in the queue + - Concurrency limit is 2 + - After current task completes, pull and enqueue next 2 tasks + + Expected behavior: + - Exactly 2 tasks are pulled from queue (respecting concurrency) + - Each task is enqueued with correct parameters + - Task waiting time is set for each new task + """ + # Arrange + concurrency_limit = 2 + document_ids = [str(uuid.uuid4())] + + # Create multiple waiting tasks + waiting_tasks = [] + for i in range(5): + task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": [f"doc_{i}"]} + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + waiting_tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks up to concurrency limit + mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert + # Should enqueue exactly concurrency_limit tasks + assert mock_task.apply_async.call_count == concurrency_limit + + # Verify task waiting time was set for each task + assert mock_redis.setex.call_count >= concurrency_limit + + def test_vector_space_limit_edge_case_at_exact_limit( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_feature_service + ): + """ + Test vector space limit validation at exact boundary. + + Edge case: When vector space is exactly at the limit (not over), + the upload should still be rejected. + + Scenario: + - Vector space limit: 100 + - Current size: 100 (exactly at limit) + - Try to upload 3 documents + + Expected behavior: + - Upload is rejected with appropriate error message + - All documents are marked with error status + """ + # Arrange + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.error = None + doc.stopped_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Set vector space exactly at limit + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 100 + mock_feature_service.get_features.return_value.vector_space.size = 100 # Exactly at limit + + # Act + _document_indexing(dataset_id, document_ids) + + # Assert - All documents should have error status + for doc in mock_documents: + assert doc.indexing_status == "error" + assert "over the limit" in doc.error + + def test_task_queue_fifo_ordering(self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset): + """ + Test that tasks are processed in FIFO (First-In-First-Out) order. + + The tenant isolated queue should maintain task order, ensuring + that tasks are processed in the sequence they were added. + + Scenario: + - Task A added first + - Task B added second + - Task C added third + - When pulling tasks, should get A, then B, then C + + Expected behavior: + - Tasks are retrieved in the order they were added + - FIFO ordering is maintained throughout processing + """ + # Arrange + document_ids = [str(uuid.uuid4())] + + # Create tasks with identifiable document IDs to track order + task_order = ["task_A", "task_B", "task_C"] + tasks = [] + for task_name in task_order: + task_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": [task_name]} + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks in FIFO order + mock_redis.rpop.side_effect = tasks + [None] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", 3): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Verify tasks were enqueued in correct order + assert mock_task.apply_async.call_count == 3 + + # Check that document_ids in calls match expected order + for i, call_obj in enumerate(mock_task.apply_async.call_args_list): + called_doc_ids = call_obj[1]["kwargs"]["document_ids"] + assert called_doc_ids == [task_order[i]] + + def test_empty_queue_after_task_completion_cleans_up( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset + ): + """ + Test cleanup behavior when queue becomes empty after task completion. + + After processing the last task in the queue, the system should: + 1. Detect that no more tasks are waiting + 2. Delete the task key to indicate tenant is idle + 3. Allow new tasks to start fresh processing + + Scenario: + - Process a task + - Check queue for next tasks + - Queue is empty + - Task key should be deleted + + Expected behavior: + - Task key is deleted when queue is empty + - Tenant is marked as idle (no active tasks) + """ + # Arrange + mock_redis.rpop.return_value = None # Empty queue + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert + # Verify delete was called to clean up task key + mock_redis.delete.assert_called_once() + + # Verify the correct key was deleted (contains tenant_id and "document_indexing") + delete_call_args = mock_redis.delete.call_args[0][0] + assert tenant_id in delete_call_args + assert "document_indexing" in delete_call_args + + def test_billing_disabled_skips_limit_checks( + self, dataset_id, document_ids, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service + ): + """ + Test that billing limit checks are skipped when billing is disabled. + + For self-hosted or enterprise deployments where billing is disabled, + the system should not enforce vector space or batch upload limits. + + Scenario: + - Billing is disabled + - Upload 100 documents (would normally exceed limits) + - No limit checks should be performed + + Expected behavior: + - Documents are processed without limit validation + - No errors related to limits + - All documents proceed to indexing + """ + # Arrange - Create many documents + large_batch_ids = [str(uuid.uuid4()) for _ in range(100)] + + mock_documents = [] + for doc_id in large_batch_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Billing disabled - limits should not be checked + mock_feature_service.get_features.return_value.billing.enabled = False + + # Act + _document_indexing(dataset_id, large_batch_ids) + + # Assert + # All documents should be set to parsing (no limit errors) + for doc in mock_documents: + assert doc.indexing_status == IndexingStatus.PARSING + + # IndexingRunner should be called with all documents + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == 100 + + +class TestIntegration: + """Integration tests for complete task workflows.""" + + def test_complete_workflow_normal_task( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test complete workflow for normal document indexing task. + + This tests the full flow from task receipt to completion. + """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set up rpop to return None for concurrency check (no more tasks) + mock_redis.rpop.side_effect = [None] + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + normal_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + # Documents should be processed + mock_indexing_runner.run.assert_called_once() + # Session should be closed + assert mock_db_session.close.called + # Task key should be deleted (no more tasks) + assert mock_redis.delete.called + + def test_complete_workflow_priority_task( + self, tenant_id, dataset_id, document_ids, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test complete workflow for priority document indexing task. + + Priority tasks should follow the same flow as normal tasks. + """ + # Arrange - Create actual document objects + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set up rpop to return None for concurrency check (no more tasks) + mock_redis.rpop.side_effect = [None] + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + # Act + priority_document_indexing_task(tenant_id, dataset_id, document_ids) + + # Assert + mock_indexing_runner.run.assert_called_once() + assert mock_db_session.close.called + assert mock_redis.delete.called + + def test_queue_chain_processing( + self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset, mock_indexing_runner + ): + """ + Test that multiple tasks in queue are processed in sequence. + + When tasks are queued, they should be processed one after another. + """ + # Arrange + task_1_docs = [str(uuid.uuid4())] + task_2_docs = [str(uuid.uuid4())] + + task_2_data = {"tenant_id": tenant_id, "dataset_id": dataset_id, "document_ids": task_2_docs} + + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_2_data) + + # First call returns task 2, second call returns None + mock_redis.rpop.side_effect = [wrapper.serialize(), None] + + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: + mock_features.return_value.billing.enabled = False + + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act - Process first task + _document_indexing_with_tenant_queue(tenant_id, dataset_id, task_1_docs, mock_task) + + # Assert - Second task should be enqueued + assert mock_task.apply_async.called + call_args = mock_task.apply_async.call_args + assert call_args[1]["kwargs"]["document_ids"] == task_2_docs + + # ============================================================================ # Additional Edge Case Tests # ============================================================================ @@ -249,6 +1330,107 @@ class TestEdgeCases: class TestPerformanceScenarios: """Test performance-related scenarios and optimizations.""" + def test_large_document_batch_processing( + self, dataset_id, mock_db_session, mock_dataset, mock_indexing_runner, mock_feature_service + ): + """ + Test processing a large batch of documents at batch limit. + + When processing the maximum allowed batch size, the system + should handle it efficiently without errors. + + Scenario: + - Process exactly batch_upload_limit documents (e.g., 50) + - All documents are valid + - Billing is enabled + + Expected behavior: + - All documents are processed successfully + - No timeout or memory issues + - Batch limit is not exceeded + """ + # Arrange + batch_limit = 50 + document_ids = [str(uuid.uuid4()) for _ in range(batch_limit)] + + mock_documents = [] + for doc_id in document_ids: + doc = MagicMock(spec=Document) + doc.id = doc_id + doc.dataset_id = dataset_id + doc.indexing_status = "waiting" + doc.processing_started_at = None + mock_documents.append(doc) + + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents + + # Configure billing with sufficient limits + mock_feature_service.get_features.return_value.billing.enabled = True + mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL + mock_feature_service.get_features.return_value.vector_space.limit = 10000 + mock_feature_service.get_features.return_value.vector_space.size = 0 + + with patch("tasks.document_indexing_task.dify_config.BATCH_UPLOAD_LIMIT", str(batch_limit)): + # Act + _document_indexing(dataset_id, document_ids) + + # Assert + for doc in mock_documents: + assert doc.indexing_status == IndexingStatus.PARSING + + mock_indexing_runner.run.assert_called_once() + call_args = mock_indexing_runner.run.call_args[0][0] + assert len(call_args) == batch_limit + + def test_tenant_queue_handles_burst_traffic(self, tenant_id, dataset_id, mock_redis, mock_db_session, mock_dataset): + """ + Test tenant queue handling burst traffic scenarios. + + When many tasks arrive in a burst for the same tenant, + the queue should handle them efficiently without dropping tasks. + + Scenario: + - 20 tasks arrive rapidly + - Concurrency limit is 3 + - Tasks should be queued and processed in batches + + Expected behavior: + - First 3 tasks are processed immediately + - Remaining tasks wait in queue + - No tasks are lost + """ + # Arrange + num_tasks = 20 + concurrency_limit = 3 + document_ids = [str(uuid.uuid4())] + + # Create waiting tasks + waiting_tasks = [] + for i in range(num_tasks): + task_data = { + "tenant_id": tenant_id, + "dataset_id": dataset_id, + "document_ids": [f"doc_{i}"], + } + from core.rag.pipeline.queue import TaskWrapper + + wrapper = TaskWrapper(data=task_data) + waiting_tasks.append(wrapper.serialize()) + + # Mock rpop to return tasks up to concurrency limit + mock_redis.rpop.side_effect = waiting_tasks[:concurrency_limit] + [None] + mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + with patch("tasks.document_indexing_task.dify_config.TENANT_ISOLATED_TASK_CONCURRENCY", concurrency_limit): + with patch("tasks.document_indexing_task.normal_document_indexing_task") as mock_task: + # Act + _document_indexing_with_tenant_queue(tenant_id, dataset_id, document_ids, mock_task) + + # Assert - Should process exactly concurrency_limit tasks + assert mock_task.apply_async.call_count == concurrency_limit + def test_multiple_tenants_isolated_processing(self, mock_redis): """ Test that multiple tenants process tasks in isolation. diff --git a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py index 68fb8b748f..f6dbc4275b 100644 --- a/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_duplicate_document_indexing_task.py @@ -1,158 +1,38 @@ -""" -Unit tests for duplicate document indexing tasks. - -This module tests the duplicate document indexing task functionality including: -- Task enqueuing to different queues (normal, priority, tenant-isolated) -- Batch processing of multiple duplicate documents -- Progress tracking through task lifecycle -- Error handling and retry mechanisms -- Cleanup of old document data before re-indexing -""" +"""Unit tests for queue/wrapper behaviors in duplicate document indexing tasks (non-database logic).""" import uuid -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest -from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.pipeline.queue import TenantIsolatedTaskQueue -from enums.cloud_plan import CloudPlan -from models.dataset import Dataset, Document, DocumentSegment from tasks.duplicate_document_indexing_task import ( - _duplicate_document_indexing_task, _duplicate_document_indexing_task_with_tenant_queue, duplicate_document_indexing_task, normal_duplicate_document_indexing_task, priority_duplicate_document_indexing_task, ) -# ============================================================================ -# Fixtures -# ============================================================================ - @pytest.fixture def tenant_id(): - """Generate a unique tenant ID for testing.""" return str(uuid.uuid4()) @pytest.fixture def dataset_id(): - """Generate a unique dataset ID for testing.""" return str(uuid.uuid4()) @pytest.fixture def document_ids(): - """Generate a list of document IDs for testing.""" return [str(uuid.uuid4()) for _ in range(3)] -@pytest.fixture -def mock_dataset(dataset_id, tenant_id): - """Create a mock Dataset object.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.indexing_technique = "high_quality" - dataset.embedding_model_provider = "openai" - dataset.embedding_model = "text-embedding-ada-002" - return dataset - - -@pytest.fixture -def mock_documents(document_ids, dataset_id): - """Create mock Document objects.""" - documents = [] - for doc_id in document_ids: - doc = Mock(spec=Document) - doc.id = doc_id - doc.dataset_id = dataset_id - doc.indexing_status = "waiting" - doc.error = None - doc.stopped_at = None - doc.processing_started_at = None - doc.doc_form = "text_model" - documents.append(doc) - return documents - - -@pytest.fixture -def mock_document_segments(document_ids): - """Create mock DocumentSegment objects.""" - segments = [] - for doc_id in document_ids: - for i in range(3): - segment = Mock(spec=DocumentSegment) - segment.id = str(uuid.uuid4()) - segment.document_id = doc_id - segment.index_node_id = f"node-{doc_id}-{i}" - segments.append(segment) - return segments - - -@pytest.fixture -def mock_db_session(): - """Mock database session via session_factory.create_session().""" - with patch("tasks.duplicate_document_indexing_task.session_factory", autospec=True) as mock_sf: - session = MagicMock() - # Allow tests to observe session.close() via context manager teardown - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - - def _exit_side_effect(*args, **kwargs): - session.close() - - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm - - query = MagicMock() - session.query.return_value = query - query.where.return_value = query - session.scalars.return_value = MagicMock() - yield session - - -@pytest.fixture -def mock_indexing_runner(): - """Mock IndexingRunner.""" - with patch("tasks.duplicate_document_indexing_task.IndexingRunner", autospec=True) as mock_runner_class: - mock_runner = MagicMock(spec=IndexingRunner) - mock_runner_class.return_value = mock_runner - yield mock_runner - - -@pytest.fixture -def mock_feature_service(): - """Mock FeatureService.""" - with patch("tasks.duplicate_document_indexing_task.FeatureService", autospec=True) as mock_service: - mock_features = Mock() - mock_features.billing = Mock() - mock_features.billing.enabled = False - mock_features.vector_space = Mock() - mock_features.vector_space.size = 0 - mock_features.vector_space.limit = 1000 - mock_service.get_features.return_value = mock_features - yield mock_service - - -@pytest.fixture -def mock_index_processor_factory(): - """Mock IndexProcessorFactory.""" - with patch("tasks.duplicate_document_indexing_task.IndexProcessorFactory", autospec=True) as mock_factory: - mock_processor = MagicMock() - mock_processor.clean = Mock() - mock_factory.return_value.init_index_processor.return_value = mock_processor - yield mock_factory - - @pytest.fixture def mock_tenant_isolated_queue(): - """Mock TenantIsolatedTaskQueue.""" with patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue", autospec=True) as mock_queue_class: - mock_queue = MagicMock(spec=TenantIsolatedTaskQueue) + mock_queue = Mock(spec=TenantIsolatedTaskQueue) mock_queue.pull_tasks.return_value = [] mock_queue.delete_task_key = Mock() mock_queue.set_task_waiting_time = Mock() @@ -160,11 +40,6 @@ def mock_tenant_isolated_queue(): yield mock_queue -# ============================================================================ -# Tests for deprecated duplicate_document_indexing_task -# ============================================================================ - - class TestDuplicateDocumentIndexingTask: """Tests for the deprecated duplicate_document_indexing_task function.""" @@ -190,258 +65,6 @@ class TestDuplicateDocumentIndexingTask: mock_core_func.assert_called_once_with(dataset_id, document_ids) -# ============================================================================ -# Tests for _duplicate_document_indexing_task core function -# ============================================================================ - - -class TestDuplicateDocumentIndexingTaskCore: - """Tests for the _duplicate_document_indexing_task core function.""" - - def test_successful_duplicate_document_indexing( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - mock_document_segments, - dataset_id, - document_ids, - ): - """Test successful duplicate document indexing flow.""" - # Arrange - # Dataset via query.first() - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - # scalars() call sequence: - # 1) documents list - # 2..N) segments per document - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - # First call returns documents; subsequent calls return segments - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = mock_document_segments - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Verify IndexingRunner was called - mock_indexing_runner.run.assert_called_once() - - # Verify all documents were set to parsing status - for doc in mock_documents: - assert doc.indexing_status == "parsing" - assert doc.processing_started_at is not None - - # Verify session operations - assert mock_db_session.commit.called - assert mock_db_session.close.called - - def test_duplicate_document_indexing_dataset_not_found(self, mock_db_session, dataset_id, document_ids): - """Test duplicate document indexing when dataset is not found.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = None - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should close the session at least once - assert mock_db_session.close.called - - def test_duplicate_document_indexing_with_billing_enabled_sandbox_plan( - self, - mock_db_session, - mock_feature_service, - mock_dataset, - dataset_id, - document_ids, - ): - """Test duplicate document indexing with billing enabled and sandbox plan.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - mock_features = mock_feature_service.get_features.return_value - mock_features.billing.enabled = True - mock_features.billing.subscription.plan = CloudPlan.SANDBOX - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # For sandbox plan with multiple documents, should fail - mock_db_session.commit.assert_called() - - def test_duplicate_document_indexing_with_billing_limit_exceeded( - self, - mock_db_session, - mock_feature_service, - mock_dataset, - mock_documents, - dataset_id, - document_ids, - ): - """Test duplicate document indexing when billing limit is exceeded.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - # First scalars() -> documents; subsequent -> empty segments - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = [] - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_features = mock_feature_service.get_features.return_value - mock_features.billing.enabled = True - mock_features.billing.subscription.plan = CloudPlan.TEAM - mock_features.vector_space.size = 990 - mock_features.vector_space.limit = 1000 - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should commit the session - assert mock_db_session.commit.called - # Should close the session - assert mock_db_session.close.called - - def test_duplicate_document_indexing_runner_error( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - dataset_id, - document_ids, - ): - """Test duplicate document indexing when IndexingRunner raises an error.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = [] - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_indexing_runner.run.side_effect = Exception("Indexing error") - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should close the session even after error - mock_db_session.close.assert_called_once() - - def test_duplicate_document_indexing_document_is_paused( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - dataset_id, - document_ids, - ): - """Test duplicate document indexing when document is paused.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = [] - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Should handle DocumentIsPausedError gracefully - mock_db_session.close.assert_called_once() - - def test_duplicate_document_indexing_cleans_old_segments( - self, - mock_db_session, - mock_indexing_runner, - mock_feature_service, - mock_index_processor_factory, - mock_dataset, - mock_documents, - mock_document_segments, - dataset_id, - document_ids, - ): - """Test that duplicate document indexing cleans old segments.""" - # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def _scalars_side_effect(*args, **kwargs): - m = MagicMock() - if not hasattr(_scalars_side_effect, "_calls"): - _scalars_side_effect._calls = 0 - if _scalars_side_effect._calls == 0: - m.all.return_value = mock_documents - else: - m.all.return_value = mock_document_segments - _scalars_side_effect._calls += 1 - return m - - mock_db_session.scalars.side_effect = _scalars_side_effect - mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - - # Act - _duplicate_document_indexing_task(dataset_id, document_ids) - - # Assert - # Verify clean was called for each document - assert mock_processor.clean.call_count == len(mock_documents) - - # Verify segments were deleted in batch (DELETE FROM document_segments) - execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] - assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) - - -# ============================================================================ -# Tests for tenant queue wrapper function -# ============================================================================ - - class TestDuplicateDocumentIndexingTaskWithTenantQueue: """Tests for _duplicate_document_indexing_task_with_tenant_queue function.""" @@ -536,11 +159,6 @@ class TestDuplicateDocumentIndexingTaskWithTenantQueue: mock_tenant_isolated_queue.pull_tasks.assert_called_once() -# ============================================================================ -# Tests for normal_duplicate_document_indexing_task -# ============================================================================ - - class TestNormalDuplicateDocumentIndexingTask: """Tests for normal_duplicate_document_indexing_task function.""" @@ -581,11 +199,6 @@ class TestNormalDuplicateDocumentIndexingTask: ) -# ============================================================================ -# Tests for priority_duplicate_document_indexing_task -# ============================================================================ - - class TestPriorityDuplicateDocumentIndexingTask: """Tests for priority_duplicate_document_indexing_task function.""" diff --git a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index ee0699ba2d..bd0182a402 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -6,7 +6,7 @@ from typing import Any import pytest -from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from tasks import human_input_timeout_tasks as task_module @@ -47,7 +47,7 @@ class _FakeSessionFactory: class _FakeFormRepo: - def __init__(self, _session_factory, form_map: dict[str, Any] | None = None): + def __init__(self, form_map: dict[str, Any] | None = None): self.calls: list[dict[str, Any]] = [] self._form_map = form_map or {} @@ -149,9 +149,9 @@ def test_check_and_handle_human_input_timeouts_marks_and_routes(monkeypatch: pyt monkeypatch.setattr(task_module, "sessionmaker", lambda *args, **kwargs: _FakeSessionFactory(forms, capture)) form_map = {form.id: form for form in forms} - repo = _FakeFormRepo(None, form_map=form_map) + repo = _FakeFormRepo(form_map=form_map) - def _repo_factory(_session_factory): + def _repo_factory(): return repo service = _FakeService(None) diff --git a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py index 20cb7a211e..37b7a85451 100644 --- a/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/unit_tests/tasks/test_mail_human_input_delivery_task.py @@ -120,4 +120,37 @@ def test_dispatch_human_input_email_task_replaces_body_variables(monkeypatch: py session_factory=lambda: _DummySession(form), ) - assert mail.sent[0]["html"] == "Body OK" + assert mail.sent[0]["html"] == "

Body OK

" + + +@pytest.mark.parametrize("line_break", ["\r\n", "\r", "\n"]) +def test_dispatch_human_input_email_task_sanitizes_subject( + monkeypatch: pytest.MonkeyPatch, + line_break: str, +): + mail = _DummyMail() + form = SimpleNamespace(id="form-1", tenant_id="tenant-1", workflow_run_id=None) + job = task_module._EmailDeliveryJob( + form_id="form-1", + subject=f"Notice{line_break}BCC:attacker@example.com Alert", + body="Body", + form_content="content", + recipients=[task_module._EmailRecipient(email="user@example.com", token="token-1")], + ) + + monkeypatch.setattr(task_module, "mail", mail) + monkeypatch.setattr( + task_module.FeatureService, + "get_features", + lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True), + ) + monkeypatch.setattr(task_module, "_load_email_jobs", lambda _session, _form: [job]) + monkeypatch.setattr(task_module, "_load_variable_pool", lambda _workflow_run_id: None) + + task_module.dispatch_human_input_email_task( + form_id="form-1", + node_title="Approve", + session_factory=lambda: _DummySession(form), + ) + + assert mail.sent[0]["subject"] == "Notice BCC:attacker@example.com Alert" diff --git a/api/tests/unit_tests/tasks/test_summary_queue_isolation.py b/api/tests/unit_tests/tasks/test_summary_queue_isolation.py new file mode 100644 index 0000000000..f6632e0a8a --- /dev/null +++ b/api/tests/unit_tests/tasks/test_summary_queue_isolation.py @@ -0,0 +1,40 @@ +""" +Unit tests for summary index task queue isolation. + +These tasks must NOT run on the shared 'dataset' queue because they invoke LLMs +for each document segment and can occupy all worker slots for hours, blocking +document indexing tasks. +""" + +import pytest + +from tasks.generate_summary_index_task import generate_summary_index_task +from tasks.regenerate_summary_index_task import regenerate_summary_index_task + +SUMMARY_QUEUE = "dataset_summary" +INDEXING_QUEUE = "dataset" + + +def _task_queue(task) -> str | None: + # Celery's @shared_task(queue=...) stores the routing key on the task instance + # at runtime, but type stubs don't declare it; use getattr to stay type-clean. + return getattr(task, "queue", None) + + +@pytest.mark.parametrize( + ("task", "task_name"), + [ + (generate_summary_index_task, "generate_summary_index_task"), + (regenerate_summary_index_task, "regenerate_summary_index_task"), + ], +) +def test_summary_task_uses_dedicated_queue(task, task_name): + """Summary tasks must use the dataset_summary queue, not the shared dataset queue. + + Summary generation is LLM-heavy and will block document indexing if placed + on the shared queue. + """ + assert _task_queue(task) == SUMMARY_QUEUE, ( + f"{task_name} must run on '{SUMMARY_QUEUE}' queue (not '{INDEXING_QUEUE}'). " + "Summary generation is LLM-heavy and will block document indexing if placed on the shared queue." + ) diff --git a/api/tests/unit_tests/tasks/test_workflow_execute_task.py b/api/tests/unit_tests/tasks/test_workflow_execute_task.py index 161151305d..d3cf632b47 100644 --- a/api/tests/unit_tests/tasks/test_workflow_execute_task.py +++ b/api/tests/unit_tests/tasks/test_workflow_execute_task.py @@ -2,12 +2,40 @@ from __future__ import annotations import json import uuid +from types import SimpleNamespace from unittest.mock import MagicMock import pytest -from models.model import AppMode -from tasks.app_generate.workflow_execute_task import _publish_streaming_response +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom +from models.enums import CreatorUserRole +from models.model import App, AppMode, Conversation +from models.workflow import Workflow, WorkflowRun +from tasks.app_generate.workflow_execute_task import _publish_streaming_response, _resume_app_execution + + +class _FakeSessionContext: + def __init__(self, session: MagicMock): + self._session = session + + def __enter__(self) -> MagicMock: + return self._session + + def __exit__(self, exc_type, exc, tb) -> bool: + return False + + +def _build_advanced_chat_generate_entity(conversation_id: str | None) -> AdvancedChatAppGenerateEntity: + return AdvancedChatAppGenerateEntity( + task_id="task-id", + inputs={}, + files=[], + user_id="user-id", + stream=True, + invoke_from=InvokeFrom.WEB_APP, + query="query", + conversation_id=conversation_id, + ) @pytest.fixture @@ -37,3 +65,138 @@ def test_publish_streaming_response_coerces_string_uuid(mock_topic: MagicMock): _publish_streaming_response(response_stream, str(workflow_run_id), app_mode=AppMode.ADVANCED_CHAT) mock_topic.publish.assert_called_once_with(json.dumps({"event": "bar"}).encode()) + + +def test_resume_app_execution_queries_message_by_conversation_and_workflow_run(mocker): + workflow_run_id = "run-id" + conversation_id = "conversation-id" + message = MagicMock() + + mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) + + pause_entity = MagicMock() + pause_entity.get_state.return_value = b"state" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_pause.return_value = pause_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + generate_entity = _build_advanced_chat_generate_entity(conversation_id) + resumption_context = MagicMock() + resumption_context.serialized_graph_runtime_state = "{}" + resumption_context.get_generate_entity.return_value = generate_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context + ) + mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock()) + + workflow_run = SimpleNamespace( + workflow_id="wf-id", + app_id="app-id", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by="account-id", + tenant_id="tenant-id", + ) + workflow = SimpleNamespace(created_by="workflow-owner") + app_model = SimpleNamespace(id="app-id") + conversation = SimpleNamespace(id=conversation_id) + + session = MagicMock() + + def _session_get(model, key): + if model is WorkflowRun: + return workflow_run + if model is Workflow: + return workflow + if model is App: + return app_model + if model is Conversation: + return conversation + return None + + session.get.side_effect = _session_get + session.scalar.return_value = message + + mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session)) + mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock()) + resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat") + mocker.patch("tasks.app_generate.workflow_execute_task._resume_workflow") + + _resume_app_execution({"workflow_run_id": workflow_run_id}) + + stmt = session.scalar.call_args.args[0] + stmt_text = str(stmt) + assert "messages.conversation_id = :conversation_id_1" in stmt_text + assert "messages.workflow_run_id = :workflow_run_id_1" in stmt_text + assert "ORDER BY messages.created_at DESC" in stmt_text + assert " LIMIT " in stmt_text + + compiled_params = stmt.compile().params + assert conversation_id in compiled_params.values() + assert workflow_run_id in compiled_params.values() + + workflow_run_repo.resume_workflow_pause.assert_called_once_with(workflow_run_id, pause_entity) + resume_advanced_chat.assert_called_once() + assert resume_advanced_chat.call_args.kwargs["conversation"] is conversation + assert resume_advanced_chat.call_args.kwargs["message"] is message + + +def test_resume_app_execution_returns_early_when_advanced_chat_missing_conversation_id(mocker): + workflow_run_id = "run-id" + + mocker.patch("tasks.app_generate.workflow_execute_task.db", SimpleNamespace(engine=object())) + + pause_entity = MagicMock() + pause_entity.get_state.return_value = b"state" + + workflow_run_repo = MagicMock() + workflow_run_repo.get_workflow_pause.return_value = pause_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.DifyAPIRepositoryFactory.create_api_workflow_run_repository", + return_value=workflow_run_repo, + ) + + generate_entity = _build_advanced_chat_generate_entity(conversation_id=None) + resumption_context = MagicMock() + resumption_context.serialized_graph_runtime_state = "{}" + resumption_context.get_generate_entity.return_value = generate_entity + mocker.patch( + "tasks.app_generate.workflow_execute_task.WorkflowResumptionContext.loads", return_value=resumption_context + ) + mocker.patch("tasks.app_generate.workflow_execute_task.GraphRuntimeState.from_snapshot", return_value=MagicMock()) + + workflow_run = SimpleNamespace( + workflow_id="wf-id", + app_id="app-id", + created_by_role=CreatorUserRole.ACCOUNT.value, + created_by="account-id", + tenant_id="tenant-id", + ) + workflow = SimpleNamespace(created_by="workflow-owner") + app_model = SimpleNamespace(id="app-id") + + session = MagicMock() + + def _session_get(model, key): + if model is WorkflowRun: + return workflow_run + if model is Workflow: + return workflow + if model is App: + return app_model + return None + + session.get.side_effect = _session_get + + mocker.patch("tasks.app_generate.workflow_execute_task.Session", return_value=_FakeSessionContext(session)) + mocker.patch("tasks.app_generate.workflow_execute_task._resolve_user_for_run", return_value=MagicMock()) + resume_advanced_chat = mocker.patch("tasks.app_generate.workflow_execute_task._resume_advanced_chat") + + _resume_app_execution({"workflow_run_id": workflow_run_id}) + + session.scalar.assert_not_called() + workflow_run_repo.resume_workflow_pause.assert_not_called() + resume_advanced_chat.assert_not_called() diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py index fd5f0713a4..a223f0119e 100644 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -11,11 +11,11 @@ # import pytest -# from core.workflow.entities.workflow_node_execution import ( +# from dify_graph.entities.workflow_node_execution import ( # WorkflowNodeExecution, # WorkflowNodeExecutionStatus, # ) -# from core.workflow.enums import NodeType +# from dify_graph.enums import BuiltinNodeTypes # from libs.datetime_utils import naive_utc_now # from models import WorkflowNodeExecutionModel # from models.enums import ExecutionOffLoadType @@ -41,7 +41,7 @@ # workflow_execution_id=str(uuid4()), # index=1, # node_id="test_node", -# node_type=NodeType.LLM, +# node_type=BuiltinNodeTypes.LLM, # title="Test Node", # inputs={"input_key": "input_value"}, # outputs={"output_key": "output_value"}, @@ -134,7 +134,7 @@ # workflow_execution_id=str(uuid4()), # index=1, # node_id="test_node", -# node_type=NodeType.LLM, +# node_type=BuiltinNodeTypes.LLM, # title="Test Node", # inputs=large_data, # outputs=large_data, diff --git a/api/tests/unit_tests/tools/__init__.py b/api/tests/unit_tests/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index 5930b63f58..fa9c6af287 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -13,11 +13,11 @@ from core.mcp.types import ( TextContent, TextResourceContents, ) -from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool +from dify_graph.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 9a0dbfa2d8..7ec1343f98 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -5,7 +5,7 @@ import pytest from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_runtime.entities.llm_entities import ( +from dify_graph.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, @@ -13,13 +13,13 @@ from core.model_runtime.entities.llm_entities import ( LLMResultWithStructuredOutput, LLMUsage, ) -from core.model_runtime.entities.message_entities import ( +from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py new file mode 100644 index 0000000000..1f0bf8ef37 --- /dev/null +++ b/api/tests/workflow_test_utils.py @@ -0,0 +1,53 @@ +from collections.abc import Mapping +from typing import Any + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from dify_graph.entities.graph_init_params import GraphInitParams + + +def build_test_run_context( + *, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> dict[str, Any]: + normalized_user_from = user_from if isinstance(user_from, UserFrom) else UserFrom(user_from) + normalized_invoke_from = invoke_from if isinstance(invoke_from, InvokeFrom) else InvokeFrom(invoke_from) + return build_dify_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=normalized_user_from, + invoke_from=normalized_invoke_from, + extra_context=extra_context, + ) + + +def build_test_graph_init_params( + *, + workflow_id: str = "workflow", + graph_config: Mapping[str, Any] | None = None, + call_depth: int = 0, + tenant_id: str = "tenant", + app_id: str = "app", + user_id: str = "user", + user_from: UserFrom | str = UserFrom.ACCOUNT, + invoke_from: InvokeFrom | str = InvokeFrom.DEBUGGER, + extra_context: Mapping[str, Any] | None = None, +) -> GraphInitParams: + return GraphInitParams( + workflow_id=workflow_id, + graph_config=graph_config or {}, + run_context=build_test_run_context( + tenant_id=tenant_id, + app_id=app_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + extra_context=extra_context, + ), + call_depth=call_depth, + ) diff --git a/api/uv.lock b/api/uv.lock index 9efb53cc5e..bbec12561d 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -2,18 +2,12 @@ version = 1 revision = 3 requires-python = ">=3.11, <3.13" resolution-markers = [ - "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", - "python_full_version >= '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform != 'linux'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_python_implementation != 'PyPy' and sys_platform != 'linux'", - "python_full_version >= '3.12.4' and platform_python_implementation == 'PyPy' and sys_platform == 'linux'", - "python_full_version >= '3.12.4' and platform_python_implementation == 'PyPy' and sys_platform != 'linux'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_python_implementation == 'PyPy' and sys_platform == 'linux'", - "python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_python_implementation == 'PyPy' and sys_platform != 'linux'", - "python_full_version < '3.12' and platform_python_implementation != 'PyPy' and sys_platform == 'linux'", - "python_full_version < '3.12' and platform_python_implementation != 'PyPy' and sys_platform != 'linux'", - "python_full_version < '3.12' and platform_python_implementation == 'PyPy' and sys_platform == 'linux'", - "python_full_version < '3.12' and platform_python_implementation == 'PyPy' and sys_platform != 'linux'", + "python_full_version >= '3.12' and sys_platform == 'win32'", + "python_full_version >= '3.12' and sys_platform == 'emscripten'", + "python_full_version >= '3.12' and sys_platform != 'emscripten' and sys_platform != 'win32'", + "python_full_version < '3.12' and sys_platform == 'win32'", + "python_full_version < '3.12' and sys_platform == 'emscripten'", + "python_full_version < '3.12' and sys_platform != 'emscripten' and sys_platform != 'win32'", ] [[package]] @@ -124,21 +118,21 @@ wheels = [ [[package]] name = "alembic" -version = "1.17.2" +version = "1.18.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "mako" }, { name = "sqlalchemy" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/02/a6/74c8cadc2882977d80ad756a13857857dbcf9bd405bc80b662eb10651282/alembic-1.17.2.tar.gz", hash = "sha256:bbe9751705c5e0f14877f02d46c53d10885e377e3d90eda810a016f9baa19e8e", size = 1988064, upload-time = "2025-11-14T20:35:04.057Z" } +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ba/88/6237e97e3385b57b5f1528647addea5cc03d4d65d5979ab24327d41fb00d/alembic-1.17.2-py3-none-any.whl", hash = "sha256:f483dd1fe93f6c5d49217055e4d15b905b425b6af906746abb35b69c1996c4e6", size = 248554, upload-time = "2025-11-14T20:35:05.699Z" }, + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, ] [[package]] name = "alibabacloud-credentials" -version = "1.0.3" +version = "1.0.8" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiofiles" }, @@ -146,9 +140,9 @@ dependencies = [ { name = "alibabacloud-tea" }, { name = "apscheduler" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/df/82/45ec98bd19387507cf058ce47f62d6fea288bf0511c5a101b832e13d3edd/alibabacloud-credentials-1.0.3.tar.gz", hash = "sha256:9d8707e96afc6f348e23f5677ed15a21c2dfce7cfe6669776548ee4c80e1dfaf", size = 35831, upload-time = "2025-10-14T06:39:58.97Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/15/2b01b4a6cbed4cc2c8a1c801efec43af945af22fd3ca5f78c932117fd4ce/alibabacloud_credentials-1.0.8.tar.gz", hash = "sha256:364c22abef2d240b259ceadf1ce6800017f19a336729553956928a1edd12e769", size = 40465, upload-time = "2026-03-11T09:13:59.398Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/df/dbd9ae9d531a40d5613573c5a22ef774ecfdcaa0dc43aad42189f89c04ce/alibabacloud_credentials-1.0.3-py3-none-any.whl", hash = "sha256:30c8302f204b663c655d97e1c283ee9f9f84a6257d7901b931477d6cf34445a8", size = 41875, upload-time = "2025-10-14T06:39:58.029Z" }, + { url = "https://files.pythonhosted.org/packages/a9/24/7c47501b24897a1379cd57cc8b8de376161f2487548fc8233b2b74ab25c7/alibabacloud_credentials-1.0.8-py3-none-any.whl", hash = "sha256:66677c3fa54aeb66cfb9cc97da4a787534f38a04d09bbfa0bc6c815fe1af7e28", size = 48799, upload-time = "2026-03-11T09:13:58.113Z" }, ] [[package]] @@ -193,13 +187,16 @@ wheels = [ [[package]] name = "alibabacloud-openapi-util" -version = "0.2.2" +version = "0.2.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-tea-util" }, { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f6/50/5f41ab550d7874c623f6e992758429802c4b52a6804db437017e5387de33/alibabacloud_openapi_util-0.2.2.tar.gz", hash = "sha256:ebbc3906f554cb4bf8f513e43e8a33e8b6a3d4a0ef13617a0e14c3dda8ef52a8", size = 7201, upload-time = "2023-10-23T07:44:18.523Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f6/51/be5802851a4ed20ac2c6db50ac8354a6e431e93db6e714ca39b50983626f/alibabacloud_openapi_util-0.2.4.tar.gz", hash = "sha256:87022b9dcb7593a601f7a40ca698227ac3ccb776b58cb7b06b8dc7f510995c34", size = 7981, upload-time = "2026-01-15T08:05:03.947Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/46/9b217343648b366eb93447f5d93116e09a61956005794aed5ef95a2e9e2e/alibabacloud_openapi_util-0.2.4-py3-none-any.whl", hash = "sha256:a2474f230b5965ae9a8c286e0dc86132a887928d02d20b8182656cf6b1b6c5bd", size = 7661, upload-time = "2026-01-15T08:05:01.374Z" }, +] [[package]] name = "alibabacloud-openplatform20191219" @@ -259,16 +256,19 @@ sdist = { url = "https://files.pythonhosted.org/packages/22/8a/ef8ddf5ee0350984c [[package]] name = "alibabacloud-tea-openapi" -version = "0.3.16" +version = "0.4.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, { name = "alibabacloud-gateway-spi" }, - { name = "alibabacloud-openapi-util" }, { name = "alibabacloud-tea-util" }, - { name = "alibabacloud-tea-xml" }, + { name = "cryptography" }, + { name = "darabonba-core" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/91/4f/b5288eea8f4d4b032c9a8f2cd1d926d5017977d10b874956f31e5343f299/alibabacloud_tea_openapi-0.4.3.tar.gz", hash = "sha256:12aef036ed993637b6f141abbd1de9d6199d5516f4a901588bb65d6a3768d41b", size = 21864, upload-time = "2026-01-15T07:55:16.744Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a5/37/48ee5468ecad19c6d44cf3b9629d77078e836ee3ec760f0366247f307b7c/alibabacloud_tea_openapi-0.4.3-py3-none-any.whl", hash = "sha256:d0b3a373b760ef6278b25fc128c73284301e07888977bf97519e7636d47bdf0a", size = 26159, upload-time = "2026-01-15T07:55:15.72Z" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/09/be/f594e79625e5ccfcfe7f12d7d70709a3c59e920878469c998886211c850d/alibabacloud_tea_openapi-0.3.16.tar.gz", hash = "sha256:6bffed8278597592e67860156f424bde4173a6599d7b6039fb640a3612bae292", size = 13087, upload-time = "2025-07-04T09:30:10.689Z" } [[package]] name = "alibabacloud-tea-util" @@ -293,7 +293,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/32/eb/5e82e419c3061823f [[package]] name = "aliyun-log-python-sdk" -version = "0.9.37" +version = "0.9.42" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dateparser" }, @@ -305,7 +305,7 @@ dependencies = [ { name = "requests" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/90/70/291d494619bb7b0cbcc00689ad995945737c2c9e0bff2733e0aa7dbaee14/aliyun_log_python_sdk-0.9.37.tar.gz", hash = "sha256:ea65c9cca3a7377cef87d568e897820338328a53a7acb1b02f1383910e103f68", size = 152549, upload-time = "2025-11-27T07:56:06.098Z" } +sdist = { url = "https://files.pythonhosted.org/packages/10/44/c77ddc6abc0770318f8c3c59db6711c04cee3507cc4f84b267d46f86ad9f/aliyun_log_python_sdk-0.9.42.tar.gz", hash = "sha256:27d2a857743fa61576947aa16e46cd3a1bab151bf3a5493b32b4e2a995362e29", size = 154460, upload-time = "2026-01-15T03:43:31.811Z" } [[package]] name = "aliyun-python-sdk-core" @@ -368,36 +368,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, ] -[[package]] -name = "anthropic" -version = "0.49.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "distro" }, - { name = "httpx" }, - { name = "jiter" }, - { name = "pydantic" }, - { name = "sniffio" }, - { name = "typing-extensions" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/86/e3/a88c8494ce4d1a88252b9e053607e885f9b14d0a32273d47b727cbee4228/anthropic-0.49.0.tar.gz", hash = "sha256:c09e885b0f674b9119b4f296d8508907f6cff0009bc20d5cf6b35936c40b4398", size = 210016, upload-time = "2025-02-28T19:35:47.01Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/76/74/5d90ad14d55fbe3f9c474fdcb6e34b4bed99e3be8efac98734a5ddce88c1/anthropic-0.49.0-py3-none-any.whl", hash = "sha256:bbc17ad4e7094988d2fa86b87753ded8dce12498f4b85fe5810f208f454a8375", size = 243368, upload-time = "2025-02-28T19:35:44.963Z" }, -] - [[package]] name = "anyio" -version = "4.11.0" +version = "4.12.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "idna" }, - { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/78/7d432127c41b50bccba979505f272c16cbcadcc33645d5fa3a738110ae75/anyio-4.11.0.tar.gz", hash = "sha256:82a8d0b81e318cc5ce71a5f1f8b5c4e63619620b63141ef8c995fa0db95a57c4", size = 219094, upload-time = "2025-09-23T09:19:12.58Z" } +sdist = { url = "https://files.pythonhosted.org/packages/96/f0/5eb65b2bb0d09ac6776f2eb54adee6abe8228ea05b20a5ad0e4945de8aac/anyio-4.12.1.tar.gz", hash = "sha256:41cfcc3a4c85d3f05c932da7c26d0201ac36f72abd4435ba90d0464a3ffed703", size = 228685, upload-time = "2026-01-06T11:45:21.246Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/15/b3/9b1a8074496371342ec1e796a96f99c82c945a339cd81a8e73de28b4cf9e/anyio-4.11.0-py3-none-any.whl", hash = "sha256:0287e96f4d26d4149305414d4e3bc32f0dcd0862365a4bddea19d7a1ec38c4fc", size = 109097, upload-time = "2025-09-23T09:19:10.601Z" }, + { url = "https://files.pythonhosted.org/packages/38/0e/27be9fdef66e72d64c0cdc3cc2823101b80585f8119b5c112c2e8f5f7dab/anyio-4.12.1-py3-none-any.whl", hash = "sha256:d405828884fc140aa80a3c667b8beed277f1dfedec42ba031bd6ac3db606ab6c", size = 113592, upload-time = "2026-01-06T11:45:19.497Z" }, ] [[package]] @@ -411,19 +392,19 @@ wheels = [ [[package]] name = "apscheduler" -version = "3.11.1" +version = "3.11.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzlocal" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d0/81/192db4f8471de5bc1f0d098783decffb1e6e69c4f8b4bc6711094691950b/apscheduler-3.11.1.tar.gz", hash = "sha256:0db77af6400c84d1747fe98a04b8b58f0080c77d11d338c4f507a9752880f221", size = 108044, upload-time = "2025-10-31T18:55:42.819Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/12/3e4389e5920b4c1763390c6d371162f3784f86f85cd6d6c1bfe68eef14e2/apscheduler-3.11.2.tar.gz", hash = "sha256:2a9966b052ec805f020c8c4c3ae6e6a06e24b1bf19f2e11d91d8cca0473eef41", size = 108683, upload-time = "2025-12-22T00:39:34.884Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/9f/d3c76f76c73fcc959d28e9def45b8b1cc3d7722660c5003b19c1022fd7f4/apscheduler-3.11.1-py3-none-any.whl", hash = "sha256:6162cb5683cb09923654fa9bdd3130c4be4bfda6ad8990971c9597ecd52965d2", size = 64278, upload-time = "2025-10-31T18:55:41.186Z" }, + { url = "https://files.pythonhosted.org/packages/9f/64/2e54428beba8d9992aa478bb8f6de9e4ecaa5f8f513bcfd567ed7fb0262d/apscheduler-3.11.2-py3-none-any.whl", hash = "sha256:ce005177f741409db4e4dd40a7431b76feb856b9dd69d57e0da49d6715bfd26d", size = 64439, upload-time = "2025-12-22T00:39:33.303Z" }, ] [[package]] name = "arize-phoenix-otel" -version = "0.9.2" +version = "0.15.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "openinference-instrumentation" }, @@ -433,19 +414,20 @@ dependencies = [ { name = "opentelemetry-sdk" }, { name = "opentelemetry-semantic-conventions" }, { name = "typing-extensions" }, + { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/27/b9/8c89191eb46915e9ba7bdb473e2fb1c510b7db3635ae5ede5e65b2176b9d/arize_phoenix_otel-0.9.2.tar.gz", hash = "sha256:a48c7d41f3ac60dc75b037f036bf3306d2af4af371cdb55e247e67957749bc31", size = 11599, upload-time = "2025-04-14T22:05:28.637Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/f0/b254118db28a2a202573472be67cf61f09cb37912bfde45b27ddc1c5b71f/arize_phoenix_otel-0.15.0.tar.gz", hash = "sha256:56c7dae09aaaa80df9e9595b7384c1bd4054b69b6032ab18e3a110a59b488388", size = 20254, upload-time = "2026-03-02T20:19:04.112Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/3d/f64136a758c649e883315939f30fe51ad0747024b0db05fd78450801a78d/arize_phoenix_otel-0.9.2-py3-none-any.whl", hash = "sha256:5286b33c58b596ef8edd9a4255ee00fd74f774b1e5dbd9393e77e87870a14d76", size = 12560, upload-time = "2025-04-14T22:05:27.162Z" }, + { url = "https://files.pythonhosted.org/packages/e4/4d/70d9c9d7137cc2e2aad819932172ef13ce21b4e60bf258910b9f15e426af/arize_phoenix_otel-0.15.0-py3-none-any.whl", hash = "sha256:5ff4d03b52d2dbd9c2a234417848f6b171cd220dc3c4020cf3568be84b89b88b", size = 17697, upload-time = "2026-03-02T20:19:03.242Z" }, ] [[package]] name = "asgiref" -version = "3.11.0" +version = "3.11.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/76/b9/4db2509eabd14b4a8c71d1b24c8d5734c52b8560a7b1e1a8b56c8d25568b/asgiref-3.11.0.tar.gz", hash = "sha256:13acff32519542a1736223fb79a715acdebe24286d98e8b164a73085f40da2c4", size = 37969, upload-time = "2025-11-19T15:32:20.106Z" } +sdist = { url = "https://files.pythonhosted.org/packages/63/40/f03da1264ae8f7cfdbf9146542e5e7e8100a4c66ab48e791df9a03d3f6c0/asgiref-3.11.1.tar.gz", hash = "sha256:5f184dc43b7e763efe848065441eac62229c9f7b0475f41f80e207a114eda4ce", size = 38550, upload-time = "2026-02-03T13:30:14.33Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/91/be/317c2c55b8bbec407257d45f5c8d1b6867abc76d12043f2d3d58c538a4ea/asgiref-3.11.0-py3-none-any.whl", hash = "sha256:1db9021efadb0d9512ce8ffaf72fcef601c7b73a8807a1bb2ef143dc6b14846d", size = 24096, upload-time = "2025-11-19T15:32:19.004Z" }, + { url = "https://files.pythonhosted.org/packages/5c/0a/a72d10ed65068e115044937873362e6e32fab1b7dce0046aeb224682c989/asgiref-3.11.1-py3-none-any.whl", hash = "sha256:e8667a091e69529631969fd45dc268fa79b99c92c5fcdda727757e52146ec133", size = 24345, upload-time = "2026-02-03T13:30:13.039Z" }, ] [[package]] @@ -459,56 +441,57 @@ wheels = [ [[package]] name = "attrs" -version = "25.4.0" +version = "26.1.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6b/5c/685e6633917e101e5dcb62b9dd76946cbb57c26e133bae9e0cd36033c0a9/attrs-25.4.0.tar.gz", hash = "sha256:16d5969b87f0859ef33a48b35d55ac1be6e42ae49d5e853b597db70c35c57e11", size = 934251, upload-time = "2025-10-06T13:54:44.725Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9a/8e/82a0fe20a541c03148528be8cac2408564a6c9a0cc7e9171802bc1d26985/attrs-26.1.0.tar.gz", hash = "sha256:d03ceb89cb322a8fd706d4fb91940737b6642aa36998fe130a9bc96c985eff32", size = 952055, upload-time = "2026-03-19T14:22:25.026Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3a/2a/7cc015f5b9f5db42b7d48157e23356022889fc354a2813c15934b7cb5c0e/attrs-25.4.0-py3-none-any.whl", hash = "sha256:adcf7e2a1fb3b36ac48d97835bb6d8ade15b8dcce26aba8bf1d14847b57a3373", size = 67615, upload-time = "2025-10-06T13:54:43.17Z" }, + { url = "https://files.pythonhosted.org/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309", size = 67548, upload-time = "2026-03-19T14:22:23.645Z" }, ] [[package]] name = "authlib" -version = "1.6.6" +version = "1.6.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bb/9b/b1661026ff24bc641b76b78c5222d614776b0c085bcfdac9bd15a1cb4b35/authlib-1.6.6.tar.gz", hash = "sha256:45770e8e056d0f283451d9996fbb59b70d45722b45d854d58f32878d0a40c38e", size = 164894, upload-time = "2025-12-12T08:01:41.464Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/98/00d3dd826d46959ad8e32af2dbb2398868fd9fd0683c26e56d0789bd0e68/authlib-1.6.9.tar.gz", hash = "sha256:d8f2421e7e5980cc1ddb4e32d3f5fa659cfaf60d8eaf3281ebed192e4ab74f04", size = 165134, upload-time = "2026-03-02T07:44:01.998Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/51/321e821856452f7386c4e9df866f196720b1ad0c5ea1623ea7399969ae3b/authlib-1.6.6-py2.py3-none-any.whl", hash = "sha256:7d9e9bc535c13974313a87f53e8430eb6ea3d1cf6ae4f6efcd793f2e949143fd", size = 244005, upload-time = "2025-12-12T08:01:40.209Z" }, + { url = "https://files.pythonhosted.org/packages/53/23/b65f568ed0c22f1efacb744d2db1a33c8068f384b8c9b482b52ebdbc3ef6/authlib-1.6.9-py2.py3-none-any.whl", hash = "sha256:f08b4c14e08f0861dc18a32357b33fbcfd2ea86cfe3fe149484b4d764c4a0ac3", size = 244197, upload-time = "2026-03-02T07:44:00.307Z" }, ] [[package]] name = "azure-core" -version = "1.38.0" +version = "1.39.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dc/1b/e503e08e755ea94e7d3419c9242315f888fc664211c90d032e40479022bf/azure_core-1.38.0.tar.gz", hash = "sha256:8194d2682245a3e4e3151a667c686464c3786fed7918b394d035bdcd61bb5993", size = 363033, upload-time = "2026-01-12T17:03:05.535Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/83/bbde3faa84ddcb8eb0eca4b3ffb3221252281db4ce351300fe248c5c70b1/azure_core-1.39.0.tar.gz", hash = "sha256:8a90a562998dd44ce84597590fff6249701b98c0e8797c95fcdd695b54c35d74", size = 367531, upload-time = "2026-03-19T01:31:29.461Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/d8/b8fcba9464f02b121f39de2db2bf57f0b216fe11d014513d666e8634380d/azure_core-1.38.0-py3-none-any.whl", hash = "sha256:ab0c9b2cd71fecb1842d52c965c95285d3cfb38902f6766e4a471f1cd8905335", size = 217825, upload-time = "2026-01-12T17:03:07.291Z" }, + { url = "https://files.pythonhosted.org/packages/7e/d6/8ebcd05b01a580f086ac9a97fb9fac65c09a4b012161cc97c21a336e880b/azure_core-1.39.0-py3-none-any.whl", hash = "sha256:4ac7b70fab5438c3f68770649a78daf97833caa83827f91df9c14e0e0ea7d34f", size = 218318, upload-time = "2026-03-19T01:31:31.25Z" }, ] [[package]] name = "azure-identity" -version = "1.16.1" +version = "1.25.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "azure-core" }, { name = "cryptography" }, { name = "msal" }, { name = "msal-extensions" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bb/1c/bd704075e555046e24b069157ca25c81aedb4199c3e0b35acba9243a6ca6/azure-identity-1.16.1.tar.gz", hash = "sha256:6d93f04468f240d59246d8afde3091494a5040d4f141cad0f49fc0c399d0d91e", size = 236726, upload-time = "2024-06-10T22:23:27.46Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c5/0e/3a63efb48aa4a5ae2cfca61ee152fbcb668092134d3eb8bfda472dd5c617/azure_identity-1.25.3.tar.gz", hash = "sha256:ab23c0d63015f50b630ef6c6cf395e7262f439ce06e5d07a64e874c724f8d9e6", size = 286304, upload-time = "2026-03-13T01:12:20.892Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/c5/ca55106564d2044ab90614381368b3756690fb7e3ab04552e17f308e4e4f/azure_identity-1.16.1-py3-none-any.whl", hash = "sha256:8fb07c25642cd4ac422559a8b50d3e77f73dcc2bbfaba419d06d6c9d7cff6726", size = 166741, upload-time = "2024-06-10T22:23:30.906Z" }, + { url = "https://files.pythonhosted.org/packages/49/9a/417b3a533e01953a7c618884df2cb05a71e7b68bdbce4fbdb62349d2a2e8/azure_identity-1.25.3-py3-none-any.whl", hash = "sha256:f4d0b956a8146f30333e071374171f3cfa7bdb8073adb8c3814b65567aa7447c", size = 192138, upload-time = "2026-03-13T01:12:22.951Z" }, ] [[package]] name = "azure-storage-blob" -version = "12.26.0" +version = "12.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "azure-core" }, @@ -516,9 +499,9 @@ dependencies = [ { name = "isodate" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/96/95/3e3414491ce45025a1cde107b6ae72bf72049e6021597c201cd6a3029b9a/azure_storage_blob-12.26.0.tar.gz", hash = "sha256:5dd7d7824224f7de00bfeb032753601c982655173061e242f13be6e26d78d71f", size = 583332, upload-time = "2025-07-16T21:34:07.644Z" } +sdist = { url = "https://files.pythonhosted.org/packages/71/24/072ba8e27b0e2d8fec401e9969b429d4f5fc4c8d4f0f05f4661e11f7234a/azure_storage_blob-12.28.0.tar.gz", hash = "sha256:e7d98ea108258d29aa0efbfd591b2e2075fa1722a2fae8699f0b3c9de11eff41", size = 604225, upload-time = "2026-01-06T23:48:57.282Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5b/64/63dbfdd83b31200ac58820a7951ddfdeed1fbee9285b0f3eae12d1357155/azure_storage_blob-12.26.0-py3-none-any.whl", hash = "sha256:8c5631b8b22b4f53ec5fff2f3bededf34cfef111e2af613ad42c9e6de00a77fe", size = 412907, upload-time = "2025-07-16T21:34:09.367Z" }, + { url = "https://files.pythonhosted.org/packages/d8/3a/6ef2047a072e54e1142718d433d50e9514c999a58f51abfff7902f3a72f8/azure_storage_blob-12.28.0-py3-none-any.whl", hash = "sha256:00fb1db28bf6a7b7ecaa48e3b1d5c83bfadacc5a678b77826081304bd87d6461", size = 431499, upload-time = "2026-01-06T23:48:58.995Z" }, ] [[package]] @@ -531,38 +514,77 @@ wheels = [ ] [[package]] -name = "backports-tarfile" -version = "1.2.0" +name = "backports-zstd" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/86/72/cd9b395f25e290e633655a100af28cb253e4393396264a98bd5f5951d50f/backports_tarfile-1.2.0.tar.gz", hash = "sha256:d75e02c268746e1b8144c278978b6e98e85de6ad16f8e4b0844a154557eca991", size = 86406, upload-time = "2024-05-28T17:01:54.731Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/b1/36a5182ce1d8ef9ef32bff69037bd28b389bbdb66338f8069e61da7028cb/backports_zstd-1.3.0.tar.gz", hash = "sha256:e8b2d68e2812f5c9970cabc5e21da8b409b5ed04e79b4585dbffa33e9b45ebe2", size = 997138, upload-time = "2025-12-29T17:28:06.143Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/fa/123043af240e49752f1c4bd24da5053b6bd00cad78c2be53c0d1e8b975bc/backports.tarfile-1.2.0-py3-none-any.whl", hash = "sha256:77e284d754527b01fb1e6fa8a1afe577858ebe4e9dad8919e34c862cb399bc34", size = 30181, upload-time = "2024-05-28T17:01:53.112Z" }, + { url = "https://files.pythonhosted.org/packages/ac/28/ed31a0e35feb4538a996348362051b52912d50f00d25c2d388eccef9242c/backports_zstd-1.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:249f90b39d3741c48620021a968b35f268ca70e35f555abeea9ff95a451f35f9", size = 435660, upload-time = "2025-12-29T17:25:55.207Z" }, + { url = "https://files.pythonhosted.org/packages/00/0d/3db362169d80442adda9dd563c4f0bb10091c8c1c9a158037f4ecd53988e/backports_zstd-1.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b0e71e83e46154a9d3ced6d4de9a2fea8207ee1e4832aeecf364dc125eda305c", size = 362056, upload-time = "2025-12-29T17:25:56.729Z" }, + { url = "https://files.pythonhosted.org/packages/bd/00/b67ba053a7d6f6dbe2f8a704b7d3a5e01b1d2e2e8edbc9b634f2702ef73c/backports_zstd-1.3.0-cp311-cp311-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:cbc6193acd21f96760c94dd71bf32b161223e8503f5277acb0a5ab54e5598957", size = 505957, upload-time = "2025-12-29T17:25:57.941Z" }, + { url = "https://files.pythonhosted.org/packages/6f/3e/2667c0ddb53ddf28667e330bf9fe92e8e17705a481c9b698e283120565f7/backports_zstd-1.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1df583adc0ae84a8d13d7139f42eade6d90182b1dd3e0d28f7df3c564b9fd55d", size = 475569, upload-time = "2025-12-29T17:25:59.075Z" }, + { url = "https://files.pythonhosted.org/packages/eb/86/4052473217bd954ccdffda5f7264a0e99e7c4ecf70c0f729845c6a45fc5a/backports_zstd-1.3.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d833fc23aa3cc2e05aeffc7cfadd87b796654ad3a7fb214555cda3f1db2d4dc2", size = 581196, upload-time = "2025-12-29T17:26:00.508Z" }, + { url = "https://files.pythonhosted.org/packages/e5/bd/064f6fdb61db3d2c473159ebc844243e650dc032de0f8208443a00127925/backports_zstd-1.3.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:142178fe981061f1d2a57c5348f2cd31a3b6397a35593e7a17dbda817b793a7f", size = 640888, upload-time = "2025-12-29T17:26:02.134Z" }, + { url = "https://files.pythonhosted.org/packages/d8/09/0822403f40932a165a4f1df289d41653683019e4fd7a86b63ed20e9b6177/backports_zstd-1.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5eed0a09a163f3a8125a857cb031be87ed052e4a47bc75085ed7fca786e9bb5b", size = 491100, upload-time = "2025-12-29T17:26:03.418Z" }, + { url = "https://files.pythonhosted.org/packages/a6/a3/f5ac28d74039b7e182a780809dc66b9dbfc893186f5d5444340bba135389/backports_zstd-1.3.0-cp311-cp311-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:60aa483fef5843749e993dde01229e5eedebca8c283023d27d6bf6800d1d4ce3", size = 565071, upload-time = "2025-12-29T17:26:05.022Z" }, + { url = "https://files.pythonhosted.org/packages/e1/ac/50209aeb92257a642ee987afa1e61d5b6731ab6bf0bff70905856e5aede6/backports_zstd-1.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ea0886c1b619773544546e243ed73f6d6c2b1ae3c00c904ccc9903a352d731e1", size = 481519, upload-time = "2025-12-29T17:26:06.255Z" }, + { url = "https://files.pythonhosted.org/packages/08/1f/b06f64199fb4b2e9437cedbf96d0155ca08aeec35fe81d41065acd44762e/backports_zstd-1.3.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5e137657c830a5ce99be40a1d713eb1d246bae488ada28ff0666ac4387aebdd5", size = 509465, upload-time = "2025-12-29T17:26:07.602Z" }, + { url = "https://files.pythonhosted.org/packages/f4/37/2c365196e61c8fffbbc930ffd69f1ada7aa1c7210857b3e565031c787ac6/backports_zstd-1.3.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:94048c8089755e482e4b34608029cf1142523a625873c272be2b1c9253871a72", size = 585552, upload-time = "2025-12-29T17:26:08.911Z" }, + { url = "https://files.pythonhosted.org/packages/93/8d/c2c4f448bb6b6c9df17410eaedce415e8db0eb25b60d09a3d22a98294d09/backports_zstd-1.3.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:d339c1ec40485e97e600eb9a285fb13169dbf44c5094b945788a62f38b96e533", size = 562893, upload-time = "2025-12-29T17:26:10.566Z" }, + { url = "https://files.pythonhosted.org/packages/74/e8/2110d4d39115130f7514cbbcec673a885f4052bb68d15e41bc96a7558856/backports_zstd-1.3.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:8aeee9210c54cf8bf83f4d263a6d0d6e7a0298aeb5a14a0a95e90487c5c3157c", size = 631462, upload-time = "2025-12-29T17:26:11.99Z" }, + { url = "https://files.pythonhosted.org/packages/b9/a8/d64b59ae0714fdace14e43873f794eff93613e35e3e85eead33a4f44cd80/backports_zstd-1.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ba7114a3099e5ea05cbb46568bd0e08bca2ca11e12c6a7b563a24b86b2b4a67f", size = 495125, upload-time = "2025-12-29T17:26:13.218Z" }, + { url = "https://files.pythonhosted.org/packages/ef/d8/bcff0a091fcf27172c57ae463e49d8dec6dc31e01d7e7bf1ae3aad9c3566/backports_zstd-1.3.0-cp311-cp311-win32.whl", hash = "sha256:08dfdfb85da5915383bfae680b6ac10ab5769ab22e690f9a854320720011ae8e", size = 288664, upload-time = "2025-12-29T17:26:14.791Z" }, + { url = "https://files.pythonhosted.org/packages/28/1a/379061e2abf8c3150ad51c1baab9ac723e01cf7538860a6a74c48f8b73ee/backports_zstd-1.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:d8aac2e7cdcc8f310c16f98a0062b48d0a081dbb82862794f4f4f5bdafde30a4", size = 313633, upload-time = "2025-12-29T17:26:16.31Z" }, + { url = "https://files.pythonhosted.org/packages/35/e7/eca40858883029fc716660106069b23253e2ec5fd34e86b4101c8cfe864b/backports_zstd-1.3.0-cp311-cp311-win_arm64.whl", hash = "sha256:440ef1be06e82dc0d69dbb57177f2ce98bbd2151013ee7e551e2f2b54caa6120", size = 288814, upload-time = "2025-12-29T17:26:17.571Z" }, + { url = "https://files.pythonhosted.org/packages/72/d4/356da49d3053f4bc50e71a8535631b57bc9ca4e8c6d2442e073e0ab41c44/backports_zstd-1.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f4a292e357f3046d18766ce06d990ccbab97411708d3acb934e63529c2ea7786", size = 435972, upload-time = "2025-12-29T17:26:18.752Z" }, + { url = "https://files.pythonhosted.org/packages/30/8f/dbe389e60c7e47af488520f31a4aa14028d66da5bf3c60d3044b571eb906/backports_zstd-1.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fb4c386f38323698991b38edcc9c091d46d4713f5df02a3b5c80a28b40e289ea", size = 362124, upload-time = "2025-12-29T17:26:19.995Z" }, + { url = "https://files.pythonhosted.org/packages/55/4b/173beafc99e99e7276ce008ef060b704471e75124c826bc5e2092815da37/backports_zstd-1.3.0-cp312-cp312-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:f52523d2bdada29e653261abdc9cfcecd9e5500d305708b7e37caddb24909d4e", size = 506378, upload-time = "2025-12-29T17:26:21.855Z" }, + { url = "https://files.pythonhosted.org/packages/df/c8/3f12a411d9a99d262cdb37b521025eecc2aa7e4a93277be3f4f4889adb74/backports_zstd-1.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3321d00beaacbd647252a7f581c1e1cdbdbda2407f2addce4bfb10e8e404b7c7", size = 476201, upload-time = "2025-12-29T17:26:23.047Z" }, + { url = "https://files.pythonhosted.org/packages/43/dc/73c090e4a2d5671422512e1b6d276ca6ea0cc0c45ec4634789106adc0d66/backports_zstd-1.3.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:88f94d238ef36c639c0ae17cf41054ce103da9c4d399c6a778ce82690d9f4919", size = 581659, upload-time = "2025-12-29T17:26:24.189Z" }, + { url = "https://files.pythonhosted.org/packages/08/4f/11bfcef534aa2bf3f476f52130217b45337f334d8a287edb2e06744a6515/backports_zstd-1.3.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:97d8c78fe20c7442c810adccfd5e3ea6a4e6f4f1fa4c73da2bc083260ebead17", size = 640388, upload-time = "2025-12-29T17:26:25.47Z" }, + { url = "https://files.pythonhosted.org/packages/71/17/8faea426d4f49b63238bdfd9f211a9f01c862efe0d756d3abeb84265a4e2/backports_zstd-1.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eefda80c3dbfbd924f1c317e7b0543d39304ee645583cb58bae29e19f42948ed", size = 494173, upload-time = "2025-12-29T17:26:26.736Z" }, + { url = "https://files.pythonhosted.org/packages/ba/9d/901f19ac90f3cd999bdcfb6edb4d7b4dc383dfba537f06f533fc9ac4777b/backports_zstd-1.3.0-cp312-cp312-manylinux_2_34_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:2ab5d3b5a54a674f4f6367bb9e0914063f22cd102323876135e9cc7a8f14f17e", size = 568628, upload-time = "2025-12-29T17:26:28.12Z" }, + { url = "https://files.pythonhosted.org/packages/60/39/4d29788590c2465a570c2fae49dbff05741d1f0c8e4a0fb2c1c310f31804/backports_zstd-1.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7558fb0e8c8197c59a5f80c56bf8f56c3690c45fd62f14e9e2081661556e3e64", size = 482233, upload-time = "2025-12-29T17:26:29.399Z" }, + { url = "https://files.pythonhosted.org/packages/d9/4b/24c7c9e8ef384b19d515a7b1644a500ceb3da3baeff6d579687da1a0f62b/backports_zstd-1.3.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:27744870e38f017159b9c0241ea51562f94c7fefcfa4c5190fb3ec4a65a7fc63", size = 509806, upload-time = "2025-12-29T17:26:30.605Z" }, + { url = "https://files.pythonhosted.org/packages/3f/7e/7ba1aeecf0b5859f1855c0e661b4559566b64000f0627698ebd9e83f2138/backports_zstd-1.3.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:b099750755bb74c280827c7d68de621da0f245189082ab48ff91bda0ec2db9df", size = 586037, upload-time = "2025-12-29T17:26:32.201Z" }, + { url = "https://files.pythonhosted.org/packages/4a/1a/18f0402b36b9cfb0aea010b5df900cfd42c214f37493561dba3abac90c4e/backports_zstd-1.3.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:5434e86f2836d453ae3e19a2711449683b7e21e107686838d12a255ad256ca99", size = 566220, upload-time = "2025-12-29T17:26:33.5Z" }, + { url = "https://files.pythonhosted.org/packages/dc/d9/44c098ab31b948bbfd909ec4ae08e1e44c5025a2d846f62991a62ab3ebea/backports_zstd-1.3.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:407e451f64e2f357c9218f5be4e372bb6102d7ae88582d415262a9d0a4f9b625", size = 630847, upload-time = "2025-12-29T17:26:35.273Z" }, + { url = "https://files.pythonhosted.org/packages/30/33/e74cb2cfb162d2e9e00dad8bcdf53118ca7786cfd467925d6864732f79cc/backports_zstd-1.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:58a071f3c198c781b2df801070290b7174e3ff61875454e9df93ab7ea9ea832b", size = 498665, upload-time = "2025-12-29T17:26:37.123Z" }, + { url = "https://files.pythonhosted.org/packages/a2/a9/67a24007c333ed22736d5cd79f1aa1d7209f09be772ff82a8fd724c1978e/backports_zstd-1.3.0-cp312-cp312-win32.whl", hash = "sha256:21a9a542ccc7958ddb51ae6e46d8ed25d585b54d0d52aaa1c8da431ea158046a", size = 288809, upload-time = "2025-12-29T17:26:38.373Z" }, + { url = "https://files.pythonhosted.org/packages/42/24/34b816118ea913debb2ea23e71ffd0fb2e2ac738064c4ac32e3fb62c18bb/backports_zstd-1.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:89ea8281821123b071a06b30b80da8e4d8a2b40a4f57315a19850337a21297ac", size = 313815, upload-time = "2025-12-29T17:26:39.665Z" }, + { url = "https://files.pythonhosted.org/packages/4e/2f/babd02c9fc4ca35376ada7c291193a208165c7be2455f0f98bc1e1243f31/backports_zstd-1.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:f6843ecb181480e423b02f60fe29e393cbc31a95fb532acdf0d3a2c87bd50ce3", size = 288927, upload-time = "2025-12-29T17:26:40.923Z" }, + { url = "https://files.pythonhosted.org/packages/9a/d9/8c9c246e5ea79a4f45d551088b11b61f2dc7efcdc5dbe6df3be84a506e0c/backports_zstd-1.3.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:968167d29f012cee7b112ad031a8925e484e97e99288e55e4d62962c3a1013e3", size = 409666, upload-time = "2025-12-29T17:27:57.37Z" }, + { url = "https://files.pythonhosted.org/packages/a4/4f/a55b33c314ca8c9074e99daab54d04c5d212070ae7dbc435329baf1b139e/backports_zstd-1.3.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:d8f6fc7d62b71083b574193dd8fb3a60e6bb34880cc0132aad242943af301f7a", size = 339199, upload-time = "2025-12-29T17:27:58.542Z" }, + { url = "https://files.pythonhosted.org/packages/9d/13/ce31bd048b1c88d0f65d7af60b6cf89cfbed826c7c978f0ebca9a8a71cfc/backports_zstd-1.3.0-pp311-pypy311_pp73-manylinux2010_i686.manylinux_2_12_i686.manylinux_2_28_i686.whl", hash = "sha256:e0f2eca6aac280fdb77991ad3362487ee91a7fb064ad40043fb5a0bf5a376943", size = 420332, upload-time = "2025-12-29T17:28:00.332Z" }, + { url = "https://files.pythonhosted.org/packages/cf/80/c0cdbc533d0037b57248588403a3afb050b2a83b8c38aa608e31b3a4d600/backports_zstd-1.3.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:676eb5e177d4ef528cf3baaeea4fffe05f664e4dd985d3ac06960ef4619c81a9", size = 393879, upload-time = "2025-12-29T17:28:01.57Z" }, + { url = "https://files.pythonhosted.org/packages/0f/38/c97428867cac058ed196ccaeddfdf82ecd43b8a65965f2950a6e7547e77a/backports_zstd-1.3.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:199eb9bd8aca6a9d489c41a682fad22c587dffe57b613d0fe6d492d0d38ce7c5", size = 413842, upload-time = "2025-12-29T17:28:03.113Z" }, + { url = "https://files.pythonhosted.org/packages/8d/ec/6247be6536668fe1c7dfae3eaa9c94b00b956b716957c0fc986ba78c3cc4/backports_zstd-1.3.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:2524bd6777a828d5e7ccd7bd1a57f9e7007ae654fc2bd1bc1a207f6428674e4a", size = 299684, upload-time = "2025-12-29T17:28:04.856Z" }, ] [[package]] name = "basedpyright" -version = "1.31.7" +version = "1.38.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/ba/ed69e8df732a09c8ca469f592c8e08707fe29149735b834c276d94d4a3da/basedpyright-1.31.7.tar.gz", hash = "sha256:394f334c742a19bcc5905b2455c9f5858182866b7679a6f057a70b44b049bceb", size = 22710948, upload-time = "2025-10-11T05:12:48.3Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/58/7abba2c743571a42b2548f07aee556ebc1e4d0bc2b277aeba1ee6c83b0af/basedpyright-1.38.3.tar.gz", hash = "sha256:9725419786afbfad8a9539527f162da02d462afad440b0412fdb3f3cdf179b90", size = 25277430, upload-time = "2026-03-17T13:10:41.526Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f8/90/ce01ad2d0afdc1b82b8b5aaba27e60d2e138e39d887e71c35c55d8f1bfcd/basedpyright-1.31.7-py3-none-any.whl", hash = "sha256:7c54beb7828c9ed0028630aaa6904f395c27e5a9f5a313aa9e91fc1d11170831", size = 11817571, upload-time = "2025-10-11T05:12:45.432Z" }, + { url = "https://files.pythonhosted.org/packages/2c/e3/3ebb5c23bd3abb5fc2053b8a06a889aa5c1cf8cff738c78cb6c1957e90cd/basedpyright-1.38.3-py3-none-any.whl", hash = "sha256:1f15c2e489c67d6c5e896c24b6a63251195c04223a55e4568b8f8e8ed49ca830", size = 12313363, upload-time = "2026-03-17T13:10:47.344Z" }, ] [[package]] name = "bce-python-sdk" -version = "0.9.53" +version = "0.9.64" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/da/8d/85ec18ca2dba624cb5932bda74e926c346a7a6403a628aeda45d848edb48/bce_python_sdk-0.9.53.tar.gz", hash = "sha256:fb14b09d1064a6987025648589c8245cb7e404acd38bb900f0775f396e3d9b3e", size = 275594, upload-time = "2025-11-21T03:48:58.869Z" } +sdist = { url = "https://files.pythonhosted.org/packages/61/33/047e9c1a6c97e0cd4d93a6490abd8fbc2ccd13569462fc0228699edc08bc/bce_python_sdk-0.9.64.tar.gz", hash = "sha256:901bf787c26ad35855a80d65e58d7584c8541f7f0f2af20847830e572e5b622e", size = 287125, upload-time = "2026-03-17T11:24:29.345Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/e9/6fc142b5ac5b2e544bc155757dc28eee2b22a576ca9eaf968ac033b6dc45/bce_python_sdk-0.9.53-py3-none-any.whl", hash = "sha256:00fc46b0ff8d1700911aef82b7263533c52a63b1cc5a51449c4f715a116846a7", size = 390434, upload-time = "2025-11-21T03:48:57.201Z" }, + { url = "https://files.pythonhosted.org/packages/48/7f/dd289582f37ab4effea47b2a8503880db4781ca0fc8e0a8ed5ff493359e5/bce_python_sdk-0.9.64-py3-none-any.whl", hash = "sha256:eaad97e4f0e7d613ae978da3cdc5294e9f724ffca2735f79820037fa1317cd6d", size = 402233, upload-time = "2026-03-17T11:24:24.673Z" }, ] [[package]] @@ -609,50 +631,36 @@ wheels = [ [[package]] name = "beautifulsoup4" -version = "4.12.2" +version = "4.14.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "soupsieve" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/0b/44c39cf3b18a9280950ad63a579ce395dda4c32193ee9da7ff0aed547094/beautifulsoup4-4.12.2.tar.gz", hash = "sha256:492bbc69dca35d12daac71c4db1bfff0c876c00ef4a2ffacce226d4638eb72da", size = 505113, upload-time = "2023-04-07T15:02:49.038Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/b0/1c6a16426d389813b48d95e26898aff79abbde42ad353958ad95cc8c9b21/beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86", size = 627737, upload-time = "2025-11-30T15:08:26.084Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl", hash = "sha256:bd2520ca0d9d7d12694a53d44ac482d181b4ec1888909b035a3dbf40d0f57d4a", size = 142979, upload-time = "2023-04-07T15:02:50.77Z" }, + { url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" }, ] [[package]] name = "billiard" -version = "4.2.3" +version = "4.2.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/50/cc2b8b6e6433918a6b9a3566483b743dcd229da1e974be9b5f259db3aad7/billiard-4.2.3.tar.gz", hash = "sha256:96486f0885afc38219d02d5f0ccd5bec8226a414b834ab244008cbb0025b8dcb", size = 156450, upload-time = "2025-11-16T17:47:30.281Z" } +sdist = { url = "https://files.pythonhosted.org/packages/58/23/b12ac0bcdfb7360d664f40a00b1bda139cbbbced012c34e375506dbd0143/billiard-4.2.4.tar.gz", hash = "sha256:55f542c371209e03cd5862299b74e52e4fbcba8250ba611ad94276b369b6a85f", size = 156537, upload-time = "2025-11-30T13:28:48.52Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/cc/38b6f87170908bd8aaf9e412b021d17e85f690abe00edf50192f1a4566b9/billiard-4.2.3-py3-none-any.whl", hash = "sha256:989e9b688e3abf153f307b68a1328dfacfb954e30a4f920005654e276c69236b", size = 87042, upload-time = "2025-11-16T17:47:29.005Z" }, + { url = "https://files.pythonhosted.org/packages/cb/87/8bab77b323f16d67be364031220069f79159117dd5e43eeb4be2fef1ac9b/billiard-4.2.4-py3-none-any.whl", hash = "sha256:525b42bdec68d2b983347ac312f892db930858495db601b5836ac24e6477cde5", size = 87070, upload-time = "2025-11-30T13:28:47.016Z" }, ] [[package]] -name = "black" -version = "25.12.0" +name = "bleach" +version = "6.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "click" }, - { name = "mypy-extensions" }, - { name = "packaging" }, - { name = "pathspec" }, - { name = "platformdirs" }, - { name = "pytokens" }, + { name = "webencodings" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c4/d9/07b458a3f1c525ac392b5edc6b191ff140b596f9d77092429417a54e249d/black-25.12.0.tar.gz", hash = "sha256:8d3dd9cea14bff7ddc0eb243c811cdb1a011ebb4800a5f0335a01a68654796a7", size = 659264, upload-time = "2025-12-08T01:40:52.501Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/18/3c8523962314be6bf4c8989c79ad9531c825210dd13a8669f6b84336e8bd/bleach-6.3.0.tar.gz", hash = "sha256:6f3b91b1c0a02bb9a78b5a454c92506aa0fdf197e1d5e114d2e00c6f64306d22", size = 203533, upload-time = "2025-10-27T17:57:39.211Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/60/ad/7ac0d0e1e0612788dbc48e62aef8a8e8feffac7eb3d787db4e43b8462fa8/black-25.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d0cfa263e85caea2cff57d8f917f9f51adae8e20b610e2b23de35b5b11ce691a", size = 1877003, upload-time = "2025-12-08T01:43:29.967Z" }, - { url = "https://files.pythonhosted.org/packages/e8/dd/a237e9f565f3617a88b49284b59cbca2a4f56ebe68676c1aad0ce36a54a7/black-25.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1a2f578ae20c19c50a382286ba78bfbeafdf788579b053d8e4980afb079ab9be", size = 1712639, upload-time = "2025-12-08T01:52:46.756Z" }, - { url = "https://files.pythonhosted.org/packages/12/80/e187079df1ea4c12a0c63282ddd8b81d5107db6d642f7d7b75a6bcd6fc21/black-25.12.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d3e1b65634b0e471d07ff86ec338819e2ef860689859ef4501ab7ac290431f9b", size = 1758143, upload-time = "2025-12-08T01:45:29.137Z" }, - { url = "https://files.pythonhosted.org/packages/93/b5/3096ccee4f29dc2c3aac57274326c4d2d929a77e629f695f544e159bfae4/black-25.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:a3fa71e3b8dd9f7c6ac4d818345237dfb4175ed3bf37cd5a581dbc4c034f1ec5", size = 1420698, upload-time = "2025-12-08T01:45:53.379Z" }, - { url = "https://files.pythonhosted.org/packages/7e/39/f81c0ffbc25ffbe61c7d0385bf277e62ffc3e52f5ee668d7369d9854fadf/black-25.12.0-cp311-cp311-win_arm64.whl", hash = "sha256:51e267458f7e650afed8445dc7edb3187143003d52a1b710c7321aef22aa9655", size = 1229317, upload-time = "2025-12-08T01:46:35.606Z" }, - { url = "https://files.pythonhosted.org/packages/d1/bd/26083f805115db17fda9877b3c7321d08c647df39d0df4c4ca8f8450593e/black-25.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:31f96b7c98c1ddaeb07dc0f56c652e25bdedaac76d5b68a059d998b57c55594a", size = 1924178, upload-time = "2025-12-08T01:49:51.048Z" }, - { url = "https://files.pythonhosted.org/packages/89/6b/ea00d6651561e2bdd9231c4177f4f2ae19cc13a0b0574f47602a7519b6ca/black-25.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:05dd459a19e218078a1f98178c13f861fe6a9a5f88fc969ca4d9b49eb1809783", size = 1742643, upload-time = "2025-12-08T01:49:59.09Z" }, - { url = "https://files.pythonhosted.org/packages/6d/f3/360fa4182e36e9875fabcf3a9717db9d27a8d11870f21cff97725c54f35b/black-25.12.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c1f68c5eff61f226934be6b5b80296cf6939e5d2f0c2f7d543ea08b204bfaf59", size = 1800158, upload-time = "2025-12-08T01:44:27.301Z" }, - { url = "https://files.pythonhosted.org/packages/f8/08/2c64830cb6616278067e040acca21d4f79727b23077633953081c9445d61/black-25.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:274f940c147ddab4442d316b27f9e332ca586d39c85ecf59ebdea82cc9ee8892", size = 1426197, upload-time = "2025-12-08T01:45:51.198Z" }, - { url = "https://files.pythonhosted.org/packages/d4/60/a93f55fd9b9816b7432cf6842f0e3000fdd5b7869492a04b9011a133ee37/black-25.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:169506ba91ef21e2e0591563deda7f00030cb466e747c4b09cb0a9dae5db2f43", size = 1237266, upload-time = "2025-12-08T01:45:10.556Z" }, - { url = "https://files.pythonhosted.org/packages/68/11/21331aed19145a952ad28fca2756a1433ee9308079bd03bd898e903a2e53/black-25.12.0-py3-none-any.whl", hash = "sha256:48ceb36c16dbc84062740049eef990bb2ce07598272e673c17d1a7720c71c828", size = 206191, upload-time = "2025-12-08T01:40:50.963Z" }, + { url = "https://files.pythonhosted.org/packages/cd/3a/577b549de0cc09d95f11087ee63c739bba856cd3952697eec4c4bb91350a/bleach-6.3.0-py3-none-any.whl", hash = "sha256:fe10ec77c93ddf3d13a73b035abaac7a9f5e436513864ccdad516693213c65d6", size = 164437, upload-time = "2025-10-27T17:57:37.538Z" }, ] [[package]] @@ -664,32 +672,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc", size = 8458, upload-time = "2024-11-08T17:25:46.184Z" }, ] +[[package]] +name = "blis" +version = "1.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d0/d0/d8cc8c9a4488a787e7fa430f6055e5bd1ddb22c340a751d9e901b82e2efe/blis-1.3.3.tar.gz", hash = "sha256:034d4560ff3cc43e8aa37e188451b0440e3261d989bb8a42ceee865607715ecd", size = 2644873, upload-time = "2025-11-17T12:28:30.511Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/0a/a4c8736bc497d386b0ffc76d321f478c03f1a4725e52092f93b38beb3786/blis-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e10c8d3e892b1dbdff365b9d00e08291876fc336915bf1a5e9f188ed087e1a91", size = 6925522, upload-time = "2025-11-17T12:27:29.199Z" }, + { url = "https://files.pythonhosted.org/packages/83/5a/3437009282f23684ecd3963a8b034f9307cdd2bf4484972e5a6b096bf9ac/blis-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:66e6249564f1db22e8af1e0513ff64134041fa7e03c8dd73df74db3f4d8415a7", size = 1232787, upload-time = "2025-11-17T12:27:30.996Z" }, + { url = "https://files.pythonhosted.org/packages/d1/0e/82221910d16259ce3017c1442c468a3f206a4143a96fbba9f5b5b81d62e8/blis-1.3.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7260da065958b4e5475f62f44895ef9d673b0f47dcf61b672b22b7dae1a18505", size = 2844596, upload-time = "2025-11-17T12:27:32.601Z" }, + { url = "https://files.pythonhosted.org/packages/6c/93/ab547f1a5c23e20bca16fbcf04021c32aac3f969be737ea4980509a7ca90/blis-1.3.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9327a6ca67de8ae76fe071e8584cc7f3b2e8bfadece4961d40f2826e1cda2df", size = 11377746, upload-time = "2025-11-17T12:27:35.342Z" }, + { url = "https://files.pythonhosted.org/packages/6e/a6/7733820aa62da32526287a63cd85c103b2b323b186c8ee43b7772ff7017c/blis-1.3.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c4ae70629cf302035d268858a10ca4eb6242a01b2dc8d64422f8e6dcb8a8ee74", size = 3041954, upload-time = "2025-11-17T12:27:37.479Z" }, + { url = "https://files.pythonhosted.org/packages/87/53/e39d67fd3296b649772780ca6aab081412838ecb54e0b0c6432d01626a50/blis-1.3.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:45866a9027d43b93e8b59980a23c5d7358b6536fc04606286e39fdcfce1101c2", size = 14251222, upload-time = "2025-11-17T12:27:39.705Z" }, + { url = "https://files.pythonhosted.org/packages/ea/44/b749f8777b020b420bceaaf60f66432fc30cc904ca5b69640ec9cbef11ed/blis-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:27f82b8633030f8d095d2b412dffa7eb6dbc8ee43813139909a20012e54422ea", size = 6171233, upload-time = "2025-11-17T12:27:41.921Z" }, + { url = "https://files.pythonhosted.org/packages/16/d1/429cf0cf693d4c7dc2efed969bd474e315aab636e4a95f66c4ed7264912d/blis-1.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2a1c74e100665f8e918ebdbae2794576adf1f691680b5cdb8b29578432f623ef", size = 6929663, upload-time = "2025-11-17T12:27:44.482Z" }, + { url = "https://files.pythonhosted.org/packages/11/69/363c8df8d98b3cc97be19aad6aabb2c9c53f372490d79316bdee92d476e7/blis-1.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3f6c595185176ce021316263e1a1d636a3425b6c48366c1fd712d08d0b71849a", size = 1230939, upload-time = "2025-11-17T12:27:46.19Z" }, + { url = "https://files.pythonhosted.org/packages/96/2a/fbf65d906d823d839076c5150a6f8eb5ecbc5f9135e0b6510609bda1e6b7/blis-1.3.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d734b19fba0be7944f272dfa7b443b37c61f9476d9ab054a9ac53555ceadd2e0", size = 2818835, upload-time = "2025-11-17T12:27:48.167Z" }, + { url = "https://files.pythonhosted.org/packages/d5/ad/58deaa3ad856dd3cc96493e40ffd2ed043d18d4d304f85a65cde1ccbf644/blis-1.3.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1ef6d6e2b599a3a2788eb6d9b443533961265aa4ec49d574ed4bb846e548dcdb", size = 11366550, upload-time = "2025-11-17T12:27:49.958Z" }, + { url = "https://files.pythonhosted.org/packages/78/82/816a7adfe1f7acc8151f01ec86ef64467a3c833932d8f19f8e06613b8a4e/blis-1.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8c888438ae99c500422d50698e3028b65caa8ebb44e24204d87fda2df64058f7", size = 3023686, upload-time = "2025-11-17T12:27:52.062Z" }, + { url = "https://files.pythonhosted.org/packages/1e/e2/0e93b865f648b5519360846669a35f28ee8f4e1d93d054f6850d8afbabde/blis-1.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8177879fd3590b5eecdd377f9deafb5dc8af6d684f065bd01553302fb3fcf9a7", size = 14250939, upload-time = "2025-11-17T12:27:53.847Z" }, + { url = "https://files.pythonhosted.org/packages/20/07/fb43edc2ff0a6a367e4a94fc39eb3b85aa1e55e24cc857af2db145ce9f0d/blis-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:f20f7ad69aaffd1ce14fe77de557b6df9b61e0c9e582f75a843715d836b5c8af", size = 6192759, upload-time = "2025-11-17T12:27:56.176Z" }, +] + [[package]] name = "boto3" -version = "1.35.99" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f7/99/3e8b48f15580672eda20f33439fc1622bd611f6238b6d05407320e1fb98c/boto3-1.35.99.tar.gz", hash = "sha256:e0abd794a7a591d90558e92e29a9f8837d25ece8e3c120e530526fe27eba5fca", size = 111028, upload-time = "2025-01-14T20:20:28.636Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e4/8b/d00575be514744ca4839e7d85bf4a8a3c7b6b4574433291e58d14c68ae09/boto3-1.42.73.tar.gz", hash = "sha256:d37b58d6cd452ca808dd6823ae19ca65b6244096c5125ef9052988b337298bae", size = 112775, upload-time = "2026-03-20T19:39:52.814Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/65/77/8bbca82f70b062181cf0ae53fd43f1ac6556f3078884bfef9da2269c06a3/boto3-1.35.99-py3-none-any.whl", hash = "sha256:83e560faaec38a956dfb3d62e05e1703ee50432b45b788c09e25107c5058bd71", size = 139178, upload-time = "2025-01-14T20:20:25.48Z" }, + { url = "https://files.pythonhosted.org/packages/aa/05/1fcf03d90abaa3d0b42a6bfd10231dd709493ecbacf794aa2eea5eae6841/boto3-1.42.73-py3-none-any.whl", hash = "sha256:1f81b79b873f130eeab14bb556417a7c66d38f3396b7f2fe3b958b3f9094f455", size = 140556, upload-time = "2026-03-20T19:39:50.298Z" }, ] [[package]] name = "boto3-stubs" -version = "1.41.3" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/5b/6d274aa25f7fa09f8b7defab5cb9389e6496a7d9b76c1efcf27b0b15e868/boto3_stubs-1.41.3.tar.gz", hash = "sha256:c7cc9706ac969c8ea284c2d45ec45b6371745666d087c6c5e7c9d39dafdd48bc", size = 100010, upload-time = "2025-11-24T20:34:27.052Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/c3/fcc47102c63278af25ad57c93d97dc393f4dbc54c0117a29c78f2b96ec1e/boto3_stubs-1.42.73.tar.gz", hash = "sha256:36f625769b5505c4bc627f16244b98de9e10dae3ac36f1aa0f0ebe2f201dc138", size = 101373, upload-time = "2026-03-20T19:59:51.463Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/d6/ef971013d1fc7333c6df322d98ebf4592df9c80e1966fb12732f91e9e71b/boto3_stubs-1.41.3-py3-none-any.whl", hash = "sha256:bec698419b31b499f3740f1dfb6dae6519167d9e3aa536f6f730ed280556230b", size = 69294, upload-time = "2025-11-24T20:34:23.1Z" }, + { url = "https://files.pythonhosted.org/packages/4b/57/d570ba61a2a0c7fe0c8667e41269a0480293cb53e1786d6661a2bd827fc5/boto3_stubs-1.42.73-py3-none-any.whl", hash = "sha256:bd658429069d8215247fc3abc003220cd875c24ab6eda7b3405090408afaacdf", size = 70009, upload-time = "2026-03-20T19:59:43.786Z" }, ] [package.optional-dependencies] @@ -699,28 +732,28 @@ bedrock-runtime = [ [[package]] name = "botocore" -version = "1.35.99" +version = "1.42.73" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7c/9c/1df6deceee17c88f7170bad8325aa91452529d683486273928eecfd946d8/botocore-1.35.99.tar.gz", hash = "sha256:1eab44e969c39c5f3d9a3104a0836c24715579a455f12b3979a31d7cde51b3c3", size = 13490969, upload-time = "2025-01-14T20:20:11.419Z" } +sdist = { url = "https://files.pythonhosted.org/packages/28/23/0c88ca116ef63b1ae77c901cd5d2095d22a8dbde9e80df74545db4a061b4/botocore-1.42.73.tar.gz", hash = "sha256:575858641e4949aaf2af1ced145b8524529edf006d075877af6b82ff96ad854c", size = 15008008, upload-time = "2026-03-20T19:39:40.082Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fc/dd/d87e2a145fad9e08d0ec6edcf9d71f838ccc7acdd919acc4c0d4a93515f8/botocore-1.35.99-py3-none-any.whl", hash = "sha256:b22d27b6b617fc2d7342090d6129000af2efd20174215948c0d7ae2da0fab445", size = 13293216, upload-time = "2025-01-14T20:20:06.427Z" }, + { url = "https://files.pythonhosted.org/packages/8e/65/971f3d55015f4d133a6ff3ad74cd39f4b8dd8f53f7775a3c2ad378ea5145/botocore-1.42.73-py3-none-any.whl", hash = "sha256:7b62e2a12f7a1b08eb7360eecd23bb16fe3b7ab7f5617cf91b25476c6f86a0fe", size = 14681861, upload-time = "2026-03-20T19:39:35.341Z" }, ] [[package]] name = "botocore-stubs" -version = "1.41.3" +version = "1.42.41" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-awscrt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ec/8f/a42c3ae68d0b9916f6e067546d73e9a24a6af8793999a742e7af0b7bffa2/botocore_stubs-1.41.3.tar.gz", hash = "sha256:bacd1647cd95259aa8fc4ccdb5b1b3893f495270c120cda0d7d210e0ae6a4170", size = 42404, upload-time = "2025-11-24T20:29:27.47Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/a8/a26608ff39e3a5866c6c79eda10133490205cbddd45074190becece3ff2a/botocore_stubs-1.42.41.tar.gz", hash = "sha256:dbeac2f744df6b814ce83ec3f3777b299a015cbea57a2efc41c33b8c38265825", size = 42411, upload-time = "2026-02-03T20:46:14.479Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/57/b7/f4a051cefaf76930c77558b31646bcce7e9b3fbdcbc89e4073783e961519/botocore_stubs-1.41.3-py3-none-any.whl", hash = "sha256:6ab911bd9f7256f1dcea2e24a4af7ae0f9f07e83d0a760bba37f028f4a2e5589", size = 66749, upload-time = "2025-11-24T20:29:26.142Z" }, + { url = "https://files.pythonhosted.org/packages/32/76/cab7af7f16c0b09347f2ebe7ffda7101132f786acb767666dce43055faab/botocore_stubs-1.42.41-py3-none-any.whl", hash = "sha256:9423110fb0e391834bd2ed44ae5f879d8cb370a444703d966d30842ce2bcb5f0", size = 66759, upload-time = "2026-02-03T20:46:13.02Z" }, ] [[package]] @@ -778,22 +811,22 @@ wheels = [ [[package]] name = "brotlicffi" -version = "1.2.0.0" +version = "1.2.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cffi", marker = "platform_python_implementation == 'PyPy'" }, + { name = "cffi" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/85/57c314a6b35336efbbdc13e5fc9ae13f6b60a0647cfa7c1221178ac6d8ae/brotlicffi-1.2.0.0.tar.gz", hash = "sha256:34345d8d1f9d534fcac2249e57a4c3c8801a33c9942ff9f8574f67a175e17adb", size = 476682, upload-time = "2025-11-21T18:17:57.334Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8a/b6/017dc5f852ed9b8735af77774509271acbf1de02d238377667145fcee01d/brotlicffi-1.2.0.1.tar.gz", hash = "sha256:c20d5c596278307ad06414a6d95a892377ea274a5c6b790c2548c009385d621c", size = 478156, upload-time = "2026-03-05T19:54:11.547Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/df/a72b284d8c7bef0ed5756b41c2eb7d0219a1dd6ac6762f1c7bdbc31ef3af/brotlicffi-1.2.0.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:9458d08a7ccde8e3c0afedbf2c70a8263227a68dea5ab13590593f4c0a4fd5f4", size = 432340, upload-time = "2025-11-21T18:17:42.277Z" }, - { url = "https://files.pythonhosted.org/packages/74/2b/cc55a2d1d6fb4f5d458fba44a3d3f91fb4320aa14145799fd3a996af0686/brotlicffi-1.2.0.0-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:84e3d0020cf1bd8b8131f4a07819edee9f283721566fe044a20ec792ca8fd8b7", size = 1534002, upload-time = "2025-11-21T18:17:43.746Z" }, - { url = "https://files.pythonhosted.org/packages/e4/9c/d51486bf366fc7d6735f0e46b5b96ca58dc005b250263525a1eea3cd5d21/brotlicffi-1.2.0.0-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:33cfb408d0cff64cd50bef268c0fed397c46fbb53944aa37264148614a62e990", size = 1536547, upload-time = "2025-11-21T18:17:45.729Z" }, - { url = "https://files.pythonhosted.org/packages/1b/37/293a9a0a7caf17e6e657668bebb92dfe730305999fe8c0e2703b8888789c/brotlicffi-1.2.0.0-cp38-abi3-win32.whl", hash = "sha256:23e5c912fdc6fd37143203820230374d24babd078fc054e18070a647118158f6", size = 343085, upload-time = "2025-11-21T18:17:48.887Z" }, - { url = "https://files.pythonhosted.org/packages/07/6b/6e92009df3b8b7272f85a0992b306b61c34b7ea1c4776643746e61c380ac/brotlicffi-1.2.0.0-cp38-abi3-win_amd64.whl", hash = "sha256:f139a7cdfe4ae7859513067b736eb44d19fae1186f9e99370092f6915216451b", size = 378586, upload-time = "2025-11-21T18:17:50.531Z" }, - { url = "https://files.pythonhosted.org/packages/a4/ec/52488a0563f1663e2ccc75834b470650f4b8bcdea3132aef3bf67219c661/brotlicffi-1.2.0.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:fa102a60e50ddbd08de86a63431a722ea216d9bc903b000bf544149cc9b823dc", size = 402002, upload-time = "2025-11-21T18:17:51.76Z" }, - { url = "https://files.pythonhosted.org/packages/e4/63/d4aea4835fd97da1401d798d9b8ba77227974de565faea402f520b37b10f/brotlicffi-1.2.0.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7d3c4332fc808a94e8c1035950a10d04b681b03ab585ce897ae2a360d479037c", size = 406447, upload-time = "2025-11-21T18:17:53.614Z" }, - { url = "https://files.pythonhosted.org/packages/62/4e/5554ecb2615ff035ef8678d4e419549a0f7a28b3f096b272174d656749fb/brotlicffi-1.2.0.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb4eb5830026b79a93bf503ad32b2c5257315e9ffc49e76b2715cffd07c8e3db", size = 402521, upload-time = "2025-11-21T18:17:54.875Z" }, - { url = "https://files.pythonhosted.org/packages/b5/d3/b07f8f125ac52bbee5dc00ef0d526f820f67321bf4184f915f17f50a4657/brotlicffi-1.2.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:3832c66e00d6d82087f20a972b2fc03e21cd99ef22705225a6f8f418a9158ecc", size = 374730, upload-time = "2025-11-21T18:17:56.334Z" }, + { url = "https://files.pythonhosted.org/packages/9a/9f/b98dcd4af47994cee97aebac866996a006a2e5fc1fd1e2b82a8ad95cf09c/brotlicffi-1.2.0.1-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:91ba5f0ccc040f6ff8f7efaf839f797723d03ed46acb8ae9408f99ffd2572cf4", size = 432608, upload-time = "2026-03-05T19:53:56.736Z" }, + { url = "https://files.pythonhosted.org/packages/b1/7a/ac4ee56595a061e3718a6d1ea7e921f4df156894acffb28ed88a1fd52022/brotlicffi-1.2.0.1-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be9a670c6811af30a4bd42d7116dc5895d3b41beaa8ed8a89050447a0181f5ce", size = 1534257, upload-time = "2026-03-05T19:53:58.667Z" }, + { url = "https://files.pythonhosted.org/packages/99/39/e7410db7f6f56de57744ea52a115084ceb2735f4d44973f349bb92136586/brotlicffi-1.2.0.1-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f3314a3476f59e5443f9f72a6dff16edc0c3463c9b318feaef04ae3e4683f5a", size = 1536838, upload-time = "2026-03-05T19:54:00.705Z" }, + { url = "https://files.pythonhosted.org/packages/a6/75/6e7977d1935fc3fbb201cbd619be8f2c7aea25d40a096967132854b34708/brotlicffi-1.2.0.1-cp38-abi3-win32.whl", hash = "sha256:82ea52e2b5d3145b6c406ebd3efb0d55db718b7ad996bd70c62cec0439de1187", size = 343337, upload-time = "2026-03-05T19:54:02.446Z" }, + { url = "https://files.pythonhosted.org/packages/d8/ef/e7e485ce5e4ba3843a0a92feb767c7b6098fd6e65ce752918074d175ae71/brotlicffi-1.2.0.1-cp38-abi3-win_amd64.whl", hash = "sha256:da2e82a08e7778b8bc539d27ca03cdd684113e81394bfaaad8d0dfc6a17ddede", size = 379026, upload-time = "2026-03-05T19:54:04.322Z" }, + { url = "https://files.pythonhosted.org/packages/7f/53/6262c2256513e6f530d81642477cb19367270922063eaa2d7b781d8c723d/brotlicffi-1.2.0.1-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:e015af99584c6db1490a69a210c765953e473e63adc2d891ac3062a737c9e851", size = 402265, upload-time = "2026-03-05T19:54:05.858Z" }, + { url = "https://files.pythonhosted.org/packages/1f/d9/d5340b43cf5fbe7fe5a083d237e5338cc1caa73bea523be1c5e452c26290/brotlicffi-1.2.0.1-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:37cb587d32bf7168e2218c455e22e409ad1f3157c6c71945879a311f3e6b6abf", size = 406710, upload-time = "2026-03-05T19:54:07.272Z" }, + { url = "https://files.pythonhosted.org/packages/a3/82/dbced4c1e0792efdf23fd90ff6d2a320c64ff4dfef7aacc85c04fde9ddd2/brotlicffi-1.2.0.1-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9d6ba65dd528892b4d9960beba2ae011a753620bcfc66cf6fa3cee18d7b0baa4", size = 402787, upload-time = "2026-03-05T19:54:08.73Z" }, + { url = "https://files.pythonhosted.org/packages/ef/6f/534205ba7590c9a8716a614f270c5c2ec419b5b7079b3f9cd31b7b5580de/brotlicffi-1.2.0.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:f2a5575653b0672638ba039b82fda56854934d7a6a24d4b8b5033f73ab43cbc1", size = 375108, upload-time = "2026-03-05T19:54:10.079Z" }, ] [[package]] @@ -813,7 +846,7 @@ name = "build" version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "os_name == 'nt' and sys_platform != 'linux'" }, + { name = "colorama", marker = "os_name == 'nt'" }, { name = "packaging" }, { name = "pyproject-hooks" }, ] @@ -831,9 +864,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fb/2b/a64c2d25a37aeb921fddb929111413049fc5f8b9a4c1aefaffaafe768d54/cachetools-5.3.3-py3-none-any.whl", hash = "sha256:0abad1021d3f8325b2fc1d2e9c8b9c9d57b04c3932657a72465447332c24d945", size = 9325, upload-time = "2024-02-26T20:33:20.308Z" }, ] +[[package]] +name = "catalogue" +version = "2.0.10" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/38/b4/244d58127e1cdf04cf2dc7d9566f0d24ef01d5ce21811bab088ecc62b5ea/catalogue-2.0.10.tar.gz", hash = "sha256:4f56daa940913d3f09d589c191c74e5a6d51762b3a9e37dd53b7437afd6cda15", size = 19561, upload-time = "2023-09-25T06:29:24.962Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/96/d32b941a501ab566a16358d68b6eb4e4acc373fab3c3c4d7d9e649f7b4bb/catalogue-2.0.10-py3-none-any.whl", hash = "sha256:58c2de0020aa90f4a2da7dfad161bf7b3b054c86a5f09fcedc0b2b740c109a9f", size = 17325, upload-time = "2023-09-25T06:29:23.337Z" }, +] + [[package]] name = "celery" -version = "5.5.3" +version = "5.6.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "billiard" }, @@ -843,32 +885,33 @@ dependencies = [ { name = "click-repl" }, { name = "kombu" }, { name = "python-dateutil" }, + { name = "tzlocal" }, { name = "vine" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bb/7d/6c289f407d219ba36d8b384b42489ebdd0c84ce9c413875a8aae0c85f35b/celery-5.5.3.tar.gz", hash = "sha256:6c972ae7968c2b5281227f01c3a3f984037d21c5129d07bf3550cc2afc6b10a5", size = 1667144, upload-time = "2025-06-01T11:08:12.563Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8f/9d/3d13596519cfa7207a6f9834f4b082554845eb3cd2684b5f8535d50c7c44/celery-5.6.2.tar.gz", hash = "sha256:4a8921c3fcf2ad76317d3b29020772103581ed2454c4c042cc55dcc43585009b", size = 1718802, upload-time = "2026-01-04T12:35:58.012Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c9/af/0dcccc7fdcdf170f9a1585e5e96b6fb0ba1749ef6be8c89a6202284759bd/celery-5.5.3-py3-none-any.whl", hash = "sha256:0b5761a07057acee94694464ca482416b959568904c9dfa41ce8413a7d65d525", size = 438775, upload-time = "2025-06-01T11:08:09.94Z" }, + { url = "https://files.pythonhosted.org/packages/dd/bd/9ecd619e456ae4ba73b6583cc313f26152afae13e9a82ac4fe7f8856bfd1/celery-5.6.2-py3-none-any.whl", hash = "sha256:3ffafacbe056951b629c7abcf9064c4a2366de0bdfc9fdba421b97ebb68619a5", size = 445502, upload-time = "2026-01-04T12:35:55.894Z" }, ] [[package]] name = "celery-types" -version = "0.23.0" +version = "0.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e9/d1/0823e71c281e4ad0044e278cf1577d1a68e05f2809424bf94e1614925c5d/celery_types-0.23.0.tar.gz", hash = "sha256:402ed0555aea3cd5e1e6248f4632e4f18eec8edb2435173f9e6dc08449fa101e", size = 31479, upload-time = "2025-03-03T23:56:51.547Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/38/813dd7534e41682684d3a5c2cc4a8710e3acc51b364920b9c4d747c7b18f/celery_types-0.26.0.tar.gz", hash = "sha256:fa318136fdad83f83f1531deecd9fe664b5dfffff29f3c31e9120a46b8e3908f", size = 106210, upload-time = "2026-03-12T23:06:49.941Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/8b/92bb54dd74d145221c3854aa245c84f4dc04cc9366147496182cec8e88e3/celery_types-0.23.0-py3-none-any.whl", hash = "sha256:0cc495b8d7729891b7e070d0ec8d4906d2373209656a6e8b8276fe1ed306af9a", size = 50189, upload-time = "2025-03-03T23:56:50.458Z" }, + { url = "https://files.pythonhosted.org/packages/d9/e5/c5ec98f7fd7817d077c9a5a5e705d54f74d4ca08ee3f14dee881c93c0511/celery_types-0.26.0-py3-none-any.whl", hash = "sha256:eb9da76f461786091970df466ec647d9a27956399852542cb6cab9309970f950", size = 211260, upload-time = "2026-03-12T23:06:48.588Z" }, ] [[package]] name = "certifi" -version = "2025.11.12" +version = "2026.2.25" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a2/8c/58f469717fa48465e4a50c014a0400602d3c437d7c0c468e17ada824da3a/certifi-2025.11.12.tar.gz", hash = "sha256:d8ab5478f2ecd78af242878415affce761ca6bc54a22a27e026d7c25357c3316", size = 160538, upload-time = "2025-11-12T02:54:51.517Z" } +sdist = { url = "https://files.pythonhosted.org/packages/af/2d/7bf41579a8986e348fa033a31cdd0e4121114f6bce2457e8876010b092dd/certifi-2026.2.25.tar.gz", hash = "sha256:e887ab5cee78ea814d3472169153c2d12cd43b14bd03329a39a9c6e2e80bfba7", size = 155029, upload-time = "2026-02-25T02:54:17.342Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/7d/9bc192684cea499815ff478dfcdc13835ddf401365057044fb721ec6bddb/certifi-2025.11.12-py3-none-any.whl", hash = "sha256:97de8790030bbd5c2d96b7ec782fc2f7820ef8dba6db909ccf95449f2d062d4b", size = 159438, upload-time = "2025-11-12T02:54:49.735Z" }, + { url = "https://files.pythonhosted.org/packages/9a/3c/c17fb3ca2d9c3acff52e30b309f538586f9f5b9c9cf454f3845fc9af4881/certifi-2026.2.25-py3-none-any.whl", hash = "sha256:027692e4402ad994f1c42e52a4997a9763c646b73e4096e4d5d6db8af1d6f0fa", size = 153684, upload-time = "2026-02-25T02:54:15.766Z" }, ] [[package]] @@ -907,63 +950,64 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6", size = 177762, upload-time = "2025-09-08T23:22:59.668Z" }, ] -[[package]] -name = "cfgv" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/4e/b5/721b8799b04bf9afe054a3899c6cf4e880fcf8563cc71c15610242490a0c/cfgv-3.5.0.tar.gz", hash = "sha256:d5b1034354820651caa73ede66a6294d6e95c1b00acc5e9b098e917404669132", size = 7334, upload-time = "2025-11-19T20:55:51.612Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/db/3c/33bac158f8ab7f89b2e59426d5fe2e4f63f7ed25df84c036890172b412b5/cfgv-3.5.0-py2.py3-none-any.whl", hash = "sha256:a8dc6b26ad22ff227d2634a65cb388215ce6cc96bbcc5cfde7641ae87e8dacc0", size = 7445, upload-time = "2025-11-19T20:55:50.744Z" }, -] - [[package]] name = "chardet" -version = "5.1.0" +version = "7.2.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/41/32/cdc91dcf83849c7385bf8e2a5693d87376536ed000807fa07f5eab33430d/chardet-5.1.0.tar.gz", hash = "sha256:0d62712b956bc154f85fb0a266e2a3c5913c2967e00348701b32411d6def31e5", size = 2069617, upload-time = "2022-12-01T22:34:18.086Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1d/94/7af830a4c63df020644aa99d76147d003a1463f255d0a054958978be5a8a/chardet-7.2.0.tar.gz", hash = "sha256:4ef7292b1342ea805c32cce58a45db204f59d080ed311d6cdaa7ca747fcc0cd5", size = 516522, upload-time = "2026-03-18T00:07:23.76Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/8f/8fc49109009e8d2169d94d72e6b1f4cd45c13d147ba7d6170fb41f22b08f/chardet-5.1.0-py3-none-any.whl", hash = "sha256:362777fb014af596ad31334fde1e8c327dfdb076e1960d1694662d46a6917ab9", size = 199124, upload-time = "2022-12-01T22:34:14.609Z" }, + { url = "https://files.pythonhosted.org/packages/88/63/3ba1b7828340ac4b4761df5454abd0c48dd620eb4f12a5106c3390539711/chardet-7.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a8685b331c4896e9135bd748387f713dd53c019475ae1b8238b8f59be1668acd", size = 545761, upload-time = "2026-03-18T00:06:53.562Z" }, + { url = "https://files.pythonhosted.org/packages/0d/b4/c3d87a7aa5ee1c71fff91a503ae1a0c3bc3b756e646948f6bfdfd2c8c873/chardet-7.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fa14cc0e7d2142dd313524b3a339e15cbd8b7a8a7e11a560686e0b6f58038ec9", size = 539103, upload-time = "2026-03-18T00:06:54.837Z" }, + { url = "https://files.pythonhosted.org/packages/71/51/8eb14c4b5308225889eb4bdd9499a3d7cab1a77a82e1bcc1ad0ad22cb3a3/chardet-7.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4c51a3d8aa3c162be0495404b39bb1c137b44a634c1f46e2909e2c6a60349c00", size = 560010, upload-time = "2026-03-18T00:06:56.442Z" }, + { url = "https://files.pythonhosted.org/packages/1e/cc/350b4ac6916291624093ea07ac186733e490bd33948d205d07848dbd51ff/chardet-7.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:347ed77bb5eed8929fae7482671690a15c731d66808f1ff0ce7d22224ca7ec79", size = 562610, upload-time = "2026-03-18T00:06:57.95Z" }, + { url = "https://files.pythonhosted.org/packages/36/f9/b757ade39ad981f89e3607abc75827729bf408359ddd31073e7a85cb8aeb/chardet-7.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:d298762002a6b6e81dbcc81ade9e0882e579e968f4801daf4d8ffd6a31b99552", size = 530914, upload-time = "2026-03-18T00:06:59.342Z" }, + { url = "https://files.pythonhosted.org/packages/04/f2/5b4bfc3c93458c2d618d71f79e34def05552f178b4d452555a8333696f1a/chardet-7.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c4604344380a6f9b982c28855c1edfd23a45a2c9142b9a34bc0c08986049f398", size = 547261, upload-time = "2026-03-18T00:07:00.869Z" }, + { url = "https://files.pythonhosted.org/packages/38/fd/3effc8151d19b6ced8d1de427df5a039b1cce4cef79a3ac6f3c1d1135502/chardet-7.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:195c54d8f04a7a9c321cb7cebececa35b1c818c7aa7c195086bae10fcbb3391f", size = 539283, upload-time = "2026-03-18T00:07:02.419Z" }, + { url = "https://files.pythonhosted.org/packages/9e/b1/c1990fcafa601fcebe9308ae23026906f1e04b53b53ed38e6a81499acd30/chardet-7.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ddd03a67fca8c91287f8718dfbe3f94c2c1aa1fd3a82433b693f5b868dedf319", size = 561023, upload-time = "2026-03-18T00:07:04.078Z" }, + { url = "https://files.pythonhosted.org/packages/19/5e/4ddbef974a1036416431ef6ceb13dae8c5ab2193a301f2b58c5348855f1b/chardet-7.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8f6af0fa005b9488c8fbf8fec2ad7023531970320901d6334c50844ccca9b117", size = 564598, upload-time = "2026-03-18T00:07:05.341Z" }, + { url = "https://files.pythonhosted.org/packages/ae/6b/045858a8b6a54777e64ff4880058018cc05e547e49808f84f7a41a45615a/chardet-7.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:e8853c71ea1261bcc1b8f8b171acb7c272a5cfd06b57729c460241ee38705049", size = 531154, upload-time = "2026-03-18T00:07:07.061Z" }, + { url = "https://files.pythonhosted.org/packages/c2/47/97786f40be59ff5ff10ec5ebcb1ef0ad28dd915717cb210cee89ae7a83a6/chardet-7.2.0-py3-none-any.whl", hash = "sha256:f8ea866b9fbd8df5f19032d765a4d81dcbf6194a3c7388b44d378d02c9784170", size = 414953, upload-time = "2026-03-18T00:07:22.48Z" }, ] [[package]] name = "charset-normalizer" -version = "3.4.4" +version = "3.4.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/13/69/33ddede1939fdd074bce5434295f38fae7136463422fe4fd3e0e89b98062/charset_normalizer-3.4.4.tar.gz", hash = "sha256:94537985111c35f28720e43603b8e7b43a6ecfb2ce1d3058bbe955b73404e21a", size = 129418, upload-time = "2025-10-14T04:42:32.879Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/60/e3bec1881450851b087e301bedc3daa9377a4d45f1c26aa90b0b235e38aa/charset_normalizer-3.4.6.tar.gz", hash = "sha256:1ae6b62897110aa7c79ea2f5dd38d1abca6db663687c0b1ad9aed6f6bae3d9d6", size = 143363, upload-time = "2026-03-15T18:53:25.478Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/27/c6491ff4954e58a10f69ad90aca8a1b6fe9c5d3c6f380907af3c37435b59/charset_normalizer-3.4.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6e1fcf0720908f200cd21aa4e6750a48ff6ce4afe7ff5a79a90d5ed8a08296f8", size = 206988, upload-time = "2025-10-14T04:40:33.79Z" }, - { url = "https://files.pythonhosted.org/packages/94/59/2e87300fe67ab820b5428580a53cad894272dbb97f38a7a814a2a1ac1011/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5f819d5fe9234f9f82d75bdfa9aef3a3d72c4d24a6e57aeaebba32a704553aa0", size = 147324, upload-time = "2025-10-14T04:40:34.961Z" }, - { url = "https://files.pythonhosted.org/packages/07/fb/0cf61dc84b2b088391830f6274cb57c82e4da8bbc2efeac8c025edb88772/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a59cb51917aa591b1c4e6a43c132f0cdc3c76dbad6155df4e28ee626cc77a0a3", size = 142742, upload-time = "2025-10-14T04:40:36.105Z" }, - { url = "https://files.pythonhosted.org/packages/62/8b/171935adf2312cd745d290ed93cf16cf0dfe320863ab7cbeeae1dcd6535f/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8ef3c867360f88ac904fd3f5e1f902f13307af9052646963ee08ff4f131adafc", size = 160863, upload-time = "2025-10-14T04:40:37.188Z" }, - { url = "https://files.pythonhosted.org/packages/09/73/ad875b192bda14f2173bfc1bc9a55e009808484a4b256748d931b6948442/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d9e45d7faa48ee908174d8fe84854479ef838fc6a705c9315372eacbc2f02897", size = 157837, upload-time = "2025-10-14T04:40:38.435Z" }, - { url = "https://files.pythonhosted.org/packages/6d/fc/de9cce525b2c5b94b47c70a4b4fb19f871b24995c728e957ee68ab1671ea/charset_normalizer-3.4.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:840c25fb618a231545cbab0564a799f101b63b9901f2569faecd6b222ac72381", size = 151550, upload-time = "2025-10-14T04:40:40.053Z" }, - { url = "https://files.pythonhosted.org/packages/55/c2/43edd615fdfba8c6f2dfbd459b25a6b3b551f24ea21981e23fb768503ce1/charset_normalizer-3.4.4-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ca5862d5b3928c4940729dacc329aa9102900382fea192fc5e52eb69d6093815", size = 149162, upload-time = "2025-10-14T04:40:41.163Z" }, - { url = "https://files.pythonhosted.org/packages/03/86/bde4ad8b4d0e9429a4e82c1e8f5c659993a9a863ad62c7df05cf7b678d75/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9c7f57c3d666a53421049053eaacdd14bbd0a528e2186fcb2e672effd053bb0", size = 150019, upload-time = "2025-10-14T04:40:42.276Z" }, - { url = "https://files.pythonhosted.org/packages/1f/86/a151eb2af293a7e7bac3a739b81072585ce36ccfb4493039f49f1d3cae8c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:277e970e750505ed74c832b4bf75dac7476262ee2a013f5574dd49075879e161", size = 143310, upload-time = "2025-10-14T04:40:43.439Z" }, - { url = "https://files.pythonhosted.org/packages/b5/fe/43dae6144a7e07b87478fdfc4dbe9efd5defb0e7ec29f5f58a55aeef7bf7/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:31fd66405eaf47bb62e8cd575dc621c56c668f27d46a61d975a249930dd5e2a4", size = 162022, upload-time = "2025-10-14T04:40:44.547Z" }, - { url = "https://files.pythonhosted.org/packages/80/e6/7aab83774f5d2bca81f42ac58d04caf44f0cc2b65fc6db2b3b2e8a05f3b3/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:0d3d8f15c07f86e9ff82319b3d9ef6f4bf907608f53fe9d92b28ea9ae3d1fd89", size = 149383, upload-time = "2025-10-14T04:40:46.018Z" }, - { url = "https://files.pythonhosted.org/packages/4f/e8/b289173b4edae05c0dde07f69f8db476a0b511eac556dfe0d6bda3c43384/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:9f7fcd74d410a36883701fafa2482a6af2ff5ba96b9a620e9e0721e28ead5569", size = 159098, upload-time = "2025-10-14T04:40:47.081Z" }, - { url = "https://files.pythonhosted.org/packages/d8/df/fe699727754cae3f8478493c7f45f777b17c3ef0600e28abfec8619eb49c/charset_normalizer-3.4.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ebf3e58c7ec8a8bed6d66a75d7fb37b55e5015b03ceae72a8e7c74495551e224", size = 152991, upload-time = "2025-10-14T04:40:48.246Z" }, - { url = "https://files.pythonhosted.org/packages/1a/86/584869fe4ddb6ffa3bd9f491b87a01568797fb9bd8933f557dba9771beaf/charset_normalizer-3.4.4-cp311-cp311-win32.whl", hash = "sha256:eecbc200c7fd5ddb9a7f16c7decb07b566c29fa2161a16cf67b8d068bd21690a", size = 99456, upload-time = "2025-10-14T04:40:49.376Z" }, - { url = "https://files.pythonhosted.org/packages/65/f6/62fdd5feb60530f50f7e38b4f6a1d5203f4d16ff4f9f0952962c044e919a/charset_normalizer-3.4.4-cp311-cp311-win_amd64.whl", hash = "sha256:5ae497466c7901d54b639cf42d5b8c1b6a4fead55215500d2f486d34db48d016", size = 106978, upload-time = "2025-10-14T04:40:50.844Z" }, - { url = "https://files.pythonhosted.org/packages/7a/9d/0710916e6c82948b3be62d9d398cb4fcf4e97b56d6a6aeccd66c4b2f2bd5/charset_normalizer-3.4.4-cp311-cp311-win_arm64.whl", hash = "sha256:65e2befcd84bc6f37095f5961e68a6f077bf44946771354a28ad434c2cce0ae1", size = 99969, upload-time = "2025-10-14T04:40:52.272Z" }, - { url = "https://files.pythonhosted.org/packages/f3/85/1637cd4af66fa687396e757dec650f28025f2a2f5a5531a3208dc0ec43f2/charset_normalizer-3.4.4-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0a98e6759f854bd25a58a73fa88833fba3b7c491169f86ce1180c948ab3fd394", size = 208425, upload-time = "2025-10-14T04:40:53.353Z" }, - { url = "https://files.pythonhosted.org/packages/9d/6a/04130023fef2a0d9c62d0bae2649b69f7b7d8d24ea5536feef50551029df/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b5b290ccc2a263e8d185130284f8501e3e36c5e02750fc6b6bdeb2e9e96f1e25", size = 148162, upload-time = "2025-10-14T04:40:54.558Z" }, - { url = "https://files.pythonhosted.org/packages/78/29/62328d79aa60da22c9e0b9a66539feae06ca0f5a4171ac4f7dc285b83688/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:74bb723680f9f7a6234dcf67aea57e708ec1fbdf5699fb91dfd6f511b0a320ef", size = 144558, upload-time = "2025-10-14T04:40:55.677Z" }, - { url = "https://files.pythonhosted.org/packages/86/bb/b32194a4bf15b88403537c2e120b817c61cd4ecffa9b6876e941c3ee38fe/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f1e34719c6ed0b92f418c7c780480b26b5d9c50349e9a9af7d76bf757530350d", size = 161497, upload-time = "2025-10-14T04:40:57.217Z" }, - { url = "https://files.pythonhosted.org/packages/19/89/a54c82b253d5b9b111dc74aca196ba5ccfcca8242d0fb64146d4d3183ff1/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2437418e20515acec67d86e12bf70056a33abdacb5cb1655042f6538d6b085a8", size = 159240, upload-time = "2025-10-14T04:40:58.358Z" }, - { url = "https://files.pythonhosted.org/packages/c0/10/d20b513afe03acc89ec33948320a5544d31f21b05368436d580dec4e234d/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11d694519d7f29d6cd09f6ac70028dba10f92f6cdd059096db198c283794ac86", size = 153471, upload-time = "2025-10-14T04:40:59.468Z" }, - { url = "https://files.pythonhosted.org/packages/61/fa/fbf177b55bdd727010f9c0a3c49eefa1d10f960e5f09d1d887bf93c2e698/charset_normalizer-3.4.4-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:ac1c4a689edcc530fc9d9aa11f5774b9e2f33f9a0c6a57864e90908f5208d30a", size = 150864, upload-time = "2025-10-14T04:41:00.623Z" }, - { url = "https://files.pythonhosted.org/packages/05/12/9fbc6a4d39c0198adeebbde20b619790e9236557ca59fc40e0e3cebe6f40/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:21d142cc6c0ec30d2efee5068ca36c128a30b0f2c53c1c07bd78cb6bc1d3be5f", size = 150647, upload-time = "2025-10-14T04:41:01.754Z" }, - { url = "https://files.pythonhosted.org/packages/ad/1f/6a9a593d52e3e8c5d2b167daf8c6b968808efb57ef4c210acb907c365bc4/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:5dbe56a36425d26d6cfb40ce79c314a2e4dd6211d51d6d2191c00bed34f354cc", size = 145110, upload-time = "2025-10-14T04:41:03.231Z" }, - { url = "https://files.pythonhosted.org/packages/30/42/9a52c609e72471b0fc54386dc63c3781a387bb4fe61c20231a4ebcd58bdd/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5bfbb1b9acf3334612667b61bd3002196fe2a1eb4dd74d247e0f2a4d50ec9bbf", size = 162839, upload-time = "2025-10-14T04:41:04.715Z" }, - { url = "https://files.pythonhosted.org/packages/c4/5b/c0682bbf9f11597073052628ddd38344a3d673fda35a36773f7d19344b23/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:d055ec1e26e441f6187acf818b73564e6e6282709e9bcb5b63f5b23068356a15", size = 150667, upload-time = "2025-10-14T04:41:05.827Z" }, - { url = "https://files.pythonhosted.org/packages/e4/24/a41afeab6f990cf2daf6cb8c67419b63b48cf518e4f56022230840c9bfb2/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:af2d8c67d8e573d6de5bc30cdb27e9b95e49115cd9baad5ddbd1a6207aaa82a9", size = 160535, upload-time = "2025-10-14T04:41:06.938Z" }, - { url = "https://files.pythonhosted.org/packages/2a/e5/6a4ce77ed243c4a50a1fecca6aaaab419628c818a49434be428fe24c9957/charset_normalizer-3.4.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:780236ac706e66881f3b7f2f32dfe90507a09e67d1d454c762cf642e6e1586e0", size = 154816, upload-time = "2025-10-14T04:41:08.101Z" }, - { url = "https://files.pythonhosted.org/packages/a8/ef/89297262b8092b312d29cdb2517cb1237e51db8ecef2e9af5edbe7b683b1/charset_normalizer-3.4.4-cp312-cp312-win32.whl", hash = "sha256:5833d2c39d8896e4e19b689ffc198f08ea58116bee26dea51e362ecc7cd3ed26", size = 99694, upload-time = "2025-10-14T04:41:09.23Z" }, - { url = "https://files.pythonhosted.org/packages/3d/2d/1e5ed9dd3b3803994c155cd9aacb60c82c331bad84daf75bcb9c91b3295e/charset_normalizer-3.4.4-cp312-cp312-win_amd64.whl", hash = "sha256:a79cfe37875f822425b89a82333404539ae63dbdddf97f84dcbc3d339aae9525", size = 107131, upload-time = "2025-10-14T04:41:10.467Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d9/0ed4c7098a861482a7b6a95603edce4c0d9db2311af23da1fb2b75ec26fc/charset_normalizer-3.4.4-cp312-cp312-win_arm64.whl", hash = "sha256:376bec83a63b8021bb5c8ea75e21c4ccb86e7e45ca4eb81146091b56599b80c3", size = 100390, upload-time = "2025-10-14T04:41:11.915Z" }, - { url = "https://files.pythonhosted.org/packages/0a/4c/925909008ed5a988ccbb72dcc897407e5d6d3bd72410d69e051fc0c14647/charset_normalizer-3.4.4-py3-none-any.whl", hash = "sha256:7a32c560861a02ff789ad905a2fe94e3f840803362c84fecf1851cb4cf3dc37f", size = 53402, upload-time = "2025-10-14T04:42:31.76Z" }, + { url = "https://files.pythonhosted.org/packages/62/28/ff6f234e628a2de61c458be2779cb182bc03f6eec12200d4a525bbfc9741/charset_normalizer-3.4.6-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:82060f995ab5003a2d6e0f4ad29065b7672b6593c8c63559beefe5b443242c3e", size = 293582, upload-time = "2026-03-15T18:50:25.454Z" }, + { url = "https://files.pythonhosted.org/packages/1c/b7/b1a117e5385cbdb3205f6055403c2a2a220c5ea80b8716c324eaf75c5c95/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:60c74963d8350241a79cb8feea80e54d518f72c26db618862a8f53e5023deaf9", size = 197240, upload-time = "2026-03-15T18:50:27.196Z" }, + { url = "https://files.pythonhosted.org/packages/a1/5f/2574f0f09f3c3bc1b2f992e20bce6546cb1f17e111c5be07308dc5427956/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f6e4333fb15c83f7d1482a76d45a0818897b3d33f00efd215528ff7c51b8e35d", size = 217363, upload-time = "2026-03-15T18:50:28.601Z" }, + { url = "https://files.pythonhosted.org/packages/4a/d1/0ae20ad77bc949ddd39b51bf383b6ca932f2916074c95cad34ae465ab71f/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:bc72863f4d9aba2e8fd9085e63548a324ba706d2ea2c83b260da08a59b9482de", size = 212994, upload-time = "2026-03-15T18:50:30.102Z" }, + { url = "https://files.pythonhosted.org/packages/60/ac/3233d262a310c1b12633536a07cde5ddd16985e6e7e238e9f3f9423d8eb9/charset_normalizer-3.4.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9cc4fc6c196d6a8b76629a70ddfcd4635a6898756e2d9cac5565cf0654605d73", size = 204697, upload-time = "2026-03-15T18:50:31.654Z" }, + { url = "https://files.pythonhosted.org/packages/25/3c/8a18fc411f085b82303cfb7154eed5bd49c77035eb7608d049468b53f87c/charset_normalizer-3.4.6-cp311-cp311-manylinux_2_31_armv7l.whl", hash = "sha256:0c173ce3a681f309f31b87125fecec7a5d1347261ea11ebbb856fa6006b23c8c", size = 191673, upload-time = "2026-03-15T18:50:33.433Z" }, + { url = "https://files.pythonhosted.org/packages/ff/a7/11cfe61d6c5c5c7438d6ba40919d0306ed83c9ab957f3d4da2277ff67836/charset_normalizer-3.4.6-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c907cdc8109f6c619e6254212e794d6548373cc40e1ec75e6e3823d9135d29cc", size = 201120, upload-time = "2026-03-15T18:50:35.105Z" }, + { url = "https://files.pythonhosted.org/packages/b5/10/cf491fa1abd47c02f69687046b896c950b92b6cd7337a27e6548adbec8e4/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:404a1e552cf5b675a87f0651f8b79f5f1e6fd100ee88dc612f89aa16abd4486f", size = 200911, upload-time = "2026-03-15T18:50:36.819Z" }, + { url = "https://files.pythonhosted.org/packages/28/70/039796160b48b18ed466fde0af84c1b090c4e288fae26cd674ad04a2d703/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:e3c701e954abf6fc03a49f7c579cc80c2c6cc52525340ca3186c41d3f33482ef", size = 192516, upload-time = "2026-03-15T18:50:38.228Z" }, + { url = "https://files.pythonhosted.org/packages/ff/34/c56f3223393d6ff3124b9e78f7de738047c2d6bc40a4f16ac0c9d7a1cb3c/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7a6967aaf043bceabab5412ed6bd6bd26603dae84d5cb75bf8d9a74a4959d398", size = 218795, upload-time = "2026-03-15T18:50:39.664Z" }, + { url = "https://files.pythonhosted.org/packages/e8/3b/ce2d4f86c5282191a041fdc5a4ce18f1c6bd40a5bd1f74cf8625f08d51c1/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:5feb91325bbceade6afab43eb3b508c63ee53579fe896c77137ded51c6b6958e", size = 201833, upload-time = "2026-03-15T18:50:41.552Z" }, + { url = "https://files.pythonhosted.org/packages/3b/9b/b6a9f76b0fd7c5b5ec58b228ff7e85095370282150f0bd50b3126f5506d6/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:f820f24b09e3e779fe84c3c456cb4108a7aa639b0d1f02c28046e11bfcd088ed", size = 213920, upload-time = "2026-03-15T18:50:43.33Z" }, + { url = "https://files.pythonhosted.org/packages/ae/98/7bc23513a33d8172365ed30ee3a3b3fe1ece14a395e5fc94129541fc6003/charset_normalizer-3.4.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b35b200d6a71b9839a46b9b7fff66b6638bb52fc9658aa58796b0326595d3021", size = 206951, upload-time = "2026-03-15T18:50:44.789Z" }, + { url = "https://files.pythonhosted.org/packages/32/73/c0b86f3d1458468e11aec870e6b3feac931facbe105a894b552b0e518e79/charset_normalizer-3.4.6-cp311-cp311-win32.whl", hash = "sha256:9ca4c0b502ab399ef89248a2c84c54954f77a070f28e546a85e91da627d1301e", size = 143703, upload-time = "2026-03-15T18:50:46.103Z" }, + { url = "https://files.pythonhosted.org/packages/c6/e3/76f2facfe8eddee0bbd38d2594e709033338eae44ebf1738bcefe0a06185/charset_normalizer-3.4.6-cp311-cp311-win_amd64.whl", hash = "sha256:a9e68c9d88823b274cf1e72f28cb5dc89c990edf430b0bfd3e2fb0785bfeabf4", size = 153857, upload-time = "2026-03-15T18:50:47.563Z" }, + { url = "https://files.pythonhosted.org/packages/e2/dc/9abe19c9b27e6cd3636036b9d1b387b78c40dedbf0b47f9366737684b4b0/charset_normalizer-3.4.6-cp311-cp311-win_arm64.whl", hash = "sha256:97d0235baafca5f2b09cf332cc275f021e694e8362c6bb9c96fc9a0eb74fc316", size = 142751, upload-time = "2026-03-15T18:50:49.234Z" }, + { url = "https://files.pythonhosted.org/packages/e5/62/c0815c992c9545347aeea7859b50dc9044d147e2e7278329c6e02ac9a616/charset_normalizer-3.4.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:2ef7fedc7a6ecbe99969cd09632516738a97eeb8bd7258bf8a0f23114c057dab", size = 295154, upload-time = "2026-03-15T18:50:50.88Z" }, + { url = "https://files.pythonhosted.org/packages/a8/37/bdca6613c2e3c58c7421891d80cc3efa1d32e882f7c4a7ee6039c3fc951a/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a4ea868bc28109052790eb2b52a9ab33f3aa7adc02f96673526ff47419490e21", size = 199191, upload-time = "2026-03-15T18:50:52.658Z" }, + { url = "https://files.pythonhosted.org/packages/6c/92/9934d1bbd69f7f398b38c5dae1cbf9cc672e7c34a4adf7b17c0a9c17d15d/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:836ab36280f21fc1a03c99cd05c6b7af70d2697e374c7af0b61ed271401a72a2", size = 218674, upload-time = "2026-03-15T18:50:54.102Z" }, + { url = "https://files.pythonhosted.org/packages/af/90/25f6ab406659286be929fd89ab0e78e38aa183fc374e03aa3c12d730af8a/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f1ce721c8a7dfec21fcbdfe04e8f68174183cf4e8188e0645e92aa23985c57ff", size = 215259, upload-time = "2026-03-15T18:50:55.616Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ef/79a463eb0fff7f96afa04c1d4c51f8fc85426f918db467854bfb6a569ce3/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e28d62a8fc7a1fa411c43bd65e346f3bce9716dc51b897fbe930c5987b402d5", size = 207276, upload-time = "2026-03-15T18:50:57.054Z" }, + { url = "https://files.pythonhosted.org/packages/f7/72/d0426afec4b71dc159fa6b4e68f868cd5a3ecd918fec5813a15d292a7d10/charset_normalizer-3.4.6-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:530d548084c4a9f7a16ed4a294d459b4f229db50df689bfe92027452452943a0", size = 195161, upload-time = "2026-03-15T18:50:58.686Z" }, + { url = "https://files.pythonhosted.org/packages/bf/18/c82b06a68bfcb6ce55e508225d210c7e6a4ea122bfc0748892f3dc4e8e11/charset_normalizer-3.4.6-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:30f445ae60aad5e1f8bdbb3108e39f6fbc09f4ea16c815c66578878325f8f15a", size = 203452, upload-time = "2026-03-15T18:51:00.196Z" }, + { url = "https://files.pythonhosted.org/packages/44/d6/0c25979b92f8adafdbb946160348d8d44aa60ce99afdc27df524379875cb/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ac2393c73378fea4e52aa56285a3d64be50f1a12395afef9cce47772f60334c2", size = 202272, upload-time = "2026-03-15T18:51:01.703Z" }, + { url = "https://files.pythonhosted.org/packages/2e/3d/7fea3e8fe84136bebbac715dd1221cc25c173c57a699c030ab9b8900cbb7/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:90ca27cd8da8118b18a52d5f547859cc1f8354a00cd1e8e5120df3e30d6279e5", size = 195622, upload-time = "2026-03-15T18:51:03.526Z" }, + { url = "https://files.pythonhosted.org/packages/57/8a/d6f7fd5cb96c58ef2f681424fbca01264461336d2a7fc875e4446b1f1346/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8e5a94886bedca0f9b78fecd6afb6629142fd2605aa70a125d49f4edc6037ee6", size = 220056, upload-time = "2026-03-15T18:51:05.269Z" }, + { url = "https://files.pythonhosted.org/packages/16/50/478cdda782c8c9c3fb5da3cc72dd7f331f031e7f1363a893cdd6ca0f8de0/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:695f5c2823691a25f17bc5d5ffe79fa90972cc34b002ac6c843bb8a1720e950d", size = 203751, upload-time = "2026-03-15T18:51:06.858Z" }, + { url = "https://files.pythonhosted.org/packages/75/fc/cc2fcac943939c8e4d8791abfa139f685e5150cae9f94b60f12520feaa9b/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:231d4da14bcd9301310faf492051bee27df11f2bc7549bc0bb41fef11b82daa2", size = 216563, upload-time = "2026-03-15T18:51:08.564Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b7/a4add1d9a5f68f3d037261aecca83abdb0ab15960a3591d340e829b37298/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a056d1ad2633548ca18ffa2f85c202cfb48b68615129143915b8dc72a806a923", size = 209265, upload-time = "2026-03-15T18:51:10.312Z" }, + { url = "https://files.pythonhosted.org/packages/6c/18/c094561b5d64a24277707698e54b7f67bd17a4f857bbfbb1072bba07c8bf/charset_normalizer-3.4.6-cp312-cp312-win32.whl", hash = "sha256:c2274ca724536f173122f36c98ce188fd24ce3dad886ec2b7af859518ce008a4", size = 144229, upload-time = "2026-03-15T18:51:11.694Z" }, + { url = "https://files.pythonhosted.org/packages/ab/20/0567efb3a8fd481b8f34f739ebddc098ed062a59fed41a8d193a61939e8f/charset_normalizer-3.4.6-cp312-cp312-win_amd64.whl", hash = "sha256:c8ae56368f8cc97c7e40a7ee18e1cedaf8e780cd8bc5ed5ac8b81f238614facb", size = 154277, upload-time = "2026-03-15T18:51:13.004Z" }, + { url = "https://files.pythonhosted.org/packages/15/57/28d79b44b51933119e21f65479d0864a8d5893e494cf5daab15df0247c17/charset_normalizer-3.4.6-cp312-cp312-win_arm64.whl", hash = "sha256:899d28f422116b08be5118ef350c292b36fc15ec2daeb9ea987c89281c7bb5c4", size = 142817, upload-time = "2026-03-15T18:51:14.408Z" }, + { url = "https://files.pythonhosted.org/packages/2a/68/687187c7e26cb24ccbd88e5069f5ef00eba804d36dde11d99aad0838ab45/charset_normalizer-3.4.6-py3-none-any.whl", hash = "sha256:947cf925bc916d90adba35a64c82aace04fa39b46b52d4630ece166655905a69", size = 61455, upload-time = "2026-03-15T18:53:23.833Z" }, ] [[package]] @@ -1097,7 +1141,7 @@ wheels = [ [[package]] name = "clickhouse-connect" -version = "0.10.0" +version = "0.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1106,29 +1150,29 @@ dependencies = [ { name = "urllib3" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7b/fd/f8bea1157d40f117248dcaa9abdbf68c729513fcf2098ab5cb4aa58768b8/clickhouse_connect-0.10.0.tar.gz", hash = "sha256:a0256328802c6e5580513e197cef7f9ba49a99fc98e9ba410922873427569564", size = 104753, upload-time = "2025-11-14T20:31:00.947Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/0e/96958db88b6ce6e9d96dc7a836f12c7644934b3a436b04843f19eb8da2db/clickhouse_connect-0.14.1.tar.gz", hash = "sha256:dc107ae9ab7b86409049ae8abe21817543284b438291796d3dd639ad5496a1ab", size = 120093, upload-time = "2026-03-12T15:51:03.606Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/4e/f90caf963d14865c7a3f0e5d80b77e67e0fe0bf39b3de84110707746fa6b/clickhouse_connect-0.10.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:195f1824405501b747b572e1365c6265bb1629eeb712ce91eda91da3c5794879", size = 272911, upload-time = "2025-11-14T20:29:57.129Z" }, - { url = "https://files.pythonhosted.org/packages/50/c7/e01bd2dd80ea4fbda8968e5022c60091a872fd9de0a123239e23851da231/clickhouse_connect-0.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7907624635fe7f28e1b85c7c8b125a72679a63ecdb0b9f4250b704106ef438f8", size = 265938, upload-time = "2025-11-14T20:29:58.443Z" }, - { url = "https://files.pythonhosted.org/packages/f4/07/8b567b949abca296e118331d13380bbdefa4225d7d1d32233c59d4b4b2e1/clickhouse_connect-0.10.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:60772faa54d56f0fa34650460910752a583f5948f44dddeabfafaecbca21fc54", size = 1113548, upload-time = "2025-11-14T20:29:59.781Z" }, - { url = "https://files.pythonhosted.org/packages/9c/13/11f2d37fc95e74d7e2d80702cde87666ce372486858599a61f5209e35fc5/clickhouse_connect-0.10.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7fe2a6cd98517330c66afe703fb242c0d3aa2c91f2f7dc9fb97c122c5c60c34b", size = 1135061, upload-time = "2025-11-14T20:30:01.244Z" }, - { url = "https://files.pythonhosted.org/packages/a0/d0/517181ea80060f84d84cff4d42d330c80c77bb352b728fb1f9681fbad291/clickhouse_connect-0.10.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a2427d312bc3526520a0be8c648479af3f6353da7a33a62db2368d6203b08efd", size = 1105105, upload-time = "2025-11-14T20:30:02.679Z" }, - { url = "https://files.pythonhosted.org/packages/7c/b2/4ad93e898562725b58c537cad83ab2694c9b1c1ef37fa6c3f674bdad366a/clickhouse_connect-0.10.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:63bbb5721bfece698e155c01b8fa95ce4377c584f4d04b43f383824e8a8fa129", size = 1150791, upload-time = "2025-11-14T20:30:03.824Z" }, - { url = "https://files.pythonhosted.org/packages/45/a4/fdfbfacc1fa67b8b1ce980adcf42f9e3202325586822840f04f068aff395/clickhouse_connect-0.10.0-cp311-cp311-win32.whl", hash = "sha256:48554e836c6b56fe0854d9a9f565569010583d4960094d60b68a53f9f83042f0", size = 244014, upload-time = "2025-11-14T20:30:05.157Z" }, - { url = "https://files.pythonhosted.org/packages/08/50/cf53f33f4546a9ce2ab1b9930db4850aa1ae53bff1e4e4fa97c566cdfa19/clickhouse_connect-0.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:9eb8df083e5fda78ac7249938691c2c369e8578b5df34c709467147e8289f1d9", size = 262356, upload-time = "2025-11-14T20:30:06.478Z" }, - { url = "https://files.pythonhosted.org/packages/9e/59/fadbbf64f4c6496cd003a0a3c9223772409a86d0eea9d4ff45d2aa88aabf/clickhouse_connect-0.10.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b090c7d8e602dd084b2795265cd30610461752284763d9ad93a5d619a0e0ff21", size = 276401, upload-time = "2025-11-14T20:30:07.469Z" }, - { url = "https://files.pythonhosted.org/packages/1c/e3/781f9970f2ef202410f0d64681e42b2aecd0010097481a91e4df186a36c7/clickhouse_connect-0.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b8a708d38b81dcc8c13bb85549c904817e304d2b7f461246fed2945524b7a31b", size = 268193, upload-time = "2025-11-14T20:30:08.503Z" }, - { url = "https://files.pythonhosted.org/packages/f0/e0/64ab66b38fce762b77b5203a4fcecc603595f2a2361ce1605fc7bb79c835/clickhouse_connect-0.10.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3646fc9184a5469b95cf4a0846e6954e6e9e85666f030a5d2acae58fa8afb37e", size = 1123810, upload-time = "2025-11-14T20:30:09.62Z" }, - { url = "https://files.pythonhosted.org/packages/f5/03/19121aecf11a30feaf19049be96988131798c54ac6ba646a38e5faecaa0a/clickhouse_connect-0.10.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fe7e6be0f40a8a77a90482944f5cc2aa39084c1570899e8d2d1191f62460365b", size = 1153409, upload-time = "2025-11-14T20:30:10.855Z" }, - { url = "https://files.pythonhosted.org/packages/ce/ee/63870fd8b666c6030393950ad4ee76b7b69430f5a49a5d3fa32a70b11942/clickhouse_connect-0.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:88b4890f13163e163bf6fa61f3a013bb974c95676853b7a4e63061faf33911ac", size = 1104696, upload-time = "2025-11-14T20:30:12.187Z" }, - { url = "https://files.pythonhosted.org/packages/e9/bc/fcd8da1c4d007ebce088783979c495e3d7360867cfa8c91327ed235778f5/clickhouse_connect-0.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6286832cc79affc6fddfbf5563075effa65f80e7cd1481cf2b771ce317c67d08", size = 1156389, upload-time = "2025-11-14T20:30:13.385Z" }, - { url = "https://files.pythonhosted.org/packages/4e/33/7cb99cc3fc503c23fd3a365ec862eb79cd81c8dc3037242782d709280fa9/clickhouse_connect-0.10.0-cp312-cp312-win32.whl", hash = "sha256:92b8b6691a92d2613ee35f5759317bd4be7ba66d39bf81c4deed620feb388ca6", size = 243682, upload-time = "2025-11-14T20:30:14.52Z" }, - { url = "https://files.pythonhosted.org/packages/48/5c/12eee6a1f5ecda2dfc421781fde653c6d6ca6f3080f24547c0af40485a5a/clickhouse_connect-0.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:1159ee2c33e7eca40b53dda917a8b6a2ed889cb4c54f3d83b303b31ddb4f351d", size = 262790, upload-time = "2025-11-14T20:30:15.555Z" }, + { url = "https://files.pythonhosted.org/packages/66/b0/04bc82ca70d4dcc35987c83e4ef04f6dec3c29d3cce4cda3523ebf4498dc/clickhouse_connect-0.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2b1d1acb8f64c3cd9d922d9e8c0b6328238c4a38e084598c86cc95a0edbd8bd", size = 278797, upload-time = "2026-03-12T15:49:34.728Z" }, + { url = "https://files.pythonhosted.org/packages/97/03/f8434ed43946dcab2d8b4ccf8e90b1c6d69abea0fa8b8aaddb1dc9931657/clickhouse_connect-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:573f3e5a6b49135b711c086050f46510d4738cc09e5a354cc18ef26f8de5cd98", size = 271849, upload-time = "2026-03-12T15:49:35.881Z" }, + { url = "https://files.pythonhosted.org/packages/a0/db/b3665f4d855c780be8d00638d874fc0d62613d1f1c06ffcad7c11a333f06/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:86b28932faab182a312779e5c3cf341abe19d31028a399bda9d8b06b3b9adab4", size = 1090975, upload-time = "2026-03-12T15:49:37.064Z" }, + { url = "https://files.pythonhosted.org/packages/ea/a2/7ba2d9669c5771734573397b034169653cdf3348dc4cc66bd66d8ab18910/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfc9650906ff96452c2b5676a7e68e8a77a5642504596f8482e0f3c0ccdffbf1", size = 1095899, upload-time = "2026-03-12T15:49:38.36Z" }, + { url = "https://files.pythonhosted.org/packages/e2/f4/0394af37b491ca832610f2ca7a129e85d8d857d40c94a42f2c2e6d3d9481/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b379749a962599f9d6ec81e773a3b907ac58b001f4a977e4ac397f6a76fedff2", size = 1077567, upload-time = "2026-03-12T15:49:40.027Z" }, + { url = "https://files.pythonhosted.org/packages/9a/b8/9279a88afac94c262b55cc75aadc6a3e83f7fa1641e618f9060d9d38415f/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:43ccb5debd13d41b97af81940c0cac01e92d39f17131d984591bedee13439a5d", size = 1100264, upload-time = "2026-03-12T15:49:41.414Z" }, + { url = "https://files.pythonhosted.org/packages/19/36/20e19ab392c211b83c967e275eb46f663853e0b8ce4da89056fda8a35fc6/clickhouse_connect-0.14.1-cp311-cp311-win32.whl", hash = "sha256:13cbe46c04be8e49da4f6aed698f2570a5295d15f498dd5511b4f761d1ef0edc", size = 250488, upload-time = "2026-03-12T15:49:42.649Z" }, + { url = "https://files.pythonhosted.org/packages/9d/3b/74a07e692a21cad4692e72595cdefbd709bd74a9f778c7334d57a98ee548/clickhouse_connect-0.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:7038cf547c542a17a465e062cd837659f46f99c991efcb010a9ea08ce70960ab", size = 268730, upload-time = "2026-03-12T15:49:44.225Z" }, + { url = "https://files.pythonhosted.org/packages/58/9e/d84a14241967b3aa1e657bbbee83e2eee02d3d6df1ebe8edd4ed72cd8643/clickhouse_connect-0.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:97665169090889a8bc4dbae4a5fc758b91a23e49a8f8ddc1ae993f18f6d71e02", size = 280679, upload-time = "2026-03-12T15:49:45.497Z" }, + { url = "https://files.pythonhosted.org/packages/d8/29/80835a980be6298a7a2ae42d5a14aab0c9c066ecafe1763bc1958a6f6f0f/clickhouse_connect-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3ee6b513ca7d83e0f7b46d87bc2e48260316431cb466680e3540400379bcd1db", size = 271570, upload-time = "2026-03-12T15:49:46.721Z" }, + { url = "https://files.pythonhosted.org/packages/8b/bf/25c17cb91d72143742d2b060c6954e8000a7753c1fd21f7bf8b49ef2bd89/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2a0e8a3f46aba99f1c574927d196e12f1ee689e31c41bf0caec86ad3e181abf3", size = 1115637, upload-time = "2026-03-12T15:49:47.921Z" }, + { url = "https://files.pythonhosted.org/packages/2d/5f/5d5df3585d98889aedc55c9eeb2ea90dba27ec4329eee392101619daf0c0/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25698cddcdd6c2e4ea12dc5c56d6035d77fc99c5d75e96a54123826c36fdd8ae", size = 1131995, upload-time = "2026-03-12T15:49:49.791Z" }, + { url = "https://files.pythonhosted.org/packages/ad/50/acc9f4c6a1d712f2ed11626f8451eff222e841cf0809655362f0e90454b6/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:29ab49e5cac44b830b58de73d17a7d895f6c362bf67a50134ff405b428774f44", size = 1095380, upload-time = "2026-03-12T15:49:51.388Z" }, + { url = "https://files.pythonhosted.org/packages/08/18/1ef01beee93d243ec9d9c37f0ce62b3083478a5dd7f59cc13279600cd3a5/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3cbf7d7a134692bacd68dd5f8661e87f5db94af60db9f3a74bd732596794910a", size = 1127217, upload-time = "2026-03-12T15:49:53.016Z" }, + { url = "https://files.pythonhosted.org/packages/18/e2/b4daee8287dc49eb9918c77b1e57f5644e47008f719b77281bf5fca63f6e/clickhouse_connect-0.14.1-cp312-cp312-win32.whl", hash = "sha256:6f295b66f3e2ed931dd0d3bb80e00ee94c6f4a584b2dc6d998872b2e0ceaa706", size = 250775, upload-time = "2026-03-12T15:49:54.639Z" }, + { url = "https://files.pythonhosted.org/packages/01/c7/7b55d346952fcd8f0f491faca4449f607a04764fd23cada846dc93facb9e/clickhouse_connect-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:c6bb2cce37041c90f8a3b1b380665acbaf252f125e401c13ce8f8df105378f69", size = 269353, upload-time = "2026-03-12T15:49:55.854Z" }, ] [[package]] name = "clickzetta-connector-python" -version = "0.8.107" +version = "0.8.106" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, @@ -1142,7 +1186,16 @@ dependencies = [ { name = "urllib3" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/19/b4/91dfe25592bbcaf7eede05849c77d09d43a2656943585bbcf7ba4cc604bc/clickzetta_connector_python-0.8.107-py3-none-any.whl", hash = "sha256:7f28752bfa0a50e89ed218db0540c02c6bfbfdae3589ac81cf28523d7caa93b0", size = 76864, upload-time = "2025-12-01T07:56:39.177Z" }, + { url = "https://files.pythonhosted.org/packages/23/38/749c708619f402d4d582dfa73fbeb64ade77b1f250a93bd064d2a1aa3776/clickzetta_connector_python-0.8.106-py3-none-any.whl", hash = "sha256:120d6700051d97609dbd6655c002ab3bc260b7c8e67d39dfc7191e749563f7b4", size = 78121, upload-time = "2025-10-29T02:38:15.014Z" }, +] + +[[package]] +name = "cloudpathlib" +version = "0.23.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f4/18/2ac35d6b3015a0c74e923d94fc69baf8307f7c3233de015d69f99e17afa8/cloudpathlib-0.23.0.tar.gz", hash = "sha256:eb38a34c6b8a048ecfd2b2f60917f7cbad4a105b7c979196450c2f541f4d6b4b", size = 53126, upload-time = "2025-10-07T22:47:56.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ae/8a/c4bb04426d608be4a3171efa2e233d2c59a5c8937850c10d098e126df18e/cloudpathlib-0.23.0-py3-none-any.whl", hash = "sha256:8520b3b01468fee77de37ab5d50b1b524ea6b4a8731c35d1b7407ac0cd716002", size = 62755, upload-time = "2025-10-07T22:47:54.905Z" }, ] [[package]] @@ -1178,20 +1231,21 @@ wheels = [ ] [[package]] -name = "coloredlogs" -version = "15.0.1" +name = "confection" +version = "0.1.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "humanfriendly" }, + { name = "pydantic" }, + { name = "srsly" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cc/c7/eed8f27100517e8c0e6b923d5f0845d0cb99763da6fdee00478f91db7325/coloredlogs-15.0.1.tar.gz", hash = "sha256:7c991aa71a4577af2f82600d8f8f3a89f936baeaf9b50a9c197da014e5bf16b0", size = 278520, upload-time = "2021-06-11T10:22:45.202Z" } +sdist = { url = "https://files.pythonhosted.org/packages/51/d3/57c6631159a1b48d273b40865c315cf51f89df7a9d1101094ef12e3a37c2/confection-0.1.5.tar.gz", hash = "sha256:8e72dd3ca6bd4f48913cd220f10b8275978e740411654b6e8ca6d7008c590f0e", size = 38924, upload-time = "2024-05-31T16:17:01.559Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/06/3d6badcf13db419e25b07041d9c7b4a2c331d3f4e7134445ec5df57714cd/coloredlogs-15.0.1-py2.py3-none-any.whl", hash = "sha256:612ee75c546f53e92e70049c9dbfcc18c935a2b9a53b66085ce9ef6a6e5c0934", size = 46018, upload-time = "2021-06-11T10:22:42.561Z" }, + { url = "https://files.pythonhosted.org/packages/0c/00/3106b1854b45bd0474ced037dfe6b73b90fe68a68968cef47c23de3d43d2/confection-0.1.5-py3-none-any.whl", hash = "sha256:e29d3c3f8eac06b3f77eb9dfb4bf2fc6bcc9622a98ca00a698e3d019c6430b14", size = 35451, upload-time = "2024-05-31T16:16:59.075Z" }, ] [[package]] name = "cos-python-sdk-v5" -version = "1.9.38" +version = "1.9.41" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "crcmod" }, @@ -1200,56 +1254,68 @@ dependencies = [ { name = "six" }, { name = "xmltodict" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/24/3c/d208266fec7cc3221b449e236b87c3fc1999d5ac4379d4578480321cfecc/cos_python_sdk_v5-1.9.38.tar.gz", hash = "sha256:491a8689ae2f1a6f04dacba66a877b2c8d361456f9cfd788ed42170a1cbf7a9f", size = 98092, upload-time = "2025-07-22T07:56:20.34Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/38/c0029f413f51238aa2319715f45d74bcae931768e36c7e4604b02f407c6c/cos_python_sdk_v5-1.9.41.tar.gz", hash = "sha256:68f4be7d8fe27a1d186b3159b93c622816e398effdc236eddd442b86db592b82", size = 102625, upload-time = "2026-01-06T07:00:11.692Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ab/c8/c9c156aa3bc7caba9b4f8a2b6abec3da6263215988f3fec0ea843f137a10/cos_python_sdk_v5-1.9.38-py3-none-any.whl", hash = "sha256:1d3dd3be2bd992b2e9c2dcd018e2596aa38eab022dbc86b4a5d14c8fc88370e6", size = 92601, upload-time = "2025-08-17T05:12:30.867Z" }, + { url = "https://files.pythonhosted.org/packages/aa/2f/ead3fb551509fdc94e4a42093b770e3de2827ff7227570165df5e35c2a3e/cos_python_sdk_v5-1.9.41-py3-none-any.whl", hash = "sha256:f465aae43a4ba3f1caa8caeaca838d0395932f6848e89d6dde2807725e3c88a0", size = 98285, upload-time = "2026-01-06T06:43:02.754Z" }, ] [[package]] name = "couchbase" -version = "4.3.6" +version = "4.5.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2f/70/7cf92b2443330e7a4b626a02fe15fbeb1531337d75e6ae6393294e960d18/couchbase-4.3.6.tar.gz", hash = "sha256:d58c5ccdad5d85fc026f328bf4190c4fc0041fdbe68ad900fb32fc5497c3f061", size = 6517695, upload-time = "2025-05-15T17:21:38.157Z" } +sdist = { url = "https://files.pythonhosted.org/packages/73/2f/8f92e743a91c2f4e2ebad0bcfc31ef386c817c64415d89bf44e64dde227a/couchbase-4.5.0.tar.gz", hash = "sha256:fb74386ea5e807ae12cfa294fa6740fe6be3ecaf3bb9ce4fb9ea73706ed05982", size = 6562752, upload-time = "2025-09-30T01:27:37.423Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f3/0a/eae21d3a9331f7c93e8483f686e1bcb9e3b48f2ce98193beb0637a620926/couchbase-4.3.6-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:4c10fd26271c5630196b9bcc0dd7e17a45fa9c7e46ed5756e5690d125423160c", size = 4775710, upload-time = "2025-05-15T17:20:29.388Z" }, - { url = "https://files.pythonhosted.org/packages/f6/98/0ca042a42f5807bbf8050f52fff39ebceebc7bea7e5897907758f3e1ad39/couchbase-4.3.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:811eee7a6013cea7b15a718e201ee1188df162c656d27c7882b618ab57a08f3a", size = 4020743, upload-time = "2025-05-15T17:20:31.515Z" }, - { url = "https://files.pythonhosted.org/packages/f8/0f/c91407cb082d2322217e8f7ca4abb8eda016a81a4db5a74b7ac6b737597d/couchbase-4.3.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2fc177e0161beb1e6e8c4b9561efcb97c51aed55a77ee11836ca194d33ae22b7", size = 4796091, upload-time = "2025-05-15T17:20:33.818Z" }, - { url = "https://files.pythonhosted.org/packages/8c/02/5567b660543828bdbbc68dcae080e388cb0be391aa8a97cce9d8c8a6c147/couchbase-4.3.6-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:02afb1c1edd6b215f702510412b5177ed609df8135930c23789bbc5901dd1b45", size = 5015684, upload-time = "2025-05-15T17:20:36.364Z" }, - { url = "https://files.pythonhosted.org/packages/dc/d1/767908826d5bdd258addab26d7f1d21bc42bafbf5f30d1b556ace06295af/couchbase-4.3.6-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:594e9eb17bb76ba8e10eeee17a16aef897dd90d33c6771cf2b5b4091da415b32", size = 5673513, upload-time = "2025-05-15T17:20:38.972Z" }, - { url = "https://files.pythonhosted.org/packages/f2/25/39ecde0a06692abce8bb0df4f15542933f05883647a1a57cdc7bbed9c77c/couchbase-4.3.6-cp311-cp311-win_amd64.whl", hash = "sha256:db22c56e38b8313f65807aa48309c8b8c7c44d5517b9ff1d8b4404d4740ec286", size = 4010728, upload-time = "2025-05-15T17:20:43.286Z" }, - { url = "https://files.pythonhosted.org/packages/b1/55/c12b8f626de71363fbe30578f4a0de1b8bb41afbe7646ff8538c3b38ce2a/couchbase-4.3.6-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:a2ae13432b859f513485d4cee691e1e4fce4af23ed4218b9355874b146343f8c", size = 4693517, upload-time = "2025-05-15T17:20:45.433Z" }, - { url = "https://files.pythonhosted.org/packages/a1/aa/2184934d283d99b34a004f577bf724d918278a2962781ca5690d4fa4b6c6/couchbase-4.3.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4ea5ca7e34b5d023c8bab406211ab5d71e74a976ba25fa693b4f8e6c74f85aa2", size = 4022393, upload-time = "2025-05-15T17:20:47.442Z" }, - { url = "https://files.pythonhosted.org/packages/80/29/ba6d3b205a51c04c270c1b56ea31da678b7edc565b35a34237ec2cfc708d/couchbase-4.3.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6eaca0a71fd8f9af4344b7d6474d7b74d1784ae9a658f6bc3751df5f9a4185ae", size = 4798396, upload-time = "2025-05-15T17:20:49.473Z" }, - { url = "https://files.pythonhosted.org/packages/4a/94/d7d791808bd9064c01f965015ff40ee76e6bac10eaf2c73308023b9bdedf/couchbase-4.3.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0470378b986f69368caed6d668ac6530e635b0c1abaef3d3f524cfac0dacd878", size = 5018099, upload-time = "2025-05-15T17:20:52.541Z" }, - { url = "https://files.pythonhosted.org/packages/a6/04/cec160f9f4b862788e2a0167616472a5695b2f569bd62204938ab674835d/couchbase-4.3.6-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:374ce392558f1688ac073aa0b15c256b1a441201d965811fd862357ff05d27a9", size = 5672633, upload-time = "2025-05-15T17:20:55.994Z" }, - { url = "https://files.pythonhosted.org/packages/1b/a2/1da2ab45412b9414e2c6a578e0e7a24f29b9261ef7de11707c2fc98045b8/couchbase-4.3.6-cp312-cp312-win_amd64.whl", hash = "sha256:cd734333de34d8594504c163bb6c47aea9cc1f2cefdf8e91875dd9bf14e61e29", size = 4013298, upload-time = "2025-05-15T17:20:59.533Z" }, + { url = "https://files.pythonhosted.org/packages/ca/a7/ba28fcab4f211e570582990d9592d8a57566158a0712fbc9d0d9ac486c2a/couchbase-4.5.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:3d3258802baa87d9ffeccbb2b31dcabe2a4ef27c9be81e0d3d710fd7436da24a", size = 5037084, upload-time = "2025-09-30T01:25:16.748Z" }, + { url = "https://files.pythonhosted.org/packages/85/38/f26912b56a41f22ab9606304014ef1435fc4bef76144382f91c1a4ce1d4c/couchbase-4.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:18b47f1f3a2007f88203f611570d96e62bb1fb9568dec0483a292a5e87f6d1df", size = 4323514, upload-time = "2025-09-30T01:25:22.628Z" }, + { url = "https://files.pythonhosted.org/packages/35/a6/5ef140f8681a2488ed6eb2a2bc9fc918b6f11e9f71bbad75e4de73b8dbf3/couchbase-4.5.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9c2a16830db9437aae92e31f9ceda6c7b70707e316152fc99552b866b09a1967", size = 5181111, upload-time = "2025-09-30T01:25:30.538Z" }, + { url = "https://files.pythonhosted.org/packages/7b/2e/1f0f06e920dbae07c3d8af6b2af3d5213e43d3825e0931c19564fe4d5c1b/couchbase-4.5.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4a86774680e46488a7955c6eae8fba5200a1fd5f9de9ac0a34acb6c87dc2b513", size = 5442969, upload-time = "2025-09-30T01:25:37.976Z" }, + { url = "https://files.pythonhosted.org/packages/9a/2e/6ece47df4d987dbeaae3fdcf7aa4d6a8154c949c28e925f01074dfd0b8b8/couchbase-4.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:b68dae005ab4c157930c76a3116e478df25aa1af00fa10cc1cc755df1831ad59", size = 6108562, upload-time = "2025-09-30T01:25:45.674Z" }, + { url = "https://files.pythonhosted.org/packages/be/a7/2f84a1d117cf70ad30e8b08ae9b1c4a03c65146bab030ed6eb84f454045b/couchbase-4.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:cbc50956fb68d42929d21d969f4512b38798259ae48c47cbf6d676cc3a01b058", size = 4269303, upload-time = "2025-09-30T01:25:49.341Z" }, + { url = "https://files.pythonhosted.org/packages/2f/bc/3b00403edd8b188a93f48b8231dbf7faf7b40d318d3e73bb0e68c4965bbd/couchbase-4.5.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:be1ac2bf7cbccf28eebd7fa8b1d7199fbe84c96b0f7f2c0d69963b1d6ce53985", size = 5128307, upload-time = "2025-09-30T01:25:53.615Z" }, + { url = "https://files.pythonhosted.org/packages/7f/52/2ccfa8c8650cc341813713a47eeeb8ad13a25e25b0f4747d224106602a24/couchbase-4.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:035c394d38297c484bd57fc92b27f6a571a36ab5675b4ec873fd15bf65e8f28e", size = 4326149, upload-time = "2025-09-30T01:25:57.524Z" }, + { url = "https://files.pythonhosted.org/packages/32/80/fe3f074f321474c824ec67b97c5c4aa99047d45c777bb29353f9397c6604/couchbase-4.5.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:117685f6827abbc332e151625b0a9890c2fafe0d3c3d9e564b903d5c411abe5d", size = 5184623, upload-time = "2025-09-30T01:26:02.166Z" }, + { url = "https://files.pythonhosted.org/packages/f3/e5/86381f49e4cf1c6db23c397b6a32b532cd4df7b9975b0cd2da3db2ffe269/couchbase-4.5.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:632a918f81a7373832991b79b6ab429e56ef4ff68dfb3517af03f0e2be7e3e4f", size = 5446579, upload-time = "2025-09-30T01:26:09.39Z" }, + { url = "https://files.pythonhosted.org/packages/c8/85/a68d04233a279e419062ceb1c6866b61852c016d1854cd09cde7f00bc53c/couchbase-4.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:67fc0fd1a4535b5be093f834116a70fb6609085399e6b63539241b919da737b7", size = 6104619, upload-time = "2025-09-30T01:26:15.525Z" }, + { url = "https://files.pythonhosted.org/packages/56/8c/0511bac5dd2d998aeabcfba6a2804ecd9eb3d83f9d21cc3293a56fbc70a8/couchbase-4.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:02199b4528f3106c231c00aaf85b7cc6723accbc654b903bb2027f78a04d12f4", size = 4274424, upload-time = "2025-09-30T01:26:21.484Z" }, ] [[package]] name = "coverage" -version = "7.2.7" +version = "7.13.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/45/8b/421f30467e69ac0e414214856798d4bc32da1336df745e49e49ae5c1e2a8/coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59", size = 762575, upload-time = "2023-05-29T20:08:50.273Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179", size = 915967, upload-time = "2026-03-17T10:33:18.341Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c6/fa/529f55c9a1029c840bcc9109d5a15ff00478b7ff550a1ae361f8745f8ad5/coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f", size = 200895, upload-time = "2023-05-29T20:07:21.963Z" }, - { url = "https://files.pythonhosted.org/packages/67/d7/cd8fe689b5743fffac516597a1222834c42b80686b99f5b44ef43ccc2a43/coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe", size = 201120, upload-time = "2023-05-29T20:07:23.765Z" }, - { url = "https://files.pythonhosted.org/packages/8c/95/16eed713202406ca0a37f8ac259bbf144c9d24f9b8097a8e6ead61da2dbb/coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3", size = 233178, upload-time = "2023-05-29T20:07:25.281Z" }, - { url = "https://files.pythonhosted.org/packages/c1/49/4d487e2ad5d54ed82ac1101e467e8994c09d6123c91b2a962145f3d262c2/coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f", size = 230754, upload-time = "2023-05-29T20:07:27.044Z" }, - { url = "https://files.pythonhosted.org/packages/a7/cd/3ce94ad9d407a052dc2a74fbeb1c7947f442155b28264eb467ee78dea812/coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb", size = 232558, upload-time = "2023-05-29T20:07:28.743Z" }, - { url = "https://files.pythonhosted.org/packages/8f/a8/12cc7b261f3082cc299ab61f677f7e48d93e35ca5c3c2f7241ed5525ccea/coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833", size = 241509, upload-time = "2023-05-29T20:07:30.434Z" }, - { url = "https://files.pythonhosted.org/packages/04/fa/43b55101f75a5e9115259e8be70ff9279921cb6b17f04c34a5702ff9b1f7/coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97", size = 239924, upload-time = "2023-05-29T20:07:32.065Z" }, - { url = "https://files.pythonhosted.org/packages/68/5f/d2bd0f02aa3c3e0311986e625ccf97fdc511b52f4f1a063e4f37b624772f/coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a", size = 240977, upload-time = "2023-05-29T20:07:34.184Z" }, - { url = "https://files.pythonhosted.org/packages/ba/92/69c0722882643df4257ecc5437b83f4c17ba9e67f15dc6b77bad89b6982e/coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a", size = 203168, upload-time = "2023-05-29T20:07:35.869Z" }, - { url = "https://files.pythonhosted.org/packages/b1/96/c12ed0dfd4ec587f3739f53eb677b9007853fd486ccb0e7d5512a27bab2e/coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562", size = 204185, upload-time = "2023-05-29T20:07:37.39Z" }, - { url = "https://files.pythonhosted.org/packages/ff/d5/52fa1891d1802ab2e1b346d37d349cb41cdd4fd03f724ebbf94e80577687/coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4", size = 201020, upload-time = "2023-05-29T20:07:38.724Z" }, - { url = "https://files.pythonhosted.org/packages/24/df/6765898d54ea20e3197a26d26bb65b084deefadd77ce7de946b9c96dfdc5/coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4", size = 233994, upload-time = "2023-05-29T20:07:40.274Z" }, - { url = "https://files.pythonhosted.org/packages/15/81/b108a60bc758b448c151e5abceed027ed77a9523ecbc6b8a390938301841/coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01", size = 231358, upload-time = "2023-05-29T20:07:41.998Z" }, - { url = "https://files.pythonhosted.org/packages/61/90/c76b9462f39897ebd8714faf21bc985b65c4e1ea6dff428ea9dc711ed0dd/coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6", size = 233316, upload-time = "2023-05-29T20:07:43.539Z" }, - { url = "https://files.pythonhosted.org/packages/04/d6/8cba3bf346e8b1a4fb3f084df7d8cea25a6b6c56aaca1f2e53829be17e9e/coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d", size = 240159, upload-time = "2023-05-29T20:07:44.982Z" }, - { url = "https://files.pythonhosted.org/packages/6e/ea/4a252dc77ca0605b23d477729d139915e753ee89e4c9507630e12ad64a80/coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de", size = 238127, upload-time = "2023-05-29T20:07:46.522Z" }, - { url = "https://files.pythonhosted.org/packages/9f/5c/d9760ac497c41f9c4841f5972d0edf05d50cad7814e86ee7d133ec4a0ac8/coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d", size = 239833, upload-time = "2023-05-29T20:07:47.992Z" }, - { url = "https://files.pythonhosted.org/packages/69/8c/26a95b08059db1cbb01e4b0e6d40f2e9debb628c6ca86b78f625ceaf9bab/coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511", size = 203463, upload-time = "2023-05-29T20:07:49.939Z" }, - { url = "https://files.pythonhosted.org/packages/b7/00/14b00a0748e9eda26e97be07a63cc911108844004687321ddcc213be956c/coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3", size = 204347, upload-time = "2023-05-29T20:07:51.909Z" }, + { url = "https://files.pythonhosted.org/packages/4b/37/d24c8f8220ff07b839b2c043ea4903a33b0f455abe673ae3c03bbdb7f212/coverage-7.13.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66a80c616f80181f4d643b0f9e709d97bcea413ecd9631e1dedc7401c8e6695d", size = 219381, upload-time = "2026-03-17T10:30:14.68Z" }, + { url = "https://files.pythonhosted.org/packages/35/8b/cd129b0ca4afe886a6ce9d183c44d8301acbd4ef248622e7c49a23145605/coverage-7.13.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:145ede53ccbafb297c1c9287f788d1bc3efd6c900da23bf6931b09eafc931587", size = 219880, upload-time = "2026-03-17T10:30:16.231Z" }, + { url = "https://files.pythonhosted.org/packages/55/2f/e0e5b237bffdb5d6c530ce87cc1d413a5b7d7dfd60fb067ad6d254c35c76/coverage-7.13.5-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0672854dc733c342fa3e957e0605256d2bf5934feeac328da9e0b5449634a642", size = 250303, upload-time = "2026-03-17T10:30:17.748Z" }, + { url = "https://files.pythonhosted.org/packages/92/be/b1afb692be85b947f3401375851484496134c5554e67e822c35f28bf2fbc/coverage-7.13.5-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ec10e2a42b41c923c2209b846126c6582db5e43a33157e9870ba9fb70dc7854b", size = 252218, upload-time = "2026-03-17T10:30:19.804Z" }, + { url = "https://files.pythonhosted.org/packages/da/69/2f47bb6fa1b8d1e3e5d0c4be8ccb4313c63d742476a619418f85740d597b/coverage-7.13.5-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:be3d4bbad9d4b037791794ddeedd7d64a56f5933a2c1373e18e9e568b9141686", size = 254326, upload-time = "2026-03-17T10:30:21.321Z" }, + { url = "https://files.pythonhosted.org/packages/d5/d0/79db81da58965bd29dabc8f4ad2a2af70611a57cba9d1ec006f072f30a54/coverage-7.13.5-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4d2afbc5cc54d286bfb54541aa50b64cdb07a718227168c87b9e2fb8f25e1743", size = 256267, upload-time = "2026-03-17T10:30:23.094Z" }, + { url = "https://files.pythonhosted.org/packages/e5/32/d0d7cc8168f91ddab44c0ce4806b969df5f5fdfdbb568eaca2dbc2a04936/coverage-7.13.5-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:3ad050321264c49c2fa67bb599100456fc51d004b82534f379d16445da40fb75", size = 250430, upload-time = "2026-03-17T10:30:25.311Z" }, + { url = "https://files.pythonhosted.org/packages/4d/06/a055311d891ddbe231cd69fdd20ea4be6e3603ffebddf8704b8ca8e10a3c/coverage-7.13.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7300c8a6d13335b29bb76d7651c66af6bd8658517c43499f110ddc6717bfc209", size = 252017, upload-time = "2026-03-17T10:30:27.284Z" }, + { url = "https://files.pythonhosted.org/packages/d6/f6/d0fd2d21e29a657b5f77a2fe7082e1568158340dceb941954f776dce1b7b/coverage-7.13.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:eb07647a5738b89baab047f14edd18ded523de60f3b30e75c2acc826f79c839a", size = 250080, upload-time = "2026-03-17T10:30:29.481Z" }, + { url = "https://files.pythonhosted.org/packages/4e/ab/0d7fb2efc2e9a5eb7ddcc6e722f834a69b454b7e6e5888c3a8567ecffb31/coverage-7.13.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:9adb6688e3b53adffefd4a52d72cbd8b02602bfb8f74dcd862337182fd4d1a4e", size = 253843, upload-time = "2026-03-17T10:30:31.301Z" }, + { url = "https://files.pythonhosted.org/packages/ba/6f/7467b917bbf5408610178f62a49c0ed4377bb16c1657f689cc61470da8ce/coverage-7.13.5-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:7c8d4bc913dd70b93488d6c496c77f3aff5ea99a07e36a18f865bca55adef8bd", size = 249802, upload-time = "2026-03-17T10:30:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/75/2c/1172fb689df92135f5bfbbd69fc83017a76d24ea2e2f3a1154007e2fb9f8/coverage-7.13.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0e3c426ffc4cd952f54ee9ffbdd10345709ecc78a3ecfd796a57236bfad0b9b8", size = 250707, upload-time = "2026-03-17T10:30:35.2Z" }, + { url = "https://files.pythonhosted.org/packages/67/21/9ac389377380a07884e3b48ba7a620fcd9dbfaf1d40565facdc6b36ec9ef/coverage-7.13.5-cp311-cp311-win32.whl", hash = "sha256:259b69bb83ad9894c4b25be2528139eecba9a82646ebdda2d9db1ba28424a6bf", size = 221880, upload-time = "2026-03-17T10:30:36.775Z" }, + { url = "https://files.pythonhosted.org/packages/af/7f/4cd8a92531253f9d7c1bbecd9fa1b472907fb54446ca768c59b531248dc5/coverage-7.13.5-cp311-cp311-win_amd64.whl", hash = "sha256:258354455f4e86e3e9d0d17571d522e13b4e1e19bf0f8596bcf9476d61e7d8a9", size = 222816, upload-time = "2026-03-17T10:30:38.891Z" }, + { url = "https://files.pythonhosted.org/packages/12/a6/1d3f6155fb0010ca68eba7fe48ca6c9da7385058b77a95848710ecf189b1/coverage-7.13.5-cp311-cp311-win_arm64.whl", hash = "sha256:bff95879c33ec8da99fc9b6fe345ddb5be6414b41d6d1ad1c8f188d26f36e028", size = 221483, upload-time = "2026-03-17T10:30:40.463Z" }, + { url = "https://files.pythonhosted.org/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01", size = 219554, upload-time = "2026-03-17T10:30:42.208Z" }, + { url = "https://files.pythonhosted.org/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422", size = 219908, upload-time = "2026-03-17T10:30:43.906Z" }, + { url = "https://files.pythonhosted.org/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f", size = 251419, upload-time = "2026-03-17T10:30:45.545Z" }, + { url = "https://files.pythonhosted.org/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5", size = 254159, upload-time = "2026-03-17T10:30:47.204Z" }, + { url = "https://files.pythonhosted.org/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376", size = 255270, upload-time = "2026-03-17T10:30:48.812Z" }, + { url = "https://files.pythonhosted.org/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256", size = 257538, upload-time = "2026-03-17T10:30:50.77Z" }, + { url = "https://files.pythonhosted.org/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c", size = 251821, upload-time = "2026-03-17T10:30:52.5Z" }, + { url = "https://files.pythonhosted.org/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5", size = 253191, upload-time = "2026-03-17T10:30:54.543Z" }, + { url = "https://files.pythonhosted.org/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09", size = 251337, upload-time = "2026-03-17T10:30:56.663Z" }, + { url = "https://files.pythonhosted.org/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9", size = 255404, upload-time = "2026-03-17T10:30:58.427Z" }, + { url = "https://files.pythonhosted.org/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf", size = 250903, upload-time = "2026-03-17T10:31:00.093Z" }, + { url = "https://files.pythonhosted.org/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c", size = 252780, upload-time = "2026-03-17T10:31:01.916Z" }, + { url = "https://files.pythonhosted.org/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf", size = 222093, upload-time = "2026-03-17T10:31:03.642Z" }, + { url = "https://files.pythonhosted.org/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810", size = 222900, upload-time = "2026-03-17T10:31:05.651Z" }, + { url = "https://files.pythonhosted.org/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de", size = 221515, upload-time = "2026-03-17T10:31:07.293Z" }, + { url = "https://files.pythonhosted.org/packages/9e/ee/a4cf96b8ce1e566ed238f0659ac2d3f007ed1d14b181bcb684e19561a69a/coverage-7.13.5-py3-none-any.whl", hash = "sha256:34b02417cf070e173989b3db962f7ed56d2f644307b2cf9d5a0f258e13084a61", size = 211346, upload-time = "2026-03-17T10:33:15.691Z" }, ] [package.optional-dependencies] @@ -1296,74 +1362,106 @@ sdist = { url = "https://files.pythonhosted.org/packages/6b/b0/e595ce2a2527e169c [[package]] name = "croniter" -version = "6.0.0" +version = "6.2.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "python-dateutil" }, - { name = "pytz" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ad/2f/44d1ae153a0e27be56be43465e5cb39b9650c781e001e7864389deb25090/croniter-6.0.0.tar.gz", hash = "sha256:37c504b313956114a983ece2c2b07790b1f1094fe9d81cc94739214748255577", size = 64481, upload-time = "2024-12-17T17:17:47.32Z" } +sdist = { url = "https://files.pythonhosted.org/packages/df/de/5832661ed55107b8a09af3f0a2e71e0957226a59eb1dcf0a445cce6daf20/croniter-6.2.2.tar.gz", hash = "sha256:ba60832a5ec8e12e51b8691c3309a113d1cf6526bdf1a48150ce8ec7a532d0ab", size = 113762, upload-time = "2026-03-15T08:43:48.112Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/07/4b/290b4c3efd6417a8b0c284896de19b1d5855e6dbdb97d2a35e68fa42de85/croniter-6.0.0-py2.py3-none-any.whl", hash = "sha256:2f878c3856f17896979b2a4379ba1f09c83e374931ea15cc835c5dd2eee9b368", size = 25468, upload-time = "2024-12-17T17:17:45.359Z" }, + { url = "https://files.pythonhosted.org/packages/d0/39/783980e78cb92c2d7bdb1fc7dbc86e94ccc6d58224d76a7f1f51b6c51e30/croniter-6.2.2-py3-none-any.whl", hash = "sha256:a5d17b1060974d36251ea4faf388233eca8acf0d09cbd92d35f4c4ac8f279960", size = 45422, upload-time = "2026-03-15T08:43:46.626Z" }, ] [[package]] name = "cryptography" -version = "46.0.5" +version = "44.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/60/04/ee2a9e8542e4fa2773b81771ff8349ff19cdd56b7258a0cc442639052edb/cryptography-46.0.5.tar.gz", hash = "sha256:abace499247268e3757271b2f1e244b36b06f8515cf27c4d49468fc9eb16e93d", size = 750064, upload-time = "2026-02-10T19:18:38.255Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/d6/1411ab4d6108ab167d06254c5be517681f1e331f90edf1379895bcb87020/cryptography-44.0.3.tar.gz", hash = "sha256:fe19d8bc5536a91a24a8133328880a41831b6c5df54599a8417b62fe015d3053", size = 711096, upload-time = "2025-05-02T19:36:04.667Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/81/b0bb27f2ba931a65409c6b8a8b358a7f03c0e46eceacddff55f7c84b1f3b/cryptography-46.0.5-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:351695ada9ea9618b3500b490ad54c739860883df6c1f555e088eaf25b1bbaad", size = 7176289, upload-time = "2026-02-10T19:17:08.274Z" }, - { url = "https://files.pythonhosted.org/packages/ff/9e/6b4397a3e3d15123de3b1806ef342522393d50736c13b20ec4c9ea6693a6/cryptography-46.0.5-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c18ff11e86df2e28854939acde2d003f7984f721eba450b56a200ad90eeb0e6b", size = 4275637, upload-time = "2026-02-10T19:17:10.53Z" }, - { url = "https://files.pythonhosted.org/packages/63/e7/471ab61099a3920b0c77852ea3f0ea611c9702f651600397ac567848b897/cryptography-46.0.5-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:4d7e3d356b8cd4ea5aff04f129d5f66ebdc7b6f8eae802b93739ed520c47c79b", size = 4424742, upload-time = "2026-02-10T19:17:12.388Z" }, - { url = "https://files.pythonhosted.org/packages/37/53/a18500f270342d66bf7e4d9f091114e31e5ee9e7375a5aba2e85a91e0044/cryptography-46.0.5-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:50bfb6925eff619c9c023b967d5b77a54e04256c4281b0e21336a130cd7fc263", size = 4277528, upload-time = "2026-02-10T19:17:13.853Z" }, - { url = "https://files.pythonhosted.org/packages/22/29/c2e812ebc38c57b40e7c583895e73c8c5adb4d1e4a0cc4c5a4fdab2b1acc/cryptography-46.0.5-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:803812e111e75d1aa73690d2facc295eaefd4439be1023fefc4995eaea2af90d", size = 4947993, upload-time = "2026-02-10T19:17:15.618Z" }, - { url = "https://files.pythonhosted.org/packages/6b/e7/237155ae19a9023de7e30ec64e5d99a9431a567407ac21170a046d22a5a3/cryptography-46.0.5-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ee190460e2fbe447175cda91b88b84ae8322a104fc27766ad09428754a618ed", size = 4456855, upload-time = "2026-02-10T19:17:17.221Z" }, - { url = "https://files.pythonhosted.org/packages/2d/87/fc628a7ad85b81206738abbd213b07702bcbdada1dd43f72236ef3cffbb5/cryptography-46.0.5-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:f145bba11b878005c496e93e257c1e88f154d278d2638e6450d17e0f31e558d2", size = 3984635, upload-time = "2026-02-10T19:17:18.792Z" }, - { url = "https://files.pythonhosted.org/packages/84/29/65b55622bde135aedf4565dc509d99b560ee4095e56989e815f8fd2aa910/cryptography-46.0.5-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e9251e3be159d1020c4030bd2e5f84d6a43fe54b6c19c12f51cde9542a2817b2", size = 4277038, upload-time = "2026-02-10T19:17:20.256Z" }, - { url = "https://files.pythonhosted.org/packages/bc/36/45e76c68d7311432741faf1fbf7fac8a196a0a735ca21f504c75d37e2558/cryptography-46.0.5-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:47fb8a66058b80e509c47118ef8a75d14c455e81ac369050f20ba0d23e77fee0", size = 4912181, upload-time = "2026-02-10T19:17:21.825Z" }, - { url = "https://files.pythonhosted.org/packages/6d/1a/c1ba8fead184d6e3d5afcf03d569acac5ad063f3ac9fb7258af158f7e378/cryptography-46.0.5-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:4c3341037c136030cb46e4b1e17b7418ea4cbd9dd207e4a6f3b2b24e0d4ac731", size = 4456482, upload-time = "2026-02-10T19:17:25.133Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e5/3fb22e37f66827ced3b902cf895e6a6bc1d095b5b26be26bd13c441fdf19/cryptography-46.0.5-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:890bcb4abd5a2d3f852196437129eb3667d62630333aacc13dfd470fad3aaa82", size = 4405497, upload-time = "2026-02-10T19:17:26.66Z" }, - { url = "https://files.pythonhosted.org/packages/1a/df/9d58bb32b1121a8a2f27383fabae4d63080c7ca60b9b5c88be742be04ee7/cryptography-46.0.5-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:80a8d7bfdf38f87ca30a5391c0c9ce4ed2926918e017c29ddf643d0ed2778ea1", size = 4667819, upload-time = "2026-02-10T19:17:28.569Z" }, - { url = "https://files.pythonhosted.org/packages/ea/ed/325d2a490c5e94038cdb0117da9397ece1f11201f425c4e9c57fe5b9f08b/cryptography-46.0.5-cp311-abi3-win32.whl", hash = "sha256:60ee7e19e95104d4c03871d7d7dfb3d22ef8a9b9c6778c94e1c8fcc8365afd48", size = 3028230, upload-time = "2026-02-10T19:17:30.518Z" }, - { url = "https://files.pythonhosted.org/packages/e9/5a/ac0f49e48063ab4255d9e3b79f5def51697fce1a95ea1370f03dc9db76f6/cryptography-46.0.5-cp311-abi3-win_amd64.whl", hash = "sha256:38946c54b16c885c72c4f59846be9743d699eee2b69b6988e0a00a01f46a61a4", size = 3480909, upload-time = "2026-02-10T19:17:32.083Z" }, - { url = "https://files.pythonhosted.org/packages/e2/fa/a66aa722105ad6a458bebd64086ca2b72cdd361fed31763d20390f6f1389/cryptography-46.0.5-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:4108d4c09fbbf2789d0c926eb4152ae1760d5a2d97612b92d508d96c861e4d31", size = 7170514, upload-time = "2026-02-10T19:17:56.267Z" }, - { url = "https://files.pythonhosted.org/packages/0f/04/c85bdeab78c8bc77b701bf0d9bdcf514c044e18a46dcff330df5448631b0/cryptography-46.0.5-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7d1f30a86d2757199cb2d56e48cce14deddf1f9c95f1ef1b64ee91ea43fe2e18", size = 4275349, upload-time = "2026-02-10T19:17:58.419Z" }, - { url = "https://files.pythonhosted.org/packages/5c/32/9b87132a2f91ee7f5223b091dc963055503e9b442c98fc0b8a5ca765fab0/cryptography-46.0.5-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:039917b0dc418bb9f6edce8a906572d69e74bd330b0b3fea4f79dab7f8ddd235", size = 4420667, upload-time = "2026-02-10T19:18:00.619Z" }, - { url = "https://files.pythonhosted.org/packages/a1/a6/a7cb7010bec4b7c5692ca6f024150371b295ee1c108bdc1c400e4c44562b/cryptography-46.0.5-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:ba2a27ff02f48193fc4daeadf8ad2590516fa3d0adeeb34336b96f7fa64c1e3a", size = 4276980, upload-time = "2026-02-10T19:18:02.379Z" }, - { url = "https://files.pythonhosted.org/packages/8e/7c/c4f45e0eeff9b91e3f12dbd0e165fcf2a38847288fcfd889deea99fb7b6d/cryptography-46.0.5-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:61aa400dce22cb001a98014f647dc21cda08f7915ceb95df0c9eaf84b4b6af76", size = 4939143, upload-time = "2026-02-10T19:18:03.964Z" }, - { url = "https://files.pythonhosted.org/packages/37/19/e1b8f964a834eddb44fa1b9a9976f4e414cbb7aa62809b6760c8803d22d1/cryptography-46.0.5-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:3ce58ba46e1bc2aac4f7d9290223cead56743fa6ab94a5d53292ffaac6a91614", size = 4453674, upload-time = "2026-02-10T19:18:05.588Z" }, - { url = "https://files.pythonhosted.org/packages/db/ed/db15d3956f65264ca204625597c410d420e26530c4e2943e05a0d2f24d51/cryptography-46.0.5-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:420d0e909050490d04359e7fdb5ed7e667ca5c3c402b809ae2563d7e66a92229", size = 3978801, upload-time = "2026-02-10T19:18:07.167Z" }, - { url = "https://files.pythonhosted.org/packages/41/e2/df40a31d82df0a70a0daf69791f91dbb70e47644c58581d654879b382d11/cryptography-46.0.5-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:582f5fcd2afa31622f317f80426a027f30dc792e9c80ffee87b993200ea115f1", size = 4276755, upload-time = "2026-02-10T19:18:09.813Z" }, - { url = "https://files.pythonhosted.org/packages/33/45/726809d1176959f4a896b86907b98ff4391a8aa29c0aaaf9450a8a10630e/cryptography-46.0.5-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:bfd56bb4b37ed4f330b82402f6f435845a5f5648edf1ad497da51a8452d5d62d", size = 4901539, upload-time = "2026-02-10T19:18:11.263Z" }, - { url = "https://files.pythonhosted.org/packages/99/0f/a3076874e9c88ecb2ecc31382f6e7c21b428ede6f55aafa1aa272613e3cd/cryptography-46.0.5-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:a3d507bb6a513ca96ba84443226af944b0f7f47dcc9a399d110cd6146481d24c", size = 4452794, upload-time = "2026-02-10T19:18:12.914Z" }, - { url = "https://files.pythonhosted.org/packages/02/ef/ffeb542d3683d24194a38f66ca17c0a4b8bf10631feef44a7ef64e631b1a/cryptography-46.0.5-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9f16fbdf4da055efb21c22d81b89f155f02ba420558db21288b3d0035bafd5f4", size = 4404160, upload-time = "2026-02-10T19:18:14.375Z" }, - { url = "https://files.pythonhosted.org/packages/96/93/682d2b43c1d5f1406ed048f377c0fc9fc8f7b0447a478d5c65ab3d3a66eb/cryptography-46.0.5-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:ced80795227d70549a411a4ab66e8ce307899fad2220ce5ab2f296e687eacde9", size = 4667123, upload-time = "2026-02-10T19:18:15.886Z" }, - { url = "https://files.pythonhosted.org/packages/45/2d/9c5f2926cb5300a8eefc3f4f0b3f3df39db7f7ce40c8365444c49363cbda/cryptography-46.0.5-cp38-abi3-win32.whl", hash = "sha256:02f547fce831f5096c9a567fd41bc12ca8f11df260959ecc7c3202555cc47a72", size = 3010220, upload-time = "2026-02-10T19:18:17.361Z" }, - { url = "https://files.pythonhosted.org/packages/48/ef/0c2f4a8e31018a986949d34a01115dd057bf536905dca38897bacd21fac3/cryptography-46.0.5-cp38-abi3-win_amd64.whl", hash = "sha256:556e106ee01aa13484ce9b0239bca667be5004efb0aabbed28d353df86445595", size = 3467050, upload-time = "2026-02-10T19:18:18.899Z" }, - { url = "https://files.pythonhosted.org/packages/eb/dd/2d9fdb07cebdf3d51179730afb7d5e576153c6744c3ff8fded23030c204e/cryptography-46.0.5-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:3b4995dc971c9fb83c25aa44cf45f02ba86f71ee600d81091c2f0cbae116b06c", size = 3476964, upload-time = "2026-02-10T19:18:20.687Z" }, - { url = "https://files.pythonhosted.org/packages/e9/6f/6cc6cc9955caa6eaf83660b0da2b077c7fe8ff9950a3c5e45d605038d439/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:bc84e875994c3b445871ea7181d424588171efec3e185dced958dad9e001950a", size = 4218321, upload-time = "2026-02-10T19:18:22.349Z" }, - { url = "https://files.pythonhosted.org/packages/3e/5d/c4da701939eeee699566a6c1367427ab91a8b7088cc2328c09dbee940415/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:2ae6971afd6246710480e3f15824ed3029a60fc16991db250034efd0b9fb4356", size = 4381786, upload-time = "2026-02-10T19:18:24.529Z" }, - { url = "https://files.pythonhosted.org/packages/ac/97/a538654732974a94ff96c1db621fa464f455c02d4bb7d2652f4edc21d600/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:d861ee9e76ace6cf36a6a89b959ec08e7bc2493ee39d07ffe5acb23ef46d27da", size = 4217990, upload-time = "2026-02-10T19:18:25.957Z" }, - { url = "https://files.pythonhosted.org/packages/ae/11/7e500d2dd3ba891197b9efd2da5454b74336d64a7cc419aa7327ab74e5f6/cryptography-46.0.5-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:2b7a67c9cd56372f3249b39699f2ad479f6991e62ea15800973b956f4b73e257", size = 4381252, upload-time = "2026-02-10T19:18:27.496Z" }, - { url = "https://files.pythonhosted.org/packages/bc/58/6b3d24e6b9bc474a2dcdee65dfd1f008867015408a271562e4b690561a4d/cryptography-46.0.5-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:8456928655f856c6e1533ff59d5be76578a7157224dbd9ce6872f25055ab9ab7", size = 3407605, upload-time = "2026-02-10T19:18:29.233Z" }, + { url = "https://files.pythonhosted.org/packages/08/53/c776d80e9d26441bb3868457909b4e74dd9ccabd182e10b2b0ae7a07e265/cryptography-44.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:962bc30480a08d133e631e8dfd4783ab71cc9e33d5d7c1e192f0b7c06397bb88", size = 6670281, upload-time = "2025-05-02T19:34:50.665Z" }, + { url = "https://files.pythonhosted.org/packages/6a/06/af2cf8d56ef87c77319e9086601bef621bedf40f6f59069e1b6d1ec498c5/cryptography-44.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffc61e8f3bf5b60346d89cd3d37231019c17a081208dfbbd6e1605ba03fa137", size = 3959305, upload-time = "2025-05-02T19:34:53.042Z" }, + { url = "https://files.pythonhosted.org/packages/ae/01/80de3bec64627207d030f47bf3536889efee8913cd363e78ca9a09b13c8e/cryptography-44.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58968d331425a6f9eedcee087f77fd3c927c88f55368f43ff7e0a19891f2642c", size = 4171040, upload-time = "2025-05-02T19:34:54.675Z" }, + { url = "https://files.pythonhosted.org/packages/bd/48/bb16b7541d207a19d9ae8b541c70037a05e473ddc72ccb1386524d4f023c/cryptography-44.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:e28d62e59a4dbd1d22e747f57d4f00c459af22181f0b2f787ea83f5a876d7c76", size = 3963411, upload-time = "2025-05-02T19:34:56.61Z" }, + { url = "https://files.pythonhosted.org/packages/42/b2/7d31f2af5591d217d71d37d044ef5412945a8a8e98d5a2a8ae4fd9cd4489/cryptography-44.0.3-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af653022a0c25ef2e3ffb2c673a50e5a0d02fecc41608f4954176f1933b12359", size = 3689263, upload-time = "2025-05-02T19:34:58.591Z" }, + { url = "https://files.pythonhosted.org/packages/25/50/c0dfb9d87ae88ccc01aad8eb93e23cfbcea6a6a106a9b63a7b14c1f93c75/cryptography-44.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:157f1f3b8d941c2bd8f3ffee0af9b049c9665c39d3da9db2dc338feca5e98a43", size = 4196198, upload-time = "2025-05-02T19:35:00.988Z" }, + { url = "https://files.pythonhosted.org/packages/66/c9/55c6b8794a74da652690c898cb43906310a3e4e4f6ee0b5f8b3b3e70c441/cryptography-44.0.3-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c6cd67722619e4d55fdb42ead64ed8843d64638e9c07f4011163e46bc512cf01", size = 3966502, upload-time = "2025-05-02T19:35:03.091Z" }, + { url = "https://files.pythonhosted.org/packages/b6/f7/7cb5488c682ca59a02a32ec5f975074084db4c983f849d47b7b67cc8697a/cryptography-44.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b424563394c369a804ecbee9b06dfb34997f19d00b3518e39f83a5642618397d", size = 4196173, upload-time = "2025-05-02T19:35:05.018Z" }, + { url = "https://files.pythonhosted.org/packages/d2/0b/2f789a8403ae089b0b121f8f54f4a3e5228df756e2146efdf4a09a3d5083/cryptography-44.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c91fc8e8fd78af553f98bc7f2a1d8db977334e4eea302a4bfd75b9461c2d8904", size = 4087713, upload-time = "2025-05-02T19:35:07.187Z" }, + { url = "https://files.pythonhosted.org/packages/1d/aa/330c13655f1af398fc154089295cf259252f0ba5df93b4bc9d9c7d7f843e/cryptography-44.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:25cd194c39fa5a0aa4169125ee27d1172097857b27109a45fadc59653ec06f44", size = 4299064, upload-time = "2025-05-02T19:35:08.879Z" }, + { url = "https://files.pythonhosted.org/packages/10/a8/8c540a421b44fd267a7d58a1fd5f072a552d72204a3f08194f98889de76d/cryptography-44.0.3-cp37-abi3-win32.whl", hash = "sha256:3be3f649d91cb182c3a6bd336de8b61a0a71965bd13d1a04a0e15b39c3d5809d", size = 2773887, upload-time = "2025-05-02T19:35:10.41Z" }, + { url = "https://files.pythonhosted.org/packages/b9/0d/c4b1657c39ead18d76bbd122da86bd95bdc4095413460d09544000a17d56/cryptography-44.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:3883076d5c4cc56dbef0b898a74eb6992fdac29a7b9013870b34efe4ddb39a0d", size = 3209737, upload-time = "2025-05-02T19:35:12.12Z" }, + { url = "https://files.pythonhosted.org/packages/34/a3/ad08e0bcc34ad436013458d7528e83ac29910943cea42ad7dd4141a27bbb/cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:5639c2b16764c6f76eedf722dbad9a0914960d3489c0cc38694ddf9464f1bb2f", size = 6673501, upload-time = "2025-05-02T19:35:13.775Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f0/7491d44bba8d28b464a5bc8cc709f25a51e3eac54c0a4444cf2473a57c37/cryptography-44.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffef566ac88f75967d7abd852ed5f182da252d23fac11b4766da3957766759", size = 3960307, upload-time = "2025-05-02T19:35:15.917Z" }, + { url = "https://files.pythonhosted.org/packages/f7/c8/e5c5d0e1364d3346a5747cdcd7ecbb23ca87e6dea4f942a44e88be349f06/cryptography-44.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192ed30fac1728f7587c6f4613c29c584abdc565d7417c13904708db10206645", size = 4170876, upload-time = "2025-05-02T19:35:18.138Z" }, + { url = "https://files.pythonhosted.org/packages/73/96/025cb26fc351d8c7d3a1c44e20cf9a01e9f7cf740353c9c7a17072e4b264/cryptography-44.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:7d5fe7195c27c32a64955740b949070f21cba664604291c298518d2e255931d2", size = 3964127, upload-time = "2025-05-02T19:35:19.864Z" }, + { url = "https://files.pythonhosted.org/packages/01/44/eb6522db7d9f84e8833ba3bf63313f8e257729cf3a8917379473fcfd6601/cryptography-44.0.3-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3f07943aa4d7dad689e3bb1638ddc4944cc5e0921e3c227486daae0e31a05e54", size = 3689164, upload-time = "2025-05-02T19:35:21.449Z" }, + { url = "https://files.pythonhosted.org/packages/68/fb/d61a4defd0d6cee20b1b8a1ea8f5e25007e26aeb413ca53835f0cae2bcd1/cryptography-44.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb90f60e03d563ca2445099edf605c16ed1d5b15182d21831f58460c48bffb93", size = 4198081, upload-time = "2025-05-02T19:35:23.187Z" }, + { url = "https://files.pythonhosted.org/packages/1b/50/457f6911d36432a8811c3ab8bd5a6090e8d18ce655c22820994913dd06ea/cryptography-44.0.3-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ab0b005721cc0039e885ac3503825661bd9810b15d4f374e473f8c89b7d5460c", size = 3967716, upload-time = "2025-05-02T19:35:25.426Z" }, + { url = "https://files.pythonhosted.org/packages/35/6e/dca39d553075980ccb631955c47b93d87d27f3596da8d48b1ae81463d915/cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3bb0847e6363c037df8f6ede57d88eaf3410ca2267fb12275370a76f85786a6f", size = 4197398, upload-time = "2025-05-02T19:35:27.678Z" }, + { url = "https://files.pythonhosted.org/packages/9b/9d/d1f2fe681eabc682067c66a74addd46c887ebacf39038ba01f8860338d3d/cryptography-44.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0cc66c74c797e1db750aaa842ad5b8b78e14805a9b5d1348dc603612d3e3ff5", size = 4087900, upload-time = "2025-05-02T19:35:29.312Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f5/3599e48c5464580b73b236aafb20973b953cd2e7b44c7c2533de1d888446/cryptography-44.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6866df152b581f9429020320e5eb9794c8780e90f7ccb021940d7f50ee00ae0b", size = 4301067, upload-time = "2025-05-02T19:35:31.547Z" }, + { url = "https://files.pythonhosted.org/packages/a7/6c/d2c48c8137eb39d0c193274db5c04a75dab20d2f7c3f81a7dcc3a8897701/cryptography-44.0.3-cp39-abi3-win32.whl", hash = "sha256:c138abae3a12a94c75c10499f1cbae81294a6f983b3af066390adee73f433028", size = 2775467, upload-time = "2025-05-02T19:35:33.805Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375, upload-time = "2025-05-02T19:35:35.369Z" }, + { url = "https://files.pythonhosted.org/packages/8d/4b/c11ad0b6c061902de5223892d680e89c06c7c4d606305eb8de56c5427ae6/cryptography-44.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:896530bc9107b226f265effa7ef3f21270f18a2026bc09fed1ebd7b66ddf6375", size = 3390230, upload-time = "2025-05-02T19:35:49.062Z" }, + { url = "https://files.pythonhosted.org/packages/58/11/0a6bf45d53b9b2290ea3cec30e78b78e6ca29dc101e2e296872a0ffe1335/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9b4d4a5dbee05a2c390bf212e78b99434efec37b17a4bff42f50285c5c8c9647", size = 3895216, upload-time = "2025-05-02T19:35:51.351Z" }, + { url = "https://files.pythonhosted.org/packages/0a/27/b28cdeb7270e957f0077a2c2bfad1b38f72f1f6d699679f97b816ca33642/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02f55fb4f8b79c1221b0961488eaae21015b69b210e18c386b69de182ebb1259", size = 4115044, upload-time = "2025-05-02T19:35:53.044Z" }, + { url = "https://files.pythonhosted.org/packages/35/b0/ec4082d3793f03cb248881fecefc26015813199b88f33e3e990a43f79835/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dd3db61b8fe5be220eee484a17233287d0be6932d056cf5738225b9c05ef4fff", size = 3898034, upload-time = "2025-05-02T19:35:54.72Z" }, + { url = "https://files.pythonhosted.org/packages/0b/7f/adf62e0b8e8d04d50c9a91282a57628c00c54d4ae75e2b02a223bd1f2613/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:978631ec51a6bbc0b7e58f23b68a8ce9e5f09721940933e9c217068388789fe5", size = 4114449, upload-time = "2025-05-02T19:35:57.139Z" }, + { url = "https://files.pythonhosted.org/packages/87/62/d69eb4a8ee231f4bf733a92caf9da13f1c81a44e874b1d4080c25ecbb723/cryptography-44.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5d20cc348cca3a8aa7312f42ab953a56e15323800ca3ab0706b8cd452a3a056c", size = 3134369, upload-time = "2025-05-02T19:35:58.907Z" }, +] + +[[package]] +name = "cymem" +version = "2.0.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/2f0fbb32535c3731b7c2974c569fb9325e0a38ed5565a08e1139a3b71e82/cymem-2.0.13.tar.gz", hash = "sha256:1c91a92ae8c7104275ac26bd4d29b08ccd3e7faff5893d3858cb6fadf1bc1588", size = 12320, upload-time = "2025-11-14T14:58:36.902Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/10/64/1db41f7576a6b69f70367e3c15e968fd775ba7419e12059c9966ceb826f8/cymem-2.0.13-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:673183466b0ff2e060d97ec5116711d44200b8f7be524323e080d215ee2d44a5", size = 43587, upload-time = "2025-11-14T14:57:22.39Z" }, + { url = "https://files.pythonhosted.org/packages/81/13/57f936fc08551323aab3f92ff6b7f4d4b89d5b4e495c870a67cb8d279757/cymem-2.0.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bee2791b3f6fc034ce41268851462bf662ff87e8947e35fb6dd0115b4644a61f", size = 43139, upload-time = "2025-11-14T14:57:23.363Z" }, + { url = "https://files.pythonhosted.org/packages/32/a6/9345754be51e0479aa387b7b6cffc289d0fd3201aaeb8dade4623abd1e02/cymem-2.0.13-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f3aee3adf16272bca81c5826eed55ba3c938add6d8c9e273f01c6b829ecfde22", size = 245063, upload-time = "2025-11-14T14:57:24.839Z" }, + { url = "https://files.pythonhosted.org/packages/d6/01/6bc654101526fa86e82bf6b05d99b2cd47c30a333cfe8622c26c0592beb2/cymem-2.0.13-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:30c4e75a3a1d809e89106b0b21803eb78e839881aa1f5b9bd27b454bc73afde3", size = 244496, upload-time = "2025-11-14T14:57:26.42Z" }, + { url = "https://files.pythonhosted.org/packages/c4/fb/853b7b021e701a1f41687f3704d5f469aeb2a4f898c3fbb8076806885955/cymem-2.0.13-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ec99efa03cf8ec11c8906aa4d4cc0c47df393bc9095c9dd64b89b9b43e220b04", size = 243287, upload-time = "2025-11-14T14:57:27.542Z" }, + { url = "https://files.pythonhosted.org/packages/d4/2b/0e4664cafc581de2896d75000651fd2ce7094d33263f466185c28ffc96e4/cymem-2.0.13-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c90a6ecba994a15b17a3f45d7ec74d34081df2f73bd1b090e2adc0317e4e01b6", size = 248287, upload-time = "2025-11-14T14:57:29.055Z" }, + { url = "https://files.pythonhosted.org/packages/21/0f/f94c6950edbfc2aafb81194fc40b6cacc8e994e9359d3cb4328c5705b9b5/cymem-2.0.13-cp311-cp311-win_amd64.whl", hash = "sha256:ce821e6ba59148ed17c4567113b8683a6a0be9c9ac86f14e969919121efb61a5", size = 40116, upload-time = "2025-11-14T14:57:30.592Z" }, + { url = "https://files.pythonhosted.org/packages/00/df/2455eff6ac0381ff165db6883b311f7016e222e3dd62185517f8e8187ed0/cymem-2.0.13-cp311-cp311-win_arm64.whl", hash = "sha256:0dca715e708e545fd1d97693542378a00394b20a37779c1ae2c8bdbb43acef79", size = 36349, upload-time = "2025-11-14T14:57:31.573Z" }, + { url = "https://files.pythonhosted.org/packages/c9/52/478a2911ab5028cb710b4900d64aceba6f4f882fcb13fd8d40a456a1b6dc/cymem-2.0.13-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e8afbc5162a0fe14b6463e1c4e45248a1b2fe2cbcecc8a5b9e511117080da0eb", size = 43745, upload-time = "2025-11-14T14:57:32.52Z" }, + { url = "https://files.pythonhosted.org/packages/f9/71/f0f8adee945524774b16af326bd314a14a478ed369a728a22834e6785a18/cymem-2.0.13-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9251d889348fe79a75e9b3e4d1b5fa651fca8a64500820685d73a3acc21b6a8", size = 42927, upload-time = "2025-11-14T14:57:33.827Z" }, + { url = "https://files.pythonhosted.org/packages/62/6d/159780fe162ff715d62b809246e5fc20901cef87ca28b67d255a8d741861/cymem-2.0.13-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:742fc19764467a49ed22e56a4d2134c262d73a6c635409584ae3bf9afa092c33", size = 258346, upload-time = "2025-11-14T14:57:34.917Z" }, + { url = "https://files.pythonhosted.org/packages/eb/12/678d16f7aa1996f947bf17b8cfb917ea9c9674ef5e2bd3690c04123d5680/cymem-2.0.13-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f190a92fe46197ee64d32560eb121c2809bb843341733227f51538ce77b3410d", size = 260843, upload-time = "2025-11-14T14:57:36.503Z" }, + { url = "https://files.pythonhosted.org/packages/31/5d/0dd8c167c08cd85e70d274b7235cfe1e31b3cebc99221178eaf4bbb95c6f/cymem-2.0.13-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:d670329ee8dbbbf241b7c08069fe3f1d3a1a3e2d69c7d05ea008a7010d826298", size = 254607, upload-time = "2025-11-14T14:57:38.036Z" }, + { url = "https://files.pythonhosted.org/packages/b7/c9/d6514a412a1160aa65db539836b3d47f9b59f6675f294ec34ae32f867c82/cymem-2.0.13-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a84ba3178d9128b9ffb52ce81ebab456e9fe959125b51109f5b73ebdfc6b60d6", size = 262421, upload-time = "2025-11-14T14:57:39.265Z" }, + { url = "https://files.pythonhosted.org/packages/dd/fe/3ee37d02ca4040f2fb22d34eb415198f955862b5dd47eee01df4c8f5454c/cymem-2.0.13-cp312-cp312-win_amd64.whl", hash = "sha256:2ff1c41fd59b789579fdace78aa587c5fc091991fa59458c382b116fc36e30dc", size = 40176, upload-time = "2025-11-14T14:57:40.706Z" }, + { url = "https://files.pythonhosted.org/packages/94/fb/1b681635bfd5f2274d0caa8f934b58435db6c091b97f5593738065ddb786/cymem-2.0.13-cp312-cp312-win_arm64.whl", hash = "sha256:6bbd701338df7bf408648191dff52472a9b334f71bcd31a21a41d83821050f67", size = 35959, upload-time = "2025-11-14T14:57:41.682Z" }, +] + +[[package]] +name = "darabonba-core" +version = "1.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "alibabacloud-tea" }, + { name = "requests" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/66/d3/a7daaee544c904548e665829b51a9fa2572acb82c73ad787a8ff90273002/darabonba_core-1.0.5-py3-none-any.whl", hash = "sha256:671ab8dbc4edc2a8f88013da71646839bb8914f1259efc069353243ef52ea27c", size = 24580, upload-time = "2025-12-12T07:53:59.494Z" }, ] [[package]] name = "databricks-sdk" -version = "0.73.0" +version = "0.102.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth" }, { name = "protobuf" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a8/7f/cfb2a00d10f6295332616e5b22f2ae3aaf2841a3afa6c49262acb6b94f5b/databricks_sdk-0.73.0.tar.gz", hash = "sha256:db09eaaacd98e07dded78d3e7ab47d2f6c886e0380cb577977bd442bace8bd8d", size = 801017, upload-time = "2025-11-05T06:52:58.509Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ab/b3/41ff1c3afe092df9085e084e0dc81c45bca5ed65f7b60dc59df0ade43c76/databricks_sdk-0.102.0.tar.gz", hash = "sha256:8fa5f82317ee27cc46323c6e2543d2cfefb4468653f92ba558271043c6f72fb9", size = 887450, upload-time = "2026-03-19T08:15:54.428Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/27/b822b474aaefb684d11df358d52e012699a2a8af231f9b47c54b73f280cb/databricks_sdk-0.73.0-py3-none-any.whl", hash = "sha256:a4d3cfd19357a2b459d2dc3101454d7f0d1b62865ce099c35d0c342b66ac64ff", size = 753896, upload-time = "2025-11-05T06:52:56.451Z" }, + { url = "https://files.pythonhosted.org/packages/02/8c/d082bd5f72d7613524d5b35dfe1f71732b2246be2704fad68cd0e3fdd020/databricks_sdk-0.102.0-py3-none-any.whl", hash = "sha256:75d1253276ee8f3dd5e7b00d62594b7051838435e618f74a8570a6dbd723ec12", size = 838533, upload-time = "2026-03-19T08:15:52.248Z" }, ] [[package]] @@ -1381,32 +1479,33 @@ wheels = [ [[package]] name = "datasets" -version = "4.7.0" +version = "2.19.2" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "aiohttp" }, { name = "dill" }, { name = "filelock" }, { name = "fsspec", extra = ["http"] }, - { name = "httpx" }, { name = "huggingface-hub" }, { name = "multiprocess" }, { name = "numpy" }, { name = "packaging" }, { name = "pandas" }, { name = "pyarrow" }, + { name = "pyarrow-hotfix" }, { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, { name = "xxhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/9c/ba18de0b70858533e422ed6cfe0e46789473cef7fc7fc3653e23fa494730/datasets-4.7.0.tar.gz", hash = "sha256:4984cdfc65d04464da7f95205a55cb50515fd94ae3176caacb50a1b7273792e2", size = 602008, upload-time = "2026-03-09T19:01:49.298Z" } +sdist = { url = "https://files.pythonhosted.org/packages/52/e7/6ee66732f74e4fb1c8915e58b3c253aded777ad0fa457f3f831dd0cd09b4/datasets-2.19.2.tar.gz", hash = "sha256:eccb82fb3bb5ee26ccc6d7a15b7f1f834e2cc4e59b7cff7733a003552bad51ef", size = 2215337, upload-time = "2024-06-03T05:11:44.756Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/03/c6d9c3119cf712f638fe763e887ecaac6acbb62bf1e2acc3cbde0df340fd/datasets-4.7.0-py3-none-any.whl", hash = "sha256:d5fe3025ec6acc3b5649f10d5576dff5e054134927604e6913c1467a04adc3c2", size = 527530, upload-time = "2026-03-09T19:01:47.443Z" }, + { url = "https://files.pythonhosted.org/packages/3f/59/46818ebeb708234a60e42ccf409d20709e482519d2aa450b501ddbba4594/datasets-2.19.2-py3-none-any.whl", hash = "sha256:e07ff15d75b1af75c87dd96323ba2a361128d495136652f37fd62f918d17bb4e", size = 542113, upload-time = "2024-06-03T05:11:41.151Z" }, ] [[package]] name = "dateparser" -version = "1.2.2" +version = "1.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "python-dateutil" }, @@ -1414,9 +1513,9 @@ dependencies = [ { name = "regex" }, { name = "tzlocal" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a9/30/064144f0df1749e7bb5faaa7f52b007d7c2d08ec08fed8411aba87207f68/dateparser-1.2.2.tar.gz", hash = "sha256:986316f17cb8cdc23ea8ce563027c5ef12fc725b6fb1d137c14ca08777c5ecf7", size = 329840, upload-time = "2025-06-26T09:29:23.211Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3d/2c/668dfb8c073a5dde3efb80fa382de1502e3b14002fd386a8c1b0b49e92a9/dateparser-1.3.0.tar.gz", hash = "sha256:5bccf5d1ec6785e5be71cc7ec80f014575a09b4923e762f850e57443bddbf1a5", size = 337152, upload-time = "2026-02-04T16:00:06.162Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/87/22/f020c047ae1346613db9322638186468238bcfa8849b4668a22b97faad65/dateparser-1.2.2-py3-none-any.whl", hash = "sha256:5a5d7211a09013499867547023a2a0c91d5a27d15dd4dbcea676ea9fe66f2482", size = 315453, upload-time = "2025-06-26T09:29:21.412Z" }, + { url = "https://files.pythonhosted.org/packages/9a/c7/95349670e193b2891176e1b8e5f43e12b31bff6d9994f70e74ab385047f6/dateparser-1.3.0-py3-none-any.whl", hash = "sha256:8dc678b0a526e103379f02ae44337d424bd366aac727d3c6cf52ce1b01efbb5a", size = 318688, upload-time = "2026-02-04T16:00:04.652Z" }, ] [[package]] @@ -1430,29 +1529,28 @@ wheels = [ [[package]] name = "deepeval" -version = "2.6.7" +version = "3.9.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, - { name = "anthropic" }, - { name = "black" }, - { name = "coverage" }, + { name = "click" }, { name = "grpcio" }, - { name = "instructor" }, - { name = "langchain-community" }, - { name = "langchain-openai" }, - { name = "llama-index" }, - { name = "ollama" }, + { name = "jinja2" }, + { name = "nest-asyncio" }, { name = "openai" }, { name = "opentelemetry-api" }, - { name = "opentelemetry-exporter-otlp-proto-grpc" }, { name = "opentelemetry-sdk" }, { name = "portalocker" }, + { name = "posthog" }, + { name = "pydantic" }, + { name = "pydantic-settings" }, + { name = "pyfiglet" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-repeat" }, { name = "pytest-rerunfailures" }, { name = "pytest-xdist" }, + { name = "python-dotenv" }, { name = "requests" }, { name = "rich" }, { name = "sentry-sdk" }, @@ -1460,13 +1558,12 @@ dependencies = [ { name = "tabulate" }, { name = "tenacity" }, { name = "tqdm" }, - { name = "twine" }, { name = "typer" }, { name = "wheel" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/da/1fb5f6f257bc17cd4b492ea10a5f5240fb8a29efe917c62b38b31a85149a/deepeval-2.6.7.tar.gz", hash = "sha256:44d4f2e0684341d1093101d3016203bb8720790b46cfeaa0bd796d4de4999b2d", size = 360966, upload-time = "2025-04-06T07:54:47.556Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/74/767e253720c58d488ffa7340618161b05dea5295c747a251d2d1900fa9a5/deepeval-3.9.2.tar.gz", hash = "sha256:267491bbb25fd1c4e8e4991721b0e44bd5b54bbd219edd6705b038028874964f", size = 611140, upload-time = "2026-03-20T10:30:29.596Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/b1/9b131e2592b7c61dd696943aecda11d87189e4852153f2971e5676111788/deepeval-2.6.7-py3-none-any.whl", hash = "sha256:22241137682f4d5d1d76f3c219e0266e0c94c04c266980ecee653b8fb9c710a7", size = 589175, upload-time = "2025-04-06T07:54:41.752Z" }, + { url = "https://files.pythonhosted.org/packages/e7/0a/3273444affbe05e1278c86e02fb8d728264fb8e3b8d59f8448b0d1bcb32b/deepeval-3.9.2-py3-none-any.whl", hash = "sha256:47f34f54ac648e06610dbf6d1eeca49c53150a04e9219d4f0619d63403c016df", size = 839941, upload-time = "2026-03-20T10:30:26.982Z" }, ] [[package]] @@ -1504,7 +1601,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.13.0" +version = "1.13.2" source = { virtual = "." } dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1512,6 +1609,7 @@ dependencies = [ { name = "arize-phoenix-otel" }, { name = "azure-identity" }, { name = "beautifulsoup4" }, + { name = "bleach" }, { name = "boto3" }, { name = "bs4" }, { name = "cachetools" }, @@ -1575,6 +1673,7 @@ dependencies = [ { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, + { name = "pypandoc" }, { name = "pypdfium2" }, { name = "python-docx" }, { name = "python-dotenv" }, @@ -1689,6 +1788,7 @@ vdb = [ { name = "clickzetta-connector-python" }, { name = "couchbase" }, { name = "elasticsearch" }, + { name = "holo-search-sdk" }, { name = "intersystems-irispython" }, { name = "mo-vector" }, { name = "mysql-connector-python" }, @@ -1713,140 +1813,142 @@ vdb = [ requires-dist = [ { name = "aliyun-log-python-sdk", specifier = "~=0.9.37" }, { name = "apscheduler", specifier = ">=3.11.0" }, - { name = "arize-phoenix-otel", specifier = "~=0.9.2" }, - { name = "azure-identity", specifier = "==1.16.1" }, - { name = "beautifulsoup4", specifier = "==4.12.2" }, - { name = "boto3", specifier = "==1.35.99" }, + { name = "arize-phoenix-otel", specifier = "~=0.15.0" }, + { name = "azure-identity", specifier = "==1.25.3" }, + { name = "beautifulsoup4", specifier = "==4.14.3" }, + { name = "bleach", specifier = "~=6.3.0" }, + { name = "boto3", specifier = "==1.42.73" }, { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, - { name = "celery", specifier = "~=5.5.2" }, + { name = "celery", specifier = "~=5.6.2" }, { name = "charset-normalizer", specifier = ">=3.4.4" }, { name = "croniter", specifier = ">=6.0.0" }, { name = "fastopenapi", extras = ["flask"], specifier = ">=0.7.0" }, { name = "flask", specifier = "~=3.1.2" }, - { name = "flask-compress", specifier = ">=1.17,<1.18" }, + { name = "flask-compress", specifier = ">=1.17,<1.24" }, { name = "flask-cors", specifier = "~=6.0.0" }, { name = "flask-login", specifier = "~=0.6.3" }, - { name = "flask-migrate", specifier = "~=4.0.7" }, + { name = "flask-migrate", specifier = "~=4.1.0" }, { name = "flask-orjson", specifier = "~=2.0.0" }, { name = "flask-restx", specifier = "~=1.3.2" }, { name = "flask-sqlalchemy", specifier = "~=3.1.1" }, { name = "gevent", specifier = "~=25.9.1" }, { name = "gmpy2", specifier = "~=2.3.0" }, - { name = "google-api-core", specifier = "==2.18.0" }, - { name = "google-api-python-client", specifier = "==2.189.0" }, - { name = "google-auth", specifier = "==2.29.0" }, - { name = "google-auth-httplib2", specifier = "==0.2.0" }, - { name = "google-cloud-aiplatform", specifier = "==1.49.0" }, - { name = "googleapis-common-protos", specifier = "==1.63.0" }, - { name = "gunicorn", specifier = "~=23.0.0" }, - { name = "httpx", extras = ["socks"], specifier = "~=0.27.0" }, + { name = "google-api-core", specifier = ">=2.19.1" }, + { name = "google-api-python-client", specifier = "==2.193.0" }, + { name = "google-auth", specifier = ">=2.47.0" }, + { name = "google-auth-httplib2", specifier = "==0.3.0" }, + { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, + { name = "googleapis-common-protos", specifier = ">=1.65.0" }, + { name = "gunicorn", specifier = "~=25.1.0" }, + { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, { name = "httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.55.1" }, { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, - { name = "langsmith", specifier = "~=0.1.77" }, - { name = "litellm", specifier = "==1.77.1" }, - { name = "markdown", specifier = "~=3.5.1" }, + { name = "langsmith", specifier = "~=0.7.16" }, + { name = "litellm", specifier = "==1.82.6" }, + { name = "markdown", specifier = "~=3.10.2" }, { name = "mlflow-skinny", specifier = ">=3.0.0" }, { name = "numpy", specifier = "~=1.26.4" }, { name = "openpyxl", specifier = "~=3.1.5" }, - { name = "opentelemetry-api", specifier = "==1.27.0" }, - { name = "opentelemetry-distro", specifier = "==0.48b0" }, - { name = "opentelemetry-exporter-otlp", specifier = "==1.27.0" }, - { name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.27.0" }, - { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.27.0" }, - { name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.27.0" }, - { name = "opentelemetry-instrumentation", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-celery", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-flask", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-httpx", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-redis", specifier = "==0.48b0" }, - { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.48b0" }, - { name = "opentelemetry-propagator-b3", specifier = "==1.27.0" }, - { name = "opentelemetry-proto", specifier = "==1.27.0" }, - { name = "opentelemetry-sdk", specifier = "==1.27.0" }, - { name = "opentelemetry-semantic-conventions", specifier = "==0.48b0" }, - { name = "opentelemetry-util-http", specifier = "==0.48b0" }, - { name = "opik", specifier = "~=1.8.72" }, + { name = "opentelemetry-api", specifier = "==1.28.0" }, + { name = "opentelemetry-distro", specifier = "==0.49b0" }, + { name = "opentelemetry-exporter-otlp", specifier = "==1.28.0" }, + { name = "opentelemetry-exporter-otlp-proto-common", specifier = "==1.28.0" }, + { name = "opentelemetry-exporter-otlp-proto-grpc", specifier = "==1.28.0" }, + { name = "opentelemetry-exporter-otlp-proto-http", specifier = "==1.28.0" }, + { name = "opentelemetry-instrumentation", specifier = "==0.49b0" }, + { name = "opentelemetry-instrumentation-celery", specifier = "==0.49b0" }, + { name = "opentelemetry-instrumentation-flask", specifier = "==0.49b0" }, + { name = "opentelemetry-instrumentation-httpx", specifier = "==0.49b0" }, + { name = "opentelemetry-instrumentation-redis", specifier = "==0.49b0" }, + { name = "opentelemetry-instrumentation-sqlalchemy", specifier = "==0.49b0" }, + { name = "opentelemetry-propagator-b3", specifier = "==1.40.0" }, + { name = "opentelemetry-proto", specifier = "==1.28.0" }, + { name = "opentelemetry-sdk", specifier = "==1.28.0" }, + { name = "opentelemetry-semantic-conventions", specifier = "==0.49b0" }, + { name = "opentelemetry-util-http", specifier = "==0.49b0" }, + { name = "opik", specifier = "~=1.10.37" }, { name = "packaging", specifier = "~=23.2" }, - { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" }, + { name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=3.0.1" }, { name = "psycogreen", specifier = "~=1.0.2" }, { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, - { name = "pydantic-extra-types", specifier = "~=2.10.3" }, - { name = "pydantic-settings", specifier = "~=2.12.0" }, - { name = "pyjwt", specifier = "~=2.11.0" }, - { name = "pypdfium2", specifier = "==5.2.0" }, + { name = "pydantic-extra-types", specifier = "~=2.11.0" }, + { name = "pydantic-settings", specifier = "~=2.13.1" }, + { name = "pyjwt", specifier = "~=2.12.0" }, + { name = "pypandoc", specifier = "~=1.13" }, + { name = "pypdfium2", specifier = "==5.6.0" }, { name = "python-docx", specifier = "~=1.2.0" }, - { name = "python-dotenv", specifier = "==1.0.1" }, + { name = "python-dotenv", specifier = "==1.2.2" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, - { name = "redis", extras = ["hiredis"], specifier = "~=7.2.0" }, - { name = "resend", specifier = "~=2.9.0" }, + { name = "redis", extras = ["hiredis"], specifier = "~=7.3.0" }, + { name = "resend", specifier = "~=2.26.0" }, { name = "sendgrid", specifier = "~=6.12.3" }, - { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.28.0" }, + { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, { name = "sqlalchemy", specifier = "~=2.0.29" }, - { name = "sseclient-py", specifier = "~=1.8.0" }, - { name = "starlette", specifier = "==0.49.1" }, - { name = "tiktoken", specifier = "~=0.9.0" }, - { name = "transformers", specifier = "~=4.56.1" }, - { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" }, + { name = "sseclient-py", specifier = "~=1.9.0" }, + { name = "starlette", specifier = "==1.0.0" }, + { name = "tiktoken", specifier = "~=0.12.0" }, + { name = "transformers", specifier = "~=5.3.0" }, + { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, - { name = "weaviate-client", specifier = "==4.17.0" }, + { name = "weaviate-client", specifier = "==4.20.4" }, { name = "webvtt-py", specifier = "~=0.5.1" }, - { name = "yarl", specifier = "~=1.18.3" }, + { name = "yarl", specifier = "~=1.23.0" }, ] [package.metadata.requires-dev] dev = [ - { name = "basedpyright", specifier = "~=1.31.0" }, + { name = "basedpyright", specifier = "~=1.38.2" }, { name = "boto3-stubs", specifier = ">=1.38.20" }, { name = "celery-types", specifier = ">=0.23.0" }, - { name = "coverage", specifier = "~=7.2.4" }, - { name = "dotenv-linter", specifier = "~=0.5.0" }, - { name = "faker", specifier = "~=38.2.0" }, + { name = "coverage", specifier = "~=7.13.4" }, + { name = "dotenv-linter", specifier = "~=0.7.0" }, + { name = "faker", specifier = "~=40.11.0" }, { name = "hypothesis", specifier = ">=6.131.15" }, { name = "import-linter", specifier = ">=2.3" }, { name = "lxml-stubs", specifier = "~=0.5.1" }, - { name = "mypy", specifier = "~=1.17.1" }, - { name = "pandas-stubs", specifier = "~=2.2.3" }, - { name = "pyrefly", specifier = ">=0.54.0" }, - { name = "pytest", specifier = "~=8.3.2" }, - { name = "pytest-benchmark", specifier = "~=4.0.0" }, - { name = "pytest-cov", specifier = "~=4.1.0" }, - { name = "pytest-env", specifier = "~=1.1.3" }, - { name = "pytest-mock", specifier = "~=3.14.0" }, + { name = "mypy", specifier = "~=1.19.1" }, + { name = "pandas-stubs", specifier = "~=3.0.0" }, + { name = "pyrefly", specifier = ">=0.55.0" }, + { name = "pytest", specifier = "~=9.0.2" }, + { name = "pytest-benchmark", specifier = "~=5.2.3" }, + { name = "pytest-cov", specifier = "~=7.1.0" }, + { name = "pytest-env", specifier = "~=1.6.0" }, + { name = "pytest-mock", specifier = "~=3.15.1" }, { name = "pytest-timeout", specifier = ">=2.4.0" }, { name = "pytest-xdist", specifier = ">=3.8.0" }, - { name = "ruff", specifier = "~=0.14.0" }, + { name = "ruff", specifier = "~=0.15.5" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "sseclient-py", specifier = ">=1.8.0" }, - { name = "testcontainers", specifier = "~=4.13.2" }, + { name = "testcontainers", specifier = "~=4.14.1" }, { name = "types-aiofiles", specifier = "~=25.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, - { name = "types-cachetools", specifier = "~=5.5.0" }, + { name = "types-cachetools", specifier = "~=6.2.0" }, { name = "types-cffi", specifier = ">=1.17.0" }, { name = "types-colorama", specifier = "~=0.4.15" }, { name = "types-defusedxml", specifier = "~=0.7.0" }, - { name = "types-deprecated", specifier = "~=1.2.15" }, - { name = "types-docutils", specifier = "~=0.21.0" }, - { name = "types-flask-cors", specifier = "~=5.0.0" }, + { name = "types-deprecated", specifier = "~=1.3.1" }, + { name = "types-docutils", specifier = "~=0.22.3" }, + { name = "types-flask-cors", specifier = "~=6.0.0" }, { name = "types-flask-migrate", specifier = "~=4.1.0" }, { name = "types-gevent", specifier = "~=25.9.0" }, { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.23.0" }, + { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = "~=3.10.2" }, - { name = "types-oauthlib", specifier = "~=3.2.0" }, + { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, { name = "types-olefile", specifier = "~=0.47.0" }, { name = "types-openpyxl", specifier = "~=3.1.5" }, { name = "types-pexpect", specifier = "~=4.9.0" }, - { name = "types-protobuf", specifier = "~=5.29.1" }, + { name = "types-protobuf", specifier = "~=6.32.1" }, { name = "types-psutil", specifier = "~=7.2.2" }, { name = "types-psycopg2", specifier = "~=2.9.21" }, { name = "types-pygments", specifier = "~=2.19.0" }, @@ -1854,10 +1956,10 @@ dev = [ { name = "types-pyopenssl", specifier = ">=24.1.0" }, { name = "types-python-dateutil", specifier = "~=2.9.0" }, { name = "types-python-http-client", specifier = ">=3.3.7.20240910" }, - { name = "types-pywin32", specifier = "~=310.0.0" }, + { name = "types-pywin32", specifier = "~=311.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, - { name = "types-regex", specifier = "~=2024.11.6" }, + { name = "types-regex", specifier = "~=2026.2.28" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -1871,13 +1973,13 @@ evaluation = [ { name = "ragas", specifier = ">=0.2.0" }, ] storage = [ - { name = "azure-storage-blob", specifier = "==12.26.0" }, + { name = "azure-storage-blob", specifier = "==12.28.0" }, { name = "bce-python-sdk", specifier = "~=0.9.23" }, - { name = "cos-python-sdk-v5", specifier = "==1.9.38" }, - { name = "esdk-obs-python", specifier = "==3.25.8" }, - { name = "google-cloud-storage", specifier = "==2.16.0" }, + { name = "cos-python-sdk-v5", specifier = "==1.9.41" }, + { name = "esdk-obs-python", specifier = "==3.26.2" }, + { name = "google-cloud-storage", specifier = ">=3.0.0" }, { name = "opendal", specifier = "~=0.46.0" }, - { name = "oss2", specifier = "==2.18.5" }, + { name = "oss2", specifier = "==2.19.1" }, { name = "supabase", specifier = "~=2.18.1" }, { name = "tos", specifier = "~=2.9.0" }, ] @@ -1887,48 +1989,40 @@ tools = [ ] vdb = [ { name = "alibabacloud-gpdb20160503", specifier = "~=3.8.0" }, - { name = "alibabacloud-tea-openapi", specifier = "~=0.3.9" }, + { name = "alibabacloud-tea-openapi", specifier = "~=0.4.3" }, { name = "chromadb", specifier = "==0.5.20" }, - { name = "clickhouse-connect", specifier = "~=0.10.0" }, + { name = "clickhouse-connect", specifier = "~=0.14.1" }, { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, - { name = "couchbase", specifier = "~=4.3.0" }, + { name = "couchbase", specifier = "~=4.5.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, + { name = "holo-search-sdk", specifier = ">=0.4.1" }, { name = "intersystems-irispython", specifier = ">=5.1.0" }, { name = "mo-vector", specifier = "~=0.1.13" }, { name = "mysql-connector-python", specifier = ">=9.3.0" }, { name = "opensearch-py", specifier = "==3.1.0" }, - { name = "oracledb", specifier = "==3.3.0" }, + { name = "oracledb", specifier = "==3.4.2" }, { name = "pgvecto-rs", extras = ["sqlalchemy"], specifier = "~=0.2.1" }, - { name = "pgvector", specifier = "==0.2.5" }, - { name = "pymilvus", specifier = "~=2.5.0" }, - { name = "pymochow", specifier = "==2.2.9" }, + { name = "pgvector", specifier = "==0.4.2" }, + { name = "pymilvus", specifier = "~=2.6.10" }, + { name = "pymochow", specifier = "==2.3.6" }, { name = "pyobvector", specifier = "~=0.2.17" }, { name = "qdrant-client", specifier = "==1.9.0" }, - { name = "tablestore", specifier = "==6.3.7" }, - { name = "tcvectordb", specifier = "~=1.6.4" }, - { name = "tidb-vector", specifier = "==0.0.9" }, - { name = "upstash-vector", specifier = "==0.6.0" }, + { name = "tablestore", specifier = "==6.4.1" }, + { name = "tcvectordb", specifier = "~=2.0.0" }, + { name = "tidb-vector", specifier = "==0.0.15" }, + { name = "upstash-vector", specifier = "==0.8.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, - { name = "weaviate-client", specifier = "==4.17.0" }, - { name = "xinference-client", specifier = "~=1.2.2" }, + { name = "weaviate-client", specifier = "==4.20.4" }, + { name = "xinference-client", specifier = "~=2.3.1" }, ] [[package]] name = "dill" -version = "0.4.0" +version = "0.3.8" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/12/80/630b4b88364e9a8c8c5797f4602d0f76ef820909ee32f0bacb9f90654042/dill-0.4.0.tar.gz", hash = "sha256:0633f1d2df477324f53a895b02c901fb961bdbf65a17122586ea7019292cbcf0", size = 186976, upload-time = "2025-04-16T00:41:48.867Z" } +sdist = { url = "https://files.pythonhosted.org/packages/17/4d/ac7ffa80c69ea1df30a8aa11b3578692a5118e7cd1aa157e3ef73b092d15/dill-0.3.8.tar.gz", hash = "sha256:3ebe3c479ad625c4553aca177444d89b486b1d84982eeacded644afc0cf797ca", size = 184847, upload-time = "2024-01-27T23:42:16.145Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/50/3d/9373ad9c56321fdab5b41197068e1d8c25883b3fea29dd361f9b55116869/dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049", size = 119668, upload-time = "2025-04-16T00:41:47.671Z" }, -] - -[[package]] -name = "dirtyjson" -version = "1.0.8" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/db/04/d24f6e645ad82ba0ef092fa17d9ef7a21953781663648a01c9371d9e8e98/dirtyjson-1.0.8.tar.gz", hash = "sha256:90ca4a18f3ff30ce849d100dcf4a003953c79d3a2348ef056f1d9c22231a25fd", size = 30782, upload-time = "2022-11-28T23:32:33.319Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/68/69/1bcf70f81de1b4a9f21b3a62ec0c83bdff991c88d6cc2267d02408457e88/dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53", size = 25197, upload-time = "2022-11-28T23:32:31.219Z" }, + { url = "https://files.pythonhosted.org/packages/c9/7a/cef76fd8438a42f96db64ddaa85280485a9c395e7df3db8158cfec1eee34/dill-0.3.8-py3-none-any.whl", hash = "sha256:c36ca9ffb54365bdd2f8eb3eff7d2a21237f8452b57ace88b1ac615b7e815bd7", size = 116252, upload-time = "2024-01-27T23:42:14.239Z" }, ] [[package]] @@ -1940,15 +2034,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/27/4570e78fc0bf5ea0ca45eb1de3818a23787af9b390c0b0a0033a1b8236f9/diskcache-5.6.3-py3-none-any.whl", hash = "sha256:5e31b2d5fbad117cc363ebaf6b689474db18a1f6438bc82358b024abd4c2ca19", size = 45550, upload-time = "2023-08-31T06:11:58.822Z" }, ] -[[package]] -name = "distlib" -version = "0.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/96/8e/709914eb2b5749865801041647dc7f4e6d00b549cfe88b65ca192995f07c/distlib-0.4.0.tar.gz", hash = "sha256:feec40075be03a04501a973d81f633735b4b69f98b05450592310c0f401a4e0d", size = 614605, upload-time = "2025-07-17T16:52:00.465Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/33/6b/e0547afaf41bf2c42e52430072fa5658766e3d65bd4b03a563d1b6336f57/distlib-0.4.0-py2.py3-none-any.whl", hash = "sha256:9659f7d87e46584a30b5780e43ac7a2143098441670ff0a49d5f9034c54a6c16", size = 469047, upload-time = "2025-07-17T16:51:58.613Z" }, -] - [[package]] name = "distro" version = "1.9.0" @@ -1981,29 +2066,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/55/e2/2537ebcff11c1ee1ff17d8d0b6f4db75873e3b0fb32c2d4a2ee31ecb310a/docstring_parser-0.17.0-py3-none-any.whl", hash = "sha256:cf2569abd23dce8099b300f9b4fa8191e9582dda731fd533daf54c4551658708", size = 36896, upload-time = "2025-07-21T07:35:00.684Z" }, ] -[[package]] -name = "docutils" -version = "0.22.4" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ae/b6/03bb70946330e88ffec97aefd3ea75ba575cb2e762061e0e62a213befee8/docutils-0.22.4.tar.gz", hash = "sha256:4db53b1fde9abecbb74d91230d32ab626d94f6badfc575d6db9194a49df29968", size = 2291750, upload-time = "2025-12-18T19:00:26.443Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/02/10/5da547df7a391dcde17f59520a231527b8571e6f46fc8efb02ccb370ab12/docutils-0.22.4-py3-none-any.whl", hash = "sha256:d0013f540772d1420576855455d050a2180186c91c15779301ac2ccb3eeb68de", size = 633196, upload-time = "2025-12-18T19:00:18.077Z" }, -] - [[package]] name = "dotenv-linter" -version = "0.5.0" +version = "0.7.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "click" }, { name = "click-default-group" }, - { name = "ply" }, + { name = "lark" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ef/fe/77e184ccc312f6263cbcc48a9579eec99f5c7ff72a9b1bd7812cafc22bbb/dotenv_linter-0.5.0.tar.gz", hash = "sha256:4862a8393e5ecdfb32982f1b32dbc006fff969a7b3c8608ba7db536108beeaea", size = 15346, upload-time = "2024-03-13T11:52:10.52Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1e/e5/515ca4e069b70ba0be477ab0a193855c08066f9ef1a9350dcfbdc8f12f87/dotenv_linter-0.7.0.tar.gz", hash = "sha256:24ed93c1028d6305d6787e51773badf3346e53012ad4f5ada9cf747d2da6de13", size = 14033, upload-time = "2025-04-28T17:40:00.771Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/01/62ed4374340e6cf17c5084828974d96db8085e4018439ac41dc3cbbbcab3/dotenv_linter-0.5.0-py3-none-any.whl", hash = "sha256:fd01cca7f2140cb1710f49cbc1bf0e62397a75a6f0522d26a8b9b2331143c8bd", size = 21770, upload-time = "2024-03-13T11:52:08.607Z" }, + { url = "https://files.pythonhosted.org/packages/6e/5e/e26881b8d6bd6498c1a7225fba8ead3626a9f4b2d7d29dd272a875753d0d/dotenv_linter-0.7.0-py3-none-any.whl", hash = "sha256:0ffdf0c7435bd638aba5ff6cc9ea53bf093488bf1c722e363e902008659bb1fb", size = 19806, upload-time = "2025-04-28T17:39:58.395Z" }, ] [[package]] @@ -2015,6 +2091,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b0/0d/9feae160378a3553fa9a339b0e9c1a048e147a4127210e286ef18b730f03/durationpy-0.10-py3-none-any.whl", hash = "sha256:3b41e1b601234296b4fb368338fdcd3e13e0b4fb5b67345948f4f2bf9868b286", size = 3922, upload-time = "2025-05-17T13:52:36.463Z" }, ] +[[package]] +name = "ecdsa" +version = "0.19.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c0/1f/924e3caae75f471eae4b26bd13b698f6af2c44279f67af317439c2f4c46a/ecdsa-0.19.1.tar.gz", hash = "sha256:478cba7b62555866fcb3bb3fe985e06decbdb68ef55713c4e5ab98c57d508e61", size = 201793, upload-time = "2025-03-13T11:52:43.25Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cb/a3/460c57f094a4a165c84a1341c373b0a4f5ec6ac244b998d5021aade89b77/ecdsa-0.19.1-py2.py3-none-any.whl", hash = "sha256:30638e27cf77b7e15c4c4cc1973720149e1033827cfd00661ca5c8cc0cdb24c3", size = 150607, upload-time = "2025-03-13T11:52:41.757Z" }, +] + [[package]] name = "elastic-transport" version = "8.17.1" @@ -2051,14 +2139,14 @@ wheels = [ [[package]] name = "esdk-obs-python" -version = "3.25.8" +version = "3.26.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "crcmod" }, { name = "pycryptodome" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/40/99/52362d6e081a642d6de78f6ab53baa5e3f82f2386c48954e18ee7b4ab22b/esdk-obs-python-3.25.8.tar.gz", hash = "sha256:aeded00b27ecd5a25ffaec38a2cc9416b51923d48db96c663f1a735f859b5273", size = 96302, upload-time = "2025-09-01T11:35:20.432Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ef/9a/090f718114eec808c04762d9ea64f9e6f170ee419a673beba8b7810ec758/esdk_obs_python-3.26.2.tar.gz", hash = "sha256:dc865356bb4be474e5eaa557ff226f0f89ac8f5afff61a1cc85143079bf6e223", size = 95922, upload-time = "2026-03-07T10:38:16.732Z" } [[package]] name = "et-xmlfile" @@ -2069,15 +2157,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c1/8b/5fe2cc11fee489817272089c4203e679c63b570a5aaeb18d852ae3cbba6a/et_xmlfile-2.0.0-py3-none-any.whl", hash = "sha256:7a91720bc756843502c3b7504c77b8fe44217c85c537d85037f0f536151b2caa", size = 18059, upload-time = "2024-10-25T17:25:39.051Z" }, ] -[[package]] -name = "eval-type-backport" -version = "0.3.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/23/079e39571d6dd8d90d7a369ecb55ad766efb6bae4e77389629e14458c280/eval_type_backport-0.3.0.tar.gz", hash = "sha256:1638210401e184ff17f877e9a2fa076b60b5838790f4532a21761cc2be67aea1", size = 9272, upload-time = "2025-11-13T20:56:50.845Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/19/d8/2a1c638d9e0aa7e269269a1a1bf423ddd94267f1a01bbe3ad03432b67dd4/eval_type_backport-0.3.0-py3-none-any.whl", hash = "sha256:975a10a0fe333c8b6260d7fdb637698c9a16c3a9e3b6eb943fee6a6f67a37fe8", size = 6061, upload-time = "2025-11-13T20:56:49.499Z" }, -] - [[package]] name = "events" version = "0.5" @@ -2097,29 +2176,30 @@ wheels = [ [[package]] name = "faker" -version = "38.2.0" +version = "40.11.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "tzdata" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/64/27/022d4dbd4c20567b4c294f79a133cc2f05240ea61e0d515ead18c995c249/faker-38.2.0.tar.gz", hash = "sha256:20672803db9c7cb97f9b56c18c54b915b6f1d8991f63d1d673642dc43f5ce7ab", size = 1941469, upload-time = "2025-11-19T16:37:31.892Z" } +sdist = { url = "https://files.pythonhosted.org/packages/94/dc/b68e5378e5a7db0ab776efcdd53b6fe374b29d703e156fd5bb4c5437069e/faker-40.11.0.tar.gz", hash = "sha256:7c419299103b13126bd02ec14bd2b47b946edb5a5eedf305e66a193b25f9a734", size = 1957570, upload-time = "2026-03-13T14:36:11.844Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/17/93/00c94d45f55c336434a15f98d906387e87ce28f9918e4444829a8fda432d/faker-38.2.0-py3-none-any.whl", hash = "sha256:35fe4a0a79dee0dc4103a6083ee9224941e7d3594811a50e3969e547b0d2ee65", size = 1980505, upload-time = "2025-11-19T16:37:30.208Z" }, + { url = "https://files.pythonhosted.org/packages/b1/fa/a86c6ba66f0308c95b9288b1e3eaccd934b545646f63494a86f1ec2f8c8e/faker-40.11.0-py3-none-any.whl", hash = "sha256:0e9816c950528d2a37d74863f3ef389ea9a3a936cbcde0b11b8499942e25bf90", size = 1989457, upload-time = "2026-03-13T14:36:09.792Z" }, ] [[package]] name = "fastapi" -version = "0.122.0" +version = "0.135.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "annotated-doc" }, { name = "pydantic" }, { name = "starlette" }, { name = "typing-extensions" }, + { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b2/de/3ee97a4f6ffef1fb70bf20561e4f88531633bb5045dc6cebc0f8471f764d/fastapi-0.122.0.tar.gz", hash = "sha256:cd9b5352031f93773228af8b4c443eedc2ac2aa74b27780387b853c3726fb94b", size = 346436, upload-time = "2025-11-24T19:17:47.95Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e7/7b/f8e0211e9380f7195ba3f3d40c292594fd81ba8ec4629e3854c353aaca45/fastapi-0.135.1.tar.gz", hash = "sha256:d04115b508d936d254cea545b7312ecaa58a7b3a0f84952535b4c9afae7668cd", size = 394962, upload-time = "2026-03-01T18:18:29.369Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7a/93/aa8072af4ff37b795f6bbf43dcaf61115f40f49935c7dbb180c9afc3f421/fastapi-0.122.0-py3-none-any.whl", hash = "sha256:a456e8915dfc6c8914a50d9651133bd47ec96d331c5b44600baa635538a30d67", size = 110671, upload-time = "2025-11-24T19:17:45.96Z" }, + { url = "https://files.pythonhosted.org/packages/e4/72/42e900510195b23a56bde950d26a51f8b723846bfcaa0286e90287f0422b/fastapi-0.135.1-py3-none-any.whl", hash = "sha256:46e2fc5745924b7c840f71ddd277382af29ce1cdb7d5eab5bf697e3fb9999c9e", size = 116999, upload-time = "2026-03-01T18:18:30.831Z" }, ] [[package]] @@ -2171,20 +2251,20 @@ wheels = [ [[package]] name = "fickling" -version = "0.1.8" +version = "0.1.10" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/88/be/cd91e3921f064230ac9462479e4647fb91a7b0d01677103fce89f52e3042/fickling-0.1.8.tar.gz", hash = "sha256:25a0bc7acda76176a9087b405b05f7f5021f76079aa26c6fe3270855ec57d9bf", size = 336756, upload-time = "2026-02-21T00:57:26.106Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9f/06/1818b8f52267599e54041349c553d5894e17ec8a539a246eb3f9eaf05629/fickling-0.1.10.tar.gz", hash = "sha256:8c8b76abd29936f1a5932e4087b8c8becb2d7ab1cf08549e63519ebcb2f71644", size = 338062, upload-time = "2026-03-13T16:34:29.287Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/02/92/af72f783ac57fa2452f8f921c9441366c42ae1f03f5af41718445114c82f/fickling-0.1.8-py3-none-any.whl", hash = "sha256:97218785cfe00a93150808dcf9e3eb512371e0484e3ce0b05bc460b97240f292", size = 52613, upload-time = "2026-02-21T00:57:24.82Z" }, + { url = "https://files.pythonhosted.org/packages/05/86/620960dff970da5311f05e25fc045dac8495557d51030e5a0827084b18fd/fickling-0.1.10-py3-none-any.whl", hash = "sha256:962c35c38ece1b3632fc119c0f4cb1eebc02dc6d65bfd93a1803afd42ca91d25", size = 52853, upload-time = "2026-03-13T16:34:27.821Z" }, ] [[package]] name = "filelock" -version = "3.20.3" +version = "3.25.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1d/65/ce7f1b70157833bf3cb851b556a37d4547ceafc158aa9b34b36782f23696/filelock-3.20.3.tar.gz", hash = "sha256:18c57ee915c7ec61cff0ecf7f0f869936c7c30191bb0cf406f1341778d0834e1", size = 19485, upload-time = "2026-01-09T17:55:05.421Z" } +sdist = { url = "https://files.pythonhosted.org/packages/94/b8/00651a0f559862f3bb7d6f7477b192afe3f583cc5e26403b44e59a55ab34/filelock-3.25.2.tar.gz", hash = "sha256:b64ece2b38f4ca29dd3e810287aa8c48182bbecd1ae6e9ae126c9b35f1382694", size = 40480, upload-time = "2026-03-11T20:45:38.487Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b5/36/7fb70f04bf00bc646cd5bb45aa9eddb15e19437a28b8fb2b4a5249fac770/filelock-3.20.3-py3-none-any.whl", hash = "sha256:4b0dda527ee31078689fc205ec4f1c1bf7d56cf88b6dc9426c4f230e46c2dce1", size = 16701, upload-time = "2026-01-09T17:55:04.334Z" }, + { url = "https://files.pythonhosted.org/packages/a4/a5/842ae8f0c08b61d6484b52f99a03510a3a72d23141942d216ebe81fefbce/filelock-3.25.2-py3-none-any.whl", hash = "sha256:ca8afb0da15f229774c9ad1b455ed96e85a81373065fb10446672f64444ddf70", size = 26759, upload-time = "2026-03-11T20:45:37.437Z" }, ] [[package]] @@ -2215,31 +2295,30 @@ wheels = [ [[package]] name = "flask-compress" -version = "1.17" +version = "1.23" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "backports-zstd" }, { name = "brotli", marker = "platform_python_implementation != 'PyPy'" }, { name = "brotlicffi", marker = "platform_python_implementation == 'PyPy'" }, { name = "flask" }, - { name = "zstandard" }, - { name = "zstandard", marker = "platform_python_implementation == 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cc/1f/260db5a4517d59bfde7b4a0d71052df68fb84983bda9231100e3b80f5989/flask_compress-1.17.tar.gz", hash = "sha256:1ebb112b129ea7c9e7d6ee6d5cc0d64f226cbc50c4daddf1a58b9bd02253fbd8", size = 15733, upload-time = "2024-10-14T08:13:33.196Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/e4/2b54da5cf8ae5d38a495ca20154aa40d6d2ee6dc1756429a82856181aa2c/flask_compress-1.23.tar.gz", hash = "sha256:5580935b422e3f136b9a90909e4b1015ac2b29c9aebe0f8733b790fde461c545", size = 20135, upload-time = "2025-11-06T09:06:29.56Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/54/ff08f947d07c0a8a5d8f1c8e57b142c97748ca912b259db6467ab35983cd/Flask_Compress-1.17-py3-none-any.whl", hash = "sha256:415131f197c41109f08e8fdfc3a6628d83d81680fb5ecd0b3a97410e02397b20", size = 8723, upload-time = "2024-10-14T08:13:31.726Z" }, + { url = "https://files.pythonhosted.org/packages/7d/9a/bebdcdba82d2786b33cd9f5fd65b8d309797c27176a9c4f357c1150c4ac0/flask_compress-1.23-py3-none-any.whl", hash = "sha256:52108afb4d133a5aab9809e6ac3c085ed7b9c788c75c6846c129faa28468f08c", size = 10515, upload-time = "2025-11-06T09:06:28.691Z" }, ] [[package]] name = "flask-cors" -version = "6.0.1" +version = "6.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "flask" }, { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/76/37/bcfa6c7d5eec777c4c7cf45ce6b27631cebe5230caf88d85eadd63edd37a/flask_cors-6.0.1.tar.gz", hash = "sha256:d81bcb31f07b0985be7f48406247e9243aced229b7747219160a0559edd678db", size = 13463, upload-time = "2025-06-11T01:32:08.518Z" } +sdist = { url = "https://files.pythonhosted.org/packages/70/74/0fc0fa68d62f21daef41017dafab19ef4b36551521260987eb3a5394c7ba/flask_cors-6.0.2.tar.gz", hash = "sha256:6e118f3698249ae33e429760db98ce032a8bf9913638d085ca0f4c5534ad2423", size = 13472, upload-time = "2025-12-12T20:31:42.861Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/17/f8/01bf35a3afd734345528f98d0353f2a978a476528ad4d7e78b70c4d149dd/flask_cors-6.0.1-py3-none-any.whl", hash = "sha256:c7b2cbfb1a31aa0d2e5341eea03a6805349f7a61647daee1a15c46bbe981494c", size = 13244, upload-time = "2025-06-11T01:32:07.352Z" }, + { url = "https://files.pythonhosted.org/packages/4f/af/72ad54402e599152de6d067324c46fe6a4f531c7c65baf7e96c63db55eaf/flask_cors-6.0.2-py3-none-any.whl", hash = "sha256:e57544d415dfd7da89a9564e1e3a9e515042df76e12130641ca6f3f2f03b699a", size = 13257, upload-time = "2025-12-12T20:31:41.3Z" }, ] [[package]] @@ -2257,16 +2336,16 @@ wheels = [ [[package]] name = "flask-migrate" -version = "4.0.7" +version = "4.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alembic" }, { name = "flask" }, { name = "flask-sqlalchemy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3b/e2/4008fc0d298d7ce797021b194bbe151d4d12db670691648a226d4fc8aefc/Flask-Migrate-4.0.7.tar.gz", hash = "sha256:dff7dd25113c210b069af280ea713b883f3840c1e3455274745d7355778c8622", size = 21770, upload-time = "2024-03-11T18:43:01.498Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5a/8e/47c7b3c93855ceffc2eabfa271782332942443321a07de193e4198f920cf/flask_migrate-4.1.0.tar.gz", hash = "sha256:1a336b06eb2c3ace005f5f2ded8641d534c18798d64061f6ff11f79e1434126d", size = 21965, upload-time = "2025-01-10T18:51:11.848Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/01/587023575286236f95d2ab8a826c320375ed5ea2102bb103ed89704ffa6b/Flask_Migrate-4.0.7-py3-none-any.whl", hash = "sha256:5c532be17e7b43a223b7500d620edae33795df27c75811ddf32560f7d48ec617", size = 21127, upload-time = "2024-03-11T18:42:59.462Z" }, + { url = "https://files.pythonhosted.org/packages/d2/c4/3f329b23d769fe7628a5fc57ad36956f1fb7132cf8837be6da762b197327/Flask_Migrate-4.1.0-py3-none-any.whl", hash = "sha256:24d8051af161782e0743af1b04a152d007bad9772b2bca67b7ec1e8ceeb3910d", size = 21237, upload-time = "2025-01-10T18:51:09.527Z" }, ] [[package]] @@ -2314,11 +2393,10 @@ wheels = [ [[package]] name = "flatbuffers" -version = "25.9.23" +version = "25.12.19" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9d/1f/3ee70b0a55137442038f2a33469cc5fddd7e0ad2abf83d7497c18a2b6923/flatbuffers-25.9.23.tar.gz", hash = "sha256:676f9fa62750bb50cf531b42a0a2a118ad8f7f797a511eda12881c016f093b12", size = 22067, upload-time = "2025-09-24T05:25:30.106Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/1b/00a78aa2e8fbd63f9af08c9c19e6deb3d5d66b4dda677a0f61654680ee89/flatbuffers-25.9.23-py2.py3-none-any.whl", hash = "sha256:255538574d6cb6d0a79a17ec8bc0d30985913b87513a01cce8bcdb6b4c44d0e2", size = 30869, upload-time = "2025-09-24T05:25:28.912Z" }, + { url = "https://files.pythonhosted.org/packages/e8/2d/d2a548598be01649e2d46231d151a6c56d10b964d94043a335ae56ea2d92/flatbuffers-25.12.19-py2.py3-none-any.whl", hash = "sha256:7634f50c427838bb021c2d66a3d1168e9d199b0607e6329399f04846d42e20b4", size = 26661, upload-time = "2025-12-19T23:16:13.622Z" }, ] [[package]] @@ -2364,11 +2442,11 @@ wheels = [ [[package]] name = "fsspec" -version = "2025.10.0" +version = "2024.3.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/7f/2747c0d332b9acfa75dc84447a066fdf812b5a6b8d30472b74d309bfe8cb/fsspec-2025.10.0.tar.gz", hash = "sha256:b6789427626f068f9a83ca4e8a3cc050850b6c0f71f99ddb4f542b8266a26a59", size = 309285, upload-time = "2025-10-30T14:58:44.036Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/b8/e3ba21f03c00c27adc9a8cd1cab8adfb37b6024757133924a9a4eab63a83/fsspec-2024.3.1.tar.gz", hash = "sha256:f39780e282d7d117ffb42bb96992f8a90795e4d0fb0f661a70ca39fe9c43ded9", size = 170742, upload-time = "2024-03-18T19:35:13.995Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/eb/02/a6b21098b1d5d6249b7c5ab69dde30108a71e4e819d4a9778f1de1d5b70d/fsspec-2025.10.0-py3-none-any.whl", hash = "sha256:7c7712353ae7d875407f97715f0e1ffcc21e33d5b24556cb1e090ae9409ec61d", size = 200966, upload-time = "2025-10-30T14:58:42.53Z" }, + { url = "https://files.pythonhosted.org/packages/93/6d/66d48b03460768f523da62a57a7e14e5e95fdf339d79e996ce3cecda2cdb/fsspec-2024.3.1-py3-none-any.whl", hash = "sha256:918d18d41bf73f0e2b261824baeb1b124bcf771767e3a26425cd7dec3332f512", size = 171991, upload-time = "2024-03-18T19:35:11.259Z" }, ] [package.optional-dependencies] @@ -2428,14 +2506,14 @@ wheels = [ [[package]] name = "gitpython" -version = "3.1.45" +version = "3.1.46" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "gitdb" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9a/c8/dd58967d119baab745caec2f9d853297cec1989ec1d63f677d3880632b88/gitpython-3.1.45.tar.gz", hash = "sha256:85b0ee964ceddf211c41b9f27a49086010a190fd8132a24e21f362a4b36a791c", size = 215076, upload-time = "2025-07-24T03:45:54.871Z" } +sdist = { url = "https://files.pythonhosted.org/packages/df/b5/59d16470a1f0dfe8c793f9ef56fd3826093fc52b3bd96d6b9d6c26c7e27b/gitpython-3.1.46.tar.gz", hash = "sha256:400124c7d0ef4ea03f7310ac2fbf7151e09ff97f2a3288d64a440c584a29c37f", size = 215371, upload-time = "2026-01-01T15:37:32.073Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/01/61/d4b89fec821f72385526e1b9d9a3a0385dda4a72b206d28049e2c7cd39b8/gitpython-3.1.45-py3-none-any.whl", hash = "sha256:8908cb2e02fb3b93b7eb0f2827125cb699869470432cc885f019b8fd0fccff77", size = 208168, upload-time = "2025-07-24T03:45:52.517Z" }, + { url = "https://files.pythonhosted.org/packages/6a/09/e21df6aef1e1ffc0c816f0522ddc3f6dcded766c3261813131c78a704470/gitpython-3.1.46-py3-none-any.whl", hash = "sha256:79812ed143d9d25b6d176a10bb511de0f9c67b1fa641d82097b0ab90398a2058", size = 208620, upload-time = "2026-01-01T15:37:30.574Z" }, ] [[package]] @@ -2481,7 +2559,7 @@ wheels = [ [[package]] name = "google-api-core" -version = "2.18.0" +version = "2.30.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth" }, @@ -2490,9 +2568,9 @@ dependencies = [ { name = "protobuf" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b2/8f/ecd68579bd2bf5e9321df60dcdee6e575adf77fedacb1d8378760b2b16b6/google-api-core-2.18.0.tar.gz", hash = "sha256:62d97417bfc674d6cef251e5c4d639a9655e00c45528c4364fbfebb478ce72a9", size = 148047, upload-time = "2024-03-21T20:16:56.269Z" } +sdist = { url = "https://files.pythonhosted.org/packages/22/98/586ec94553b569080caef635f98a3723db36a38eac0e3d7eb3ea9d2e4b9a/google_api_core-2.30.0.tar.gz", hash = "sha256:02edfa9fab31e17fc0befb5f161b3bf93c9096d99aed584625f38065c511ad9b", size = 176959, upload-time = "2026-02-18T20:28:11.926Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/86/75/59a3ad90d9b4ff5b3e0537611dbe885aeb96124521c9d35aa079f1e0f2c9/google_api_core-2.18.0-py3-none-any.whl", hash = "sha256:5a63aa102e0049abe85b5b88cb9409234c1f70afcda21ce1e40b285b9629c1d6", size = 138293, upload-time = "2024-03-21T20:16:53.645Z" }, + { url = "https://files.pythonhosted.org/packages/45/27/09c33d67f7e0dcf06d7ac17d196594e66989299374bfb0d4331d1038e76b/google_api_core-2.30.0-py3-none-any.whl", hash = "sha256:80be49ee937ff9aba0fd79a6eddfde35fe658b9953ab9b79c57dd7061afa8df5", size = 173288, upload-time = "2026-02-18T20:28:10.367Z" }, ] [package.optional-dependencies] @@ -2503,7 +2581,7 @@ grpc = [ [[package]] name = "google-api-python-client" -version = "2.189.0" +version = "2.193.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2512,41 +2590,45 @@ dependencies = [ { name = "httplib2" }, { name = "uritemplate" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6f/f8/0783aeca3410ee053d4dd1fccafd85197847b8f84dd038e036634605d083/google_api_python_client-2.189.0.tar.gz", hash = "sha256:45f2d8559b5c895dde6ad3fb33de025f5cb2c197fa5862f18df7f5295a172741", size = 13979470, upload-time = "2026-02-03T19:24:55.432Z" } +sdist = { url = "https://files.pythonhosted.org/packages/90/f4/e14b6815d3b1885328dd209676a3a4c704882743ac94e18ef0093894f5c8/google_api_python_client-2.193.0.tar.gz", hash = "sha256:8f88d16e89d11341e0a8b199cafde0fb7e6b44260dffb88d451577cbd1bb5d33", size = 14281006, upload-time = "2026-03-17T18:25:29.415Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/44/3677ff27998214f2fa7957359da48da378a0ffff1bd0bdaba42e752bc13e/google_api_python_client-2.189.0-py3-none-any.whl", hash = "sha256:a258c09660a49c6159173f8bbece171278e917e104a11f0640b34751b79c8a1a", size = 14547633, upload-time = "2026-02-03T19:24:52.845Z" }, + { url = "https://files.pythonhosted.org/packages/f0/6d/fe75167797790a56d17799b75e1129bb93f7ff061efc7b36e9731bd4be2b/google_api_python_client-2.193.0-py3-none-any.whl", hash = "sha256:c42aa324b822109901cfecab5dc4fc3915d35a7b376835233c916c70610322db", size = 14856490, upload-time = "2026-03-17T18:25:26.608Z" }, ] [[package]] name = "google-auth" -version = "2.29.0" +version = "2.49.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cachetools" }, + { name = "cryptography" }, { name = "pyasn1-modules" }, - { name = "rsa" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/18/b2/f14129111cfd61793609643a07ecb03651a71dd65c6974f63b0310ff4b45/google-auth-2.29.0.tar.gz", hash = "sha256:672dff332d073227550ffc7457868ac4218d6c500b155fe6cc17d2b13602c360", size = 244326, upload-time = "2024-03-20T17:24:27.72Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/80/6a696a07d3d3b0a92488933532f03dbefa4a24ab80fb231395b9a2a1be77/google_auth-2.49.1.tar.gz", hash = "sha256:16d40da1c3c5a0533f57d268fe72e0ebb0ae1cc3b567024122651c045d879b64", size = 333825, upload-time = "2026-03-12T19:30:58.135Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9e/8d/ddbcf81ec751d8ee5fd18ac11ff38a0e110f39dfbf105e6d9db69d556dd0/google_auth-2.29.0-py2.py3-none-any.whl", hash = "sha256:d452ad095688cd52bae0ad6fafe027f6a6d6f560e810fec20914e17a09526415", size = 189186, upload-time = "2024-03-20T17:24:24.292Z" }, + { url = "https://files.pythonhosted.org/packages/e9/eb/c6c2478d8a8d633460be40e2a8a6f8f429171997a35a96f81d3b680dec83/google_auth-2.49.1-py3-none-any.whl", hash = "sha256:195ebe3dca18eddd1b3db5edc5189b76c13e96f29e73043b923ebcf3f1a860f7", size = 240737, upload-time = "2026-03-12T19:30:53.159Z" }, +] + +[package.optional-dependencies] +requests = [ + { name = "requests" }, ] [[package]] name = "google-auth-httplib2" -version = "0.2.0" +version = "0.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-auth" }, { name = "httplib2" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/56/be/217a598a818567b28e859ff087f347475c807a5649296fb5a817c58dacef/google-auth-httplib2-0.2.0.tar.gz", hash = "sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05", size = 10842, upload-time = "2023-12-12T17:40:30.722Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d5/ad/c1f2b1175096a8d04cf202ad5ea6065f108d26be6fc7215876bde4a7981d/google_auth_httplib2-0.3.0.tar.gz", hash = "sha256:177898a0175252480d5ed916aeea183c2df87c1f9c26705d74ae6b951c268b0b", size = 11134, upload-time = "2025-12-15T22:13:51.825Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/be/8a/fe34d2f3f9470a27b01c9e76226965863f153d5fbe276f83608562e49c04/google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d", size = 9253, upload-time = "2023-12-12T17:40:13.055Z" }, + { url = "https://files.pythonhosted.org/packages/99/d5/3c97526c8796d3caf5f4b3bed2b05e8a7102326f00a334e7a438237f3b22/google_auth_httplib2-0.3.0-py3-none-any.whl", hash = "sha256:426167e5df066e3f5a0fc7ea18768c08e7296046594ce4c8c409c2457dd1f776", size = 9529, upload-time = "2025-12-15T22:13:51.048Z" }, ] [[package]] name = "google-cloud-aiplatform" -version = "1.49.0" +version = "1.142.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2555,15 +2637,16 @@ dependencies = [ { name = "google-cloud-bigquery" }, { name = "google-cloud-resource-manager" }, { name = "google-cloud-storage" }, + { name = "google-genai" }, { name = "packaging" }, { name = "proto-plus" }, { name = "protobuf" }, { name = "pydantic" }, - { name = "shapely" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/47/21/5930a1420f82bec246ae09e1b7cc8458544f3befe669193b33a7b5c0691c/google-cloud-aiplatform-1.49.0.tar.gz", hash = "sha256:e6e6d01079bb5def49e4be4db4d12b13c624b5c661079c869c13c855e5807429", size = 5766450, upload-time = "2024-04-29T17:25:31.646Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/39/6a/7d9e1c03c814e760361fe8b0ffd373ead4124ace66ed33bb16d526ae1ecf/google_cloud_aiplatform-1.49.0-py2.py3-none-any.whl", hash = "sha256:8072d9e0c18d8942c704233d1a93b8d6312fc7b278786a283247950e28ae98df", size = 4914049, upload-time = "2024-04-29T17:25:27.625Z" }, + { url = "https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, ] [[package]] @@ -2599,7 +2682,7 @@ wheels = [ [[package]] name = "google-cloud-resource-manager" -version = "1.15.0" +version = "1.16.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core", extra = ["grpc"] }, @@ -2609,14 +2692,14 @@ dependencies = [ { name = "proto-plus" }, { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/19/b95d0e8814ce42522e434cdd85c0cb6236d874d9adf6685fc8e6d1fda9d1/google_cloud_resource_manager-1.15.0.tar.gz", hash = "sha256:3d0b78c3daa713f956d24e525b35e9e9a76d597c438837171304d431084cedaf", size = 449227, upload-time = "2025-10-20T14:57:01.108Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/7f/db00b2820475793a52958dc55fe9ec2eb8e863546e05fcece9b921f86ebe/google_cloud_resource_manager-1.16.0.tar.gz", hash = "sha256:cc938f87cc36c2672f062b1e541650629e0d954c405a4dac35ceedee70c267c3", size = 459840, upload-time = "2026-01-15T13:04:07.726Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/93/5aef41a5f146ad4559dd7040ae5fa8e7ddcab4dfadbef6cb4b66d775e690/google_cloud_resource_manager-1.15.0-py3-none-any.whl", hash = "sha256:0ccde5db644b269ddfdf7b407a2c7b60bdbf459f8e666344a5285601d00c7f6d", size = 397151, upload-time = "2025-10-20T14:53:45.409Z" }, + { url = "https://files.pythonhosted.org/packages/94/ff/4b28bcc791d9d7e4ac8fea00fbd90ccb236afda56746a3b4564d2ae45df3/google_cloud_resource_manager-1.16.0-py3-none-any.whl", hash = "sha256:fb9a2ad2b5053c508e1c407ac31abfd1a22e91c32876c1892830724195819a28", size = 400218, upload-time = "2026-01-15T13:02:47.378Z" }, ] [[package]] name = "google-cloud-storage" -version = "2.16.0" +version = "3.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2626,29 +2709,50 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/17/c5/0bc3f97cf4c14a731ecc5a95c5cde6883aec7289dc74817f9b41f866f77e/google-cloud-storage-2.16.0.tar.gz", hash = "sha256:dda485fa503710a828d01246bd16ce9db0823dc51bbca742ce96a6817d58669f", size = 5525307, upload-time = "2024-03-18T23:55:37.102Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size = 17309829, upload-time = "2026-03-18T15:54:11.875Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/e5/7d045d188f4ef85d94b9e3ae1bf876170c6b9f4c9a950124978efc36f680/google_cloud_storage-2.16.0-py2.py3-none-any.whl", hash = "sha256:91a06b96fb79cf9cdfb4e759f178ce11ea885c79938f89590344d079305f5852", size = 125604, upload-time = "2024-03-18T23:55:33.987Z" }, + { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = "sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, ] [[package]] name = "google-crc32c" -version = "1.7.1" +version = "1.8.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/ae/87802e6d9f9d69adfaedfcfd599266bf386a54d0be058b532d04c794f76d/google_crc32c-1.7.1.tar.gz", hash = "sha256:2bff2305f98846f3e825dbeec9ee406f89da7962accdb29356e4eadc251bd472", size = 14495, upload-time = "2025-03-26T14:29:13.32Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79", size = 14192, upload-time = "2025-12-16T00:35:25.142Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/94/220139ea87822b6fdfdab4fb9ba81b3fff7ea2c82e2af34adc726085bffc/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6fbab4b935989e2c3610371963ba1b86afb09537fd0c633049be82afe153ac06", size = 30468, upload-time = "2025-03-26T14:32:52.215Z" }, - { url = "https://files.pythonhosted.org/packages/94/97/789b23bdeeb9d15dc2904660463ad539d0318286d7633fe2760c10ed0c1c/google_crc32c-1.7.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:ed66cbe1ed9cbaaad9392b5259b3eba4a9e565420d734e6238813c428c3336c9", size = 30313, upload-time = "2025-03-26T14:57:38.758Z" }, - { url = "https://files.pythonhosted.org/packages/81/b8/976a2b843610c211e7ccb3e248996a61e87dbb2c09b1499847e295080aec/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee6547b657621b6cbed3562ea7826c3e11cab01cd33b74e1f677690652883e77", size = 33048, upload-time = "2025-03-26T14:41:30.679Z" }, - { url = "https://files.pythonhosted.org/packages/c9/16/a3842c2cf591093b111d4a5e2bfb478ac6692d02f1b386d2a33283a19dc9/google_crc32c-1.7.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d68e17bad8f7dd9a49181a1f5a8f4b251c6dbc8cc96fb79f1d321dfd57d66f53", size = 32669, upload-time = "2025-03-26T14:41:31.432Z" }, - { url = "https://files.pythonhosted.org/packages/04/17/ed9aba495916fcf5fe4ecb2267ceb851fc5f273c4e4625ae453350cfd564/google_crc32c-1.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:6335de12921f06e1f774d0dd1fbea6bf610abe0887a1638f64d694013138be5d", size = 33476, upload-time = "2025-03-26T14:29:10.211Z" }, - { url = "https://files.pythonhosted.org/packages/dd/b7/787e2453cf8639c94b3d06c9d61f512234a82e1d12d13d18584bd3049904/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:2d73a68a653c57281401871dd4aeebbb6af3191dcac751a76ce430df4d403194", size = 30470, upload-time = "2025-03-26T14:34:31.655Z" }, - { url = "https://files.pythonhosted.org/packages/ed/b4/6042c2b0cbac3ec3a69bb4c49b28d2f517b7a0f4a0232603c42c58e22b44/google_crc32c-1.7.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:22beacf83baaf59f9d3ab2bbb4db0fb018da8e5aebdce07ef9f09fce8220285e", size = 30315, upload-time = "2025-03-26T15:01:54.634Z" }, - { url = "https://files.pythonhosted.org/packages/29/ad/01e7a61a5d059bc57b702d9ff6a18b2585ad97f720bd0a0dbe215df1ab0e/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19eafa0e4af11b0a4eb3974483d55d2d77ad1911e6cf6f832e1574f6781fd337", size = 33180, upload-time = "2025-03-26T14:41:32.168Z" }, - { url = "https://files.pythonhosted.org/packages/3b/a5/7279055cf004561894ed3a7bfdf5bf90a53f28fadd01af7cd166e88ddf16/google_crc32c-1.7.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6d86616faaea68101195c6bdc40c494e4d76f41e07a37ffdef270879c15fb65", size = 32794, upload-time = "2025-03-26T14:41:33.264Z" }, - { url = "https://files.pythonhosted.org/packages/0f/d6/77060dbd140c624e42ae3ece3df53b9d811000729a5c821b9fd671ceaac6/google_crc32c-1.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:b7491bdc0c7564fcf48c0179d2048ab2f7c7ba36b84ccd3a3e1c3f7a72d3bba6", size = 33477, upload-time = "2025-03-26T14:29:10.94Z" }, - { url = "https://files.pythonhosted.org/packages/16/1b/1693372bf423ada422f80fd88260dbfd140754adb15cbc4d7e9a68b1cb8e/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:85fef7fae11494e747c9fd1359a527e5970fc9603c90764843caabd3a16a0a48", size = 28241, upload-time = "2025-03-26T14:41:45.898Z" }, - { url = "https://files.pythonhosted.org/packages/fd/3c/2a19a60a473de48717b4efb19398c3f914795b64a96cf3fbe82588044f78/google_crc32c-1.7.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6efb97eb4369d52593ad6f75e7e10d053cf00c48983f7a973105bc70b0ac4d82", size = 28048, upload-time = "2025-03-26T14:41:46.696Z" }, + { url = "https://files.pythonhosted.org/packages/5d/ef/21ccfaab3d5078d41efe8612e0ed0bfc9ce22475de074162a91a25f7980d/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:014a7e68d623e9a4222d663931febc3033c5c7c9730785727de2a81f87d5bab8", size = 31298, upload-time = "2025-12-16T00:20:32.241Z" }, + { url = "https://files.pythonhosted.org/packages/c5/b8/f8413d3f4b676136e965e764ceedec904fe38ae8de0cdc52a12d8eb1096e/google_crc32c-1.8.0-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:86cfc00fe45a0ac7359e5214a1704e51a99e757d0272554874f419f79838c5f7", size = 30872, upload-time = "2025-12-16T00:33:58.785Z" }, + { url = "https://files.pythonhosted.org/packages/f6/fd/33aa4ec62b290477181c55bb1c9302c9698c58c0ce9a6ab4874abc8b0d60/google_crc32c-1.8.0-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:19b40d637a54cb71e0829179f6cb41835f0fbd9e8eb60552152a8b52c36cbe15", size = 33243, upload-time = "2025-12-16T00:40:21.46Z" }, + { url = "https://files.pythonhosted.org/packages/71/03/4820b3bd99c9653d1a5210cb32f9ba4da9681619b4d35b6a052432df4773/google_crc32c-1.8.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:17446feb05abddc187e5441a45971b8394ea4c1b6efd88ab0af393fd9e0a156a", size = 33608, upload-time = "2025-12-16T00:40:22.204Z" }, + { url = "https://files.pythonhosted.org/packages/7c/43/acf61476a11437bf9733fb2f70599b1ced11ec7ed9ea760fdd9a77d0c619/google_crc32c-1.8.0-cp311-cp311-win_amd64.whl", hash = "sha256:71734788a88f551fbd6a97be9668a0020698e07b2bf5b3aa26a36c10cdfb27b2", size = 34439, upload-time = "2025-12-16T00:35:20.458Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113", size = 31300, upload-time = "2025-12-16T00:21:56.723Z" }, + { url = "https://files.pythonhosted.org/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb", size = 30867, upload-time = "2025-12-16T00:38:31.302Z" }, + { url = "https://files.pythonhosted.org/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411", size = 33364, upload-time = "2025-12-16T00:40:22.96Z" }, + { url = "https://files.pythonhosted.org/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454", size = 33740, upload-time = "2025-12-16T00:40:23.96Z" }, + { url = "https://files.pythonhosted.org/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962", size = 34437, upload-time = "2025-12-16T00:35:21.395Z" }, + { url = "https://files.pythonhosted.org/packages/52/c5/c171e4d8c44fec1422d801a6d2e5d7ddabd733eeda505c79730ee9607f07/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:87fa445064e7db928226b2e6f0d5304ab4cd0339e664a4e9a25029f384d9bb93", size = 28615, upload-time = "2025-12-16T00:40:29.298Z" }, + { url = "https://files.pythonhosted.org/packages/9c/97/7d75fe37a7a6ed171a2cf17117177e7aab7e6e0d115858741b41e9dd4254/google_crc32c-1.8.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f639065ea2042d5c034bf258a9f085eaa7af0cd250667c0635a3118e8f92c69c", size = 28800, upload-time = "2025-12-16T00:40:30.322Z" }, +] + +[[package]] +name = "google-genai" +version = "1.68.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "google-auth", extra = ["requests"] }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "sniffio" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "websockets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9c/2c/f059982dbcb658cc535c81bbcbe7e2c040d675f4b563b03cdb01018a4bc3/google_genai-1.68.0.tar.gz", hash = "sha256:ac30c0b8bc630f9372993a97e4a11dae0e36f2e10d7c55eacdca95a9fa14ca96", size = 511285, upload-time = "2026-03-18T01:03:18.243Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15", size = 750912, upload-time = "2026-03-18T01:03:15.983Z" }, ] [[package]] @@ -2665,14 +2769,14 @@ wheels = [ [[package]] name = "googleapis-common-protos" -version = "1.63.0" +version = "1.73.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d2/dc/291cebf3c73e108ef8210f19cb83d671691354f4f7dd956445560d778715/googleapis-common-protos-1.63.0.tar.gz", hash = "sha256:17ad01b11d5f1d0171c06d3ba5c04c54474e883b66b949722b4938ee2694ef4e", size = 121646, upload-time = "2024-03-11T12:33:15.765Z" } +sdist = { url = "https://files.pythonhosted.org/packages/99/96/a0205167fa0154f4a542fd6925bdc63d039d88dab3588b875078107e6f06/googleapis_common_protos-1.73.0.tar.gz", hash = "sha256:778d07cd4fbeff84c6f7c72102f0daf98fa2bfd3fa8bea426edc545588da0b5a", size = 147323, upload-time = "2026-03-06T21:53:09.727Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dc/a6/12a0c976140511d8bc8a16ad15793b2aef29ac927baa0786ccb7ddbb6e1c/googleapis_common_protos-1.63.0-py2.py3-none-any.whl", hash = "sha256:ae45f75702f7c08b541f750854a678bd8f534a1a6bace6afe975f1d0a82d6632", size = 229141, upload-time = "2024-03-11T12:33:14.052Z" }, + { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, ] [package.optional-dependencies] @@ -2696,21 +2800,17 @@ wheels = [ ] [package.optional-dependencies] -aiohttp = [ - { name = "aiohttp" }, -] -requests = [ - { name = "requests" }, - { name = "requests-toolbelt" }, +httpx = [ + { name = "httpx" }, ] [[package]] name = "graphql-core" -version = "3.2.7" +version = "3.2.8" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ac/9b/037a640a2983b09aed4a823f9cf1729e6d780b0671f854efa4727a7affbe/graphql_core-3.2.7.tar.gz", hash = "sha256:27b6904bdd3b43f2a0556dad5d579bdfdeab1f38e8e8788e555bdcb586a6f62c", size = 513484, upload-time = "2025-11-01T22:30:40.436Z" } +sdist = { url = "https://files.pythonhosted.org/packages/68/c5/36aa96205c3ecbb3d34c7c24189e4553c7ca2ebc7e1dd07432339b980272/graphql_core-3.2.8.tar.gz", hash = "sha256:015457da5d996c924ddf57a43f4e959b0b94fb695b85ed4c29446e508ed65cf3", size = 513181, upload-time = "2026-03-05T19:55:37.332Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/14/933037032608787fb92e365883ad6a741c235e0ff992865ec5d904a38f1e/graphql_core-3.2.7-py3-none-any.whl", hash = "sha256:17fc8f3ca4a42913d8e24d9ac9f08deddf0a0b2483076575757f6c412ead2ec0", size = 207262, upload-time = "2025-11-01T22:30:38.912Z" }, + { url = "https://files.pythonhosted.org/packages/86/41/cb887d9afc5dabd78feefe6ccbaf83ff423c206a7a1b7aeeac05120b2125/graphql_core-3.2.8-py3-none-any.whl", hash = "sha256:cbee07bee1b3ed5e531723685369039f32ff815ef60166686e0162f540f1520c", size = 207349, upload-time = "2026-03-05T19:55:35.911Z" }, ] [[package]] @@ -2724,81 +2824,77 @@ wheels = [ [[package]] name = "greenlet" -version = "3.2.4" +version = "3.3.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/03/b8/704d753a5a45507a7aab61f18db9509302ed3d0a27ac7e0359ec2905b1a6/greenlet-3.2.4.tar.gz", hash = "sha256:0dca0d95ff849f9a364385f36ab49f50065d76964944638be9691e1832e9f86d", size = 188260, upload-time = "2025-08-07T13:24:33.51Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/51/1664f6b78fc6ebbd98019a1fd730e83fa78f2db7058f72b1463d3612b8db/greenlet-3.3.2.tar.gz", hash = "sha256:2eaf067fc6d886931c7962e8c6bede15d2f01965560f3359b27c80bde2d151f2", size = 188267, upload-time = "2026-02-20T20:54:15.531Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a4/de/f28ced0a67749cac23fecb02b694f6473f47686dff6afaa211d186e2ef9c/greenlet-3.2.4-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:96378df1de302bc38e99c3a9aa311967b7dc80ced1dcc6f171e99842987882a2", size = 272305, upload-time = "2025-08-07T13:15:41.288Z" }, - { url = "https://files.pythonhosted.org/packages/09/16/2c3792cba130000bf2a31c5272999113f4764fd9d874fb257ff588ac779a/greenlet-3.2.4-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:1ee8fae0519a337f2329cb78bd7a8e128ec0f881073d43f023c7b8d4831d5246", size = 632472, upload-time = "2025-08-07T13:42:55.044Z" }, - { url = "https://files.pythonhosted.org/packages/ae/8f/95d48d7e3d433e6dae5b1682e4292242a53f22df82e6d3dda81b1701a960/greenlet-3.2.4-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:94abf90142c2a18151632371140b3dba4dee031633fe614cb592dbb6c9e17bc3", size = 644646, upload-time = "2025-08-07T13:45:26.523Z" }, - { url = "https://files.pythonhosted.org/packages/d5/5e/405965351aef8c76b8ef7ad370e5da58d57ef6068df197548b015464001a/greenlet-3.2.4-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:4d1378601b85e2e5171b99be8d2dc85f594c79967599328f95c1dc1a40f1c633", size = 640519, upload-time = "2025-08-07T13:53:13.928Z" }, - { url = "https://files.pythonhosted.org/packages/25/5d/382753b52006ce0218297ec1b628e048c4e64b155379331f25a7316eb749/greenlet-3.2.4-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0db5594dce18db94f7d1650d7489909b57afde4c580806b8d9203b6e79cdc079", size = 639707, upload-time = "2025-08-07T13:18:27.146Z" }, - { url = "https://files.pythonhosted.org/packages/1f/8e/abdd3f14d735b2929290a018ecf133c901be4874b858dd1c604b9319f064/greenlet-3.2.4-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2523e5246274f54fdadbce8494458a2ebdcdbc7b802318466ac5606d3cded1f8", size = 587684, upload-time = "2025-08-07T13:18:25.164Z" }, - { url = "https://files.pythonhosted.org/packages/5d/65/deb2a69c3e5996439b0176f6651e0052542bb6c8f8ec2e3fba97c9768805/greenlet-3.2.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:1987de92fec508535687fb807a5cea1560f6196285a4cde35c100b8cd632cc52", size = 1116647, upload-time = "2025-08-07T13:42:38.655Z" }, - { url = "https://files.pythonhosted.org/packages/3f/cc/b07000438a29ac5cfb2194bfc128151d52f333cee74dd7dfe3fb733fc16c/greenlet-3.2.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:55e9c5affaa6775e2c6b67659f3a71684de4c549b3dd9afca3bc773533d284fa", size = 1142073, upload-time = "2025-08-07T13:18:21.737Z" }, - { url = "https://files.pythonhosted.org/packages/67/24/28a5b2fa42d12b3d7e5614145f0bd89714c34c08be6aabe39c14dd52db34/greenlet-3.2.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c9c6de1940a7d828635fbd254d69db79e54619f165ee7ce32fda763a9cb6a58c", size = 1548385, upload-time = "2025-11-04T12:42:11.067Z" }, - { url = "https://files.pythonhosted.org/packages/6a/05/03f2f0bdd0b0ff9a4f7b99333d57b53a7709c27723ec8123056b084e69cd/greenlet-3.2.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:03c5136e7be905045160b1b9fdca93dd6727b180feeafda6818e6496434ed8c5", size = 1613329, upload-time = "2025-11-04T12:42:12.928Z" }, - { url = "https://files.pythonhosted.org/packages/d8/0f/30aef242fcab550b0b3520b8e3561156857c94288f0332a79928c31a52cf/greenlet-3.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:9c40adce87eaa9ddb593ccb0fa6a07caf34015a29bf8d344811665b573138db9", size = 299100, upload-time = "2025-08-07T13:44:12.287Z" }, - { url = "https://files.pythonhosted.org/packages/44/69/9b804adb5fd0671f367781560eb5eb586c4d495277c93bde4307b9e28068/greenlet-3.2.4-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3b67ca49f54cede0186854a008109d6ee71f66bd57bb36abd6d0a0267b540cdd", size = 274079, upload-time = "2025-08-07T13:15:45.033Z" }, - { url = "https://files.pythonhosted.org/packages/46/e9/d2a80c99f19a153eff70bc451ab78615583b8dac0754cfb942223d2c1a0d/greenlet-3.2.4-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddf9164e7a5b08e9d22511526865780a576f19ddd00d62f8a665949327fde8bb", size = 640997, upload-time = "2025-08-07T13:42:56.234Z" }, - { url = "https://files.pythonhosted.org/packages/3b/16/035dcfcc48715ccd345f3a93183267167cdd162ad123cd93067d86f27ce4/greenlet-3.2.4-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:f28588772bb5fb869a8eb331374ec06f24a83a9c25bfa1f38b6993afe9c1e968", size = 655185, upload-time = "2025-08-07T13:45:27.624Z" }, - { url = "https://files.pythonhosted.org/packages/31/da/0386695eef69ffae1ad726881571dfe28b41970173947e7c558d9998de0f/greenlet-3.2.4-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:5c9320971821a7cb77cfab8d956fa8e39cd07ca44b6070db358ceb7f8797c8c9", size = 649926, upload-time = "2025-08-07T13:53:15.251Z" }, - { url = "https://files.pythonhosted.org/packages/68/88/69bf19fd4dc19981928ceacbc5fd4bb6bc2215d53199e367832e98d1d8fe/greenlet-3.2.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c60a6d84229b271d44b70fb6e5fa23781abb5d742af7b808ae3f6efd7c9c60f6", size = 651839, upload-time = "2025-08-07T13:18:30.281Z" }, - { url = "https://files.pythonhosted.org/packages/19/0d/6660d55f7373b2ff8152401a83e02084956da23ae58cddbfb0b330978fe9/greenlet-3.2.4-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b3812d8d0c9579967815af437d96623f45c0f2ae5f04e366de62a12d83a8fb0", size = 607586, upload-time = "2025-08-07T13:18:28.544Z" }, - { url = "https://files.pythonhosted.org/packages/8e/1a/c953fdedd22d81ee4629afbb38d2f9d71e37d23caace44775a3a969147d4/greenlet-3.2.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:abbf57b5a870d30c4675928c37278493044d7c14378350b3aa5d484fa65575f0", size = 1123281, upload-time = "2025-08-07T13:42:39.858Z" }, - { url = "https://files.pythonhosted.org/packages/3f/c7/12381b18e21aef2c6bd3a636da1088b888b97b7a0362fac2e4de92405f97/greenlet-3.2.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:20fb936b4652b6e307b8f347665e2c615540d4b42b3b4c8a321d8286da7e520f", size = 1151142, upload-time = "2025-08-07T13:18:22.981Z" }, - { url = "https://files.pythonhosted.org/packages/27/45/80935968b53cfd3f33cf99ea5f08227f2646e044568c9b1555b58ffd61c2/greenlet-3.2.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ee7a6ec486883397d70eec05059353b8e83eca9168b9f3f9a361971e77e0bcd0", size = 1564846, upload-time = "2025-11-04T12:42:15.191Z" }, - { url = "https://files.pythonhosted.org/packages/69/02/b7c30e5e04752cb4db6202a3858b149c0710e5453b71a3b2aec5d78a1aab/greenlet-3.2.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:326d234cbf337c9c3def0676412eb7040a35a768efc92504b947b3e9cfc7543d", size = 1633814, upload-time = "2025-11-04T12:42:17.175Z" }, - { url = "https://files.pythonhosted.org/packages/e9/08/b0814846b79399e585f974bbeebf5580fbe59e258ea7be64d9dfb253c84f/greenlet-3.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:a7d4e128405eea3814a12cc2605e0e6aedb4035bf32697f72deca74de4105e02", size = 299899, upload-time = "2025-08-07T13:38:53.448Z" }, + { url = "https://files.pythonhosted.org/packages/f3/47/16400cb42d18d7a6bb46f0626852c1718612e35dcb0dffa16bbaffdf5dd2/greenlet-3.3.2-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:c56692189a7d1c7606cb794be0a8381470d95c57ce5be03fb3d0ef57c7853b86", size = 278890, upload-time = "2026-02-20T20:19:39.263Z" }, + { url = "https://files.pythonhosted.org/packages/a3/90/42762b77a5b6aa96cd8c0e80612663d39211e8ae8a6cd47c7f1249a66262/greenlet-3.3.2-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1ebd458fa8285960f382841da585e02201b53a5ec2bac6b156fc623b5ce4499f", size = 581120, upload-time = "2026-02-20T20:47:30.161Z" }, + { url = "https://files.pythonhosted.org/packages/bf/6f/f3d64f4fa0a9c7b5c5b3c810ff1df614540d5aa7d519261b53fba55d4df9/greenlet-3.3.2-cp311-cp311-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a443358b33c4ec7b05b79a7c8b466f5d275025e750298be7340f8fc63dff2a55", size = 594363, upload-time = "2026-02-20T20:55:56.965Z" }, + { url = "https://files.pythonhosted.org/packages/9c/8b/1430a04657735a3f23116c2e0d5eb10220928846e4537a938a41b350bed6/greenlet-3.3.2-cp311-cp311-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4375a58e49522698d3e70cc0b801c19433021b5c37686f7ce9c65b0d5c8677d2", size = 605046, upload-time = "2026-02-20T21:02:45.234Z" }, + { url = "https://files.pythonhosted.org/packages/72/83/3e06a52aca8128bdd4dcd67e932b809e76a96ab8c232a8b025b2850264c5/greenlet-3.3.2-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e2cd90d413acbf5e77ae41e5d3c9b3ac1d011a756d7284d7f3f2b806bbd6358", size = 594156, upload-time = "2026-02-20T20:20:59.955Z" }, + { url = "https://files.pythonhosted.org/packages/70/79/0de5e62b873e08fe3cef7dbe84e5c4bc0e8ed0c7ff131bccb8405cd107c8/greenlet-3.3.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:442b6057453c8cb29b4fb36a2ac689382fc71112273726e2423f7f17dc73bf99", size = 1554649, upload-time = "2026-02-20T20:49:32.293Z" }, + { url = "https://files.pythonhosted.org/packages/5a/00/32d30dee8389dc36d42170a9c66217757289e2afb0de59a3565260f38373/greenlet-3.3.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:45abe8eb6339518180d5a7fa47fa01945414d7cca5ecb745346fc6a87d2750be", size = 1619472, upload-time = "2026-02-20T20:21:07.966Z" }, + { url = "https://files.pythonhosted.org/packages/f1/3a/efb2cf697fbccdf75b24e2c18025e7dfa54c4f31fab75c51d0fe79942cef/greenlet-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:1e692b2dae4cc7077cbb11b47d258533b48c8fde69a33d0d8a82e2fe8d8531d5", size = 230389, upload-time = "2026-02-20T20:17:18.772Z" }, + { url = "https://files.pythonhosted.org/packages/e1/a1/65bbc059a43a7e2143ec4fc1f9e3f673e04f9c7b371a494a101422ac4fd5/greenlet-3.3.2-cp311-cp311-win_arm64.whl", hash = "sha256:02b0a8682aecd4d3c6c18edf52bc8e51eacdd75c8eac52a790a210b06aa295fd", size = 229645, upload-time = "2026-02-20T20:18:18.695Z" }, + { url = "https://files.pythonhosted.org/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd", size = 280358, upload-time = "2026-02-20T20:17:43.971Z" }, + { url = "https://files.pythonhosted.org/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd", size = 601217, upload-time = "2026-02-20T20:47:31.462Z" }, + { url = "https://files.pythonhosted.org/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac", size = 611792, upload-time = "2026-02-20T20:55:58.423Z" }, + { url = "https://files.pythonhosted.org/packages/5c/c5/cc09412a29e43406eba18d61c70baa936e299bc27e074e2be3806ed29098/greenlet-3.3.2-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ae9e21c84035c490506c17002f5c8ab25f980205c3e61ddb3a2a2a2e6c411fcb", size = 626250, upload-time = "2026-02-20T21:02:46.596Z" }, + { url = "https://files.pythonhosted.org/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070", size = 613875, upload-time = "2026-02-20T20:21:01.102Z" }, + { url = "https://files.pythonhosted.org/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79", size = 1571467, upload-time = "2026-02-20T20:49:33.495Z" }, + { url = "https://files.pythonhosted.org/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395", size = 1640001, upload-time = "2026-02-20T20:21:09.154Z" }, + { url = "https://files.pythonhosted.org/packages/9b/40/cc802e067d02af8b60b6771cea7d57e21ef5e6659912814babb42b864713/greenlet-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:34308836d8370bddadb41f5a7ce96879b72e2fdfb4e87729330c6ab52376409f", size = 231081, upload-time = "2026-02-20T20:17:28.121Z" }, + { url = "https://files.pythonhosted.org/packages/58/2e/fe7f36ff1982d6b10a60d5e0740c759259a7d6d2e1dc41da6d96de32fff6/greenlet-3.3.2-cp312-cp312-win_arm64.whl", hash = "sha256:d3a62fa76a32b462a97198e4c9e99afb9ab375115e74e9a83ce180e7a496f643", size = 230331, upload-time = "2026-02-20T20:17:23.34Z" }, ] [[package]] name = "grimp" -version = "3.13" +version = "3.14" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/80/b3/ff0d704cdc5cf399d74aabd2bf1694d4c4c3231d4d74b011b8f39f686a86/grimp-3.13.tar.gz", hash = "sha256:759bf6e05186e6473ee71af4119ec181855b2b324f4fcdd78dee9e5b59d87874", size = 847508, upload-time = "2025-10-29T13:04:57.704Z" } +sdist = { url = "https://files.pythonhosted.org/packages/63/46/79764cfb61a3ac80dadae5d94fb10acdb7800e31fecf4113cf3d345e4952/grimp-3.14.tar.gz", hash = "sha256:645fbd835983901042dae4e1b24fde3a89bf7ac152f9272dd17a97e55cb4f871", size = 830882, upload-time = "2025-12-10T17:55:01.287Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/45/cc/d272cf87728a7e6ddb44d3c57c1d3cbe7daf2ffe4dc76e3dc9b953b69ab1/grimp-3.13-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:57745996698932768274a2ed9ba3e5c424f60996c53ecaf1c82b75be9e819ee9", size = 2074518, upload-time = "2025-10-29T13:03:58.51Z" }, - { url = "https://files.pythonhosted.org/packages/06/11/31dc622c5a0d1615b20532af2083f4bba2573aebbba5b9d6911dfd60a37d/grimp-3.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ca29f09710342b94fa6441f4d1102a0e49f0b463b1d91e43223baa949c5e9337", size = 1988182, upload-time = "2025-10-29T13:03:50.129Z" }, - { url = "https://files.pythonhosted.org/packages/aa/83/a0e19beb5c42df09e9a60711b227b4f910ba57f46bea258a9e1df883976c/grimp-3.13-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:adda25aa158e11d96dd27166300b955c8ec0c76ce2fd1a13597e9af012aada06", size = 2145832, upload-time = "2025-10-29T13:02:35.218Z" }, - { url = "https://files.pythonhosted.org/packages/bc/f5/13752205e290588e970fdc019b4ab2c063ca8da352295c332e34df5d5842/grimp-3.13-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03e17029d75500a5282b40cb15cdae030bf14df9dfaa6a2b983f08898dfe74b6", size = 2106762, upload-time = "2025-10-29T13:02:51.681Z" }, - { url = "https://files.pythonhosted.org/packages/ff/30/c4d62543beda4b9a483a6cd5b7dd5e4794aafb511f144d21a452467989a1/grimp-3.13-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6cbfc9d2d0ebc0631fb4012a002f3d8f4e3acb8325be34db525c0392674433b8", size = 2256674, upload-time = "2025-10-29T13:03:27.923Z" }, - { url = "https://files.pythonhosted.org/packages/9b/ea/d07ed41b7121719c3f7bf30c9881dbde69efeacfc2daf4e4a628efe5f123/grimp-3.13-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:161449751a085484608c5b9f863e41e8fb2a98e93f7312ead5d831e487a94518", size = 2442699, upload-time = "2025-10-29T13:03:04.451Z" }, - { url = "https://files.pythonhosted.org/packages/fe/a0/1923f0480756effb53c7e6cef02a3918bb519a86715992720838d44f0329/grimp-3.13-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:119628fbe7f941d1e784edac98e8ced7e78a0b966a4ff2c449e436ee860bd507", size = 2317145, upload-time = "2025-10-29T13:03:15.941Z" }, - { url = "https://files.pythonhosted.org/packages/0d/d9/aef4c8350090653e34bc755a5d9e39cc300f5c46c651c1d50195f69bf9ab/grimp-3.13-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ca1ac776baf1fa105342b23c72f2e7fdd6771d4cce8d2903d28f92fd34a9e8f", size = 2180288, upload-time = "2025-10-29T13:03:41.023Z" }, - { url = "https://files.pythonhosted.org/packages/9f/2e/a206f76eccffa56310a1c5d5950ed34923a34ae360cb38e297604a288837/grimp-3.13-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:941ff414cc66458f56e6af93c618266091ea70bfdabe7a84039be31d937051ee", size = 2328696, upload-time = "2025-10-29T13:04:06.888Z" }, - { url = "https://files.pythonhosted.org/packages/40/3b/88ff1554409b58faf2673854770e6fc6e90167a182f5166147b7618767d7/grimp-3.13-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:87ad9bcd1caaa2f77c369d61a04b9f2f1b87f4c3b23ae6891b2c943193c4ec62", size = 2367574, upload-time = "2025-10-29T13:04:21.404Z" }, - { url = "https://files.pythonhosted.org/packages/b6/b3/e9c99ecd94567465a0926ae7136e589aed336f6979a4cddcb8dfba16d27c/grimp-3.13-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:751fe37104a4f023d5c6556558b723d843d44361245c20f51a5d196de00e4774", size = 2358842, upload-time = "2025-10-29T13:04:34.26Z" }, - { url = "https://files.pythonhosted.org/packages/74/65/a5fffeeb9273e06dfbe962c8096331ba181ca8415c5f9d110b347f2c0c34/grimp-3.13-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9b561f79ec0b3a4156937709737191ad57520f2d58fa1fc43cd79f67839a3cd7", size = 2382268, upload-time = "2025-10-29T13:04:46.864Z" }, - { url = "https://files.pythonhosted.org/packages/d9/79/2f3b4323184329b26b46de2b6d1bd64ba1c26e0a9c3cfa0aaecec237b75e/grimp-3.13-cp311-cp311-win32.whl", hash = "sha256:52405ea8c8f20cf5d2d1866c80ee3f0243a38af82bd49d1464c5e254bf2e1f8f", size = 1759345, upload-time = "2025-10-29T13:05:10.435Z" }, - { url = "https://files.pythonhosted.org/packages/b6/ce/e86cf73e412a6bf531cbfa5c733f8ca48b28ebea23a037338be763f24849/grimp-3.13-cp311-cp311-win_amd64.whl", hash = "sha256:6a45d1d3beeefad69717b3718e53680fb3579fe67696b86349d6f39b75e850bf", size = 1859382, upload-time = "2025-10-29T13:05:01.071Z" }, - { url = "https://files.pythonhosted.org/packages/1d/06/ff7e3d72839f46f0fccdc79e1afe332318986751e20f65d7211a5e51366c/grimp-3.13-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:3e715c56ffdd055e5c84d27b4c02d83369b733e6a24579d42bbbc284bd0664a9", size = 2070161, upload-time = "2025-10-29T13:03:59.755Z" }, - { url = "https://files.pythonhosted.org/packages/58/2f/a95bdf8996db9400fd7e288f32628b2177b8840fe5f6b7cd96247b5fa173/grimp-3.13-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f794dea35a4728b948ab8fec970ffbdf2589b34209f3ab902cf8a9148cf1eaad", size = 1984365, upload-time = "2025-10-29T13:03:51.805Z" }, - { url = "https://files.pythonhosted.org/packages/1f/45/cc3d7f3b7b4d93e0b9d747dc45ed73a96203ba083dc857f24159eb6966b4/grimp-3.13-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69571270f2c27e8a64b968195aa7ecc126797112a9bf1e804ff39ba9f42d6f6d", size = 2145486, upload-time = "2025-10-29T13:02:36.591Z" }, - { url = "https://files.pythonhosted.org/packages/16/92/a6e493b71cb5a9145ad414cc4790c3779853372b840a320f052b22879606/grimp-3.13-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8f7b226398ae476762ef0afb5ef8f838d39c8e0e2f6d1a4378ce47059b221a4a", size = 2106747, upload-time = "2025-10-29T13:02:53.084Z" }, - { url = "https://files.pythonhosted.org/packages/db/8d/36a09f39fe14ad8843ef3ff81090ef23abbd02984c1fcc1cef30e5713d82/grimp-3.13-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5498aeac4df0131a1787fcbe9bb460b52fc9b781ec6bba607fd6a7d6d3ea6fce", size = 2257027, upload-time = "2025-10-29T13:03:29.44Z" }, - { url = "https://files.pythonhosted.org/packages/a1/7a/90f78787f80504caeef501f1bff47e8b9f6058d45995f1d4c921df17bfef/grimp-3.13-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4be702bb2b5c001a6baf709c452358470881e15e3e074cfc5308903603485dcb", size = 2441208, upload-time = "2025-10-29T13:03:05.733Z" }, - { url = "https://files.pythonhosted.org/packages/61/71/0fbd3a3e914512b9602fa24c8ebc85a8925b101f04f8a8c1d1e220e0a717/grimp-3.13-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fcf988f3e3d272a88f7be68f0c1d3719fee8624d902e9c0346b9015a0ea6a65", size = 2318758, upload-time = "2025-10-29T13:03:17.454Z" }, - { url = "https://files.pythonhosted.org/packages/34/e9/29c685e88b3b0688f0a2e30c0825e02076ecdf22bc0e37b1468562eaa09a/grimp-3.13-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ede36d104ff88c208140f978de3345f439345f35b8ef2b4390c59ef6984deba", size = 2180523, upload-time = "2025-10-29T13:03:42.3Z" }, - { url = "https://files.pythonhosted.org/packages/86/bc/7cc09574b287b8850a45051e73272f365259d9b6ca58d7b8773265c6fe35/grimp-3.13-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b35e44bb8dc80e0bd909a64387f722395453593a1884caca9dc0748efea33764", size = 2328855, upload-time = "2025-10-29T13:04:08.111Z" }, - { url = "https://files.pythonhosted.org/packages/34/86/3b0845900c8f984a57c6afe3409b20638065462d48b6afec0fd409fd6118/grimp-3.13-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:becb88e9405fc40896acd6e2b9bbf6f242a5ae2fd43a1ec0a32319ab6c10a227", size = 2367756, upload-time = "2025-10-29T13:04:22.736Z" }, - { url = "https://files.pythonhosted.org/packages/06/2d/4e70e8c06542db92c3fffaecb43ebfc4114a411505bff574d4da7d82c7db/grimp-3.13-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:a66585b4af46c3fbadbef495483514bee037e8c3075ed179ba4f13e494eb7899", size = 2358595, upload-time = "2025-10-29T13:04:35.595Z" }, - { url = "https://files.pythonhosted.org/packages/dd/06/c511d39eb6c73069af277f4e74991f1f29a05d90cab61f5416b9fc43932f/grimp-3.13-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:29f68c6e2ff70d782ca0e989ec4ec44df73ba847937bcbb6191499224a2f84e2", size = 2381464, upload-time = "2025-10-29T13:04:48.265Z" }, - { url = "https://files.pythonhosted.org/packages/86/f5/42197d69e4c9e2e7eed091d06493da3824e07c37324155569aa895c3b5f7/grimp-3.13-cp312-cp312-win32.whl", hash = "sha256:cc996dcd1a44ae52d257b9a3e98838f8ecfdc42f7c62c8c82c2fcd3828155c98", size = 1758510, upload-time = "2025-10-29T13:05:11.74Z" }, - { url = "https://files.pythonhosted.org/packages/30/dd/59c5f19f51e25f3dbf1c9e88067a88165f649ba1b8e4174dbaf1c950f78b/grimp-3.13-cp312-cp312-win_amd64.whl", hash = "sha256:e2966435947e45b11568f04a65863dcf836343c11ae44aeefdaa7f07eb1a0576", size = 1859530, upload-time = "2025-10-29T13:05:02.638Z" }, - { url = "https://files.pythonhosted.org/packages/e5/81/82de1b5d82701214b1f8e32b2e71fde8e1edbb4f2cdca9beb22ee6c8796d/grimp-3.13-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a6a3c76525b018c85c0e3a632d94d72be02225f8ada56670f3f213cf0762be4", size = 2145955, upload-time = "2025-10-29T13:02:47.559Z" }, - { url = "https://files.pythonhosted.org/packages/8c/ae/ada18cb73bdf97094af1c60070a5b85549482a57c509ee9a23fdceed4fc3/grimp-3.13-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:239e9b347af4da4cf69465bfa7b2901127f6057bc73416ba8187fb1eabafc6ea", size = 2107150, upload-time = "2025-10-29T13:02:59.891Z" }, - { url = "https://files.pythonhosted.org/packages/10/5e/6d8c65643ad5a1b6e00cc2cd8f56fc063923485f07c59a756fa61eefe7f2/grimp-3.13-pp311-pypy311_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6db85ce2dc2f804a2edd1c1e9eaa46d282e1f0051752a83ca08ca8b87f87376", size = 2257515, upload-time = "2025-10-29T13:03:36.705Z" }, - { url = "https://files.pythonhosted.org/packages/b2/62/72cbfd7d0f2b95a53edd01d5f6b0d02bde38db739a727e35b76c13e0d0a8/grimp-3.13-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e000f3590bcc6ff7c781ebbc1ac4eb919f97180f13cc4002c868822167bd9aed", size = 2441262, upload-time = "2025-10-29T13:03:12.158Z" }, - { url = "https://files.pythonhosted.org/packages/18/00/b9209ab385567c3bddffb5d9eeecf9cb432b05c30ca8f35904b06e206a89/grimp-3.13-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e2374c217c862c1af933a430192d6a7c6723ed1d90303f1abbc26f709bbb9263", size = 2318557, upload-time = "2025-10-29T13:03:23.925Z" }, - { url = "https://files.pythonhosted.org/packages/11/4d/a3d73c11d09da00a53ceafe2884a71c78f5a76186af6d633cadd6c85d850/grimp-3.13-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ed0ff17d559ff2e7fa1be8ae086bc4fedcace5d7b12017f60164db8d9a8d806", size = 2180811, upload-time = "2025-10-29T13:03:47.461Z" }, - { url = "https://files.pythonhosted.org/packages/c1/9a/1cdfaa7d7beefd8859b190dfeba11d5ec074e8702b2903e9f182d662ed63/grimp-3.13-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:43960234aabce018c8d796ec8b77c484a1c9cbb6a3bc036a0d307c8dade9874c", size = 2329205, upload-time = "2025-10-29T13:04:15.845Z" }, - { url = "https://files.pythonhosted.org/packages/86/73/b36f86ef98df96e7e8a6166dfa60c8db5d597f051e613a3112f39a870b4c/grimp-3.13-pp311-pypy311_pp73-musllinux_1_2_armv7l.whl", hash = "sha256:44420b638b3e303f32314bd4d309f15de1666629035acd1cdd3720c15917ac85", size = 2368745, upload-time = "2025-10-29T13:04:29.706Z" }, - { url = "https://files.pythonhosted.org/packages/02/2f/0ce37872fad5c4b82d727f6e435fd5bc76f701279bddc9666710318940cf/grimp-3.13-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:f6127fdb982cf135612504d34aa16b841f421e54751fcd54f80b9531decb2b3f", size = 2358753, upload-time = "2025-10-29T13:04:42.632Z" }, - { url = "https://files.pythonhosted.org/packages/bb/23/935c888ac9ee71184fe5adf5ea86648746739be23c85932857ac19fc1d17/grimp-3.13-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:69893a9ef1edea25226ed17e8e8981e32900c59703972e0780c0e927ce624f75", size = 2383066, upload-time = "2025-10-29T13:04:55.073Z" }, + { url = "https://files.pythonhosted.org/packages/25/31/d4a86207c38954b6c3d859a1fc740a80b04bbe6e3b8a39f4e66f9633dfa4/grimp-3.14-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f1c91e3fa48c2196bf62e3c71492140d227b2bfcd6d15e735cbc0b3e2d5308e0", size = 2185572, upload-time = "2025-12-10T17:53:41.287Z" }, + { url = "https://files.pythonhosted.org/packages/f5/61/ed4cba5bd75d37fe46e17a602f616619a9e4f74ad8adfcf560ce4b2a1697/grimp-3.14-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6291c8f1690a9fe21b70923c60b075f4a89676541999e3d33084cbc69ac06a1", size = 2118002, upload-time = "2025-12-10T17:53:18.546Z" }, + { url = "https://files.pythonhosted.org/packages/77/6a/688f6144d0b207d7845bd8ab403820a83630ce3c9420cbbc7c9e9282f9c0/grimp-3.14-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ec312383935c2d09e4085c8435780ada2e13ebef14e105609c2988a02a5b2ce", size = 2283939, upload-time = "2025-12-10T17:52:06.228Z" }, + { url = "https://files.pythonhosted.org/packages/a5/98/4c540de151bf3fd58d6d7b3fe2269b6a6af6c61c915de1bc991802bfaff8/grimp-3.14-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4f43cbf640e73ee703ad91639591046828d20103a1c363a02516e77a66a4ac07", size = 2233693, upload-time = "2025-12-10T17:52:18.938Z" }, + { url = "https://files.pythonhosted.org/packages/3e/7b/84b4b52b6c6dd5bf083cb1a72945748f56ea2e61768bbebf87e8d9d0ef75/grimp-3.14-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2a93c9fddccb9ff16f5c6b5fca44227f5f86cba7cffc145d2176119603d2d7c7", size = 2389745, upload-time = "2025-12-10T17:53:00.659Z" }, + { url = "https://files.pythonhosted.org/packages/a7/33/31b96907c7dd78953df5e1ce67c558bd6057220fa1203d28d52566315a2e/grimp-3.14-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5653a2769fdc062cb7598d12200352069c9c6559b6643af6ada3639edb98fcc3", size = 2569055, upload-time = "2025-12-10T17:52:33.556Z" }, + { url = "https://files.pythonhosted.org/packages/b2/24/ce1a8110f3d5b178153b903aafe54b6a9216588b5bff3656e30af43e9c29/grimp-3.14-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:071c7ddf5e5bb7b2fdf79aefdf6e1c237cd81c095d6d0a19620e777e85bf103c", size = 2358044, upload-time = "2025-12-10T17:52:47.545Z" }, + { url = "https://files.pythonhosted.org/packages/05/7f/16d98c02287bc99884843478b9a68b04a2ef13b5cb8b9f36a9ca7daea75b/grimp-3.14-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e01b7a4419f535b667dfdcb556d3815b52981474f791fb40d72607228389a31", size = 2310304, upload-time = "2025-12-10T17:53:09.679Z" }, + { url = "https://files.pythonhosted.org/packages/a5/8c/0fde9781b0f6b4f9227d485685f48f6bcc70b95af22e2f85ff7f416cbfc1/grimp-3.14-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c29682f336151d1d018d0c3aa9eeaa35734b970e4593fa396b901edca7ef5c79", size = 2463682, upload-time = "2025-12-10T17:53:49.185Z" }, + { url = "https://files.pythonhosted.org/packages/51/cb/2baff301c2c2cc2792b6e225ea0784793ca587c81b97572be0bad122cfc8/grimp-3.14-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:a5c4fd71f363ea39e8aab0630010ced77a8de9789f27c0acdd0d7e6269d4a8ef", size = 2500573, upload-time = "2025-12-10T17:54:03.899Z" }, + { url = "https://files.pythonhosted.org/packages/96/69/797e4242f42d6665da5fe22cb250cae3f14ece4cb22ad153e9cd97158179/grimp-3.14-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:766911e3ba0b13d833fdd03ad1f217523a8a2b2527b5507335f71dca1153183d", size = 2503005, upload-time = "2025-12-10T17:54:32.993Z" }, + { url = "https://files.pythonhosted.org/packages/fd/45/da1a27a6377807ca427cd56534231f0920e1895e16630204f382a0df14c5/grimp-3.14-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:154e84a2053e9f858ae48743de23a5ad4eb994007518c29371276f59b8419036", size = 2515776, upload-time = "2025-12-10T17:54:47.962Z" }, + { url = "https://files.pythonhosted.org/packages/4f/8d/b918a29ce98029cd7a9e33a584be43a93288d5283fb7ccef5b6b2ba39ede/grimp-3.14-cp311-cp311-win32.whl", hash = "sha256:3189c86c3e73016a1907ee3ba9f7a6ca037e3601ad09e60ce9bf12b88877f812", size = 1873189, upload-time = "2025-12-10T17:55:11.872Z" }, + { url = "https://files.pythonhosted.org/packages/90/d7/2327c203f83a25766fbd62b0df3b24230d422b6e53518ff4d1c5e69793f1/grimp-3.14-cp311-cp311-win_amd64.whl", hash = "sha256:201f46a6a4e5ee9dfba4a2f7d043f7deab080d1d84233f4a1aee812678c25307", size = 2014277, upload-time = "2025-12-10T17:55:04.144Z" }, + { url = "https://files.pythonhosted.org/packages/75/d6/a35ff62f35aa5fd148053506eddd7a8f2f6afaed31870dc608dd0eb38e4f/grimp-3.14-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:ffabc6940301214753bad89ec0bfe275892fa1f64b999e9a101f6cebfc777133", size = 2178573, upload-time = "2025-12-10T17:53:42.836Z" }, + { url = "https://files.pythonhosted.org/packages/93/e2/bd2e80273da4d46110969fc62252e5372e0249feb872bc7fe76fdc7f1818/grimp-3.14-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:075d9a1c78d607792d0ed8d4d3d7754a621ef04c8a95eaebf634930dc9232bb2", size = 2110452, upload-time = "2025-12-10T17:53:19.831Z" }, + { url = "https://files.pythonhosted.org/packages/44/c3/7307249c657d34dca9d250d73ba027d6cfe15a98fb3119b6e5210bc388b7/grimp-3.14-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:06ff52addeb20955a4d6aa097bee910573ffc9ef0d3c8a860844f267ad958156", size = 2283064, upload-time = "2025-12-10T17:52:07.673Z" }, + { url = "https://files.pythonhosted.org/packages/c7/d2/cae4cf32dc8d4188837cc4ab183300d655f898969b0f169e240f3b7c25be/grimp-3.14-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d10e0663e961fcbe8d0f54608854af31f911f164c96a44112d5173050132701f", size = 2235893, upload-time = "2025-12-10T17:52:20.418Z" }, + { url = "https://files.pythonhosted.org/packages/04/92/3f58bc3064fc305dac107d08003ba65713a5bc89a6d327f1c06b30cce752/grimp-3.14-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4ab874d7ddddc7a1291259cf7c31a4e7b5c612e9da2e24c67c0eb1a44a624e67", size = 2393376, upload-time = "2025-12-10T17:53:02.397Z" }, + { url = "https://files.pythonhosted.org/packages/06/b8/f476f30edf114f04cb58e8ae162cb4daf52bda0ab01919f3b5b7edb98430/grimp-3.14-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:54fec672ec83355636a852177f5a470c964bede0f6730f9ba3c7b5c8419c9eab", size = 2571342, upload-time = "2025-12-10T17:52:35.214Z" }, + { url = "https://files.pythonhosted.org/packages/c4/ae/2e44d3c4f591f95f86322a8f4dbb5aac17001d49e079f3a80e07e7caaf09/grimp-3.14-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b9e221b5e8070a916c780e88c877fee2a61c95a76a76a2a076396e459511b0bb", size = 2359022, upload-time = "2025-12-10T17:52:49.063Z" }, + { url = "https://files.pythonhosted.org/packages/69/ac/42b4d6bc0ea119ce2e91e1788feabf32c5433e9617dbb495c2a3d0dc7f12/grimp-3.14-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eea6b495f9b4a8d82f5ce544921e76d0d12017f5d1ac3a3bd2f5ac88ab055b1c", size = 2309424, upload-time = "2025-12-10T17:53:11.069Z" }, + { url = "https://files.pythonhosted.org/packages/e8/c7/6a731989625c1790f4da7602dcbf9d6525512264e853cda77b3b3602d5e0/grimp-3.14-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:655e8d3f79cd99bb859e09c9dd633515150e9d850879ca71417d5ac31809b745", size = 2462754, upload-time = "2025-12-10T17:53:50.886Z" }, + { url = "https://files.pythonhosted.org/packages/cd/4d/3d1571c0a39a59dd68be4835f766da64fe64cbab0d69426210b716a8bdf0/grimp-3.14-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:a14f10b1b71c6c37647a76e6a49c226509648107abc0f48c1e3ecd158ba05531", size = 2501356, upload-time = "2025-12-10T17:54:06.014Z" }, + { url = "https://files.pythonhosted.org/packages/eb/d1/8950b8229095ebda5c54c8784e4d1f0a6e19423f2847289ef9751f878798/grimp-3.14-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:81685111ee24d3e25f8ed9e77ed00b92b58b2414e1a1c2937236026900972744", size = 2504631, upload-time = "2025-12-10T17:54:34.441Z" }, + { url = "https://files.pythonhosted.org/packages/0a/e6/23bed3da9206138d36d01890b656c7fb7adfb3a37daac8842d84d8777ade/grimp-3.14-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ce8352a8ea0e27b143136ea086582fc6653419aa8a7c15e28ed08c898c42b185", size = 2514751, upload-time = "2025-12-10T17:54:49.384Z" }, + { url = "https://files.pythonhosted.org/packages/eb/45/6f1f55c97ee982f133ec5ccb22fc99bf5335aee70c208f4fb86cd833b8d5/grimp-3.14-cp312-cp312-win32.whl", hash = "sha256:3fc0f98b3c60d88e9ffa08faff3200f36604930972f8b29155f323b76ea25a06", size = 1875041, upload-time = "2025-12-10T17:55:13.326Z" }, + { url = "https://files.pythonhosted.org/packages/cf/cf/03ba01288e2a41a948bc8526f32c2eeaddd683ed34be1b895e31658d5a4c/grimp-3.14-cp312-cp312-win_amd64.whl", hash = "sha256:6bca77d1d50c8dc402c96af21f4e28e2f1e9938eeabd7417592a22bd83cde3c3", size = 2013868, upload-time = "2025-12-10T17:55:05.907Z" }, + { url = "https://files.pythonhosted.org/packages/65/cc/dbc00210d0324b8fc1242d8e857757c7e0b62ff0fc0c1bc8dcc42342da85/grimp-3.14-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7c8a8aab9b4310a7e69d7d845cac21cf14563aa0520ea322b948eadeae56d303", size = 2284804, upload-time = "2025-12-10T17:52:16.379Z" }, + { url = "https://files.pythonhosted.org/packages/80/89/851d3d345342e9bcec3fe85d3997db29501fa59f958c1566bf3e24d9d7d9/grimp-3.14-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d781943b27e5875a41c8f9cfc80f8f0a349f864379192b8c3faa0e6a22593313", size = 2235176, upload-time = "2025-12-10T17:52:30.795Z" }, + { url = "https://files.pythonhosted.org/packages/58/78/5f94702a8d5c121cafcdc9664de34c34f19d0d91a1127bf3946a2631f7a3/grimp-3.14-pp311-pypy311_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9630d4633607aff94d0ac84b9c64fef1382cdb05b00d9acbde47f8745e264871", size = 2391258, upload-time = "2025-12-10T17:53:06.906Z" }, + { url = "https://files.pythonhosted.org/packages/e9/a2/df8c79de5c9e227856d048cc1551c4742a5f97660c40304ac278bd48607f/grimp-3.14-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7cb00e1bcca583668554a8e9e1e4229a1d11b0620969310aae40148829ff6a32", size = 2571443, upload-time = "2025-12-10T17:52:43.853Z" }, + { url = "https://files.pythonhosted.org/packages/f0/21/747b7ed9572bbdc34a76dfec12ce510e80164b1aa06d3b21b34994e5f567/grimp-3.14-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3389da4ceaaa7f7de24a668c0afc307a9f95997bd90f81ec359a828a9bd1d270", size = 2357767, upload-time = "2025-12-10T17:52:57.84Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e6/485c5e3b64933e71f72f0cc45b0d7130418a6a5a13cedc2e8411bd76f290/grimp-3.14-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd7a32970ef97e42d4e7369397c7795287d84a736d788ccb90b6c14f0561d975", size = 2309069, upload-time = "2025-12-10T17:53:15.203Z" }, + { url = "https://files.pythonhosted.org/packages/31/bd/12024a8cba1c77facc1422a7b48cd0d04c252fc9178fd6f99dc05a8af57b/grimp-3.14-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:fd1278623fa09f62abc0fd8a6500f31b421a1fd479980f44c2926020a0becf02", size = 2466429, upload-time = "2025-12-10T17:54:00.286Z" }, + { url = "https://files.pythonhosted.org/packages/ee/7f/0e5977887e1c8f00f84bb4125217534806ffdcef9cf52f3580aa3b151f4b/grimp-3.14-pp311-pypy311_pp73-musllinux_1_2_armv7l.whl", hash = "sha256:9cfa52c89333d3d8fe9dc782529e888270d060231c3783e036d424044671dde0", size = 2501190, upload-time = "2025-12-10T17:54:30.107Z" }, + { url = "https://files.pythonhosted.org/packages/42/6b/06acb94b6d0d8c7277bb3e33f93224aa3be5b04643f853479d3bf7b23ace/grimp-3.14-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:48a5be4a12fca6587e6885b4fc13b9e242ab8bf874519292f0f13814aecf52cc", size = 2503440, upload-time = "2025-12-10T17:54:44.444Z" }, + { url = "https://files.pythonhosted.org/packages/5b/4d/2e531370d12e7a564f67f680234710bbc08554238a54991cd244feb61fb6/grimp-3.14-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:3fcc332466783a12a42cd317fd344c30fe734ba4fa2362efff132dc3f8d36da7", size = 2516525, upload-time = "2025-12-10T17:54:58.987Z" }, ] [[package]] @@ -2817,88 +2913,92 @@ wheels = [ [[package]] name = "grpcio" -version = "1.76.0" +version = "1.78.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b6/e0/318c1ce3ae5a17894d5791e87aea147587c9e702f24122cc7a5c8bbaeeb1/grpcio-1.76.0.tar.gz", hash = "sha256:7be78388d6da1a25c0d5ec506523db58b18be22d9c37d8d3a32c08be4987bd73", size = 12785182, upload-time = "2025-10-21T16:23:12.106Z" } +sdist = { url = "https://files.pythonhosted.org/packages/06/8a/3d098f35c143a89520e568e6539cc098fcd294495910e359889ce8741c84/grpcio-1.78.0.tar.gz", hash = "sha256:7382b95189546f375c174f53a5fa873cef91c4b8005faa05cc5b3beea9c4f1c5", size = 12852416, upload-time = "2026-02-06T09:57:18.093Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/00/8163a1beeb6971f66b4bbe6ac9457b97948beba8dd2fc8e1281dce7f79ec/grpcio-1.76.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2e1743fbd7f5fa713a1b0a8ac8ebabf0ec980b5d8809ec358d488e273b9cf02a", size = 5843567, upload-time = "2025-10-21T16:20:52.829Z" }, - { url = "https://files.pythonhosted.org/packages/10/c1/934202f5cf335e6d852530ce14ddb0fef21be612ba9ecbbcbd4d748ca32d/grpcio-1.76.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:a8c2cf1209497cf659a667d7dea88985e834c24b7c3b605e6254cbb5076d985c", size = 11848017, upload-time = "2025-10-21T16:20:56.705Z" }, - { url = "https://files.pythonhosted.org/packages/11/0b/8dec16b1863d74af6eb3543928600ec2195af49ca58b16334972f6775663/grpcio-1.76.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:08caea849a9d3c71a542827d6df9d5a69067b0a1efbea8a855633ff5d9571465", size = 6412027, upload-time = "2025-10-21T16:20:59.3Z" }, - { url = "https://files.pythonhosted.org/packages/d7/64/7b9e6e7ab910bea9d46f2c090380bab274a0b91fb0a2fe9b0cd399fffa12/grpcio-1.76.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:f0e34c2079d47ae9f6188211db9e777c619a21d4faba6977774e8fa43b085e48", size = 7075913, upload-time = "2025-10-21T16:21:01.645Z" }, - { url = "https://files.pythonhosted.org/packages/68/86/093c46e9546073cefa789bd76d44c5cb2abc824ca62af0c18be590ff13ba/grpcio-1.76.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8843114c0cfce61b40ad48df65abcfc00d4dba82eae8718fab5352390848c5da", size = 6615417, upload-time = "2025-10-21T16:21:03.844Z" }, - { url = "https://files.pythonhosted.org/packages/f7/b6/5709a3a68500a9c03da6fb71740dcdd5ef245e39266461a03f31a57036d8/grpcio-1.76.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8eddfb4d203a237da6f3cc8a540dad0517d274b5a1e9e636fd8d2c79b5c1d397", size = 7199683, upload-time = "2025-10-21T16:21:06.195Z" }, - { url = "https://files.pythonhosted.org/packages/91/d3/4b1f2bf16ed52ce0b508161df3a2d186e4935379a159a834cb4a7d687429/grpcio-1.76.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:32483fe2aab2c3794101c2a159070584e5db11d0aa091b2c0ea9c4fc43d0d749", size = 8163109, upload-time = "2025-10-21T16:21:08.498Z" }, - { url = "https://files.pythonhosted.org/packages/5c/61/d9043f95f5f4cf085ac5dd6137b469d41befb04bd80280952ffa2a4c3f12/grpcio-1.76.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:dcfe41187da8992c5f40aa8c5ec086fa3672834d2be57a32384c08d5a05b4c00", size = 7626676, upload-time = "2025-10-21T16:21:10.693Z" }, - { url = "https://files.pythonhosted.org/packages/36/95/fd9a5152ca02d8881e4dd419cdd790e11805979f499a2e5b96488b85cf27/grpcio-1.76.0-cp311-cp311-win32.whl", hash = "sha256:2107b0c024d1b35f4083f11245c0e23846ae64d02f40b2b226684840260ed054", size = 3997688, upload-time = "2025-10-21T16:21:12.746Z" }, - { url = "https://files.pythonhosted.org/packages/60/9c/5c359c8d4c9176cfa3c61ecd4efe5affe1f38d9bae81e81ac7186b4c9cc8/grpcio-1.76.0-cp311-cp311-win_amd64.whl", hash = "sha256:522175aba7af9113c48ec10cc471b9b9bd4f6ceb36aeb4544a8e2c80ed9d252d", size = 4709315, upload-time = "2025-10-21T16:21:15.26Z" }, - { url = "https://files.pythonhosted.org/packages/bf/05/8e29121994b8d959ffa0afd28996d452f291b48cfc0875619de0bde2c50c/grpcio-1.76.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:81fd9652b37b36f16138611c7e884eb82e0cec137c40d3ef7c3f9b3ed00f6ed8", size = 5799718, upload-time = "2025-10-21T16:21:17.939Z" }, - { url = "https://files.pythonhosted.org/packages/d9/75/11d0e66b3cdf998c996489581bdad8900db79ebd83513e45c19548f1cba4/grpcio-1.76.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:04bbe1bfe3a68bbfd4e52402ab7d4eb59d72d02647ae2042204326cf4bbad280", size = 11825627, upload-time = "2025-10-21T16:21:20.466Z" }, - { url = "https://files.pythonhosted.org/packages/28/50/2f0aa0498bc188048f5d9504dcc5c2c24f2eb1a9337cd0fa09a61a2e75f0/grpcio-1.76.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d388087771c837cdb6515539f43b9d4bf0b0f23593a24054ac16f7a960be16f4", size = 6359167, upload-time = "2025-10-21T16:21:23.122Z" }, - { url = "https://files.pythonhosted.org/packages/66/e5/bbf0bb97d29ede1d59d6588af40018cfc345b17ce979b7b45424628dc8bb/grpcio-1.76.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:9f8f757bebaaea112c00dba718fc0d3260052ce714e25804a03f93f5d1c6cc11", size = 7044267, upload-time = "2025-10-21T16:21:25.995Z" }, - { url = "https://files.pythonhosted.org/packages/f5/86/f6ec2164f743d9609691115ae8ece098c76b894ebe4f7c94a655c6b03e98/grpcio-1.76.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:980a846182ce88c4f2f7e2c22c56aefd515daeb36149d1c897f83cf57999e0b6", size = 6573963, upload-time = "2025-10-21T16:21:28.631Z" }, - { url = "https://files.pythonhosted.org/packages/60/bc/8d9d0d8505feccfdf38a766d262c71e73639c165b311c9457208b56d92ae/grpcio-1.76.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f92f88e6c033db65a5ae3d97905c8fea9c725b63e28d5a75cb73b49bda5024d8", size = 7164484, upload-time = "2025-10-21T16:21:30.837Z" }, - { url = "https://files.pythonhosted.org/packages/67/e6/5d6c2fc10b95edf6df9b8f19cf10a34263b7fd48493936fffd5085521292/grpcio-1.76.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:4baf3cbe2f0be3289eb68ac8ae771156971848bb8aaff60bad42005539431980", size = 8127777, upload-time = "2025-10-21T16:21:33.577Z" }, - { url = "https://files.pythonhosted.org/packages/3f/c8/dce8ff21c86abe025efe304d9e31fdb0deaaa3b502b6a78141080f206da0/grpcio-1.76.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:615ba64c208aaceb5ec83bfdce7728b80bfeb8be97562944836a7a0a9647d882", size = 7594014, upload-time = "2025-10-21T16:21:41.882Z" }, - { url = "https://files.pythonhosted.org/packages/e0/42/ad28191ebf983a5d0ecef90bab66baa5a6b18f2bfdef9d0a63b1973d9f75/grpcio-1.76.0-cp312-cp312-win32.whl", hash = "sha256:45d59a649a82df5718fd9527ce775fd66d1af35e6d31abdcdc906a49c6822958", size = 3984750, upload-time = "2025-10-21T16:21:44.006Z" }, - { url = "https://files.pythonhosted.org/packages/9e/00/7bd478cbb851c04a48baccaa49b75abaa8e4122f7d86da797500cccdd771/grpcio-1.76.0-cp312-cp312-win_amd64.whl", hash = "sha256:c088e7a90b6017307f423efbb9d1ba97a22aa2170876223f9709e9d1de0b5347", size = 4704003, upload-time = "2025-10-21T16:21:46.244Z" }, + { url = "https://files.pythonhosted.org/packages/86/c7/d0b780a29b0837bf4ca9580904dfb275c1fc321ded7897d620af7047ec57/grpcio-1.78.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:2777b783f6c13b92bd7b716667452c329eefd646bfb3f2e9dabea2e05dbd34f6", size = 5951525, upload-time = "2026-02-06T09:55:01.989Z" }, + { url = "https://files.pythonhosted.org/packages/c5/b1/96920bf2ee61df85a9503cb6f733fe711c0ff321a5a697d791b075673281/grpcio-1.78.0-cp311-cp311-macosx_11_0_universal2.whl", hash = "sha256:9dca934f24c732750389ce49d638069c3892ad065df86cb465b3fa3012b70c9e", size = 11830418, upload-time = "2026-02-06T09:55:04.462Z" }, + { url = "https://files.pythonhosted.org/packages/83/0c/7c1528f098aeb75a97de2bae18c530f56959fb7ad6c882db45d9884d6edc/grpcio-1.78.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:459ab414b35f4496138d0ecd735fed26f1318af5e52cb1efbc82a09f0d5aa911", size = 6524477, upload-time = "2026-02-06T09:55:07.111Z" }, + { url = "https://files.pythonhosted.org/packages/8d/52/e7c1f3688f949058e19a011c4e0dec973da3d0ae5e033909677f967ae1f4/grpcio-1.78.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:082653eecbdf290e6e3e2c276ab2c54b9e7c299e07f4221872380312d8cf395e", size = 7198266, upload-time = "2026-02-06T09:55:10.016Z" }, + { url = "https://files.pythonhosted.org/packages/e5/61/8ac32517c1e856677282c34f2e7812d6c328fa02b8f4067ab80e77fdc9c9/grpcio-1.78.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:85f93781028ec63f383f6bc90db785a016319c561cc11151fbb7b34e0d012303", size = 6730552, upload-time = "2026-02-06T09:55:12.207Z" }, + { url = "https://files.pythonhosted.org/packages/bd/98/b8ee0158199250220734f620b12e4a345955ac7329cfd908d0bf0fda77f0/grpcio-1.78.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f12857d24d98441af6a1d5c87442d624411db486f7ba12550b07788f74b67b04", size = 7304296, upload-time = "2026-02-06T09:55:15.044Z" }, + { url = "https://files.pythonhosted.org/packages/bd/0f/7b72762e0d8840b58032a56fdbd02b78fc645b9fa993d71abf04edbc54f4/grpcio-1.78.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:5397fff416b79e4b284959642a4e95ac4b0f1ece82c9993658e0e477d40551ec", size = 8288298, upload-time = "2026-02-06T09:55:17.276Z" }, + { url = "https://files.pythonhosted.org/packages/24/ae/ae4ce56bc5bb5caa3a486d60f5f6083ac3469228faa734362487176c15c5/grpcio-1.78.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fbe6e89c7ffb48518384068321621b2a69cab509f58e40e4399fdd378fa6d074", size = 7730953, upload-time = "2026-02-06T09:55:19.545Z" }, + { url = "https://files.pythonhosted.org/packages/b5/6e/8052e3a28eb6a820c372b2eb4b5e32d195c661e137d3eca94d534a4cfd8a/grpcio-1.78.0-cp311-cp311-win32.whl", hash = "sha256:6092beabe1966a3229f599d7088b38dfc8ffa1608b5b5cdda31e591e6500f856", size = 4076503, upload-time = "2026-02-06T09:55:21.521Z" }, + { url = "https://files.pythonhosted.org/packages/08/62/f22c98c5265dfad327251fa2f840b591b1df5f5e15d88b19c18c86965b27/grpcio-1.78.0-cp311-cp311-win_amd64.whl", hash = "sha256:1afa62af6e23f88629f2b29ec9e52ec7c65a7176c1e0a83292b93c76ca882558", size = 4799767, upload-time = "2026-02-06T09:55:24.107Z" }, + { url = "https://files.pythonhosted.org/packages/4e/f4/7384ed0178203d6074446b3c4f46c90a22ddf7ae0b3aee521627f54cfc2a/grpcio-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:f9ab915a267fc47c7e88c387a3a28325b58c898e23d4995f765728f4e3dedb97", size = 5913985, upload-time = "2026-02-06T09:55:26.832Z" }, + { url = "https://files.pythonhosted.org/packages/81/ed/be1caa25f06594463f685b3790b320f18aea49b33166f4141bfdc2bfb236/grpcio-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3f8904a8165ab21e07e58bf3e30a73f4dffc7a1e0dbc32d51c61b5360d26f43e", size = 11811853, upload-time = "2026-02-06T09:55:29.224Z" }, + { url = "https://files.pythonhosted.org/packages/24/a7/f06d151afc4e64b7e3cc3e872d331d011c279aaab02831e40a81c691fb65/grpcio-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:859b13906ce098c0b493af92142ad051bf64c7870fa58a123911c88606714996", size = 6475766, upload-time = "2026-02-06T09:55:31.825Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a8/4482922da832ec0082d0f2cc3a10976d84a7424707f25780b82814aafc0a/grpcio-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b2342d87af32790f934a79c3112641e7b27d63c261b8b4395350dad43eff1dc7", size = 7170027, upload-time = "2026-02-06T09:55:34.7Z" }, + { url = "https://files.pythonhosted.org/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:12a771591ae40bc65ba67048fa52ef4f0e6db8279e595fd349f9dfddeef571f9", size = 6690766, upload-time = "2026-02-06T09:55:36.902Z" }, + { url = "https://files.pythonhosted.org/packages/c7/b9/521875265cc99fe5ad4c5a17010018085cae2810a928bf15ebe7d8bcd9cc/grpcio-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:185dea0d5260cbb2d224c507bf2a5444d5abbb1fa3594c1ed7e4c709d5eb8383", size = 7266161, upload-time = "2026-02-06T09:55:39.824Z" }, + { url = "https://files.pythonhosted.org/packages/05/86/296a82844fd40a4ad4a95f100b55044b4f817dece732bf686aea1a284147/grpcio-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51b13f9aed9d59ee389ad666b8c2214cc87b5de258fa712f9ab05f922e3896c6", size = 8253303, upload-time = "2026-02-06T09:55:42.353Z" }, + { url = "https://files.pythonhosted.org/packages/f3/e4/ea3c0caf5468537f27ad5aab92b681ed7cc0ef5f8c9196d3fd42c8c2286b/grpcio-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fd5f135b1bd58ab088930b3c613455796dfa0393626a6972663ccdda5b4ac6ce", size = 7698222, upload-time = "2026-02-06T09:55:44.629Z" }, + { url = "https://files.pythonhosted.org/packages/d7/47/7f05f81e4bb6b831e93271fb12fd52ba7b319b5402cbc101d588f435df00/grpcio-1.78.0-cp312-cp312-win32.whl", hash = "sha256:94309f498bcc07e5a7d16089ab984d42ad96af1d94b5a4eb966a266d9fcabf68", size = 4066123, upload-time = "2026-02-06T09:55:47.644Z" }, + { url = "https://files.pythonhosted.org/packages/ad/e7/d6914822c88aa2974dbbd10903d801a28a19ce9cd8bad7e694cbbcf61528/grpcio-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:9566fe4ababbb2610c39190791e5b829869351d14369603702e890ef3ad2d06e", size = 4797657, upload-time = "2026-02-06T09:55:49.86Z" }, ] [[package]] name = "grpcio-status" -version = "1.62.3" +version = "1.71.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "googleapis-common-protos" }, { name = "grpcio" }, { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7c/d7/013ef01c5a1c2fd0932c27c904934162f69f41ca0f28396d3ffe4d386123/grpcio-status-1.62.3.tar.gz", hash = "sha256:289bdd7b2459794a12cf95dc0cb727bd4a1742c37bd823f760236c937e53a485", size = 13063, upload-time = "2024-08-06T00:37:08.003Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/d1/b6e9877fedae3add1afdeae1f89d1927d296da9cf977eca0eb08fb8a460e/grpcio_status-1.71.2.tar.gz", hash = "sha256:c7a97e176df71cdc2c179cd1847d7fc86cca5832ad12e9798d7fed6b7a1aab50", size = 13677, upload-time = "2025-06-28T04:24:05.426Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/90/40/972271de05f9315c0d69f9f7ebbcadd83bc85322f538637d11bb8c67803d/grpcio_status-1.62.3-py3-none-any.whl", hash = "sha256:f9049b762ba8de6b1086789d8315846e094edac2c50beaf462338b301a8fd4b8", size = 14448, upload-time = "2024-08-06T00:30:15.702Z" }, + { url = "https://files.pythonhosted.org/packages/67/58/317b0134129b556a93a3b0afe00ee675b5657f0155509e22fcb853bafe2d/grpcio_status-1.71.2-py3-none-any.whl", hash = "sha256:803c98cb6a8b7dc6dbb785b1111aed739f241ab5e9da0bba96888aa74704cfd3", size = 14424, upload-time = "2025-06-28T04:23:42.136Z" }, ] [[package]] name = "grpcio-tools" -version = "1.62.3" +version = "1.71.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "grpcio" }, { name = "protobuf" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/54/fa/b69bd8040eafc09b88bb0ec0fea59e8aacd1a801e688af087cead213b0d0/grpcio-tools-1.62.3.tar.gz", hash = "sha256:7c7136015c3d62c3eef493efabaf9e3380e3e66d24ee8e94c01cb71377f57833", size = 4538520, upload-time = "2024-08-06T00:37:11.035Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ad/9a/edfefb47f11ef6b0f39eea4d8f022c5bb05ac1d14fcc7058e84a51305b73/grpcio_tools-1.71.2.tar.gz", hash = "sha256:b5304d65c7569b21270b568e404a5a843cf027c66552a6a0978b23f137679c09", size = 5330655, upload-time = "2025-06-28T04:22:00.308Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/23/52/2dfe0a46b63f5ebcd976570aa5fc62f793d5a8b169e211c6a5aede72b7ae/grpcio_tools-1.62.3-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:703f46e0012af83a36082b5f30341113474ed0d91e36640da713355cd0ea5d23", size = 5147623, upload-time = "2024-08-06T00:30:54.894Z" }, - { url = "https://files.pythonhosted.org/packages/f0/2e/29fdc6c034e058482e054b4a3c2432f84ff2e2765c1342d4f0aa8a5c5b9a/grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:7cc83023acd8bc72cf74c2edbe85b52098501d5b74d8377bfa06f3e929803492", size = 2719538, upload-time = "2024-08-06T00:30:57.928Z" }, - { url = "https://files.pythonhosted.org/packages/f9/60/abe5deba32d9ec2c76cdf1a2f34e404c50787074a2fee6169568986273f1/grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7ff7d58a45b75df67d25f8f144936a3e44aabd91afec833ee06826bd02b7fbe7", size = 3070964, upload-time = "2024-08-06T00:31:00.267Z" }, - { url = "https://files.pythonhosted.org/packages/bc/ad/e2b066684c75f8d9a48508cde080a3a36618064b9cadac16d019ca511444/grpcio_tools-1.62.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f2483ea232bd72d98a6dc6d7aefd97e5bc80b15cd909b9e356d6f3e326b6e43", size = 2805003, upload-time = "2024-08-06T00:31:02.565Z" }, - { url = "https://files.pythonhosted.org/packages/9c/3f/59bf7af786eae3f9d24ee05ce75318b87f541d0950190ecb5ffb776a1a58/grpcio_tools-1.62.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:962c84b4da0f3b14b3cdb10bc3837ebc5f136b67d919aea8d7bb3fd3df39528a", size = 3685154, upload-time = "2024-08-06T00:31:05.339Z" }, - { url = "https://files.pythonhosted.org/packages/f1/79/4dd62478b91e27084c67b35a2316ce8a967bd8b6cb8d6ed6c86c3a0df7cb/grpcio_tools-1.62.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8ad0473af5544f89fc5a1ece8676dd03bdf160fb3230f967e05d0f4bf89620e3", size = 3297942, upload-time = "2024-08-06T00:31:08.456Z" }, - { url = "https://files.pythonhosted.org/packages/b8/cb/86449ecc58bea056b52c0b891f26977afc8c4464d88c738f9648da941a75/grpcio_tools-1.62.3-cp311-cp311-win32.whl", hash = "sha256:db3bc9fa39afc5e4e2767da4459df82b095ef0cab2f257707be06c44a1c2c3e5", size = 910231, upload-time = "2024-08-06T00:31:11.464Z" }, - { url = "https://files.pythonhosted.org/packages/45/a4/9736215e3945c30ab6843280b0c6e1bff502910156ea2414cd77fbf1738c/grpcio_tools-1.62.3-cp311-cp311-win_amd64.whl", hash = "sha256:e0898d412a434e768a0c7e365acabe13ff1558b767e400936e26b5b6ed1ee51f", size = 1052496, upload-time = "2024-08-06T00:31:13.665Z" }, - { url = "https://files.pythonhosted.org/packages/2a/a5/d6887eba415ce318ae5005e8dfac3fa74892400b54b6d37b79e8b4f14f5e/grpcio_tools-1.62.3-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:d102b9b21c4e1e40af9a2ab3c6d41afba6bd29c0aa50ca013bf85c99cdc44ac5", size = 5147690, upload-time = "2024-08-06T00:31:16.436Z" }, - { url = "https://files.pythonhosted.org/packages/8a/7c/3cde447a045e83ceb4b570af8afe67ffc86896a2fe7f59594dc8e5d0a645/grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:0a52cc9444df978438b8d2332c0ca99000521895229934a59f94f37ed896b133", size = 2720538, upload-time = "2024-08-06T00:31:18.905Z" }, - { url = "https://files.pythonhosted.org/packages/88/07/f83f2750d44ac4f06c07c37395b9c1383ef5c994745f73c6bfaf767f0944/grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141d028bf5762d4a97f981c501da873589df3f7e02f4c1260e1921e565b376fa", size = 3071571, upload-time = "2024-08-06T00:31:21.684Z" }, - { url = "https://files.pythonhosted.org/packages/37/74/40175897deb61e54aca716bc2e8919155b48f33aafec8043dda9592d8768/grpcio_tools-1.62.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47a5c093ab256dec5714a7a345f8cc89315cb57c298b276fa244f37a0ba507f0", size = 2806207, upload-time = "2024-08-06T00:31:24.208Z" }, - { url = "https://files.pythonhosted.org/packages/ec/ee/d8de915105a217cbcb9084d684abdc032030dcd887277f2ef167372287fe/grpcio_tools-1.62.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:f6831fdec2b853c9daa3358535c55eed3694325889aa714070528cf8f92d7d6d", size = 3685815, upload-time = "2024-08-06T00:31:26.917Z" }, - { url = "https://files.pythonhosted.org/packages/fd/d9/4360a6c12be3d7521b0b8c39e5d3801d622fbb81cc2721dbd3eee31e28c8/grpcio_tools-1.62.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e02d7c1a02e3814c94ba0cfe43d93e872c758bd8fd5c2797f894d0c49b4a1dfc", size = 3298378, upload-time = "2024-08-06T00:31:30.401Z" }, - { url = "https://files.pythonhosted.org/packages/29/3b/7cdf4a9e5a3e0a35a528b48b111355cd14da601413a4f887aa99b6da468f/grpcio_tools-1.62.3-cp312-cp312-win32.whl", hash = "sha256:b881fd9505a84457e9f7e99362eeedd86497b659030cf57c6f0070df6d9c2b9b", size = 910416, upload-time = "2024-08-06T00:31:33.118Z" }, - { url = "https://files.pythonhosted.org/packages/6c/66/dd3ec249e44c1cc15e902e783747819ed41ead1336fcba72bf841f72c6e9/grpcio_tools-1.62.3-cp312-cp312-win_amd64.whl", hash = "sha256:11c625eebefd1fd40a228fc8bae385e448c7e32a6ae134e43cf13bbc23f902b7", size = 1052856, upload-time = "2024-08-06T00:31:36.519Z" }, + { url = "https://files.pythonhosted.org/packages/17/e4/0568d38b8da6237ea8ea15abb960fb7ab83eb7bb51e0ea5926dab3d865b1/grpcio_tools-1.71.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:0acb8151ea866be5b35233877fbee6445c36644c0aa77e230c9d1b46bf34b18b", size = 2385557, upload-time = "2025-06-28T04:20:54.323Z" }, + { url = "https://files.pythonhosted.org/packages/76/fb/700d46f72b0f636cf0e625f3c18a4f74543ff127471377e49a071f64f1e7/grpcio_tools-1.71.2-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:b28f8606f4123edb4e6da281547465d6e449e89f0c943c376d1732dc65e6d8b3", size = 5447590, upload-time = "2025-06-28T04:20:55.836Z" }, + { url = "https://files.pythonhosted.org/packages/12/69/d9bb2aec3de305162b23c5c884b9f79b1a195d42b1e6dabcc084cc9d0804/grpcio_tools-1.71.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:cbae6f849ad2d1f5e26cd55448b9828e678cb947fa32c8729d01998238266a6a", size = 2348495, upload-time = "2025-06-28T04:20:57.33Z" }, + { url = "https://files.pythonhosted.org/packages/d5/83/f840aba1690461b65330efbca96170893ee02fae66651bcc75f28b33a46c/grpcio_tools-1.71.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4d1027615cfb1e9b1f31f2f384251c847d68c2f3e025697e5f5c72e26ed1316", size = 2742333, upload-time = "2025-06-28T04:20:59.051Z" }, + { url = "https://files.pythonhosted.org/packages/30/34/c02cd9b37de26045190ba665ee6ab8597d47f033d098968f812d253bbf8c/grpcio_tools-1.71.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9bac95662dc69338edb9eb727cc3dd92342131b84b12b3e8ec6abe973d4cbf1b", size = 2473490, upload-time = "2025-06-28T04:21:00.614Z" }, + { url = "https://files.pythonhosted.org/packages/4d/c7/375718ae091c8f5776828ce97bdcb014ca26244296f8b7f70af1a803ed2f/grpcio_tools-1.71.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c50250c7248055040f89eb29ecad39d3a260a4b6d3696af1575945f7a8d5dcdc", size = 2850333, upload-time = "2025-06-28T04:21:01.95Z" }, + { url = "https://files.pythonhosted.org/packages/19/37/efc69345bd92a73b2bc80f4f9e53d42dfdc234b2491ae58c87da20ca0ea5/grpcio_tools-1.71.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6ab1ad955e69027ef12ace4d700c5fc36341bdc2f420e87881e9d6d02af3d7b8", size = 3300748, upload-time = "2025-06-28T04:21:03.451Z" }, + { url = "https://files.pythonhosted.org/packages/d2/1f/15f787eb25ae42086f55ed3e4260e85f385921c788debf0f7583b34446e3/grpcio_tools-1.71.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dd75dde575781262b6b96cc6d0b2ac6002b2f50882bf5e06713f1bf364ee6e09", size = 2913178, upload-time = "2025-06-28T04:21:04.879Z" }, + { url = "https://files.pythonhosted.org/packages/12/aa/69cb3a9dff7d143a05e4021c3c9b5cde07aacb8eb1c892b7c5b9fb4973e3/grpcio_tools-1.71.2-cp311-cp311-win32.whl", hash = "sha256:9a3cb244d2bfe0d187f858c5408d17cb0e76ca60ec9a274c8fd94cc81457c7fc", size = 946256, upload-time = "2025-06-28T04:21:06.518Z" }, + { url = "https://files.pythonhosted.org/packages/1e/df/fb951c5c87eadb507a832243942e56e67d50d7667b0e5324616ffd51b845/grpcio_tools-1.71.2-cp311-cp311-win_amd64.whl", hash = "sha256:00eb909997fd359a39b789342b476cbe291f4dd9c01ae9887a474f35972a257e", size = 1117661, upload-time = "2025-06-28T04:21:08.18Z" }, + { url = "https://files.pythonhosted.org/packages/9c/d3/3ed30a9c5b2424627b4b8411e2cd6a1a3f997d3812dbc6a8630a78bcfe26/grpcio_tools-1.71.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:bfc0b5d289e383bc7d317f0e64c9dfb59dc4bef078ecd23afa1a816358fb1473", size = 2385479, upload-time = "2025-06-28T04:21:10.413Z" }, + { url = "https://files.pythonhosted.org/packages/54/61/e0b7295456c7e21ef777eae60403c06835160c8d0e1e58ebfc7d024c51d3/grpcio_tools-1.71.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:b4669827716355fa913b1376b1b985855d5cfdb63443f8d18faf210180199006", size = 5431521, upload-time = "2025-06-28T04:21:12.261Z" }, + { url = "https://files.pythonhosted.org/packages/75/d7/7bcad6bcc5f5b7fab53e6bce5db87041f38ef3e740b1ec2d8c49534fa286/grpcio_tools-1.71.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d4071f9b44564e3f75cdf0f05b10b3e8c7ea0ca5220acbf4dc50b148552eef2f", size = 2350289, upload-time = "2025-06-28T04:21:13.625Z" }, + { url = "https://files.pythonhosted.org/packages/b2/8a/e4c1c4cb8c9ff7f50b7b2bba94abe8d1e98ea05f52a5db476e7f1c1a3c70/grpcio_tools-1.71.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a28eda8137d587eb30081384c256f5e5de7feda34776f89848b846da64e4be35", size = 2743321, upload-time = "2025-06-28T04:21:15.007Z" }, + { url = "https://files.pythonhosted.org/packages/fd/aa/95bc77fda5c2d56fb4a318c1b22bdba8914d5d84602525c99047114de531/grpcio_tools-1.71.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b19c083198f5eb15cc69c0a2f2c415540cbc636bfe76cea268e5894f34023b40", size = 2474005, upload-time = "2025-06-28T04:21:16.443Z" }, + { url = "https://files.pythonhosted.org/packages/c9/ff/ca11f930fe1daa799ee0ce1ac9630d58a3a3deed3dd2f465edb9a32f299d/grpcio_tools-1.71.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:784c284acda0d925052be19053d35afbf78300f4d025836d424cf632404f676a", size = 2851559, upload-time = "2025-06-28T04:21:18.139Z" }, + { url = "https://files.pythonhosted.org/packages/64/10/c6fc97914c7e19c9bb061722e55052fa3f575165da9f6510e2038d6e8643/grpcio_tools-1.71.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:381e684d29a5d052194e095546eef067201f5af30fd99b07b5d94766f44bf1ae", size = 3300622, upload-time = "2025-06-28T04:21:20.291Z" }, + { url = "https://files.pythonhosted.org/packages/e5/d6/965f36cfc367c276799b730d5dd1311b90a54a33726e561393b808339b04/grpcio_tools-1.71.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3e4b4801fabd0427fc61d50d09588a01b1cfab0ec5e8a5f5d515fbdd0891fd11", size = 2913863, upload-time = "2025-06-28T04:21:22.196Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f0/c05d5c3d0c1d79ac87df964e9d36f1e3a77b60d948af65bec35d3e5c75a3/grpcio_tools-1.71.2-cp312-cp312-win32.whl", hash = "sha256:84ad86332c44572305138eafa4cc30040c9a5e81826993eae8227863b700b490", size = 945744, upload-time = "2025-06-28T04:21:23.463Z" }, + { url = "https://files.pythonhosted.org/packages/e2/e9/c84c1078f0b7af7d8a40f5214a9bdd8d2a567ad6c09975e6e2613a08d29d/grpcio_tools-1.71.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e1108d37eecc73b1c4a27350a6ed921b5dda25091700c1da17cfe30761cd462", size = 1117695, upload-time = "2025-06-28T04:21:25.22Z" }, ] [[package]] name = "gunicorn" -version = "23.0.0" +version = "25.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/34/72/9614c465dc206155d93eff0ca20d42e1e35afc533971379482de953521a4/gunicorn-23.0.0.tar.gz", hash = "sha256:f014447a0101dc57e294f6c18ca6b40227a4c90e9bdb586042628030cba004ec", size = 375031, upload-time = "2024-08-10T20:25:27.378Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/13/ef67f59f6a7896fdc2c1d62b5665c5219d6b0a9a1784938eb9a28e55e128/gunicorn-25.1.0.tar.gz", hash = "sha256:1426611d959fa77e7de89f8c0f32eed6aa03ee735f98c01efba3e281b1c47616", size = 594377, upload-time = "2026-02-13T11:09:58.989Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/7d/6dac2a6e1eba33ee43f318edbed4ff29151a49b5d37f080aad1e6469bca4/gunicorn-23.0.0-py3-none-any.whl", hash = "sha256:ec400d38950de4dfd418cff8328b2c8faed0edb0d517d3394e457c317908ca4d", size = 85029, upload-time = "2024-08-10T20:25:24.996Z" }, + { url = "https://files.pythonhosted.org/packages/da/73/4ad5b1f6a2e21cf1e85afdaad2b7b1a933985e2f5d679147a1953aaa192c/gunicorn-25.1.0-py3-none-any.whl", hash = "sha256:d0b1236ccf27f72cfe14bce7caadf467186f19e865094ca84221424e839b8b8b", size = 197067, upload-time = "2026-02-13T11:09:57.146Z" }, ] [[package]] @@ -2925,51 +3025,66 @@ wheels = [ [[package]] name = "hf-xet" -version = "1.2.0" +version = "1.4.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5e/6e/0f11bacf08a67f7fb5ee09740f2ca54163863b07b70d579356e9222ce5d8/hf_xet-1.2.0.tar.gz", hash = "sha256:a8c27070ca547293b6890c4bf389f713f80e8c478631432962bb7f4bc0bd7d7f", size = 506020, upload-time = "2025-10-24T19:04:32.129Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/08/23c84a26716382c89151b5b447b4beb19e3345f3a93d3b73009a71a57ad3/hf_xet-1.4.2.tar.gz", hash = "sha256:b7457b6b482d9e0743bd116363239b1fa904a5e65deede350fbc0c4ea67c71ea", size = 672357, upload-time = "2026-03-13T06:58:51.077Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/2d/22338486473df5923a9ab7107d375dbef9173c338ebef5098ef593d2b560/hf_xet-1.2.0-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:46740d4ac024a7ca9b22bebf77460ff43332868b661186a8e46c227fdae01848", size = 2866099, upload-time = "2025-10-24T19:04:15.366Z" }, - { url = "https://files.pythonhosted.org/packages/7f/8c/c5becfa53234299bc2210ba314eaaae36c2875e0045809b82e40a9544f0c/hf_xet-1.2.0-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:27df617a076420d8845bea087f59303da8be17ed7ec0cd7ee3b9b9f579dff0e4", size = 2722178, upload-time = "2025-10-24T19:04:13.695Z" }, - { url = "https://files.pythonhosted.org/packages/9a/92/cf3ab0b652b082e66876d08da57fcc6fa2f0e6c70dfbbafbd470bb73eb47/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3651fd5bfe0281951b988c0facbe726aa5e347b103a675f49a3fa8144c7968fd", size = 3320214, upload-time = "2025-10-24T19:04:03.596Z" }, - { url = "https://files.pythonhosted.org/packages/46/92/3f7ec4a1b6a65bf45b059b6d4a5d38988f63e193056de2f420137e3c3244/hf_xet-1.2.0-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d06fa97c8562fb3ee7a378dd9b51e343bc5bc8190254202c9771029152f5e08c", size = 3229054, upload-time = "2025-10-24T19:04:01.949Z" }, - { url = "https://files.pythonhosted.org/packages/0b/dd/7ac658d54b9fb7999a0ccb07ad863b413cbaf5cf172f48ebcd9497ec7263/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4c1428c9ae73ec0939410ec73023c4f842927f39db09b063b9482dac5a3bb737", size = 3413812, upload-time = "2025-10-24T19:04:24.585Z" }, - { url = "https://files.pythonhosted.org/packages/92/68/89ac4e5b12a9ff6286a12174c8538a5930e2ed662091dd2572bbe0a18c8a/hf_xet-1.2.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a55558084c16b09b5ed32ab9ed38421e2d87cf3f1f89815764d1177081b99865", size = 3508920, upload-time = "2025-10-24T19:04:26.927Z" }, - { url = "https://files.pythonhosted.org/packages/cb/44/870d44b30e1dcfb6a65932e3e1506c103a8a5aea9103c337e7a53180322c/hf_xet-1.2.0-cp37-abi3-win_amd64.whl", hash = "sha256:e6584a52253f72c9f52f9e549d5895ca7a471608495c4ecaa6cc73dba2b24d69", size = 2905735, upload-time = "2025-10-24T19:04:35.928Z" }, + { url = "https://files.pythonhosted.org/packages/b4/86/b40b83a2ff03ef05c4478d2672b1fc2b9683ff870e2b25f4f3af240f2e7b/hf_xet-1.4.2-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:71f02d6e4cdd07f344f6844845d78518cc7186bd2bc52d37c3b73dc26a3b0bc5", size = 3800339, upload-time = "2026-03-13T06:58:36.245Z" }, + { url = "https://files.pythonhosted.org/packages/64/2e/af4475c32b4378b0e92a587adb1aa3ec53e3450fd3e5fe0372a874531c00/hf_xet-1.4.2-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:e9b38d876e94d4bdcf650778d6ebbaa791dd28de08db9736c43faff06ede1b5a", size = 3559664, upload-time = "2026-03-13T06:58:34.787Z" }, + { url = "https://files.pythonhosted.org/packages/3c/4c/781267da3188db679e601de18112021a5cb16506fe86b246e22c5401a9c4/hf_xet-1.4.2-cp37-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:77e8c180b7ef12d8a96739a4e1e558847002afe9ea63b6f6358b2271a8bdda1c", size = 4217422, upload-time = "2026-03-13T06:58:27.472Z" }, + { url = "https://files.pythonhosted.org/packages/68/47/d6cf4a39ecf6c7705f887a46f6ef5c8455b44ad9eb0d391aa7e8a2ff7fea/hf_xet-1.4.2-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:c3b3c6a882016b94b6c210957502ff7877802d0dbda8ad142c8595db8b944271", size = 3992847, upload-time = "2026-03-13T06:58:25.989Z" }, + { url = "https://files.pythonhosted.org/packages/2d/ef/e80815061abff54697239803948abc665c6b1d237102c174f4f7a9a5ffc5/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9d9a634cc929cfbaf2e1a50c0e532ae8c78fa98618426769480c58501e8c8ac2", size = 4193843, upload-time = "2026-03-13T06:58:44.59Z" }, + { url = "https://files.pythonhosted.org/packages/54/75/07f6aa680575d9646c4167db6407c41340cbe2357f5654c4e72a1b01ca14/hf_xet-1.4.2-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6b0932eb8b10317ea78b7da6bab172b17be03bbcd7809383d8d5abd6a2233e04", size = 4432751, upload-time = "2026-03-13T06:58:46.533Z" }, + { url = "https://files.pythonhosted.org/packages/cd/71/193eabd7e7d4b903c4aa983a215509c6114915a5a237525ec562baddb868/hf_xet-1.4.2-cp37-abi3-win_amd64.whl", hash = "sha256:ad185719fb2e8ac26f88c8100562dbf9dbdcc3d9d2add00faa94b5f106aea53f", size = 3671149, upload-time = "2026-03-13T06:58:57.07Z" }, + { url = "https://files.pythonhosted.org/packages/b4/7e/ccf239da366b37ba7f0b36095450efae4a64980bdc7ec2f51354205fdf39/hf_xet-1.4.2-cp37-abi3-win_arm64.whl", hash = "sha256:32c012286b581f783653e718c1862aea5b9eb140631685bb0c5e7012c8719a87", size = 3533426, upload-time = "2026-03-13T06:58:55.46Z" }, ] [[package]] name = "hiredis" -version = "3.3.0" +version = "3.3.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/65/82/d2817ce0653628e0a0cb128533f6af0dd6318a49f3f3a6a7bd1f2f2154af/hiredis-3.3.0.tar.gz", hash = "sha256:105596aad9249634361815c574351f1bd50455dc23b537c2940066c4a9dea685", size = 89048, upload-time = "2025-10-14T16:33:34.263Z" } +sdist = { url = "https://files.pythonhosted.org/packages/97/d6/9bef6dc3052c168c93fbf7e6c0f2b12c45f0f741a2d30fd919096774343a/hiredis-3.3.1.tar.gz", hash = "sha256:da6f0302360e99d32bc2869772692797ebadd536e1b826d0103c72ba49d38698", size = 89101, upload-time = "2026-03-16T15:21:08.092Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/0c/be3b1093f93a7c823ca16fbfbb83d3a1de671bbd2add8da1fe2bcfccb2b8/hiredis-3.3.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:63ee6c1ae6a2462a2439eb93c38ab0315cd5f4b6d769c6a34903058ba538b5d6", size = 81813, upload-time = "2025-10-14T16:32:00.576Z" }, - { url = "https://files.pythonhosted.org/packages/95/2b/ed722d392ac59a7eee548d752506ef32c06ffdd0bce9cf91125a74b8edf9/hiredis-3.3.0-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:31eda3526e2065268a8f97fbe3d0e9a64ad26f1d89309e953c80885c511ea2ae", size = 46049, upload-time = "2025-10-14T16:32:01.319Z" }, - { url = "https://files.pythonhosted.org/packages/e5/61/8ace8027d5b3f6b28e1dc55f4a504be038ba8aa8bf71882b703e8f874c91/hiredis-3.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a26bae1b61b7bcafe3d0d0c7d012fb66ab3c95f2121dbea336df67e344e39089", size = 41814, upload-time = "2025-10-14T16:32:02.076Z" }, - { url = "https://files.pythonhosted.org/packages/23/0e/380ade1ffb21034976663a5128f0383533f35caccdba13ff0537dd5ace79/hiredis-3.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b9546079f7fd5c50fbff9c791710049b32eebe7f9b94debec1e8b9f4c048cba2", size = 167572, upload-time = "2025-10-14T16:32:03.125Z" }, - { url = "https://files.pythonhosted.org/packages/ca/60/b4a8d2177575b896730f73e6890644591aa56790a75c2b6d6f2302a1dae6/hiredis-3.3.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ae327fc13b1157b694d53f92d50920c0051e30b0c245f980a7036e299d039ab4", size = 179373, upload-time = "2025-10-14T16:32:04.04Z" }, - { url = "https://files.pythonhosted.org/packages/31/53/a473a18d27cfe8afda7772ff9adfba1718fd31d5e9c224589dc17774fa0b/hiredis-3.3.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4016e50a8be5740a59c5af5252e5ad16c395021a999ad24c6604f0d9faf4d346", size = 177504, upload-time = "2025-10-14T16:32:04.934Z" }, - { url = "https://files.pythonhosted.org/packages/7e/0f/f6ee4c26b149063dbf5b1b6894b4a7a1f00a50e3d0cfd30a22d4c3479db3/hiredis-3.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c17b473f273465a3d2168a57a5b43846165105ac217d5652a005e14068589ddc", size = 169449, upload-time = "2025-10-14T16:32:05.808Z" }, - { url = "https://files.pythonhosted.org/packages/64/38/e3e113172289e1261ccd43e387a577dd268b0b9270721b5678735803416c/hiredis-3.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9ecd9b09b11bd0b8af87d29c3f5da628d2bdc2a6c23d2dd264d2da082bd4bf32", size = 164010, upload-time = "2025-10-14T16:32:06.695Z" }, - { url = "https://files.pythonhosted.org/packages/8d/9a/ccf4999365691ea73d0dd2ee95ee6ef23ebc9a835a7417f81765bc49eade/hiredis-3.3.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:00fb04eac208cd575d14f246e74a468561081ce235937ab17d77cde73aefc66c", size = 174623, upload-time = "2025-10-14T16:32:07.627Z" }, - { url = "https://files.pythonhosted.org/packages/ed/c7/ee55fa2ade078b7c4f17e8ddc9bc28881d0b71b794ebf9db4cfe4c8f0623/hiredis-3.3.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:60814a7d0b718adf3bfe2c32c6878b0e00d6ae290ad8e47f60d7bba3941234a6", size = 167650, upload-time = "2025-10-14T16:32:08.615Z" }, - { url = "https://files.pythonhosted.org/packages/bf/06/f6cd90275dcb0ba03f69767805151eb60b602bc25830648bd607660e1f97/hiredis-3.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:fcbd1a15e935aa323b5b2534b38419511b7909b4b8ee548e42b59090a1b37bb1", size = 165452, upload-time = "2025-10-14T16:32:09.561Z" }, - { url = "https://files.pythonhosted.org/packages/c3/10/895177164a6c4409a07717b5ae058d84a908e1ab629f0401110b02aaadda/hiredis-3.3.0-cp311-cp311-win32.whl", hash = "sha256:73679607c5a19f4bcfc9cf6eb54480bcd26617b68708ac8b1079da9721be5449", size = 20394, upload-time = "2025-10-14T16:32:10.469Z" }, - { url = "https://files.pythonhosted.org/packages/3c/c7/1e8416ae4d4134cb62092c61cabd76b3d720507ee08edd19836cdeea4c7a/hiredis-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:30a4df3d48f32538de50648d44146231dde5ad7f84f8f08818820f426840ae97", size = 22336, upload-time = "2025-10-14T16:32:11.221Z" }, - { url = "https://files.pythonhosted.org/packages/48/1c/ed28ae5d704f5c7e85b946fa327f30d269e6272c847fef7e91ba5fc86193/hiredis-3.3.0-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:5b8e1d6a2277ec5b82af5dce11534d3ed5dffeb131fd9b210bc1940643b39b5f", size = 82026, upload-time = "2025-10-14T16:32:12.004Z" }, - { url = "https://files.pythonhosted.org/packages/f4/9b/79f30c5c40e248291023b7412bfdef4ad9a8a92d9e9285d65d600817dac7/hiredis-3.3.0-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:c4981de4d335f996822419e8a8b3b87367fcef67dc5fb74d3bff4df9f6f17783", size = 46217, upload-time = "2025-10-14T16:32:13.133Z" }, - { url = "https://files.pythonhosted.org/packages/e7/c3/02b9ed430ad9087aadd8afcdf616717452d16271b701fa47edfe257b681e/hiredis-3.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1706480a683e328ae9ba5d704629dee2298e75016aa0207e7067b9c40cecc271", size = 41858, upload-time = "2025-10-14T16:32:13.98Z" }, - { url = "https://files.pythonhosted.org/packages/f1/98/b2a42878b82130a535c7aa20bc937ba2d07d72e9af3ad1ad93e837c419b5/hiredis-3.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a95cef9989736ac313639f8f545b76b60b797e44e65834aabbb54e4fad8d6c8", size = 170195, upload-time = "2025-10-14T16:32:14.728Z" }, - { url = "https://files.pythonhosted.org/packages/66/1d/9dcde7a75115d3601b016113d9b90300726fa8e48aacdd11bf01a453c145/hiredis-3.3.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ca2802934557ccc28a954414c245ba7ad904718e9712cb67c05152cf6b9dd0a3", size = 181808, upload-time = "2025-10-14T16:32:15.622Z" }, - { url = "https://files.pythonhosted.org/packages/56/a1/60f6bda9b20b4e73c85f7f5f046bc2c154a5194fc94eb6861e1fd97ced52/hiredis-3.3.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:fe730716775f61e76d75810a38ee4c349d3af3896450f1525f5a4034cf8f2ed7", size = 180578, upload-time = "2025-10-14T16:32:16.514Z" }, - { url = "https://files.pythonhosted.org/packages/d9/01/859d21de65085f323a701824e23ea3330a0ac05f8e184544d7aa5c26128d/hiredis-3.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:749faa69b1ce1f741f5eaf743435ac261a9262e2d2d66089192477e7708a9abc", size = 172508, upload-time = "2025-10-14T16:32:17.411Z" }, - { url = "https://files.pythonhosted.org/packages/99/a8/28fd526e554c80853d0fbf57ef2a3235f00e4ed34ce0e622e05d27d0f788/hiredis-3.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:95c9427f2ac3f1dd016a3da4e1161fa9d82f221346c8f3fdd6f3f77d4e28946c", size = 166341, upload-time = "2025-10-14T16:32:18.561Z" }, - { url = "https://files.pythonhosted.org/packages/f2/91/ded746b7d2914f557fbbf77be55e90d21f34ba758ae10db6591927c642c8/hiredis-3.3.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c863ee44fe7bff25e41f3a5105c936a63938b76299b802d758f40994ab340071", size = 176765, upload-time = "2025-10-14T16:32:19.491Z" }, - { url = "https://files.pythonhosted.org/packages/d6/4c/04aa46ff386532cb5f08ee495c2bf07303e93c0acf2fa13850e031347372/hiredis-3.3.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:2213c7eb8ad5267434891f3241c7776e3bafd92b5933fc57d53d4456247dc542", size = 170312, upload-time = "2025-10-14T16:32:20.404Z" }, - { url = "https://files.pythonhosted.org/packages/90/6e/67f9d481c63f542a9cf4c9f0ea4e5717db0312fb6f37fb1f78f3a66de93c/hiredis-3.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a172bae3e2837d74530cd60b06b141005075db1b814d966755977c69bd882ce8", size = 167965, upload-time = "2025-10-14T16:32:21.259Z" }, - { url = "https://files.pythonhosted.org/packages/7a/df/dde65144d59c3c0d85e43255798f1fa0c48d413e668cfd92b3d9f87924ef/hiredis-3.3.0-cp312-cp312-win32.whl", hash = "sha256:cb91363b9fd6d41c80df9795e12fffbaf5c399819e6ae8120f414dedce6de068", size = 20533, upload-time = "2025-10-14T16:32:22.192Z" }, - { url = "https://files.pythonhosted.org/packages/f5/a9/55a4ac9c16fdf32e92e9e22c49f61affe5135e177ca19b014484e28950f7/hiredis-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:04ec150e95eea3de9ff8bac754978aa17b8bf30a86d4ab2689862020945396b0", size = 22379, upload-time = "2025-10-14T16:32:22.916Z" }, + { url = "https://files.pythonhosted.org/packages/8b/71/b8e7da87ff0a270e086670da46732ff8e0af2fb4042afe1486846cf44ea7/hiredis-3.3.1-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:26f899cde0279e4b7d370716ff80320601c2bd93cdf3e774a42bdd44f65b41f8", size = 81823, upload-time = "2026-03-16T15:19:20.139Z" }, + { url = "https://files.pythonhosted.org/packages/50/e0/8bdafc6251aada93c670eb1893335bb248e10faa784f54de6b9384c7d2cb/hiredis-3.3.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:a2f049c3f3c83e886cd1f53958e2a1ebb369be626bef9e50d8b24d79864f1df6", size = 46043, upload-time = "2026-03-16T15:19:21.292Z" }, + { url = "https://files.pythonhosted.org/packages/1b/e8/48e5eee6dffb2d5659f437231341bfbf00c53d9fdb5d069ea629f1b2fa96/hiredis-3.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5f316cf2d0558f5027aab19dde7d7e4901c26c21fa95367bc37784e8f547bbf2", size = 41813, upload-time = "2026-03-16T15:19:22.404Z" }, + { url = "https://files.pythonhosted.org/packages/d8/e0/8dcd593db6d0e91cd797fafc565995cd28bd9d7ae85807c820b5e245ab82/hiredis-3.3.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:03baa381964b8df356d19ec4e3a6ae656044249a87b0def257fe1e08dbaf6094", size = 167570, upload-time = "2026-03-16T15:19:23.328Z" }, + { url = "https://files.pythonhosted.org/packages/76/e5/e2d75ecc15db51117ebd260fab4059b8a4cbbf74eb89c407c6d437bd6413/hiredis-3.3.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:304481241e081bc26f0778b2c2b99f9c43917e4e724a016dcc9439b7ab12c726", size = 179373, upload-time = "2026-03-16T15:19:24.739Z" }, + { url = "https://files.pythonhosted.org/packages/86/9a/4fafde37b86f70125bcd01e8af5e9f448fc99f4116db6d0e9ad214fc688a/hiredis-3.3.1-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:8597c35c9e82f65fd5897c4a2188c65d7daf10607b102960137b23d261cd957b", size = 177501, upload-time = "2026-03-16T15:19:25.982Z" }, + { url = "https://files.pythonhosted.org/packages/a3/73/413a17d6926c015683a608c148862f1dc7e8ad6f5c205b626607be9b9ddf/hiredis-3.3.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ad940dc2db545dc978cb41cb9a683e2ff328f3ef581230b9ca40ff6c3d01d542", size = 169446, upload-time = "2026-03-16T15:19:27.35Z" }, + { url = "https://files.pythonhosted.org/packages/e3/39/aa8e41d5f728dfa99f2236c1176a6b348c7577fd68ca9960d20d251d3b29/hiredis-3.3.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:156be6a0c736ee145cfe0fb155d0e96cec8d4872cf8b4f76ad6a2ee6ab391d0a", size = 164009, upload-time = "2026-03-16T15:19:28.426Z" }, + { url = "https://files.pythonhosted.org/packages/c4/37/85a609a2cf2b6354749bcef8f488c3298976358601cb4906bbaf2eb53944/hiredis-3.3.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:583de2f16528e66081cbdfe510d8488c2de73039dc00aada7d22bd49d73a4a94", size = 174623, upload-time = "2026-03-16T15:19:29.466Z" }, + { url = "https://files.pythonhosted.org/packages/a3/5e/75e0a76e4c9021f9914cfa1de8d98cff4acd0a0eb3344d31f43c02ec9375/hiredis-3.3.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:c24c1460486b6b36083252c2db21a814becf8495ccd0e76b7286623e37239b63", size = 167649, upload-time = "2026-03-16T15:19:30.438Z" }, + { url = "https://files.pythonhosted.org/packages/f7/08/1212138ee61e9b72d3f561da60cf6dc15031c10117735938ac258613803f/hiredis-3.3.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a58a58cef0d911b1717154179a9ff47852249c536ea5966bde4370b6b20638ff", size = 165451, upload-time = "2026-03-16T15:19:31.404Z" }, + { url = "https://files.pythonhosted.org/packages/46/36/cd776ef13b44afbb86c3d63c1a76b09d54cb1b545cce9e26fcd439d69606/hiredis-3.3.1-cp311-cp311-win32.whl", hash = "sha256:e0db44cf81e4d7b94f3776b9f89111f74ed6bbdbfd42a22bc4a5ce0644d3e060", size = 20399, upload-time = "2026-03-16T15:19:32.44Z" }, + { url = "https://files.pythonhosted.org/packages/df/0e/5b2a73bea6d18e7ebda7ed73520854cdc176ba70a945bd541bdeeb3f8caa/hiredis-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:1f7bceb03a1b934872ffe3942eaeed7c7e09096e67b53f095b81f39c7a819113", size = 22336, upload-time = "2026-03-16T15:19:33.238Z" }, + { url = "https://files.pythonhosted.org/packages/b3/1d/1a7d925d886211948ab9cca44221b1d9dd4d3481d015511e98794e37d369/hiredis-3.3.1-cp312-cp312-macosx_10_15_universal2.whl", hash = "sha256:60543f3b068b16a86e99ed96b7fdae71cdc1d8abdfe9b3f82032a555e52ece7e", size = 82023, upload-time = "2026-03-16T15:19:34.157Z" }, + { url = "https://files.pythonhosted.org/packages/13/2f/a6017fe1db47cd63a4aefc0dd21dd4dcb0c4e857bfbcfaa27329745f24a3/hiredis-3.3.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:2611bfaaadc5e8d43fb7967f9bbf1110c8beaa83aee2f2d812c76f11cfb56c6a", size = 46215, upload-time = "2026-03-16T15:19:35.068Z" }, + { url = "https://files.pythonhosted.org/packages/77/4b/35a71d088c6934e162aa81c7e289fa3110a3aca84ab695d88dbd488c74a2/hiredis-3.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8e3754ce60e1b11b0afad9a053481ff184d2ee24bea47099107156d1b84a84aa", size = 41861, upload-time = "2026-03-16T15:19:36.32Z" }, + { url = "https://files.pythonhosted.org/packages/1f/54/904bc723a95926977764fefd6f0d46067579bac38fffc32b806f3f2c05c0/hiredis-3.3.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e89dabf436ee79b358fd970dcbed6333a36d91db73f27069ca24a02fb138a404", size = 170196, upload-time = "2026-03-16T15:19:37.274Z" }, + { url = "https://files.pythonhosted.org/packages/1d/01/4e840cd4cb53c28578234708b08fb9ec9e41c2880acc0e269a7264e1b3af/hiredis-3.3.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4f7e242eab698ad0be5a4b2ec616fa856569c57455cc67c625fd567726290e5f", size = 181808, upload-time = "2026-03-16T15:19:38.637Z" }, + { url = "https://files.pythonhosted.org/packages/87/0d/fc845f06f8203ab76c401d4d2b97f9fb768e644b053a40f441f7dcc71f2d/hiredis-3.3.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:53148a4e21057541b6d8e493b2ea1b500037ddf34433c391970036f3cbce00e3", size = 180577, upload-time = "2026-03-16T15:19:39.749Z" }, + { url = "https://files.pythonhosted.org/packages/52/3a/859afe2620666bf6d58eb977870c47d98af4999d473b50528b323918f3f7/hiredis-3.3.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c25132902d3eff38781e0d54f27a0942ec849e3c07dbdce83c4d92b7e43c8dce", size = 172507, upload-time = "2026-03-16T15:19:40.87Z" }, + { url = "https://files.pythonhosted.org/packages/60/a8/004349708ad8bf0d188d46049f846d3fe2d4a7a8d0d5a6a8ba024017d8b3/hiredis-3.3.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3fb6573efa15a29c12c0c0f7170b14e7c1347fe4bb39b6a15b779f46015cc929", size = 166339, upload-time = "2026-03-16T15:19:41.912Z" }, + { url = "https://files.pythonhosted.org/packages/c3/fb/bfc6df29381830c99bfd9e97ed3b6d75d9303866a28c23d51ab8c50f63e3/hiredis-3.3.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:487658e1db83c1ee9fbbac6a43039ea76957767a5987ffb16b590613f9e68297", size = 176766, upload-time = "2026-03-16T15:19:42.981Z" }, + { url = "https://files.pythonhosted.org/packages/53/e7/f54aaad4559a413ec8b1043a89567a5a1f898426e4091b9af5e0f2120371/hiredis-3.3.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a1d190790ee39b8b7adeeb10fc4090dc4859eb4e75ed27bd8108710eef18f358", size = 170313, upload-time = "2026-03-16T15:19:44.082Z" }, + { url = "https://files.pythonhosted.org/packages/60/51/b80394db4c74d4cba342fa4208f690a2739c16f1125c2a62ba1701b8e2b7/hiredis-3.3.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a42c7becd4c9ec4ab5769c754eb61112777bdc6e1c1525e2077389e193b5f5aa", size = 167964, upload-time = "2026-03-16T15:19:45.237Z" }, + { url = "https://files.pythonhosted.org/packages/47/ef/5e438d1e058be57cdc1bafc1b1ec8ab43cc890c61447e88f8b878a0e32c3/hiredis-3.3.1-cp312-cp312-win32.whl", hash = "sha256:17ec8b524055a88b80d76c177dbbbe475a25c17c5bf4b67bdbdbd0629bcae838", size = 20532, upload-time = "2026-03-16T15:19:46.233Z" }, + { url = "https://files.pythonhosted.org/packages/e9/c6/39994b9c5646e7bf7d5e92170c07fd5f224ae9f34d95ff202f31845eb94b/hiredis-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:0fac4af8515e6cca74fc701169ae4dc9a71a90e9319c9d21006ec9454b43aa2f", size = 22381, upload-time = "2026-03-16T15:19:47.082Z" }, +] + +[[package]] +name = "holo-search-sdk" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "psycopg", extra = ["binary"] }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/b8/70a4999dabbba15e98d201a7399aab76ab96931ad1a27392ba5252cc9165/holo_search_sdk-0.4.1.tar.gz", hash = "sha256:9aea98b6078b9202abb568ed69d798d5e0505d2b4cc3a136a6aa84402bcd2133", size = 56701, upload-time = "2026-01-28T01:44:57.645Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8a/30/3059a979272f90a96f31b167443cc27675e8cc8f970a3ac0cb80bf803c70/holo_search_sdk-0.4.1-py3-none-any.whl", hash = "sha256:ef1059895ea936ff6a087f68dac92bd1ae0320e51ec5b1d4e7bed7a5dd6beb45", size = 32647, upload-time = "2026-01-28T01:44:56.098Z" }, ] [[package]] @@ -3009,14 +3124,14 @@ wheels = [ [[package]] name = "httplib2" -version = "0.31.0" +version = "0.31.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pyparsing" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/52/77/6653db69c1f7ecfe5e3f9726fdadc981794656fcd7d98c4209fecfea9993/httplib2-0.31.0.tar.gz", hash = "sha256:ac7ab497c50975147d4f7b1ade44becc7df2f8954d42b38b3d69c515f531135c", size = 250759, upload-time = "2025-09-11T12:16:03.403Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c1/1f/e86365613582c027dda5ddb64e1010e57a3d53e99ab8a72093fa13d565ec/httplib2-0.31.2.tar.gz", hash = "sha256:385e0869d7397484f4eab426197a4c020b606edd43372492337c0b4010ae5d24", size = 250800, upload-time = "2026-01-23T11:04:44.165Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8c/a2/0d269db0f6163be503775dc8b6a6fa15820cc9fdc866f6ba608d86b721f2/httplib2-0.31.0-py3-none-any.whl", hash = "sha256:b9cd78abea9b4e43a7714c6e0f8b6b8561a6fc1e95d5dbd367f5bf0ef35f5d24", size = 91148, upload-time = "2025-09-11T12:16:01.803Z" }, + { url = "https://files.pythonhosted.org/packages/2f/90/fd509079dfcab01102c0fdd87f3a9506894bc70afcf9e9785ef6b2b3aff6/httplib2-0.31.2-py3-none-any.whl", hash = "sha256:dbf0c2fa3862acf3c55c078ea9c0bc4481d7dc5117cae71be9514912cf9f8349", size = 91099, upload-time = "2026-01-23T11:04:42.78Z" }, ] [[package]] @@ -3043,18 +3158,17 @@ wheels = [ [[package]] name = "httpx" -version = "0.27.2" +version = "0.28.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "certifi" }, { name = "httpcore" }, { name = "idna" }, - { name = "sniffio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/78/82/08f8c936781f67d9e6b9eeb8a0c8b4e406136ea4c3d1f89a5db71d42e0e6/httpx-0.27.2.tar.gz", hash = "sha256:f7c2be1d2f3c3c3160d441802406b206c2b76f5947b11115e6df10c6c65e66c2", size = 144189, upload-time = "2024-08-27T12:54:01.334Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/df/48c586a5fe32a0f01324ee087459e112ebb7224f646c0b5023f5e79e9956/httpx-0.28.1.tar.gz", hash = "sha256:75e98c5f16b0f35b567856f597f06ff2270a374470a5c2392242528e3e3e42fc", size = 141406, upload-time = "2024-12-06T15:37:23.222Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/95/9377bcb415797e44274b51d46e3249eba641711cf3348050f76ee7b15ffc/httpx-0.27.2-py3-none-any.whl", hash = "sha256:7bb2708e112d8fdd7829cd4243970f0c223274051cb35ee80c03301ee29a3df0", size = 76395, upload-time = "2024-08-27T12:53:59.653Z" }, + { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" }, ] [package.optional-dependencies] @@ -3076,33 +3190,22 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.36.0" +version = "1.7.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, - { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, + { name = "httpx" }, { name = "packaging" }, { name = "pyyaml" }, - { name = "requests" }, { name = "tqdm" }, + { name = "typer" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/98/63/4910c5fa9128fdadf6a9c5ac138e8b1b6cee4ca44bf7915bbfbce4e355ee/huggingface_hub-0.36.0.tar.gz", hash = "sha256:47b3f0e2539c39bf5cde015d63b72ec49baff67b6931c3d97f3f84532e2b8d25", size = 463358, upload-time = "2025-10-23T12:12:01.413Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/15/eafc1c57bf0f8afffb243dcd4c0cceb785e956acc17bba4d9bf2ae21fc9c/huggingface_hub-1.7.2.tar.gz", hash = "sha256:7f7e294e9bbb822e025bdb2ada025fa4344d978175a7f78e824d86e35f7ab43b", size = 724684, upload-time = "2026-03-20T10:36:08.767Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cb/bd/1a875e0d592d447cbc02805fd3fe0f497714d6a2583f59d14fa9ebad96eb/huggingface_hub-0.36.0-py3-none-any.whl", hash = "sha256:7bcc9ad17d5b3f07b57c78e79d527102d08313caa278a641993acddcb894548d", size = 566094, upload-time = "2025-10-23T12:11:59.557Z" }, -] - -[[package]] -name = "humanfriendly" -version = "10.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyreadline3", marker = "sys_platform == 'win32'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/cc/3f/2c29224acb2e2df4d2046e4c73ee2662023c58ff5b113c4c1adac0886c43/humanfriendly-10.0.tar.gz", hash = "sha256:6b0b831ce8f15f7300721aa49829fc4e83921a9a301cc7f606be6686a2288ddc", size = 360702, upload-time = "2021-09-17T21:40:43.31Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/0f/310fb31e39e2d734ccaa2c0fb981ee41f7bd5056ce9bc29b2248bd569169/humanfriendly-10.0-py2.py3-none-any.whl", hash = "sha256:1697e1a8a8f550fd43c2865cd84542fc175a61dcb779b6fee18cf6b6ccba1477", size = 86794, upload-time = "2021-09-17T21:40:39.897Z" }, + { url = "https://files.pythonhosted.org/packages/08/de/3ad061a05f74728927ded48c90b73521b9a9328c85d841bdefb30e01fb85/huggingface_hub-1.7.2-py3-none-any.whl", hash = "sha256:288f33a0a17b2a73a1359e2a5fd28d1becb2c121748c6173ab8643fb342c850e", size = 618036, upload-time = "2026-03-20T10:36:06.824Z" }, ] [[package]] @@ -3116,23 +3219,14 @@ wheels = [ [[package]] name = "hypothesis" -version = "6.148.2" +version = "6.151.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4a/99/a3c6eb3fdd6bfa01433d674b0f12cd9102aa99630689427422d920aea9c6/hypothesis-6.148.2.tar.gz", hash = "sha256:07e65d34d687ddff3e92a3ac6b43966c193356896813aec79f0a611c5018f4b1", size = 469984, upload-time = "2025-11-18T20:21:17.047Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/e1/ef365ff480903b929d28e057f57b76cae51a30375943e33374ec9a165d9c/hypothesis-6.151.9.tar.gz", hash = "sha256:2f284428dda6c3c48c580de0e18470ff9c7f5ef628a647ee8002f38c3f9097ca", size = 463534, upload-time = "2026-02-16T22:59:23.09Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/d2/c2673aca0127e204965e0e9b3b7a0e91e9b12993859ac8758abd22669b89/hypothesis-6.148.2-py3-none-any.whl", hash = "sha256:bf8ddc829009da73b321994b902b1964bcc3e5c3f0ed9a1c1e6a1631ab97c5fa", size = 536986, upload-time = "2025-11-18T20:21:15.212Z" }, -] - -[[package]] -name = "identify" -version = "2.6.17" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/57/84/376a3b96e5a8d33a7aa2c5b3b31a4b3c364117184bf0b17418055f6ace66/identify-2.6.17.tar.gz", hash = "sha256:f816b0b596b204c9fdf076ded172322f2723cf958d02f9c3587504834c8ff04d", size = 99579, upload-time = "2026-03-01T20:04:12.702Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/40/66/71c1227dff78aaeb942fed29dd5651f2aec166cc7c9aeea3e8b26a539b7d/identify-2.6.17-py2.py3-none-any.whl", hash = "sha256:be5f8412d5ed4b20f2bd41a65f920990bdccaa6a4a18a08f1eefdcd0bdd885f0", size = 99382, upload-time = "2026-03-01T20:04:11.439Z" }, + { url = "https://files.pythonhosted.org/packages/c4/f7/5cc291d701094754a1d327b44d80a44971e13962881d9a400235726171da/hypothesis-6.151.9-py3-none-any.whl", hash = "sha256:7b7220585c67759b1b1ef839b1e6e9e3d82ed468cfc1ece43c67184848d7edd9", size = 529307, upload-time = "2026-02-16T22:59:20.443Z" }, ] [[package]] @@ -3146,28 +3240,29 @@ wheels = [ [[package]] name = "import-linter" -version = "2.6" +version = "2.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "grimp" }, + { name = "rich" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1d/20/37f3661ccbdba41072a74cb7a57a932b6884ab6c489318903d2d870c6c07/import_linter-2.6.tar.gz", hash = "sha256:60429a450eb6ebeed536f6d2b83428b026c5747ca69d029812e2f1360b136f85", size = 161294, upload-time = "2025-11-10T09:59:20.977Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ba/66/55b697a17bb15c6cb88d97d73716813f5427281527b90f02cc0a600abc6e/import_linter-2.11.tar.gz", hash = "sha256:5abc3394797a54f9bae315e7242dc98715ba485f840ac38c6d3192c370d0085e", size = 1153682, upload-time = "2026-03-06T12:11:38.198Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/df/02389e13d340229baa687bd0b9be4878e13668ce0beadbe531fb2b597386/import_linter-2.6-py3-none-any.whl", hash = "sha256:4e835141294b803325a619b8c789398320b81f0bde7771e0dd36f34524e51b1e", size = 46488, upload-time = "2025-11-10T09:59:19.611Z" }, + { url = "https://files.pythonhosted.org/packages/e9/aa/2ed2c89543632ded7196e0d93dcc6c7fe87769e88391a648c4a298ea864a/import_linter-2.11-py3-none-any.whl", hash = "sha256:3dc54cae933bae3430358c30989762b721c77aa99d424f56a08265be0eeaa465", size = 637315, upload-time = "2026-03-06T12:11:36.599Z" }, ] [[package]] name = "importlib-metadata" -version = "8.4.0" +version = "8.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "zipp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/bd/fa8ce65b0a7d4b6d143ec23b0f5fd3f7ab80121078c465bc02baeaab22dc/importlib_metadata-8.4.0.tar.gz", hash = "sha256:9a547d3bc3608b025f93d403fdd1aae741c24fbb8314df4b155675742ce303c5", size = 54320, upload-time = "2024-08-20T17:11:42.348Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/12/33e59336dca5be0c398a7482335911a33aa0e20776128f038019f1a95f1b/importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7", size = 55304, upload-time = "2024-09-11T14:56:08.937Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/14/362d31bf1076b21e1bcdcb0dc61944822ff263937b804a79231df2774d28/importlib_metadata-8.4.0-py3-none-any.whl", hash = "sha256:66f342cc6ac9818fc6ff340576acd24d65ba0b3efabb2b4ac08b598965a4a2f1", size = 26269, upload-time = "2024-08-20T17:11:41.102Z" }, + { url = "https://files.pythonhosted.org/packages/a0/d9/a1e041c5e7caa9a05c925f4bdbdfb7f006d1f74996af53467bc394c97be7/importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b", size = 26514, upload-time = "2024-09-11T14:56:07.019Z" }, ] [[package]] @@ -3188,9 +3283,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/cb/b1/3846dd7f199d53cb17f49cba7e651e9ce294d8497c8c150530ed11865bb8/iniconfig-2.3.0-py3-none-any.whl", hash = "sha256:f631c04d2c48c52b84d0d0549c99ff3859c98df65b3101406327ecc7d53fbf12", size = 7484, upload-time = "2025-10-18T21:55:41.639Z" }, ] +[[package]] +name = "installer" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/05/18/ceeb4e3ab3aa54495775775b38ae42b10a92f42ce42dfa44da684289b8c8/installer-0.7.0.tar.gz", hash = "sha256:a26d3e3116289bb08216e0d0f7d925fcef0b0194eedfa0c944bcaaa106c4b631", size = 474349, upload-time = "2023-03-17T20:39:38.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/ca/1172b6638d52f2d6caa2dd262ec4c811ba59eee96d54a7701930726bce18/installer-0.7.0-py3-none-any.whl", hash = "sha256:05d1933f0a5ba7d8d6296bb6d5018e7c94fa473ceb10cf198a92ccea19c27b53", size = 453838, upload-time = "2023-03-17T20:39:36.219Z" }, +] + [[package]] name = "instructor" -version = "1.12.0" +version = "1.14.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3199,7 +3303,6 @@ dependencies = [ { name = "jinja2" }, { name = "jiter" }, { name = "openai" }, - { name = "pre-commit" }, { name = "pydantic" }, { name = "pydantic-core" }, { name = "requests" }, @@ -3207,9 +3310,9 @@ dependencies = [ { name = "tenacity" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f8/4d/cc37bc2bb0fcd9584f4935ecb5f4b23d33c63ddeea20d899d4d99f72a69a/instructor-1.12.0.tar.gz", hash = "sha256:f0e4dd7f275120f49200df0204af6a2d4e3e2f1f698b6b8c0f776e3a8c977e54", size = 69892486, upload-time = "2025-10-27T18:47:55.191Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/ef/986d059424db204ed57b29d8c07fda35de2a2c72dee8ea7994bc90a6f767/instructor-1.14.5.tar.gz", hash = "sha256:fcb6432867f2fe4a5986e8bf389dcc64ed2ad4039a12a2dff85464e51c2f171a", size = 69950754, upload-time = "2026-01-29T14:18:50.454Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/8a/af9e30cd9ec64ab595a39996fe761cf2c7ce47475a9607559e3ddf25104a/instructor-1.12.0-py3-none-any.whl", hash = "sha256:88c2161c5ac7ccb60f9b9fc3e93e6a5750a0a28f2927d835b7d198018c3165d9", size = 157906, upload-time = "2025-10-27T18:47:52.007Z" }, + { url = "https://files.pythonhosted.org/packages/45/04/e442e1356c97b03a6d30d2b462f7c0bdfbf207e75f6833815fd1225a75b4/instructor-1.14.5-py3-none-any.whl", hash = "sha256:2a5a31222b008c0989be1cc001e33a237f49506e80ac5833f6d36d7690bae7b1", size = 177445, upload-time = "2026-01-29T14:18:53.641Z" }, ] [[package]] @@ -3226,12 +3329,15 @@ wheels = [ [[package]] name = "intervaltree" -version = "3.1.0" +version = "3.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/50/fb/396d568039d21344639db96d940d40eb62befe704ef849b27949ded5c3bb/intervaltree-3.1.0.tar.gz", hash = "sha256:902b1b88936918f9b2a19e0e5eb7ccb430ae45cde4f39ea4b36932920d33952d", size = 32861, upload-time = "2020-08-03T08:01:11.392Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/c3/b2afa612aa0373f3e6bb190e6de35f293b307d1537f109e3e25dbfcdf212/intervaltree-3.2.1.tar.gz", hash = "sha256:f3f7e8baeb7dd75b9f7a6d33cf3ec10025984a8e66e3016d537e52130c73cfe2", size = 1231531, upload-time = "2025-12-24T04:25:06.773Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/7f/8a80a1c7c2ed05822b5a2b312d2995f30c533641f8198366ba2e26a7bb03/intervaltree-3.2.1-py2.py3-none-any.whl", hash = "sha256:a8a8381bbd35d48ceebee932c77ffc988492d22fb1d27d0ba1d74a7694eb8f0b", size = 25929, upload-time = "2025-12-24T04:25:05.298Z" }, +] [[package]] name = "isodate" @@ -3251,51 +3357,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/04/96/92447566d16df59b2a776c0fb82dbc4d9e07cd95062562af01e408583fc4/itsdangerous-2.2.0-py3-none-any.whl", hash = "sha256:c6242fc49e35958c8b15141343aa660db5fc54d4f13a1db01a3f5891b98700ef", size = 16234, upload-time = "2024-04-16T21:28:14.499Z" }, ] -[[package]] -name = "jaraco-classes" -version = "3.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "more-itertools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/06/c0/ed4a27bc5571b99e3cff68f8a9fa5b56ff7df1c2251cc715a652ddd26402/jaraco.classes-3.4.0.tar.gz", hash = "sha256:47a024b51d0239c0dd8c8540c6c7f484be3b8fcf0b2d85c13825780d3b3f3acd", size = 11780, upload-time = "2024-03-31T07:27:36.643Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/66/b15ce62552d84bbfcec9a4873ab79d993a1dd4edb922cbfccae192bd5b5f/jaraco.classes-3.4.0-py3-none-any.whl", hash = "sha256:f662826b6bed8cace05e7ff873ce0f9283b5c924470fe664fff1c2f00f581790", size = 6777, upload-time = "2024-03-31T07:27:34.792Z" }, -] - -[[package]] -name = "jaraco-context" -version = "6.1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "backports-tarfile", marker = "python_full_version < '3.12'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/27/7b/c3081ff1af947915503121c649f26a778e1a2101fd525f74aef997d75b7e/jaraco_context-6.1.1.tar.gz", hash = "sha256:bc046b2dc94f1e5532bd02402684414575cc11f565d929b6563125deb0a6e581", size = 15832, upload-time = "2026-03-07T15:46:04.63Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/f4/49/c152890d49102b280ecf86ba5f80a8c111c3a155dafa3bd24aeb64fde9e1/jaraco_context-6.1.1-py3-none-any.whl", hash = "sha256:0df6a0287258f3e364072c3e40d5411b20cafa30cb28c4839d24319cecf9f808", size = 7005, upload-time = "2026-03-07T15:46:03.515Z" }, -] - -[[package]] -name = "jaraco-functools" -version = "4.4.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "more-itertools" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/0f/27/056e0638a86749374d6f57d0b0db39f29509cce9313cf91bdc0ac4d91084/jaraco_functools-4.4.0.tar.gz", hash = "sha256:da21933b0417b89515562656547a77b4931f98176eb173644c0d35032a33d6bb", size = 19943, upload-time = "2025-12-21T09:29:43.6Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/c4/813bb09f0985cb21e959f21f2464169eca882656849adf727ac7bb7e1767/jaraco_functools-4.4.0-py3-none-any.whl", hash = "sha256:9eec1e36f45c818d9bf307c8948eb03b2b56cd44087b3cdc989abca1f20b9176", size = 10481, upload-time = "2025-12-21T09:29:42.27Z" }, -] - -[[package]] -name = "jeepney" -version = "0.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/7b/6f/357efd7602486741aa73ffc0617fb310a29b588ed0fd69c2399acbb85b0c/jeepney-0.9.0.tar.gz", hash = "sha256:cf0e9e845622b81e4a28df94c40345400256ec608d0e55bb8a3feaa9163f5732", size = 106758, upload-time = "2025-02-27T18:51:01.684Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/a3/e137168c9c44d18eff0376253da9f1e9234d0239e0ee230d2fee6cea8e55/jeepney-0.9.0-py3-none-any.whl", hash = "sha256:97e5714520c16fc0a45695e5365a2e11b81ea79bba796e26f9f1d178cb182683", size = 49010, upload-time = "2025-02-27T18:51:00.104Z" }, -] - [[package]] name = "jieba" version = "0.42.1" @@ -3316,34 +3377,44 @@ wheels = [ [[package]] name = "jiter" -version = "0.10.0" +version = "0.11.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ee/9d/ae7ddb4b8ab3fb1b51faf4deb36cb48a4fbbd7cb36bad6a5fca4741306f7/jiter-0.10.0.tar.gz", hash = "sha256:07a7142c38aacc85194391108dc91b5b57093c978a9932bd86a36862759d9500", size = 162759, upload-time = "2025-05-18T19:04:59.73Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/68/0357982493a7b20925aece061f7fb7a2678e3b232f8d73a6edb7e5304443/jiter-0.11.1.tar.gz", hash = "sha256:849dcfc76481c0ea0099391235b7ca97d7279e0fa4c86005457ac7c88e8b76dc", size = 168385, upload-time = "2025-10-17T11:31:15.186Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1b/dd/6cefc6bd68b1c3c979cecfa7029ab582b57690a31cd2f346c4d0ce7951b6/jiter-0.10.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:3bebe0c558e19902c96e99217e0b8e8b17d570906e72ed8a87170bc290b1e978", size = 317473, upload-time = "2025-05-18T19:03:25.942Z" }, - { url = "https://files.pythonhosted.org/packages/be/cf/fc33f5159ce132be1d8dd57251a1ec7a631c7df4bd11e1cd198308c6ae32/jiter-0.10.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:558cc7e44fd8e507a236bee6a02fa17199ba752874400a0ca6cd6e2196cdb7dc", size = 321971, upload-time = "2025-05-18T19:03:27.255Z" }, - { url = "https://files.pythonhosted.org/packages/68/a4/da3f150cf1d51f6c472616fb7650429c7ce053e0c962b41b68557fdf6379/jiter-0.10.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d613e4b379a07d7c8453c5712ce7014e86c6ac93d990a0b8e7377e18505e98d", size = 345574, upload-time = "2025-05-18T19:03:28.63Z" }, - { url = "https://files.pythonhosted.org/packages/84/34/6e8d412e60ff06b186040e77da5f83bc158e9735759fcae65b37d681f28b/jiter-0.10.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f62cf8ba0618eda841b9bf61797f21c5ebd15a7a1e19daab76e4e4b498d515b2", size = 371028, upload-time = "2025-05-18T19:03:30.292Z" }, - { url = "https://files.pythonhosted.org/packages/fb/d9/9ee86173aae4576c35a2f50ae930d2ccb4c4c236f6cb9353267aa1d626b7/jiter-0.10.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:919d139cdfa8ae8945112398511cb7fca58a77382617d279556b344867a37e61", size = 491083, upload-time = "2025-05-18T19:03:31.654Z" }, - { url = "https://files.pythonhosted.org/packages/d9/2c/f955de55e74771493ac9e188b0f731524c6a995dffdcb8c255b89c6fb74b/jiter-0.10.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13ddbc6ae311175a3b03bd8994881bc4635c923754932918e18da841632349db", size = 388821, upload-time = "2025-05-18T19:03:33.184Z" }, - { url = "https://files.pythonhosted.org/packages/81/5a/0e73541b6edd3f4aada586c24e50626c7815c561a7ba337d6a7eb0a915b4/jiter-0.10.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4c440ea003ad10927a30521a9062ce10b5479592e8a70da27f21eeb457b4a9c5", size = 352174, upload-time = "2025-05-18T19:03:34.965Z" }, - { url = "https://files.pythonhosted.org/packages/1c/c0/61eeec33b8c75b31cae42be14d44f9e6fe3ac15a4e58010256ac3abf3638/jiter-0.10.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dc347c87944983481e138dea467c0551080c86b9d21de6ea9306efb12ca8f606", size = 391869, upload-time = "2025-05-18T19:03:36.436Z" }, - { url = "https://files.pythonhosted.org/packages/41/22/5beb5ee4ad4ef7d86f5ea5b4509f680a20706c4a7659e74344777efb7739/jiter-0.10.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:13252b58c1f4d8c5b63ab103c03d909e8e1e7842d302473f482915d95fefd605", size = 523741, upload-time = "2025-05-18T19:03:38.168Z" }, - { url = "https://files.pythonhosted.org/packages/ea/10/768e8818538e5817c637b0df52e54366ec4cebc3346108a4457ea7a98f32/jiter-0.10.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:7d1bbf3c465de4a24ab12fb7766a0003f6f9bce48b8b6a886158c4d569452dc5", size = 514527, upload-time = "2025-05-18T19:03:39.577Z" }, - { url = "https://files.pythonhosted.org/packages/73/6d/29b7c2dc76ce93cbedabfd842fc9096d01a0550c52692dfc33d3cc889815/jiter-0.10.0-cp311-cp311-win32.whl", hash = "sha256:db16e4848b7e826edca4ccdd5b145939758dadf0dc06e7007ad0e9cfb5928ae7", size = 210765, upload-time = "2025-05-18T19:03:41.271Z" }, - { url = "https://files.pythonhosted.org/packages/c2/c9/d394706deb4c660137caf13e33d05a031d734eb99c051142e039d8ceb794/jiter-0.10.0-cp311-cp311-win_amd64.whl", hash = "sha256:9c9c1d5f10e18909e993f9641f12fe1c77b3e9b533ee94ffa970acc14ded3812", size = 209234, upload-time = "2025-05-18T19:03:42.918Z" }, - { url = "https://files.pythonhosted.org/packages/6d/b5/348b3313c58f5fbfb2194eb4d07e46a35748ba6e5b3b3046143f3040bafa/jiter-0.10.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1e274728e4a5345a6dde2d343c8da018b9d4bd4350f5a472fa91f66fda44911b", size = 312262, upload-time = "2025-05-18T19:03:44.637Z" }, - { url = "https://files.pythonhosted.org/packages/9c/4a/6a2397096162b21645162825f058d1709a02965606e537e3304b02742e9b/jiter-0.10.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7202ae396446c988cb2a5feb33a543ab2165b786ac97f53b59aafb803fef0744", size = 320124, upload-time = "2025-05-18T19:03:46.341Z" }, - { url = "https://files.pythonhosted.org/packages/2a/85/1ce02cade7516b726dd88f59a4ee46914bf79d1676d1228ef2002ed2f1c9/jiter-0.10.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23ba7722d6748b6920ed02a8f1726fb4b33e0fd2f3f621816a8b486c66410ab2", size = 345330, upload-time = "2025-05-18T19:03:47.596Z" }, - { url = "https://files.pythonhosted.org/packages/75/d0/bb6b4f209a77190ce10ea8d7e50bf3725fc16d3372d0a9f11985a2b23eff/jiter-0.10.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:371eab43c0a288537d30e1f0b193bc4eca90439fc08a022dd83e5e07500ed026", size = 369670, upload-time = "2025-05-18T19:03:49.334Z" }, - { url = "https://files.pythonhosted.org/packages/a0/f5/a61787da9b8847a601e6827fbc42ecb12be2c925ced3252c8ffcb56afcaf/jiter-0.10.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c675736059020365cebc845a820214765162728b51ab1e03a1b7b3abb70f74c", size = 489057, upload-time = "2025-05-18T19:03:50.66Z" }, - { url = "https://files.pythonhosted.org/packages/12/e4/6f906272810a7b21406c760a53aadbe52e99ee070fc5c0cb191e316de30b/jiter-0.10.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0c5867d40ab716e4684858e4887489685968a47e3ba222e44cde6e4a2154f959", size = 389372, upload-time = "2025-05-18T19:03:51.98Z" }, - { url = "https://files.pythonhosted.org/packages/e2/ba/77013b0b8ba904bf3762f11e0129b8928bff7f978a81838dfcc958ad5728/jiter-0.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:395bb9a26111b60141757d874d27fdea01b17e8fac958b91c20128ba8f4acc8a", size = 352038, upload-time = "2025-05-18T19:03:53.703Z" }, - { url = "https://files.pythonhosted.org/packages/67/27/c62568e3ccb03368dbcc44a1ef3a423cb86778a4389e995125d3d1aaa0a4/jiter-0.10.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6842184aed5cdb07e0c7e20e5bdcfafe33515ee1741a6835353bb45fe5d1bd95", size = 391538, upload-time = "2025-05-18T19:03:55.046Z" }, - { url = "https://files.pythonhosted.org/packages/c0/72/0d6b7e31fc17a8fdce76164884edef0698ba556b8eb0af9546ae1a06b91d/jiter-0.10.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:62755d1bcea9876770d4df713d82606c8c1a3dca88ff39046b85a048566d56ea", size = 523557, upload-time = "2025-05-18T19:03:56.386Z" }, - { url = "https://files.pythonhosted.org/packages/2f/09/bc1661fbbcbeb6244bd2904ff3a06f340aa77a2b94e5a7373fd165960ea3/jiter-0.10.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:533efbce2cacec78d5ba73a41756beff8431dfa1694b6346ce7af3a12c42202b", size = 514202, upload-time = "2025-05-18T19:03:57.675Z" }, - { url = "https://files.pythonhosted.org/packages/1b/84/5a5d5400e9d4d54b8004c9673bbe4403928a00d28529ff35b19e9d176b19/jiter-0.10.0-cp312-cp312-win32.whl", hash = "sha256:8be921f0cadd245e981b964dfbcd6fd4bc4e254cdc069490416dd7a2632ecc01", size = 211781, upload-time = "2025-05-18T19:03:59.025Z" }, - { url = "https://files.pythonhosted.org/packages/9b/52/7ec47455e26f2d6e5f2ea4951a0652c06e5b995c291f723973ae9e724a65/jiter-0.10.0-cp312-cp312-win_amd64.whl", hash = "sha256:a7c7d785ae9dda68c2678532a5a1581347e9c15362ae9f6e68f3fdbfb64f2e49", size = 206176, upload-time = "2025-05-18T19:04:00.305Z" }, + { url = "https://files.pythonhosted.org/packages/8b/34/c9e6cfe876f9a24f43ed53fe29f052ce02bd8d5f5a387dbf46ad3764bef0/jiter-0.11.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:9b0088ff3c374ce8ce0168523ec8e97122ebb788f950cf7bb8e39c7dc6a876a2", size = 310160, upload-time = "2025-10-17T11:28:59.174Z" }, + { url = "https://files.pythonhosted.org/packages/bc/9f/b06ec8181d7165858faf2ac5287c54fe52b2287760b7fe1ba9c06890255f/jiter-0.11.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:74433962dd3c3090655e02e461267095d6c84f0741c7827de11022ef8d7ff661", size = 316573, upload-time = "2025-10-17T11:29:00.905Z" }, + { url = "https://files.pythonhosted.org/packages/66/49/3179d93090f2ed0c6b091a9c210f266d2d020d82c96f753260af536371d0/jiter-0.11.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6d98030e345e6546df2cc2c08309c502466c66c4747b043f1a0d415fada862b8", size = 348998, upload-time = "2025-10-17T11:29:02.321Z" }, + { url = "https://files.pythonhosted.org/packages/ae/9d/63db2c8eabda7a9cad65a2e808ca34aaa8689d98d498f5a2357d7a2e2cec/jiter-0.11.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1d6db0b2e788db46bec2cf729a88b6dd36959af2abd9fa2312dfba5acdd96dcb", size = 363413, upload-time = "2025-10-17T11:29:03.787Z" }, + { url = "https://files.pythonhosted.org/packages/25/ff/3e6b3170c5053053c7baddb8d44e2bf11ff44cd71024a280a8438ae6ba32/jiter-0.11.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55678fbbda261eafe7289165dd2ddd0e922df5f9a1ae46d7c79a5a15242bd7d1", size = 487144, upload-time = "2025-10-17T11:29:05.37Z" }, + { url = "https://files.pythonhosted.org/packages/b0/50/b63fcadf699893269b997f4c2e88400bc68f085c6db698c6e5e69d63b2c1/jiter-0.11.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6a6b74fae8e40497653b52ce6ca0f1b13457af769af6fb9c1113efc8b5b4d9be", size = 376215, upload-time = "2025-10-17T11:29:07.123Z" }, + { url = "https://files.pythonhosted.org/packages/39/8c/57a8a89401134167e87e73471b9cca321cf651c1fd78c45f3a0f16932213/jiter-0.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a55a453f8b035eb4f7852a79a065d616b7971a17f5e37a9296b4b38d3b619e4", size = 359163, upload-time = "2025-10-17T11:29:09.047Z" }, + { url = "https://files.pythonhosted.org/packages/4b/96/30b0cdbffbb6f753e25339d3dbbe26890c9ef119928314578201c758aace/jiter-0.11.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2638148099022e6bdb3f42904289cd2e403609356fb06eb36ddec2d50958bc29", size = 385344, upload-time = "2025-10-17T11:29:10.69Z" }, + { url = "https://files.pythonhosted.org/packages/c6/d5/31dae27c1cc9410ad52bb514f11bfa4f286f7d6ef9d287b98b8831e156ec/jiter-0.11.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:252490567a5d990986f83b95a5f1ca1bf205ebd27b3e9e93bb7c2592380e29b9", size = 517972, upload-time = "2025-10-17T11:29:12.174Z" }, + { url = "https://files.pythonhosted.org/packages/61/1e/5905a7a3aceab80de13ab226fd690471a5e1ee7e554dc1015e55f1a6b896/jiter-0.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d431d52b0ca2436eea6195f0f48528202100c7deda354cb7aac0a302167594d5", size = 508408, upload-time = "2025-10-17T11:29:13.597Z" }, + { url = "https://files.pythonhosted.org/packages/91/12/1c49b97aa49077e136e8591cef7162f0d3e2860ae457a2d35868fd1521ef/jiter-0.11.1-cp311-cp311-win32.whl", hash = "sha256:db6f41e40f8bae20c86cb574b48c4fd9f28ee1c71cb044e9ec12e78ab757ba3a", size = 203937, upload-time = "2025-10-17T11:29:14.894Z" }, + { url = "https://files.pythonhosted.org/packages/6d/9d/2255f7c17134ee9892c7e013c32d5bcf4bce64eb115402c9fe5e727a67eb/jiter-0.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:0cc407b8e6cdff01b06bb80f61225c8b090c3df108ebade5e0c3c10993735b19", size = 207589, upload-time = "2025-10-17T11:29:16.166Z" }, + { url = "https://files.pythonhosted.org/packages/3c/28/6307fc8f95afef84cae6caf5429fee58ef16a582c2ff4db317ceb3e352fa/jiter-0.11.1-cp311-cp311-win_arm64.whl", hash = "sha256:fe04ea475392a91896d1936367854d346724a1045a247e5d1c196410473b8869", size = 188391, upload-time = "2025-10-17T11:29:17.488Z" }, + { url = "https://files.pythonhosted.org/packages/15/8b/318e8af2c904a9d29af91f78c1e18f0592e189bbdb8a462902d31fe20682/jiter-0.11.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:c92148eec91052538ce6823dfca9525f5cfc8b622d7f07e9891a280f61b8c96c", size = 305655, upload-time = "2025-10-17T11:29:18.859Z" }, + { url = "https://files.pythonhosted.org/packages/f7/29/6c7de6b5d6e511d9e736312c0c9bfcee8f9b6bef68182a08b1d78767e627/jiter-0.11.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ecd4da91b5415f183a6be8f7158d127bdd9e6a3174138293c0d48d6ea2f2009d", size = 315645, upload-time = "2025-10-17T11:29:20.889Z" }, + { url = "https://files.pythonhosted.org/packages/ac/5f/ef9e5675511ee0eb7f98dd8c90509e1f7743dbb7c350071acae87b0145f3/jiter-0.11.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7e3ac25c00b9275684d47aa42febaa90a9958e19fd1726c4ecf755fbe5e553b", size = 348003, upload-time = "2025-10-17T11:29:22.712Z" }, + { url = "https://files.pythonhosted.org/packages/56/1b/abe8c4021010b0a320d3c62682769b700fb66f92c6db02d1a1381b3db025/jiter-0.11.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:57d7305c0a841858f866cd459cd9303f73883fb5e097257f3d4a3920722c69d4", size = 365122, upload-time = "2025-10-17T11:29:24.408Z" }, + { url = "https://files.pythonhosted.org/packages/2a/2d/4a18013939a4f24432f805fbd5a19893e64650b933edb057cd405275a538/jiter-0.11.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e86fa10e117dce22c547f31dd6d2a9a222707d54853d8de4e9a2279d2c97f239", size = 488360, upload-time = "2025-10-17T11:29:25.724Z" }, + { url = "https://files.pythonhosted.org/packages/f0/77/38124f5d02ac4131f0dfbcfd1a19a0fac305fa2c005bc4f9f0736914a1a4/jiter-0.11.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ae5ef1d48aec7e01ee8420155d901bb1d192998fa811a65ebb82c043ee186711", size = 376884, upload-time = "2025-10-17T11:29:27.056Z" }, + { url = "https://files.pythonhosted.org/packages/7b/43/59fdc2f6267959b71dd23ce0bd8d4aeaf55566aa435a5d00f53d53c7eb24/jiter-0.11.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eb68e7bf65c990531ad8715e57d50195daf7c8e6f1509e617b4e692af1108939", size = 358827, upload-time = "2025-10-17T11:29:28.698Z" }, + { url = "https://files.pythonhosted.org/packages/7d/d0/b3cc20ff5340775ea3bbaa0d665518eddecd4266ba7244c9cb480c0c82ec/jiter-0.11.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:43b30c8154ded5845fa454ef954ee67bfccce629b2dea7d01f795b42bc2bda54", size = 385171, upload-time = "2025-10-17T11:29:30.078Z" }, + { url = "https://files.pythonhosted.org/packages/d2/bc/94dd1f3a61f4dc236f787a097360ec061ceeebebf4ea120b924d91391b10/jiter-0.11.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:586cafbd9dd1f3ce6a22b4a085eaa6be578e47ba9b18e198d4333e598a91db2d", size = 518359, upload-time = "2025-10-17T11:29:31.464Z" }, + { url = "https://files.pythonhosted.org/packages/7e/8c/12ee132bd67e25c75f542c227f5762491b9a316b0dad8e929c95076f773c/jiter-0.11.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:677cc2517d437a83bb30019fd4cf7cad74b465914c56ecac3440d597ac135250", size = 509205, upload-time = "2025-10-17T11:29:32.895Z" }, + { url = "https://files.pythonhosted.org/packages/39/d5/9de848928ce341d463c7e7273fce90ea6d0ea4343cd761f451860fa16b59/jiter-0.11.1-cp312-cp312-win32.whl", hash = "sha256:fa992af648fcee2b850a3286a35f62bbbaeddbb6dbda19a00d8fbc846a947b6e", size = 205448, upload-time = "2025-10-17T11:29:34.217Z" }, + { url = "https://files.pythonhosted.org/packages/ee/b0/8002d78637e05009f5e3fb5288f9d57d65715c33b5d6aa20fd57670feef5/jiter-0.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:88b5cae9fa51efeb3d4bd4e52bfd4c85ccc9cac44282e2a9640893a042ba4d87", size = 204285, upload-time = "2025-10-17T11:29:35.446Z" }, + { url = "https://files.pythonhosted.org/packages/9f/a2/bb24d5587e4dff17ff796716542f663deee337358006a80c8af43ddc11e5/jiter-0.11.1-cp312-cp312-win_arm64.whl", hash = "sha256:9a6cae1ab335551917f882f2c3c1efe7617b71b4c02381e4382a8fc80a02588c", size = 188712, upload-time = "2025-10-17T11:29:37.027Z" }, + { url = "https://files.pythonhosted.org/packages/9d/51/bd41562dd284e2a18b6dc0a99d195fd4a3560d52ab192c42e56fe0316643/jiter-0.11.1-graalpy311-graalpy242_311_native-macosx_10_12_x86_64.whl", hash = "sha256:e642b5270e61dd02265866398707f90e365b5db2eb65a4f30c789d826682e1f6", size = 306871, upload-time = "2025-10-17T11:31:03.616Z" }, + { url = "https://files.pythonhosted.org/packages/ba/cb/64e7f21dd357e8cd6b3c919c26fac7fc198385bbd1d85bb3b5355600d787/jiter-0.11.1-graalpy311-graalpy242_311_native-macosx_11_0_arm64.whl", hash = "sha256:464ba6d000585e4e2fd1e891f31f1231f497273414f5019e27c00a4b8f7a24ad", size = 301454, upload-time = "2025-10-17T11:31:05.338Z" }, + { url = "https://files.pythonhosted.org/packages/55/b0/54bdc00da4ef39801b1419a01035bd8857983de984fd3776b0be6b94add7/jiter-0.11.1-graalpy311-graalpy242_311_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:055568693ab35e0bf3a171b03bb40b2dcb10352359e0ab9b5ed0da2bf1eb6f6f", size = 336801, upload-time = "2025-10-17T11:31:06.893Z" }, + { url = "https://files.pythonhosted.org/packages/de/8f/87176ed071d42e9db415ed8be787ef4ef31a4fa27f52e6a4fbf34387bd28/jiter-0.11.1-graalpy311-graalpy242_311_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e0c69ea798d08a915ba4478113efa9e694971e410056392f4526d796f136d3fa", size = 343452, upload-time = "2025-10-17T11:31:08.259Z" }, + { url = "https://files.pythonhosted.org/packages/a6/bc/950dd7f170c6394b6fdd73f989d9e729bd98907bcc4430ef080a72d06b77/jiter-0.11.1-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:0d4d6993edc83cf75e8c6828a8d6ce40a09ee87e38c7bfba6924f39e1337e21d", size = 302626, upload-time = "2025-10-17T11:31:09.645Z" }, + { url = "https://files.pythonhosted.org/packages/3a/65/43d7971ca82ee100b7b9b520573eeef7eabc0a45d490168ebb9a9b5bb8b2/jiter-0.11.1-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:f78d151c83a87a6cf5461d5ee55bc730dd9ae227377ac6f115b922989b95f838", size = 297034, upload-time = "2025-10-17T11:31:10.975Z" }, + { url = "https://files.pythonhosted.org/packages/19/4c/000e1e0c0c67e96557a279f8969487ea2732d6c7311698819f977abae837/jiter-0.11.1-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c9022974781155cd5521d5cb10997a03ee5e31e8454c9d999dcdccd253f2353f", size = 337328, upload-time = "2025-10-17T11:31:12.399Z" }, + { url = "https://files.pythonhosted.org/packages/d9/71/71408b02c6133153336d29fa3ba53000f1e1a3f78bb2fc2d1a1865d2e743/jiter-0.11.1-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18c77aaa9117510d5bdc6a946baf21b1f0cfa58ef04d31c8d016f206f2118960", size = 343697, upload-time = "2025-10-17T11:31:13.773Z" }, ] [[package]] @@ -3357,20 +3428,20 @@ wheels = [ [[package]] name = "joblib" -version = "1.5.2" +version = "1.5.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/5d/447af5ea094b9e4c4054f82e223ada074c552335b9b4b2d14bd9b35a67c4/joblib-1.5.2.tar.gz", hash = "sha256:3faa5c39054b2f03ca547da9b2f52fde67c06240c31853f306aea97f13647b55", size = 331077, upload-time = "2025-08-27T12:15:46.575Z" } +sdist = { url = "https://files.pythonhosted.org/packages/41/f2/d34e8b3a08a9cc79a50b2208a93dce981fe615b64d5a4d4abee421d898df/joblib-1.5.3.tar.gz", hash = "sha256:8561a3269e6801106863fd0d6d84bb737be9e7631e33aaed3fb9ce5953688da3", size = 331603, upload-time = "2025-12-15T08:41:46.427Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/e8/685f47e0d754320684db4425a0967f7d3fa70126bffd76110b7009a0090f/joblib-1.5.2-py3-none-any.whl", hash = "sha256:4e1f0bdbb987e6d843c70cf43714cb276623def372df3c22fe5266b2670bc241", size = 308396, upload-time = "2025-08-27T12:15:45.188Z" }, + { url = "https://files.pythonhosted.org/packages/7b/91/984aca2ec129e2757d1e4e3c81c3fcda9d0f85b74670a094cc443d9ee949/joblib-1.5.3-py3-none-any.whl", hash = "sha256:5fc3c5039fc5ca8c0276333a188bbd59d6b7ab37fe6632daa76bc7f9ec18e713", size = 309071, upload-time = "2025-12-15T08:41:44.973Z" }, ] [[package]] name = "json-repair" -version = "0.55.1" +version = "0.58.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c0/de/71d6bb078d167c0d0959776cee6b6bb8d2ad843f512a5222d7151dde4955/json_repair-0.55.1.tar.gz", hash = "sha256:b27aa0f6bf2e5bf58554037468690446ef26f32ca79c8753282adb3df25fb888", size = 39231, upload-time = "2026-01-23T09:37:20.93Z" } +sdist = { url = "https://files.pythonhosted.org/packages/76/70/484f97a744d2614218a2b162004accda3f3c4ccc8c5d688712624567ebec/json_repair-0.58.6.tar.gz", hash = "sha256:aa740113a1c9dede4ba84c29aa8f81493253aede6f0e4edde9a560ec4b1d7762", size = 44804, upload-time = "2026-03-16T13:43:34.722Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/da/289ba9eb550ae420cfc457926f6c49b87cacf8083ee9927e96921888a665/json_repair-0.55.1-py3-none-any.whl", hash = "sha256:a1bcc151982a12bc3ef9e9528198229587b1074999cfe08921ab6333b0c8e206", size = 29743, upload-time = "2026-01-23T09:37:19.404Z" }, + { url = "https://files.pythonhosted.org/packages/da/bb/c019ac05a6923c5776fa134c65e5b19d216ef17227618d93b1f608bc2806/json_repair-0.58.6-py3-none-any.whl", hash = "sha256:e438a1e4ea03179dfe9a05dfd738e678e888f1ea5b4a40398f8f220925df1c5c", size = 43482, upload-time = "2026-03-16T13:43:33.569Z" }, ] [[package]] @@ -3387,16 +3458,16 @@ wheels = [ [[package]] name = "jsonpointer" -version = "3.0.0" +version = "3.1.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114, upload-time = "2024-06-10T19:24:42.462Z" } +sdist = { url = "https://files.pythonhosted.org/packages/48/bf/9ecc036fbc15cf4153ea6ed4dbeed31ef043f762cccc9d44a534be8319b0/jsonpointer-3.1.0.tar.gz", hash = "sha256:f9b39abd59ba8c1de8a4ff16141605d2a8dacc4dd6cf399672cf237bfe47c211", size = 9000, upload-time = "2026-03-20T21:47:09.982Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = "sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595, upload-time = "2024-06-10T19:24:40.698Z" }, + { url = "https://files.pythonhosted.org/packages/e4/25/cebb241a435cbf4626b5ea096d8385c04416d7ca3082a15299b746e248fa/jsonpointer-3.1.0-py3-none-any.whl", hash = "sha256:f82aa0f745001f169b96473348370b43c3f581446889c41c807bab1db11c8e7b", size = 7651, upload-time = "2026-03-20T21:47:08.792Z" }, ] [[package]] name = "jsonschema" -version = "4.25.1" +version = "4.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, @@ -3404,9 +3475,9 @@ dependencies = [ { name = "referencing" }, { name = "rpds-py" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/74/69/f7185de793a29082a9f3c7728268ffb31cb5095131a9c139a74078e27336/jsonschema-4.25.1.tar.gz", hash = "sha256:e4a9655ce0da0c0b67a085847e00a3a51449e1157f4f75e9fb5aa545e122eb85", size = 357342, upload-time = "2025-08-18T17:03:50.038Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/fc/e067678238fa451312d4c62bf6e6cf5ec56375422aee02f9cb5f909b3047/jsonschema-4.26.0.tar.gz", hash = "sha256:0c26707e2efad8aa1bfc5b7ce170f3fccc2e4918ff85989ba9ffa9facb2be326", size = 366583, upload-time = "2026-01-07T13:41:07.246Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/9c/8c95d856233c1f82500c2450b8c68576b4cf1c871db3afac5c34ff84e6fd/jsonschema-4.25.1-py3-none-any.whl", hash = "sha256:3fba0169e345c7175110351d456342c364814cfcf3b964ba4587f22915230a63", size = 90040, upload-time = "2025-08-18T17:03:48.373Z" }, + { url = "https://files.pythonhosted.org/packages/69/90/f63fb5873511e014207a475e2bb4e8b2e570d655b00ac19a9a0ca0a385ee/jsonschema-4.26.0-py3-none-any.whl", hash = "sha256:d489f15263b8d200f8387e64b4c3a75f06629559fb73deb8fdfb525f2dab50ce", size = 90630, upload-time = "2026-01-07T13:41:05.306Z" }, ] [[package]] @@ -3430,27 +3501,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4a/4a/cf14bf3b1f5ffb13c69cf5f0ea78031247790558ee88984a8bdd22fae60d/kaitaistruct-0.11-py2.py3-none-any.whl", hash = "sha256:5c6ce79177b4e193a577ecd359e26516d1d6d000a0bffd6e1010f2a46a62a561", size = 11372, upload-time = "2025-09-08T15:46:23.635Z" }, ] -[[package]] -name = "keyring" -version = "25.7.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "importlib-metadata", marker = "python_full_version < '3.12'" }, - { name = "jaraco-classes" }, - { name = "jaraco-context" }, - { name = "jaraco-functools" }, - { name = "jeepney", marker = "sys_platform == 'linux'" }, - { name = "pywin32-ctypes", marker = "sys_platform == 'win32'" }, - { name = "secretstorage", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/43/4b/674af6ef2f97d56f0ab5153bf0bfa28ccb6c3ed4d1babf4305449668807b/keyring-25.7.0.tar.gz", hash = "sha256:fe01bd85eb3f8fb3dd0405defdeac9a5b4f6f0439edbb3149577f244a2e8245b", size = 63516, upload-time = "2025-11-16T16:26:09.482Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/db/e655086b7f3a705df045bf0933bdd9c2f79bb3c97bfef1384598bb79a217/keyring-25.7.0-py3-none-any.whl", hash = "sha256:be4a0b195f149690c166e850609a477c532ddbfbaed96a404d4e43f8d5e2689f", size = 39160, upload-time = "2025-11-16T16:26:08.402Z" }, -] - [[package]] name = "kombu" -version = "5.5.4" +version = "5.6.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "amqp" }, @@ -3458,20 +3511,18 @@ dependencies = [ { name = "tzdata" }, { name = "vine" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/d3/5ff936d8319ac86b9c409f1501b07c426e6ad41966fedace9ef1b966e23f/kombu-5.5.4.tar.gz", hash = "sha256:886600168275ebeada93b888e831352fe578168342f0d1d5833d88ba0d847363", size = 461992, upload-time = "2025-06-01T10:19:22.281Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b6/a5/607e533ed6c83ae1a696969b8e1c137dfebd5759a2e9682e26ff1b97740b/kombu-5.6.2.tar.gz", hash = "sha256:8060497058066c6f5aed7c26d7cd0d3b574990b09de842a8c5aaed0b92cc5a55", size = 472594, upload-time = "2025-12-29T20:30:07.779Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/70/a07dcf4f62598c8ad579df241af55ced65bed76e42e45d3c368a6d82dbc1/kombu-5.5.4-py3-none-any.whl", hash = "sha256:a12ed0557c238897d8e518f1d1fdf84bd1516c5e305af2dacd85c2015115feb8", size = 210034, upload-time = "2025-06-01T10:19:20.436Z" }, + { url = "https://files.pythonhosted.org/packages/fb/0f/834427d8c03ff1d7e867d3db3d176470c64871753252b21b4f4897d1fa45/kombu-5.6.2-py3-none-any.whl", hash = "sha256:efcfc559da324d41d61ca311b0c64965ea35b4c55cc04ee36e55386145dace93", size = 214219, upload-time = "2025-12-29T20:30:05.74Z" }, ] [[package]] name = "kubernetes" -version = "33.1.0" +version = "35.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "durationpy" }, - { name = "google-auth" }, - { name = "oauthlib" }, { name = "python-dateutil" }, { name = "pyyaml" }, { name = "requests" }, @@ -3480,14 +3531,28 @@ dependencies = [ { name = "urllib3" }, { name = "websocket-client" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ae/52/19ebe8004c243fdfa78268a96727c71e08f00ff6fe69a301d0b7fcbce3c2/kubernetes-33.1.0.tar.gz", hash = "sha256:f64d829843a54c251061a8e7a14523b521f2dc5c896cf6d65ccf348648a88993", size = 1036779, upload-time = "2025-06-09T21:57:58.521Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/8f/85bf51ad4150f64e8c665daf0d9dfe9787ae92005efb9a4d1cba592bd79d/kubernetes-35.0.0.tar.gz", hash = "sha256:3d00d344944239821458b9efd484d6df9f011da367ecb155dadf9513f05f09ee", size = 1094642, upload-time = "2026-01-16T01:05:27.76Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/89/43/d9bebfc3db7dea6ec80df5cb2aad8d274dd18ec2edd6c4f21f32c237cbbb/kubernetes-33.1.0-py2.py3-none-any.whl", hash = "sha256:544de42b24b64287f7e0aa9513c93cb503f7f40eea39b20f66810011a86eabc5", size = 1941335, upload-time = "2025-06-09T21:57:56.327Z" }, + { url = "https://files.pythonhosted.org/packages/0c/70/05b685ea2dffcb2adbf3cdcea5d8865b7bc66f67249084cf845012a0ff13/kubernetes-35.0.0-py2.py3-none-any.whl", hash = "sha256:39e2b33b46e5834ef6c3985ebfe2047ab39135d41de51ce7641a7ca5b372a13d", size = 2017602, upload-time = "2026-01-16T01:05:25.991Z" }, ] [[package]] name = "langchain" -version = "0.3.25" +version = "1.2.13" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/81/e5/56fdeedaa0ef1be3c53721d382d9e21c63930179567361610ea6102c04ea/langchain-1.2.13.tar.gz", hash = "sha256:d566ef67c8287e7f2e2df3c99bf3953a6beefd2a75a97fe56ecce905e21f3ef4", size = 573819, upload-time = "2026-03-19T17:16:07.641Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/1d/a509af07535d8f4621d77f3ba5ec846ee6d52c59d2239e1385ec3b29bf92/langchain-1.2.13-py3-none-any.whl", hash = "sha256:37d4526ac4b0cdd3d7706a6366124c30dc0771bf5340865b37cdc99d5e5ad9b1", size = 112488, upload-time = "2026-03-19T17:16:06.134Z" }, +] + +[[package]] +name = "langchain-classic" +version = "1.0.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, @@ -3498,20 +3563,20 @@ dependencies = [ { name = "requests" }, { name = "sqlalchemy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/f9/a256609096a9fc7a1b3a6300a97000091efabdf21555a97988f93d4d9258/langchain-0.3.25.tar.gz", hash = "sha256:a1d72aa39546a23db08492d7228464af35c9ee83379945535ceef877340d2a3a", size = 10225045, upload-time = "2025-05-02T18:39:04.353Z" } +sdist = { url = "https://files.pythonhosted.org/packages/32/04/b01c09e37414bab9f209efa311502841a3c0de5bc6c35e729c8d8a9893c9/langchain_classic-1.0.3.tar.gz", hash = "sha256:168ef1dfbfb18cae5a9ff0accecc9413a5b5aa3464b53fa841561a3384b6324a", size = 10534933, upload-time = "2026-03-13T13:56:11.96Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/5c/5c0be747261e1f8129b875fa3bfea736bc5fe17652f9d5e15ca118571b6f/langchain-0.3.25-py3-none-any.whl", hash = "sha256:931f7d2d1eaf182f9f41c5e3272859cfe7f94fc1f7cef6b3e5a46024b4884c21", size = 1011008, upload-time = "2025-05-02T18:39:02.21Z" }, + { url = "https://files.pythonhosted.org/packages/ab/e6/cfdeedec0537ffbf5041773590d25beb7f2aa467cc6630e788c9c7c72c3e/langchain_classic-1.0.3-py3-none-any.whl", hash = "sha256:26df1ec9806b1fbff19d9085a747ea7d8d82d7e3fb1d25132859979de627ef79", size = 1041335, upload-time = "2026-03-13T13:56:09.677Z" }, ] [[package]] name = "langchain-community" -version = "0.3.24" +version = "0.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, { name = "dataclasses-json" }, { name = "httpx-sse" }, - { name = "langchain" }, + { name = "langchain-classic" }, { name = "langchain-core" }, { name = "langsmith" }, { name = "numpy" }, @@ -3521,14 +3586,14 @@ dependencies = [ { name = "sqlalchemy" }, { name = "tenacity" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/f6/4892d1f1cf6d3e89da6ee6cfb0eb82b908c706c58bde7df28367ee76a93f/langchain_community-0.3.24.tar.gz", hash = "sha256:62d9e8cf9aadf35182ec3925f9ec1c8e5e84fb4f199f67a01aee496d289dc264", size = 33233643, upload-time = "2025-05-12T13:26:39.222Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/97/a03585d42b9bdb6fbd935282d6e3348b10322a24e6ce12d0c99eb461d9af/langchain_community-0.4.1.tar.gz", hash = "sha256:f3b211832728ee89f169ddce8579b80a085222ddb4f4ed445a46e977d17b1e85", size = 33241144, upload-time = "2025-10-27T15:20:32.504Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/cb/582f22d74d69f4dbd41e98d361ee36922b79a245a9411383327bd4b63747/langchain_community-0.3.24-py3-none-any.whl", hash = "sha256:b6cdb376bf1c2f4d2503aca20f8f35f2d5b3d879c52848277f20ce1950e7afaf", size = 2528335, upload-time = "2025-05-12T13:26:37.375Z" }, + { url = "https://files.pythonhosted.org/packages/f0/a4/c4fde67f193401512337456cabc2148f2c43316e445f5decd9f8806e2992/langchain_community-0.4.1-py3-none-any.whl", hash = "sha256:2135abb2c7748a35c84613108f7ebf30f8505b18c3c18305ffaecfc7651f6c6a", size = 2533285, upload-time = "2025-10-27T15:20:30.767Z" }, ] [[package]] name = "langchain-core" -version = "0.3.63" +version = "1.2.20" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jsonpatch" }, @@ -3538,36 +3603,37 @@ dependencies = [ { name = "pyyaml" }, { name = "tenacity" }, { name = "typing-extensions" }, + { name = "uuid-utils" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b9/0a/b71a9a5d42e743d6876cce23d803e284b191ed4d6544e2f7fe1b37f7854c/langchain_core-0.3.63.tar.gz", hash = "sha256:e2e30cfbb7684a5a0319f6cbf065fc3c438bfd1060302f085a122527890fb01e", size = 558302, upload-time = "2025-05-29T18:57:19.933Z" } +sdist = { url = "https://files.pythonhosted.org/packages/db/41/6552a419fe549a79601e5a698d1d5ee2ca7fe93bb87fd624a16a8c1bdee3/langchain_core-1.2.20.tar.gz", hash = "sha256:c7ac8b976039b5832abb989fef058b88c270594ba331efc79e835df046e7dc44", size = 838330, upload-time = "2026-03-18T17:34:45.522Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/71/a748861e6a69ab6ef50ab8e65120422a1f36245c71a0dd0f02de49c208e1/langchain_core-0.3.63-py3-none-any.whl", hash = "sha256:f91db8221b1bc6808f70b2e72fded1a94d50ee3f1dff1636fb5a5a514c64b7f5", size = 438468, upload-time = "2025-05-29T18:57:17.424Z" }, + { url = "https://files.pythonhosted.org/packages/d9/06/08c88ddd4d6766de4e6c43111ae8f3025df383d2a4379cb938fc571b49d4/langchain_core-1.2.20-py3-none-any.whl", hash = "sha256:b65ff678f3c3dc1f1b4d03a3af5ee3b8d51f9be5181d74eb53c6c11cd9dd5e68", size = 504215, upload-time = "2026-03-18T17:34:44.087Z" }, ] [[package]] name = "langchain-openai" -version = "0.3.19" +version = "1.1.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, { name = "openai" }, { name = "tiktoken" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5c/aa/4622a8c722f7bbd5261c8492d805165d845bc3212eca16b156fa48ce5626/langchain_openai-0.3.19.tar.gz", hash = "sha256:2ff103f272d01694aef650dfe0dc64525481b89f7f9e61f5e3ef8eb21da9f5fe", size = 544254, upload-time = "2025-06-02T17:57:29.34Z" } +sdist = { url = "https://files.pythonhosted.org/packages/40/cd/439be2b8deb8bd0d4c470c7c7f66698a84d823e583c3d36a322483cb7cab/langchain_openai-1.1.11.tar.gz", hash = "sha256:44b003a2960d1f6699f23721196b3b97d0c420d2e04444950869213214b7a06a", size = 1088560, upload-time = "2026-03-09T23:02:36.894Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/30/1dd3bebaccdce52afc0139f4733b31d01986b42d53e4ac1a919eca9531d8/langchain_openai-0.3.19-py3-none-any.whl", hash = "sha256:0ffb9eb86e1d25909a8e406e7e1fead3bd2e8d74a7e6daa74bacf2c6971e8b99", size = 64463, upload-time = "2025-06-02T17:57:27.938Z" }, + { url = "https://files.pythonhosted.org/packages/f0/0f/e4cb42848c25f65969adfb500a06dea1a541831604250fd0d8aa6e54fef5/langchain_openai-1.1.11-py3-none-any.whl", hash = "sha256:a03596221405d38d6852fb865467cb0d9ff9e79f335905eb6a576e8c4874ac71", size = 87694, upload-time = "2026-03-09T23:02:35.651Z" }, ] [[package]] name = "langchain-text-splitters" -version = "0.3.8" +version = "1.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "langchain-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e7/ac/b4a25c5716bb0103b1515f1f52cc69ffb1035a5a225ee5afe3aed28bf57b/langchain_text_splitters-0.3.8.tar.gz", hash = "sha256:116d4b9f2a22dda357d0b79e30acf005c5518177971c66a9f1ab0edfdb0f912e", size = 42128, upload-time = "2025-04-04T14:03:51.521Z" } +sdist = { url = "https://files.pythonhosted.org/packages/85/38/14121ead61e0e75f79c3a35e5148ac7c2fe754a55f76eab3eed573269524/langchain_text_splitters-1.1.1.tar.gz", hash = "sha256:34861abe7c07d9e49d4dc852d0129e26b32738b60a74486853ec9b6d6a8e01d2", size = 279352, upload-time = "2026-02-18T23:02:42.798Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/a3/3696ff2444658053c01b6b7443e761f28bb71217d82bb89137a978c5f66f/langchain_text_splitters-0.3.8-py3-none-any.whl", hash = "sha256:e75cc0f4ae58dcf07d9f18776400cf8ade27fadd4ff6d264df6278bb302f6f02", size = 32440, upload-time = "2025-04-04T14:03:50.6Z" }, + { url = "https://files.pythonhosted.org/packages/84/66/d9e0c3b83b0ad75ee746c51ba347cacecb8d656b96e1d513f3e334d1ccab/langchain_text_splitters-1.1.1-py3-none-any.whl", hash = "sha256:5ed0d7bf314ba925041e7d7d17cd8b10f688300d5415fb26c29442f061e329dc", size = 35734, upload-time = "2026-02-18T23:02:41.913Z" }, ] [[package]] @@ -3597,25 +3663,128 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/f7/242a13ca094c78464b7d4df77dfe7d4c44ed77b15fed3d2e3486afa5d2e1/langfuse-2.51.5-py3-none-any.whl", hash = "sha256:b95401ca710ef94b521afa6541933b6f93d7cfd4a97523c8fc75bca4d6d219fb", size = 214281, upload-time = "2024-10-09T00:59:12.596Z" }, ] +[[package]] +name = "langgraph" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph-checkpoint" }, + { name = "langgraph-prebuilt" }, + { name = "langgraph-sdk" }, + { name = "pydantic" }, + { name = "xxhash" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d2/b2/e7db624e8b0ee063ecfbf7acc09467c0836a05914a78e819dfb3744a0fac/langgraph-1.1.3.tar.gz", hash = "sha256:ee496c297a9c93b38d8560be15cbb918110f49077d83abd14976cb13ac3b3370", size = 545120, upload-time = "2026-03-18T23:42:58.24Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fb/f7/221cc479e95e03e260496616e5ce6fb50c1ea01472e3a5bc481a9b8a2f83/langgraph-1.1.3-py3-none-any.whl", hash = "sha256:57cd6964ebab41cbd211f222293a2352404e55f8b2312cecde05e8753739b546", size = 168149, upload-time = "2026-03-18T23:42:56.967Z" }, +] + +[[package]] +name = "langgraph-checkpoint" +version = "4.0.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "ormsgpack" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/44/a8df45d1e8b4637e29789fa8bae1db022c953cc7ac80093cfc52e923547e/langgraph_checkpoint-4.0.1.tar.gz", hash = "sha256:b433123735df11ade28829e40ce25b9be614930cd50245ff2af60629234befd9", size = 158135, upload-time = "2026-02-27T21:06:16.092Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/65/4c/09a4a0c42f5d2fc38d6c4d67884788eff7fd2cfdf367fdf7033de908b4c0/langgraph_checkpoint-4.0.1-py3-none-any.whl", hash = "sha256:e3adcd7a0e0166f3b48b8cf508ce0ea366e7420b5a73aa81289888727769b034", size = 50453, upload-time = "2026-02-27T21:06:14.293Z" }, +] + +[[package]] +name = "langgraph-prebuilt" +version = "1.0.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langchain-core" }, + { name = "langgraph-checkpoint" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0d/06/dd61a5c2dce009d1b03b1d56f2a85b3127659fdddf5b3be5d8f1d60820fb/langgraph_prebuilt-1.0.8.tar.gz", hash = "sha256:0cd3cf5473ced8a6cd687cc5294e08d3de57529d8dd14fdc6ae4899549efcf69", size = 164442, upload-time = "2026-02-19T18:14:39.083Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/41/ec966424ad3f2ed3996d24079d3342c8cd6c0bd0653c12b2a917a685ec6c/langgraph_prebuilt-1.0.8-py3-none-any.whl", hash = "sha256:d16a731e591ba4470f3e313a319c7eee7dbc40895bcf15c821f985a3522a7ce0", size = 35648, upload-time = "2026-02-19T18:14:37.611Z" }, +] + +[[package]] +name = "langgraph-sdk" +version = "0.3.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, + { name = "orjson" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fd/a1/012f0e0f5c9fd26f92bdc9d244756ad673c428230156ef668e6ec7c18cee/langgraph_sdk-0.3.12.tar.gz", hash = "sha256:c9c9ec22b3c0fcd352e2b8f32a815164f69446b8648ca22606329f4ff4c59a71", size = 194932, upload-time = "2026-03-18T22:15:54.592Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/17/4d/4f796e86b03878ab20d9b30aaed1ad459eda71a5c5b67f7cfe712f3548f2/langgraph_sdk-0.3.12-py3-none-any.whl", hash = "sha256:44323804965d6ec2a07127b3cf08a0428ea6deaeb172c2d478d5cd25540e3327", size = 95834, upload-time = "2026-03-18T22:15:53.545Z" }, +] + [[package]] name = "langsmith" -version = "0.1.147" +version = "0.7.22" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, { name = "orjson", marker = "platform_python_implementation != 'PyPy'" }, + { name = "packaging" }, { name = "pydantic" }, { name = "requests" }, { name = "requests-toolbelt" }, + { name = "uuid-utils" }, + { name = "xxhash" }, + { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/6c/56/201dd94d492ae47c1bf9b50cacc1985113dc2288d8f15857e1f4a6818376/langsmith-0.1.147.tar.gz", hash = "sha256:2e933220318a4e73034657103b3b1a3a6109cc5db3566a7e8e03be8d6d7def7a", size = 300453, upload-time = "2024-11-27T17:32:41.297Z" } +sdist = { url = "https://files.pythonhosted.org/packages/be/2a/2d5e6c67396fd228670af278c4da7bd6db2b8d11deaf6f108490b6d3f561/langsmith-0.7.22.tar.gz", hash = "sha256:35bfe795d648b069958280760564632fd28ebc9921c04f3e209c0db6a6c7dc04", size = 1134923, upload-time = "2026-03-19T22:45:23.492Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/de/f0/63b06b99b730b9954f8709f6f7d9b8d076fa0a973e472efe278089bde42b/langsmith-0.1.147-py3-none-any.whl", hash = "sha256:7166fc23b965ccf839d64945a78e9f1157757add228b086141eb03a60d699a15", size = 311812, upload-time = "2024-11-27T17:32:39.569Z" }, + { url = "https://files.pythonhosted.org/packages/1a/94/1f5d72655ab6534129540843776c40eff757387b88e798d8b3bf7e313fd4/langsmith-0.7.22-py3-none-any.whl", hash = "sha256:6e9d5148314d74e86748cb9d3898632cad0320c9323d95f70f969e5bc078eee4", size = 359927, upload-time = "2026-03-19T22:45:21.603Z" }, +] + +[[package]] +name = "lark" +version = "1.3.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/da/34/28fff3ab31ccff1fd4f6c7c7b0ceb2b6968d8ea4950663eadcb5720591a0/lark-1.3.1.tar.gz", hash = "sha256:b426a7a6d6d53189d318f2b6236ab5d6429eaf09259f1ca33eb716eed10d2905", size = 382732, upload-time = "2025-10-27T18:25:56.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12", size = 113151, upload-time = "2025-10-27T18:25:54.882Z" }, +] + +[[package]] +name = "librt" +version = "0.8.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/56/9c/b4b0c54d84da4a94b37bd44151e46d5e583c9534c7e02250b961b1b6d8a8/librt-0.8.1.tar.gz", hash = "sha256:be46a14693955b3bd96014ccbdb8339ee8c9346fbe11c1b78901b55125f14c73", size = 177471, upload-time = "2026-02-17T16:13:06.101Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1d/01/0e748af5e4fee180cf7cd12bd12b0513ad23b045dccb2a83191bde82d168/librt-0.8.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:681dc2451d6d846794a828c16c22dc452d924e9f700a485b7ecb887a30aad1fd", size = 65315, upload-time = "2026-02-17T16:11:25.152Z" }, + { url = "https://files.pythonhosted.org/packages/9d/4d/7184806efda571887c798d573ca4134c80ac8642dcdd32f12c31b939c595/librt-0.8.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a3b4350b13cc0e6f5bec8fa7caf29a8fb8cdc051a3bae45cfbfd7ce64f009965", size = 68021, upload-time = "2026-02-17T16:11:26.129Z" }, + { url = "https://files.pythonhosted.org/packages/ae/88/c3c52d2a5d5101f28d3dc89298444626e7874aa904eed498464c2af17627/librt-0.8.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ac1e7817fd0ed3d14fd7c5df91daed84c48e4c2a11ee99c0547f9f62fdae13da", size = 194500, upload-time = "2026-02-17T16:11:27.177Z" }, + { url = "https://files.pythonhosted.org/packages/d6/5d/6fb0a25b6a8906e85b2c3b87bee1d6ed31510be7605b06772f9374ca5cb3/librt-0.8.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:747328be0c5b7075cde86a0e09d7a9196029800ba75a1689332348e998fb85c0", size = 205622, upload-time = "2026-02-17T16:11:28.242Z" }, + { url = "https://files.pythonhosted.org/packages/b2/a6/8006ae81227105476a45691f5831499e4d936b1c049b0c1feb17c11b02d1/librt-0.8.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f0af2bd2bc204fa27f3d6711d0f360e6b8c684a035206257a81673ab924aa11e", size = 218304, upload-time = "2026-02-17T16:11:29.344Z" }, + { url = "https://files.pythonhosted.org/packages/ee/19/60e07886ad16670aae57ef44dada41912c90906a6fe9f2b9abac21374748/librt-0.8.1-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:d480de377f5b687b6b1bc0c0407426da556e2a757633cc7e4d2e1a057aa688f3", size = 211493, upload-time = "2026-02-17T16:11:30.445Z" }, + { url = "https://files.pythonhosted.org/packages/9c/cf/f666c89d0e861d05600438213feeb818c7514d3315bae3648b1fc145d2b6/librt-0.8.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d0ee06b5b5291f609ddb37b9750985b27bc567791bc87c76a569b3feed8481ac", size = 219129, upload-time = "2026-02-17T16:11:32.021Z" }, + { url = "https://files.pythonhosted.org/packages/8f/ef/f1bea01e40b4a879364c031476c82a0dc69ce068daad67ab96302fed2d45/librt-0.8.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:9e2c6f77b9ad48ce5603b83b7da9ee3e36b3ab425353f695cba13200c5d96596", size = 213113, upload-time = "2026-02-17T16:11:33.192Z" }, + { url = "https://files.pythonhosted.org/packages/9b/80/cdab544370cc6bc1b72ea369525f547a59e6938ef6863a11ab3cd24759af/librt-0.8.1-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:439352ba9373f11cb8e1933da194dcc6206daf779ff8df0ed69c5e39113e6a99", size = 212269, upload-time = "2026-02-17T16:11:34.373Z" }, + { url = "https://files.pythonhosted.org/packages/9d/9c/48d6ed8dac595654f15eceab2035131c136d1ae9a1e3548e777bb6dbb95d/librt-0.8.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:82210adabbc331dbb65d7868b105185464ef13f56f7f76688565ad79f648b0fe", size = 234673, upload-time = "2026-02-17T16:11:36.063Z" }, + { url = "https://files.pythonhosted.org/packages/16/01/35b68b1db517f27a01be4467593292eb5315def8900afad29fabf56304ba/librt-0.8.1-cp311-cp311-win32.whl", hash = "sha256:52c224e14614b750c0a6d97368e16804a98c684657c7518752c356834fff83bb", size = 54597, upload-time = "2026-02-17T16:11:37.544Z" }, + { url = "https://files.pythonhosted.org/packages/71/02/796fe8f02822235966693f257bf2c79f40e11337337a657a8cfebba5febc/librt-0.8.1-cp311-cp311-win_amd64.whl", hash = "sha256:c00e5c884f528c9932d278d5c9cbbea38a6b81eb62c02e06ae53751a83a4d52b", size = 61733, upload-time = "2026-02-17T16:11:38.691Z" }, + { url = "https://files.pythonhosted.org/packages/28/ad/232e13d61f879a42a4e7117d65e4984bb28371a34bb6fb9ca54ec2c8f54e/librt-0.8.1-cp311-cp311-win_arm64.whl", hash = "sha256:f7cdf7f26c2286ffb02e46d7bac56c94655540b26347673bea15fa52a6af17e9", size = 52273, upload-time = "2026-02-17T16:11:40.308Z" }, + { url = "https://files.pythonhosted.org/packages/95/21/d39b0a87ac52fc98f621fb6f8060efb017a767ebbbac2f99fbcbc9ddc0d7/librt-0.8.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a28f2612ab566b17f3698b0da021ff9960610301607c9a5e8eaca62f5e1c350a", size = 66516, upload-time = "2026-02-17T16:11:41.604Z" }, + { url = "https://files.pythonhosted.org/packages/69/f1/46375e71441c43e8ae335905e069f1c54febee63a146278bcee8782c84fd/librt-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:60a78b694c9aee2a0f1aaeaa7d101cf713e92e8423a941d2897f4fa37908dab9", size = 68634, upload-time = "2026-02-17T16:11:43.268Z" }, + { url = "https://files.pythonhosted.org/packages/0a/33/c510de7f93bf1fa19e13423a606d8189a02624a800710f6e6a0a0f0784b3/librt-0.8.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:758509ea3f1eba2a57558e7e98f4659d0ea7670bff49673b0dde18a3c7e6c0eb", size = 198941, upload-time = "2026-02-17T16:11:44.28Z" }, + { url = "https://files.pythonhosted.org/packages/dd/36/e725903416409a533d92398e88ce665476f275081d0d7d42f9c4951999e5/librt-0.8.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:039b9f2c506bd0ab0f8725aa5ba339c6f0cd19d3b514b50d134789809c24285d", size = 209991, upload-time = "2026-02-17T16:11:45.462Z" }, + { url = "https://files.pythonhosted.org/packages/30/7a/8d908a152e1875c9f8eac96c97a480df425e657cdb47854b9efaa4998889/librt-0.8.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5bb54f1205a3a6ab41a6fd71dfcdcbd278670d3a90ca502a30d9da583105b6f7", size = 224476, upload-time = "2026-02-17T16:11:46.542Z" }, + { url = "https://files.pythonhosted.org/packages/a8/b8/a22c34f2c485b8903a06f3fe3315341fe6876ef3599792344669db98fcff/librt-0.8.1-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:05bd41cdee35b0c59c259f870f6da532a2c5ca57db95b5f23689fcb5c9e42440", size = 217518, upload-time = "2026-02-17T16:11:47.746Z" }, + { url = "https://files.pythonhosted.org/packages/79/6f/5c6fea00357e4f82ba44f81dbfb027921f1ab10e320d4a64e1c408d035d9/librt-0.8.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adfab487facf03f0d0857b8710cf82d0704a309d8ffc33b03d9302b4c64e91a9", size = 225116, upload-time = "2026-02-17T16:11:49.298Z" }, + { url = "https://files.pythonhosted.org/packages/f2/a0/95ced4e7b1267fe1e2720a111685bcddf0e781f7e9e0ce59d751c44dcfe5/librt-0.8.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:153188fe98a72f206042be10a2c6026139852805215ed9539186312d50a8e972", size = 217751, upload-time = "2026-02-17T16:11:50.49Z" }, + { url = "https://files.pythonhosted.org/packages/93/c2/0517281cb4d4101c27ab59472924e67f55e375bc46bedae94ac6dc6e1902/librt-0.8.1-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:dd3c41254ee98604b08bd5b3af5bf0a89740d4ee0711de95b65166bf44091921", size = 218378, upload-time = "2026-02-17T16:11:51.783Z" }, + { url = "https://files.pythonhosted.org/packages/43/e8/37b3ac108e8976888e559a7b227d0ceac03c384cfd3e7a1c2ee248dbae79/librt-0.8.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e0d138c7ae532908cbb342162b2611dbd4d90c941cd25ab82084aaf71d2c0bd0", size = 241199, upload-time = "2026-02-17T16:11:53.561Z" }, + { url = "https://files.pythonhosted.org/packages/4b/5b/35812d041c53967fedf551a39399271bbe4257e681236a2cf1a69c8e7fa1/librt-0.8.1-cp312-cp312-win32.whl", hash = "sha256:43353b943613c5d9c49a25aaffdba46f888ec354e71e3529a00cca3f04d66a7a", size = 54917, upload-time = "2026-02-17T16:11:54.758Z" }, + { url = "https://files.pythonhosted.org/packages/de/d1/fa5d5331b862b9775aaf2a100f5ef86854e5d4407f71bddf102f4421e034/librt-0.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:ff8baf1f8d3f4b6b7257fcb75a501f2a5499d0dda57645baa09d4d0d34b19444", size = 62017, upload-time = "2026-02-17T16:11:55.748Z" }, + { url = "https://files.pythonhosted.org/packages/c7/7c/c614252f9acda59b01a66e2ddfd243ed1c7e1deab0293332dfbccf862808/librt-0.8.1-cp312-cp312-win_arm64.whl", hash = "sha256:0f2ae3725904f7377e11cc37722d5d401e8b3d5851fb9273d7f4fe04f6b3d37d", size = 52441, upload-time = "2026-02-17T16:11:56.801Z" }, ] [[package]] name = "litellm" -version = "1.77.1" +version = "1.82.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -3631,56 +3800,25 @@ dependencies = [ { name = "tiktoken" }, { name = "tokenizers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8c/65/71fe4851709fa4a612e41b80001a9ad803fea979d21b90970093fd65eded/litellm-1.77.1.tar.gz", hash = "sha256:76bab5203115efb9588244e5bafbfc07a800a239be75d8dc6b1b9d17394c6418", size = 10275745, upload-time = "2025-09-13T21:05:21.377Z" } +sdist = { url = "https://files.pythonhosted.org/packages/29/75/1c537aa458426a9127a92bc2273787b2f987f4e5044e21f01f2eed5244fd/litellm-1.82.6.tar.gz", hash = "sha256:2aa1c2da21fe940c33613aa447119674a3ad4d2ad5eb064e4d5ce5ee42420136", size = 17414147, upload-time = "2026-03-22T06:36:00.452Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bb/dc/ff4f119cd4d783742c9648a03e0ba5c2b52fc385b2ae9f0d32acf3a78241/litellm-1.77.1-py3-none-any.whl", hash = "sha256:407761dc3c35fbcd41462d3fe65dd3ed70aac705f37cde318006c18940f695a0", size = 9067070, upload-time = "2025-09-13T21:05:18.078Z" }, -] - -[[package]] -name = "llama-index" -version = "0.9.48" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "aiohttp" }, - { name = "dataclasses-json" }, - { name = "deprecated" }, - { name = "dirtyjson" }, - { name = "fsspec" }, - { name = "httpx" }, - { name = "nest-asyncio" }, - { name = "networkx" }, - { name = "nltk" }, - { name = "numpy" }, - { name = "openai" }, - { name = "pandas" }, - { name = "requests" }, - { name = "sqlalchemy", extra = ["asyncio"] }, - { name = "tenacity" }, - { name = "tiktoken" }, - { name = "typing-extensions" }, - { name = "typing-inspect" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f3/83/58657a17a6138d031370fd0ce58b8bb64686b75180422c9a234dee6ce96d/llama_index-0.9.48.tar.gz", hash = "sha256:c50d02ac8c7e4ff9fb41f0860391fe0020ad8a3d7c30048db52d17d8be654bf3", size = 16246528, upload-time = "2024-02-12T15:41:04.037Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/9a/d5/02b3c3572e62f51e005a025923636c59258817d02a7c61f60b484c480f34/llama_index-0.9.48-py3-none-any.whl", hash = "sha256:56aa406d39e7ca53a5d990b55d69901fbb9eddc9af6a40950367dc5d734f6283", size = 15888075, upload-time = "2024-02-12T15:40:57.855Z" }, + { url = "https://files.pythonhosted.org/packages/02/6c/5327667e6dbe9e98cbfbd4261c8e91386a52e38f41419575854248bbab6a/litellm-1.82.6-py3-none-any.whl", hash = "sha256:164a3ef3e19f309e3cabc199bef3d2045212712fefdfa25fc7f75884a5b5b205", size = 15591595, upload-time = "2026-03-22T06:35:56.795Z" }, ] [[package]] name = "llvmlite" -version = "0.45.1" +version = "0.46.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/99/8d/5baf1cef7f9c084fb35a8afbde88074f0d6a727bc63ef764fe0e7543ba40/llvmlite-0.45.1.tar.gz", hash = "sha256:09430bb9d0bb58fc45a45a57c7eae912850bedc095cd0810a57de109c69e1c32", size = 185600, upload-time = "2025-10-01T17:59:52.046Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/cd/08ae687ba099c7e3d21fe2ea536500563ef1943c5105bf6ab4ee3829f68e/llvmlite-0.46.0.tar.gz", hash = "sha256:227c9fd6d09dce2783c18b754b7cd9d9b3b3515210c46acc2d3c5badd9870ceb", size = 193456, upload-time = "2025-12-08T18:15:36.295Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/ad/9bdc87b2eb34642c1cfe6bcb4f5db64c21f91f26b010f263e7467e7536a3/llvmlite-0.45.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:60f92868d5d3af30b4239b50e1717cb4e4e54f6ac1c361a27903b318d0f07f42", size = 43043526, upload-time = "2025-10-01T18:03:15.051Z" }, - { url = "https://files.pythonhosted.org/packages/a5/ea/c25c6382f452a943b4082da5e8c1665ce29a62884e2ec80608533e8e82d5/llvmlite-0.45.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98baab513e19beb210f1ef39066288784839a44cd504e24fff5d17f1b3cf0860", size = 37253118, upload-time = "2025-10-01T18:04:06.783Z" }, - { url = "https://files.pythonhosted.org/packages/fe/af/85fc237de98b181dbbe8647324331238d6c52a3554327ccdc83ced28efba/llvmlite-0.45.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3adc2355694d6a6fbcc024d59bb756677e7de506037c878022d7b877e7613a36", size = 56288209, upload-time = "2025-10-01T18:01:00.168Z" }, - { url = "https://files.pythonhosted.org/packages/0a/df/3daf95302ff49beff4230065e3178cd40e71294968e8d55baf4a9e560814/llvmlite-0.45.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2f3377a6db40f563058c9515dedcc8a3e562d8693a106a28f2ddccf2c8fcf6ca", size = 55140958, upload-time = "2025-10-01T18:02:11.199Z" }, - { url = "https://files.pythonhosted.org/packages/a4/56/4c0d503fe03bac820ecdeb14590cf9a248e120f483bcd5c009f2534f23f0/llvmlite-0.45.1-cp311-cp311-win_amd64.whl", hash = "sha256:f9c272682d91e0d57f2a76c6d9ebdfccc603a01828cdbe3d15273bdca0c3363a", size = 38132232, upload-time = "2025-10-01T18:04:52.181Z" }, - { url = "https://files.pythonhosted.org/packages/e2/7c/82cbd5c656e8991bcc110c69d05913be2229302a92acb96109e166ae31fb/llvmlite-0.45.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:28e763aba92fe9c72296911e040231d486447c01d4f90027c8e893d89d49b20e", size = 43043524, upload-time = "2025-10-01T18:03:30.666Z" }, - { url = "https://files.pythonhosted.org/packages/9d/bc/5314005bb2c7ee9f33102c6456c18cc81745d7055155d1218f1624463774/llvmlite-0.45.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1a53f4b74ee9fd30cb3d27d904dadece67a7575198bd80e687ee76474620735f", size = 37253123, upload-time = "2025-10-01T18:04:18.177Z" }, - { url = "https://files.pythonhosted.org/packages/96/76/0f7154952f037cb320b83e1c952ec4a19d5d689cf7d27cb8a26887d7bbc1/llvmlite-0.45.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5b3796b1b1e1c14dcae34285d2f4ea488402fbd2c400ccf7137603ca3800864f", size = 56288211, upload-time = "2025-10-01T18:01:24.079Z" }, - { url = "https://files.pythonhosted.org/packages/00/b1/0b581942be2683ceb6862d558979e87387e14ad65a1e4db0e7dd671fa315/llvmlite-0.45.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:779e2f2ceefef0f4368548685f0b4adde34e5f4b457e90391f570a10b348d433", size = 55140958, upload-time = "2025-10-01T18:02:30.482Z" }, - { url = "https://files.pythonhosted.org/packages/33/94/9ba4ebcf4d541a325fd8098ddc073b663af75cc8b065b6059848f7d4dce7/llvmlite-0.45.1-cp312-cp312-win_amd64.whl", hash = "sha256:9e6c9949baf25d9aa9cd7cf0f6d011b9ca660dd17f5ba2b23bdbdb77cc86b116", size = 38132231, upload-time = "2025-10-01T18:05:03.664Z" }, + { url = "https://files.pythonhosted.org/packages/7a/a1/2ad4b2367915faeebe8447f0a057861f646dbf5fbbb3561db42c65659cf3/llvmlite-0.46.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:82f3d39b16f19aa1a56d5fe625883a6ab600d5cc9ea8906cca70ce94cabba067", size = 37232766, upload-time = "2025-12-08T18:14:48.836Z" }, + { url = "https://files.pythonhosted.org/packages/12/b5/99cf8772fdd846c07da4fd70f07812a3c8fd17ea2409522c946bb0f2b277/llvmlite-0.46.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a3df43900119803bbc52720e758c76f316a9a0f34612a886862dfe0a5591a17e", size = 56275175, upload-time = "2025-12-08T18:14:51.604Z" }, + { url = "https://files.pythonhosted.org/packages/38/f2/ed806f9c003563732da156139c45d970ee435bd0bfa5ed8de87ba972b452/llvmlite-0.46.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:de183fefc8022d21b0aa37fc3e90410bc3524aed8617f0ff76732fc6c3af5361", size = 55128630, upload-time = "2025-12-08T18:14:55.107Z" }, + { url = "https://files.pythonhosted.org/packages/19/0c/8f5a37a65fc9b7b17408508145edd5f86263ad69c19d3574e818f533a0eb/llvmlite-0.46.0-cp311-cp311-win_amd64.whl", hash = "sha256:e8b10bc585c58bdffec9e0c309bb7d51be1f2f15e169a4b4d42f2389e431eb93", size = 38138652, upload-time = "2025-12-08T18:14:58.171Z" }, + { url = "https://files.pythonhosted.org/packages/2b/f8/4db016a5e547d4e054ff2f3b99203d63a497465f81ab78ec8eb2ff7b2304/llvmlite-0.46.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b9588ad4c63b4f0175a3984b85494f0c927c6b001e3a246a3a7fb3920d9a137", size = 37232767, upload-time = "2025-12-08T18:15:00.737Z" }, + { url = "https://files.pythonhosted.org/packages/aa/85/4890a7c14b4fa54400945cb52ac3cd88545bbdb973c440f98ca41591cdc5/llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3535bd2bb6a2d7ae4012681ac228e5132cdb75fefb1bcb24e33f2f3e0c865ed4", size = 56275176, upload-time = "2025-12-08T18:15:03.936Z" }, + { url = "https://files.pythonhosted.org/packages/6a/07/3d31d39c1a1a08cd5337e78299fca77e6aebc07c059fbd0033e3edfab45c/llvmlite-0.46.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4cbfd366e60ff87ea6cc62f50bc4cd800ebb13ed4c149466f50cf2163a473d1e", size = 55128630, upload-time = "2025-12-08T18:15:07.196Z" }, + { url = "https://files.pythonhosted.org/packages/2a/6b/d139535d7590a1bba1ceb68751bef22fadaa5b815bbdf0e858e3875726b2/llvmlite-0.46.0-cp312-cp312-win_amd64.whl", hash = "sha256:398b39db462c39563a97b912d4f2866cd37cba60537975a09679b28fbbc0fb38", size = 38138940, upload-time = "2025-12-08T18:15:10.162Z" }, ] [[package]] @@ -3778,11 +3916,11 @@ wheels = [ [[package]] name = "markdown" -version = "3.5.2" +version = "3.10.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/11/28/c5441a6642681d92de56063fa7984df56f783d3f1eba518dc3e7a253b606/Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8", size = 349398, upload-time = "2024-01-10T15:19:38.261Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/f4/69fa6ed85ae003c2378ffa8f6d2e3234662abd02c10d216c0ba96081a238/markdown-3.10.2.tar.gz", hash = "sha256:994d51325d25ad8aa7ce4ebaec003febcce822c3f8c911e3b17c52f7f589f950", size = 368805, upload-time = "2026-02-09T14:57:26.942Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/f4/f0031854de10a0bc7821ef9fca0b92ca0d7aa6fbfbf504c5473ba825e49c/Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd", size = 103870, upload-time = "2024-01-10T15:19:36.071Z" }, + { url = "https://files.pythonhosted.org/packages/de/1f/77fa3081e4f66ca3576c896ae5d31c3002ac6607f9747d2e3aa49227e464/markdown-3.10.2-py3-none-any.whl", hash = "sha256:e91464b71ae3ee7afd3017d9f358ef0baf158fd9a298db92f1d4761133824c36", size = 108180, upload-time = "2026-02-09T14:57:25.787Z" }, ] [[package]] @@ -3848,23 +3986,9 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/38/89ba8ad64ae25be8de66a6d463314cf1eb366222074cfda9ee839c56a4b4/mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8", size = 9979, upload-time = "2022-08-14T12:40:09.779Z" }, ] -[[package]] -name = "milvus-lite" -version = "2.5.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "tqdm" }, -] -wheels = [ - { url = "https://files.pythonhosted.org/packages/a9/b2/acc5024c8e8b6a0b034670b8e8af306ebd633ede777dcbf557eac4785937/milvus_lite-2.5.1-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6b014453200ba977be37ba660cb2d021030375fa6a35bc53c2e1d92980a0c512", size = 27934713, upload-time = "2025-06-30T04:23:37.028Z" }, - { url = "https://files.pythonhosted.org/packages/9b/2e/746f5bb1d6facd1e73eb4af6dd5efda11125b0f29d7908a097485ca6cad9/milvus_lite-2.5.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a2e031088bf308afe5f8567850412d618cfb05a65238ed1a6117f60decccc95a", size = 24421451, upload-time = "2025-06-30T04:23:51.747Z" }, - { url = "https://files.pythonhosted.org/packages/2e/cf/3d1fee5c16c7661cf53977067a34820f7269ed8ba99fe9cf35efc1700866/milvus_lite-2.5.1-py3-none-manylinux2014_aarch64.whl", hash = "sha256:a13277e9bacc6933dea172e42231f7e6135bd3bdb073dd2688ee180418abd8d9", size = 45337093, upload-time = "2025-06-30T04:24:06.706Z" }, - { url = "https://files.pythonhosted.org/packages/d3/82/41d9b80f09b82e066894d9b508af07b7b0fa325ce0322980674de49106a0/milvus_lite-2.5.1-py3-none-manylinux2014_x86_64.whl", hash = "sha256:25ce13f4b8d46876dd2b7ac8563d7d8306da7ff3999bb0d14b116b30f71d706c", size = 55263911, upload-time = "2025-06-30T04:24:19.434Z" }, -] - [[package]] name = "mlflow-skinny" -version = "3.6.0" +version = "3.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -3887,49 +4011,49 @@ dependencies = [ { name = "typing-extensions" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/8e/2a2d0cd5b1b985c5278202805f48aae6f2adc3ddc0fce3385ec50e07e258/mlflow_skinny-3.6.0.tar.gz", hash = "sha256:cc04706b5b6faace9faf95302a6e04119485e1bfe98ddc9b85b81984e80944b6", size = 1963286, upload-time = "2025-11-07T18:33:52.596Z" } +sdist = { url = "https://files.pythonhosted.org/packages/71/65/5b2c28e74c167ba8a5afe59399ef44291a0f140487f534db1900f09f59f6/mlflow_skinny-3.10.1.tar.gz", hash = "sha256:3d1c5c30245b6e7065b492b09dd47be7528e0a14c4266b782fe58f9bcd1e0be0", size = 2478631, upload-time = "2026-03-05T10:49:01.47Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/78/e8fdc3e1708bdfd1eba64f41ce96b461cae1b505aa08b69352ac99b4caa4/mlflow_skinny-3.6.0-py3-none-any.whl", hash = "sha256:c83b34fce592acb2cc6bddcb507587a6d9ef3f590d9e7a8658c85e0980596d78", size = 2364629, upload-time = "2025-11-07T18:33:50.744Z" }, + { url = "https://files.pythonhosted.org/packages/4b/52/17460157271e70b0d8444d27f8ad730ef7d95fb82fac59dc19f11519b921/mlflow_skinny-3.10.1-py3-none-any.whl", hash = "sha256:df1dd507d8ddadf53bfab2423c76cdcafc235cd1a46921a06d1a6b4dd04b023c", size = 2987098, upload-time = "2026-03-05T10:48:59.566Z" }, ] [[package]] name = "mmh3" -version = "5.2.0" +version = "5.2.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a7/af/f28c2c2f51f31abb4725f9a64bc7863d5f491f6539bd26aee2a1d21a649e/mmh3-5.2.0.tar.gz", hash = "sha256:1efc8fec8478e9243a78bb993422cf79f8ff85cb4cf6b79647480a31e0d950a8", size = 33582, upload-time = "2025-07-29T07:43:48.49Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/1a/edb23803a168f070ded7a3014c6d706f63b90c84ccc024f89d794a3b7a6d/mmh3-5.2.1.tar.gz", hash = "sha256:bbea5b775f0ac84945191fb83f845a6fd9a21a03ea7f2e187defac7e401616ad", size = 33775, upload-time = "2026-03-05T15:55:57.716Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/87/399567b3796e134352e11a8b973cd470c06b2ecfad5468fe580833be442b/mmh3-5.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7901c893e704ee3c65f92d39b951f8f34ccf8e8566768c58103fb10e55afb8c1", size = 56107, upload-time = "2025-07-29T07:41:57.07Z" }, - { url = "https://files.pythonhosted.org/packages/c3/09/830af30adf8678955b247d97d3d9543dd2fd95684f3cd41c0cd9d291da9f/mmh3-5.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4a5f5536b1cbfa72318ab3bfc8a8188b949260baed186b75f0abc75b95d8c051", size = 40635, upload-time = "2025-07-29T07:41:57.903Z" }, - { url = "https://files.pythonhosted.org/packages/07/14/eaba79eef55b40d653321765ac5e8f6c9ac38780b8a7c2a2f8df8ee0fb72/mmh3-5.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:cedac4f4054b8f7859e5aed41aaa31ad03fce6851901a7fdc2af0275ac533c10", size = 40078, upload-time = "2025-07-29T07:41:58.772Z" }, - { url = "https://files.pythonhosted.org/packages/bb/26/83a0f852e763f81b2265d446b13ed6d49ee49e1fc0c47b9655977e6f3d81/mmh3-5.2.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:eb756caf8975882630ce4e9fbbeb9d3401242a72528230422c9ab3a0d278e60c", size = 97262, upload-time = "2025-07-29T07:41:59.678Z" }, - { url = "https://files.pythonhosted.org/packages/00/7d/b7133b10d12239aeaebf6878d7eaf0bf7d3738c44b4aba3c564588f6d802/mmh3-5.2.0-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:097e13c8b8a66c5753c6968b7640faefe85d8e38992703c1f666eda6ef4c3762", size = 103118, upload-time = "2025-07-29T07:42:01.197Z" }, - { url = "https://files.pythonhosted.org/packages/7b/3e/62f0b5dce2e22fd5b7d092aba285abd7959ea2b17148641e029f2eab1ffa/mmh3-5.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a7c0c7845566b9686480e6a7e9044db4afb60038d5fabd19227443f0104eeee4", size = 106072, upload-time = "2025-07-29T07:42:02.601Z" }, - { url = "https://files.pythonhosted.org/packages/66/84/ea88bb816edfe65052c757a1c3408d65c4201ddbd769d4a287b0f1a628b2/mmh3-5.2.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:61ac226af521a572700f863d6ecddc6ece97220ce7174e311948ff8c8919a363", size = 112925, upload-time = "2025-07-29T07:42:03.632Z" }, - { url = "https://files.pythonhosted.org/packages/2e/13/c9b1c022807db575fe4db806f442d5b5784547e2e82cff36133e58ea31c7/mmh3-5.2.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:582f9dbeefe15c32a5fa528b79b088b599a1dfe290a4436351c6090f90ddebb8", size = 120583, upload-time = "2025-07-29T07:42:04.991Z" }, - { url = "https://files.pythonhosted.org/packages/8a/5f/0e2dfe1a38f6a78788b7eb2b23432cee24623aeabbc907fed07fc17d6935/mmh3-5.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2ebfc46b39168ab1cd44670a32ea5489bcbc74a25795c61b6d888c5c2cf654ed", size = 99127, upload-time = "2025-07-29T07:42:05.929Z" }, - { url = "https://files.pythonhosted.org/packages/77/27/aefb7d663b67e6a0c4d61a513c83e39ba2237e8e4557fa7122a742a23de5/mmh3-5.2.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1556e31e4bd0ac0c17eaf220be17a09c171d7396919c3794274cb3415a9d3646", size = 98544, upload-time = "2025-07-29T07:42:06.87Z" }, - { url = "https://files.pythonhosted.org/packages/ab/97/a21cc9b1a7c6e92205a1b5fa030cdf62277d177570c06a239eca7bd6dd32/mmh3-5.2.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:81df0dae22cd0da87f1c978602750f33d17fb3d21fb0f326c89dc89834fea79b", size = 106262, upload-time = "2025-07-29T07:42:07.804Z" }, - { url = "https://files.pythonhosted.org/packages/43/18/db19ae82ea63c8922a880e1498a75342311f8aa0c581c4dd07711473b5f7/mmh3-5.2.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:eba01ec3bd4a49b9ac5ca2bc6a73ff5f3af53374b8556fcc2966dd2af9eb7779", size = 109824, upload-time = "2025-07-29T07:42:08.735Z" }, - { url = "https://files.pythonhosted.org/packages/9f/f5/41dcf0d1969125fc6f61d8618b107c79130b5af50b18a4651210ea52ab40/mmh3-5.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e9a011469b47b752e7d20de296bb34591cdfcbe76c99c2e863ceaa2aa61113d2", size = 97255, upload-time = "2025-07-29T07:42:09.706Z" }, - { url = "https://files.pythonhosted.org/packages/32/b3/cce9eaa0efac1f0e735bb178ef9d1d2887b4927fe0ec16609d5acd492dda/mmh3-5.2.0-cp311-cp311-win32.whl", hash = "sha256:bc44fc2b886243d7c0d8daeb37864e16f232e5b56aaec27cc781d848264cfd28", size = 40779, upload-time = "2025-07-29T07:42:10.546Z" }, - { url = "https://files.pythonhosted.org/packages/7c/e9/3fa0290122e6d5a7041b50ae500b8a9f4932478a51e48f209a3879fe0b9b/mmh3-5.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:8ebf241072cf2777a492d0e09252f8cc2b3edd07dfdb9404b9757bffeb4f2cee", size = 41549, upload-time = "2025-07-29T07:42:11.399Z" }, - { url = "https://files.pythonhosted.org/packages/3a/54/c277475b4102588e6f06b2e9095ee758dfe31a149312cdbf62d39a9f5c30/mmh3-5.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:b5f317a727bba0e633a12e71228bc6a4acb4f471a98b1c003163b917311ea9a9", size = 39336, upload-time = "2025-07-29T07:42:12.209Z" }, - { url = "https://files.pythonhosted.org/packages/bf/6a/d5aa7edb5c08e0bd24286c7d08341a0446f9a2fbbb97d96a8a6dd81935ee/mmh3-5.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:384eda9361a7bf83a85e09447e1feafe081034af9dd428893701b959230d84be", size = 56141, upload-time = "2025-07-29T07:42:13.456Z" }, - { url = "https://files.pythonhosted.org/packages/08/49/131d0fae6447bc4a7299ebdb1a6fb9d08c9f8dcf97d75ea93e8152ddf7ab/mmh3-5.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2c9da0d568569cc87315cb063486d761e38458b8ad513fedd3dc9263e1b81bcd", size = 40681, upload-time = "2025-07-29T07:42:14.306Z" }, - { url = "https://files.pythonhosted.org/packages/8f/6f/9221445a6bcc962b7f5ff3ba18ad55bba624bacdc7aa3fc0a518db7da8ec/mmh3-5.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86d1be5d63232e6eb93c50881aea55ff06eb86d8e08f9b5417c8c9b10db9db96", size = 40062, upload-time = "2025-07-29T07:42:15.08Z" }, - { url = "https://files.pythonhosted.org/packages/1e/d4/6bb2d0fef81401e0bb4c297d1eb568b767de4ce6fc00890bc14d7b51ecc4/mmh3-5.2.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:bf7bee43e17e81671c447e9c83499f53d99bf440bc6d9dc26a841e21acfbe094", size = 97333, upload-time = "2025-07-29T07:42:16.436Z" }, - { url = "https://files.pythonhosted.org/packages/44/e0/ccf0daff8134efbb4fbc10a945ab53302e358c4b016ada9bf97a6bdd50c1/mmh3-5.2.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:7aa18cdb58983ee660c9c400b46272e14fa253c675ed963d3812487f8ca42037", size = 103310, upload-time = "2025-07-29T07:42:17.796Z" }, - { url = "https://files.pythonhosted.org/packages/02/63/1965cb08a46533faca0e420e06aff8bbaf9690a6f0ac6ae6e5b2e4544687/mmh3-5.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae9d032488fcec32d22be6542d1a836f00247f40f320844dbb361393b5b22773", size = 106178, upload-time = "2025-07-29T07:42:19.281Z" }, - { url = "https://files.pythonhosted.org/packages/c2/41/c883ad8e2c234013f27f92061200afc11554ea55edd1bcf5e1accd803a85/mmh3-5.2.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:e1861fb6b1d0453ed7293200139c0a9011eeb1376632e048e3766945b13313c5", size = 113035, upload-time = "2025-07-29T07:42:20.356Z" }, - { url = "https://files.pythonhosted.org/packages/df/b5/1ccade8b1fa625d634a18bab7bf08a87457e09d5ec8cf83ca07cbea9d400/mmh3-5.2.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:99bb6a4d809aa4e528ddfe2c85dd5239b78b9dd14be62cca0329db78505e7b50", size = 120784, upload-time = "2025-07-29T07:42:21.377Z" }, - { url = "https://files.pythonhosted.org/packages/77/1c/919d9171fcbdcdab242e06394464ccf546f7d0f3b31e0d1e3a630398782e/mmh3-5.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1f8d8b627799f4e2fcc7c034fed8f5f24dc7724ff52f69838a3d6d15f1ad4765", size = 99137, upload-time = "2025-07-29T07:42:22.344Z" }, - { url = "https://files.pythonhosted.org/packages/66/8a/1eebef5bd6633d36281d9fc83cf2e9ba1ba0e1a77dff92aacab83001cee4/mmh3-5.2.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b5995088dd7023d2d9f310a0c67de5a2b2e06a570ecfd00f9ff4ab94a67cde43", size = 98664, upload-time = "2025-07-29T07:42:23.269Z" }, - { url = "https://files.pythonhosted.org/packages/13/41/a5d981563e2ee682b21fb65e29cc0f517a6734a02b581359edd67f9d0360/mmh3-5.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1a5f4d2e59d6bba8ef01b013c472741835ad961e7c28f50c82b27c57748744a4", size = 106459, upload-time = "2025-07-29T07:42:24.238Z" }, - { url = "https://files.pythonhosted.org/packages/24/31/342494cd6ab792d81e083680875a2c50fa0c5df475ebf0b67784f13e4647/mmh3-5.2.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:fd6e6c3d90660d085f7e73710eab6f5545d4854b81b0135a3526e797009dbda3", size = 110038, upload-time = "2025-07-29T07:42:25.629Z" }, - { url = "https://files.pythonhosted.org/packages/28/44/efda282170a46bb4f19c3e2b90536513b1d821c414c28469a227ca5a1789/mmh3-5.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c4a2f3d83879e3de2eb8cbf562e71563a8ed15ee9b9c2e77ca5d9f73072ac15c", size = 97545, upload-time = "2025-07-29T07:42:27.04Z" }, - { url = "https://files.pythonhosted.org/packages/68/8f/534ae319c6e05d714f437e7206f78c17e66daca88164dff70286b0e8ea0c/mmh3-5.2.0-cp312-cp312-win32.whl", hash = "sha256:2421b9d665a0b1ad724ec7332fb5a98d075f50bc51a6ff854f3a1882bd650d49", size = 40805, upload-time = "2025-07-29T07:42:28.032Z" }, - { url = "https://files.pythonhosted.org/packages/b8/f6/f6abdcfefcedab3c964868048cfe472764ed358c2bf6819a70dd4ed4ed3a/mmh3-5.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:72d80005b7634a3a2220f81fbeb94775ebd12794623bb2e1451701ea732b4aa3", size = 41597, upload-time = "2025-07-29T07:42:28.894Z" }, - { url = "https://files.pythonhosted.org/packages/15/fd/f7420e8cbce45c259c770cac5718badf907b302d3a99ec587ba5ce030237/mmh3-5.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:3d6bfd9662a20c054bc216f861fa330c2dac7c81e7fb8307b5e32ab5b9b4d2e0", size = 39350, upload-time = "2025-07-29T07:42:29.794Z" }, + { url = "https://files.pythonhosted.org/packages/65/d7/3312a59df3c1cdd783f4cf0c4ee8e9decff9c5466937182e4cc7dbbfe6c5/mmh3-5.2.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:dae0f0bd7d30c0ad61b9a504e8e272cb8391eed3f1587edf933f4f6b33437450", size = 56082, upload-time = "2026-03-05T15:53:59.702Z" }, + { url = "https://files.pythonhosted.org/packages/61/96/6f617baa098ca0d2989bfec6d28b5719532cd8d8848782662f5b755f657f/mmh3-5.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9aeaf53eaa075dd63e81512522fd180097312fb2c9f476333309184285c49ce0", size = 40458, upload-time = "2026-03-05T15:54:01.548Z" }, + { url = "https://files.pythonhosted.org/packages/c1/b4/9cd284bd6062d711e13d26c04d4778ab3f690c1c38a4563e3c767ec8802e/mmh3-5.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:0634581290e6714c068f4aa24020acf7880927d1f0084fa753d9799ae9610082", size = 40079, upload-time = "2026-03-05T15:54:02.743Z" }, + { url = "https://files.pythonhosted.org/packages/f6/09/a806334ce1d3d50bf782b95fcee8b3648e1e170327d4bb7b4bad2ad7d956/mmh3-5.2.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:e080c0637aea036f35507e803a4778f119a9b436617694ae1c5c366805f1e997", size = 97242, upload-time = "2026-03-05T15:54:04.536Z" }, + { url = "https://files.pythonhosted.org/packages/ee/93/723e317dd9e041c4dc4566a2eb53b01ad94de31750e0b834f1643905e97c/mmh3-5.2.1-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:db0562c5f71d18596dcd45e854cf2eeba27d7543e1a3acdafb7eef728f7fe85d", size = 103082, upload-time = "2026-03-05T15:54:06.387Z" }, + { url = "https://files.pythonhosted.org/packages/61/b5/f96121e69cc48696075071531cf574f112e1ffd08059f4bffb41210e6fc5/mmh3-5.2.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1d9f9a3ce559a5267014b04b82956993270f63ec91765e13e9fd73daf2d2738e", size = 106054, upload-time = "2026-03-05T15:54:07.506Z" }, + { url = "https://files.pythonhosted.org/packages/82/49/192b987ec48d0b2aecf8ac285a9b11fbc00030f6b9c694664ae923458dde/mmh3-5.2.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:960b1b3efa39872ac8b6cc3a556edd6fb90ed74f08c9c45e028f1005b26aa55d", size = 112910, upload-time = "2026-03-05T15:54:09.403Z" }, + { url = "https://files.pythonhosted.org/packages/cf/a1/03e91fd334ed0144b83343a76eb11f17434cd08f746401488cfeafb2d241/mmh3-5.2.1-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d30b650595fdbe32366b94cb14f30bb2b625e512bd4e1df00611f99dc5c27fd4", size = 120551, upload-time = "2026-03-05T15:54:10.587Z" }, + { url = "https://files.pythonhosted.org/packages/93/b9/b89a71d2ff35c3a764d1c066c7313fc62c7cc48fa48a4b3b0304a4a0146f/mmh3-5.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:82f3802bfc4751f420d591c5c864de538b71cea117fce67e4595c2afede08a15", size = 99096, upload-time = "2026-03-05T15:54:11.76Z" }, + { url = "https://files.pythonhosted.org/packages/36/b5/613772c1c6ed5f7b63df55eb131e887cc43720fec392777b95a79d34e640/mmh3-5.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:915e7a2418f10bd1151b1953df06d896db9783c9cfdb9a8ee1f9b3a4331ab503", size = 98524, upload-time = "2026-03-05T15:54:13.122Z" }, + { url = "https://files.pythonhosted.org/packages/5e/0e/1524566fe8eaf871e4f7bc44095929fcd2620488f402822d848df19d679c/mmh3-5.2.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:fc78739b5ec6e4fb02301984a3d442a91406e7700efbe305071e7fd1c78278f2", size = 106239, upload-time = "2026-03-05T15:54:14.601Z" }, + { url = "https://files.pythonhosted.org/packages/04/94/21adfa7d90a7a697137ad6de33eeff6445420ca55e433a5d4919c79bc3b5/mmh3-5.2.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:41aac7002a749f08727cb91babff1daf8deac317c0b1f317adc69be0e6c375d1", size = 109797, upload-time = "2026-03-05T15:54:15.819Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e6/1aacc3a219e1aa62fa65669995d4a3562b35be5200ec03680c7e4bec9676/mmh3-5.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9d8089d853c7963a8ce87fff93e2a67075c0bc08684a08ea6ad13577c38ffc38", size = 97228, upload-time = "2026-03-05T15:54:16.992Z" }, + { url = "https://files.pythonhosted.org/packages/f1/b9/5e4cca8dcccf298add0a27f3c357bc8cf8baf821d35cdc6165e4bd5a48b0/mmh3-5.2.1-cp311-cp311-win32.whl", hash = "sha256:baeb47635cb33375dee4924cd93d7f5dcaa786c740b08423b0209b824a1ee728", size = 40751, upload-time = "2026-03-05T15:54:18.714Z" }, + { url = "https://files.pythonhosted.org/packages/72/fc/5b11d49247f499bcda591171e9cf3b6ee422b19e70aa2cef2e0ae65ca3b9/mmh3-5.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:1e4ecee40ba19e6975e1120829796770325841c2f153c0e9aecca927194c6a2a", size = 41517, upload-time = "2026-03-05T15:54:19.764Z" }, + { url = "https://files.pythonhosted.org/packages/8a/5f/2a511ee8a1c2a527c77726d5231685b72312c5a1a1b7639ad66a9652aa84/mmh3-5.2.1-cp311-cp311-win_arm64.whl", hash = "sha256:c302245fd6c33d96bd169c7ccf2513c20f4c1e417c07ce9dce107c8bc3f8411f", size = 39287, upload-time = "2026-03-05T15:54:20.904Z" }, + { url = "https://files.pythonhosted.org/packages/92/94/bc5c3b573b40a328c4d141c20e399039ada95e5e2a661df3425c5165fd84/mmh3-5.2.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:0cc21533878e5586b80d74c281d7f8da7932bc8ace50b8d5f6dbf7e3935f63f1", size = 56087, upload-time = "2026-03-05T15:54:21.92Z" }, + { url = "https://files.pythonhosted.org/packages/f6/80/64a02cc3e95c3af0aaa2590849d9ed24a9f14bb93537addde688e039b7c3/mmh3-5.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4eda76074cfca2787c8cf1bec603eaebdddd8b061ad5502f85cddae998d54f00", size = 40500, upload-time = "2026-03-05T15:54:22.953Z" }, + { url = "https://files.pythonhosted.org/packages/8b/72/e6d6602ce18adf4ddcd0e48f2e13590cc92a536199e52109f46f259d3c46/mmh3-5.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:eee884572b06bbe8a2b54f424dbd996139442cf83c76478e1ec162512e0dd2c7", size = 40034, upload-time = "2026-03-05T15:54:23.943Z" }, + { url = "https://files.pythonhosted.org/packages/59/c2/bf4537a8e58e21886ef16477041238cab5095c836496e19fafc34b7445d2/mmh3-5.2.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0d0b7e803191db5f714d264044e06189c8ccd3219e936cc184f07106bd17fd7b", size = 97292, upload-time = "2026-03-05T15:54:25.335Z" }, + { url = "https://files.pythonhosted.org/packages/e5/e2/51ed62063b44d10b06d975ac87af287729eeb5e3ed9772f7584a17983e90/mmh3-5.2.1-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8e6c219e375f6341d0959af814296372d265a8ca1af63825f65e2e87c618f006", size = 103274, upload-time = "2026-03-05T15:54:26.44Z" }, + { url = "https://files.pythonhosted.org/packages/75/ce/12a7524dca59eec92e5b31fdb13ede1e98eda277cf2b786cf73bfbc24e81/mmh3-5.2.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:26fb5b9c3946bf7f1daed7b37e0c03898a6f062149127570f8ede346390a0825", size = 106158, upload-time = "2026-03-05T15:54:28.578Z" }, + { url = "https://files.pythonhosted.org/packages/86/1f/d3ba6dd322d01ab5d44c46c8f0c38ab6bbbf9b5e20e666dfc05bf4a23604/mmh3-5.2.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3c38d142c706201db5b2345166eeef1e7740e3e2422b470b8ba5c8727a9b4c7a", size = 113005, upload-time = "2026-03-05T15:54:29.767Z" }, + { url = "https://files.pythonhosted.org/packages/b6/a9/15d6b6f913294ea41b44d901741298e3718e1cb89ee626b3694625826a43/mmh3-5.2.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:50885073e2909251d4718634a191c49ae5f527e5e1736d738e365c3e8be8f22b", size = 120744, upload-time = "2026-03-05T15:54:30.931Z" }, + { url = "https://files.pythonhosted.org/packages/76/b3/70b73923fd0284c439860ff5c871b20210dfdbe9a6b9dd0ee6496d77f174/mmh3-5.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b3f99e1756fc48ad507b95e5d86f2fb21b3d495012ff13e6592ebac14033f166", size = 99111, upload-time = "2026-03-05T15:54:32.353Z" }, + { url = "https://files.pythonhosted.org/packages/dd/38/99f7f75cd27d10d8b899a1caafb9d531f3903e4d54d572220e3d8ac35e89/mmh3-5.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:62815d2c67f2dd1be76a253d88af4e1da19aeaa1820146dec52cf8bee2958b16", size = 98623, upload-time = "2026-03-05T15:54:33.801Z" }, + { url = "https://files.pythonhosted.org/packages/fd/68/6e292c0853e204c44d2f03ea5f090be3317a0e2d9417ecb62c9eb27687df/mmh3-5.2.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8f767ba0911602ddef289404e33835a61168314ebd3c729833db2ed685824211", size = 106437, upload-time = "2026-03-05T15:54:35.177Z" }, + { url = "https://files.pythonhosted.org/packages/dd/c6/fedd7284c459cfb58721d461fcf5607a4c1f5d9ab195d113d51d10164d16/mmh3-5.2.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:67e41a497bac88cc1de96eeba56eeb933c39d54bc227352f8455aa87c4ca4000", size = 110002, upload-time = "2026-03-05T15:54:36.673Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ac/ca8e0c19a34f5b71390171d2ff0b9f7f187550d66801a731bb68925126a4/mmh3-5.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3d74a03fb57757ece25aa4b3c1c60157a1cece37a020542785f942e2f827eed5", size = 97507, upload-time = "2026-03-05T15:54:37.804Z" }, + { url = "https://files.pythonhosted.org/packages/df/94/6ebb9094cfc7ac5e7950776b9d13a66bb4a34f83814f32ba2abc9494fc68/mmh3-5.2.1-cp312-cp312-win32.whl", hash = "sha256:7374d6e3ef72afe49697ecd683f3da12f4fc06af2d75433d0580c6746d2fa025", size = 40773, upload-time = "2026-03-05T15:54:40.077Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3c/cd3527198cf159495966551c84a5f36805a10ac17b294f41f67b83f6a4d6/mmh3-5.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:3a9fed49c6ce4ed7e73f13182760c65c816da006debe67f37635580dfb0fae00", size = 41560, upload-time = "2026-03-05T15:54:41.148Z" }, + { url = "https://files.pythonhosted.org/packages/15/96/6fe5ebd0f970a076e3ed5512871ce7569447b962e96c125528a2f9724470/mmh3-5.2.1-cp312-cp312-win_arm64.whl", hash = "sha256:bbfcb95d9a744e6e2827dfc66ad10e1020e0cac255eb7f85652832d5a264c2fc", size = 39313, upload-time = "2026-03-05T15:54:42.171Z" }, ] [[package]] @@ -3946,15 +4070,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0d/e7/514f5cf5909f96adf09b78146a9e5c92f82abcc212bc3f88456bf2640c23/mo_vector-0.1.13-py3-none-any.whl", hash = "sha256:f7d619acc3e92ed59631e6b3a12508240e22cf428c87daf022c0d87fbd5da459", size = 20091, upload-time = "2025-06-18T09:27:26.899Z" }, ] -[[package]] -name = "more-itertools" -version = "10.8.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ea/5d/38b681d3fce7a266dd9ab73c66959406d565b3e85f21d5e66e1181d93721/more_itertools-10.8.0.tar.gz", hash = "sha256:f638ddf8a1a0d134181275fb5d58b086ead7c6a72429ad725c67503f13ba30bd", size = 137431, upload-time = "2025-09-02T15:23:11.018Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a4/8e/469e5a4a2f5855992e425f3cb33804cc07bf18d48f2db061aec61ce50270/more_itertools-10.8.0-py3-none-any.whl", hash = "sha256:52d4362373dcf7c52546bc4af9a86ee7c4579df9a8dc268be0a2f949d376cc9b", size = 69667, upload-time = "2025-09-02T15:23:09.635Z" }, -] - [[package]] name = "mpmath" version = "1.3.0" @@ -3966,16 +4081,16 @@ wheels = [ [[package]] name = "msal" -version = "1.34.0" +version = "1.35.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, { name = "pyjwt", extra = ["crypto"] }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cf/0e/c857c46d653e104019a84f22d4494f2119b4fe9f896c92b4b864b3b045cc/msal-1.34.0.tar.gz", hash = "sha256:76ba83b716ea5a6d75b0279c0ac353a0e05b820ca1f6682c0eb7f45190c43c2f", size = 153961, upload-time = "2025-09-22T23:05:48.989Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3c/aa/5a646093ac218e4a329391d5a31e5092a89db7d2ef1637a90b82cd0b6f94/msal-1.35.1.tar.gz", hash = "sha256:70cac18ab80a053bff86219ba64cfe3da1f307c74b009e2da57ef040eb1b5656", size = 165658, upload-time = "2026-03-04T23:38:51.812Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/dc/18d48843499e278538890dc709e9ee3dea8375f8be8e82682851df1b48b5/msal-1.34.0-py3-none-any.whl", hash = "sha256:f669b1644e4950115da7a176441b0e13ec2975c29528d8b9e81316023676d6e1", size = 116987, upload-time = "2025-09-22T23:05:47.294Z" }, + { url = "https://files.pythonhosted.org/packages/96/86/16815fddf056ca998853c6dc525397edf0b43559bb4073a80d2bc7fe8009/msal-1.35.1-py3-none-any.whl", hash = "sha256:8f4e82f34b10c19e326ec69f44dc6b30171f2f7098f3720ea8a9f0c11832caa3", size = 119909, upload-time = "2026-03-04T23:38:50.452Z" }, ] [[package]] @@ -3992,104 +4107,126 @@ wheels = [ [[package]] name = "multidict" -version = "6.7.0" +version = "6.7.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/80/1e/5492c365f222f907de1039b91f922b93fa4f764c713ee858d235495d8f50/multidict-6.7.0.tar.gz", hash = "sha256:c6e99d9a65ca282e578dfea819cfa9c0a62b2499d8677392e09feaf305e9e6f5", size = 101834, upload-time = "2025-10-06T14:52:30.657Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1a/c2/c2d94cbe6ac1753f3fc980da97b3d930efe1da3af3c9f5125354436c073d/multidict-6.7.1.tar.gz", hash = "sha256:ec6652a1bee61c53a3e5776b6049172c53b6aaba34f18c9ad04f82712bac623d", size = 102010, upload-time = "2026-01-26T02:46:45.979Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/34/9e/5c727587644d67b2ed479041e4b1c58e30afc011e3d45d25bbe35781217c/multidict-6.7.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4d409aa42a94c0b3fa617708ef5276dfe81012ba6753a0370fcc9d0195d0a1fc", size = 76604, upload-time = "2025-10-06T14:48:54.277Z" }, - { url = "https://files.pythonhosted.org/packages/17/e4/67b5c27bd17c085a5ea8f1ec05b8a3e5cba0ca734bfcad5560fb129e70ca/multidict-6.7.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14c9e076eede3b54c636f8ce1c9c252b5f057c62131211f0ceeec273810c9721", size = 44715, upload-time = "2025-10-06T14:48:55.445Z" }, - { url = "https://files.pythonhosted.org/packages/4d/e1/866a5d77be6ea435711bef2a4291eed11032679b6b28b56b4776ab06ba3e/multidict-6.7.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4c09703000a9d0fa3c3404b27041e574cc7f4df4c6563873246d0e11812a94b6", size = 44332, upload-time = "2025-10-06T14:48:56.706Z" }, - { url = "https://files.pythonhosted.org/packages/31/61/0c2d50241ada71ff61a79518db85ada85fdabfcf395d5968dae1cbda04e5/multidict-6.7.0-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:a265acbb7bb33a3a2d626afbe756371dce0279e7b17f4f4eda406459c2b5ff1c", size = 245212, upload-time = "2025-10-06T14:48:58.042Z" }, - { url = "https://files.pythonhosted.org/packages/ac/e0/919666a4e4b57fff1b57f279be1c9316e6cdc5de8a8b525d76f6598fefc7/multidict-6.7.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:51cb455de290ae462593e5b1cb1118c5c22ea7f0d3620d9940bf695cea5a4bd7", size = 246671, upload-time = "2025-10-06T14:49:00.004Z" }, - { url = "https://files.pythonhosted.org/packages/a1/cc/d027d9c5a520f3321b65adea289b965e7bcbd2c34402663f482648c716ce/multidict-6.7.0-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:db99677b4457c7a5c5a949353e125ba72d62b35f74e26da141530fbb012218a7", size = 225491, upload-time = "2025-10-06T14:49:01.393Z" }, - { url = "https://files.pythonhosted.org/packages/75/c4/bbd633980ce6155a28ff04e6a6492dd3335858394d7bb752d8b108708558/multidict-6.7.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:f470f68adc395e0183b92a2f4689264d1ea4b40504a24d9882c27375e6662bb9", size = 257322, upload-time = "2025-10-06T14:49:02.745Z" }, - { url = "https://files.pythonhosted.org/packages/4c/6d/d622322d344f1f053eae47e033b0b3f965af01212de21b10bcf91be991fb/multidict-6.7.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0db4956f82723cc1c270de9c6e799b4c341d327762ec78ef82bb962f79cc07d8", size = 254694, upload-time = "2025-10-06T14:49:04.15Z" }, - { url = "https://files.pythonhosted.org/packages/a8/9f/78f8761c2705d4c6d7516faed63c0ebdac569f6db1bef95e0d5218fdc146/multidict-6.7.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3e56d780c238f9e1ae66a22d2adf8d16f485381878250db8d496623cd38b22bd", size = 246715, upload-time = "2025-10-06T14:49:05.967Z" }, - { url = "https://files.pythonhosted.org/packages/78/59/950818e04f91b9c2b95aab3d923d9eabd01689d0dcd889563988e9ea0fd8/multidict-6.7.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9d14baca2ee12c1a64740d4531356ba50b82543017f3ad6de0deb943c5979abb", size = 243189, upload-time = "2025-10-06T14:49:07.37Z" }, - { url = "https://files.pythonhosted.org/packages/7a/3d/77c79e1934cad2ee74991840f8a0110966d9599b3af95964c0cd79bb905b/multidict-6.7.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:295a92a76188917c7f99cda95858c822f9e4aae5824246bba9b6b44004ddd0a6", size = 237845, upload-time = "2025-10-06T14:49:08.759Z" }, - { url = "https://files.pythonhosted.org/packages/63/1b/834ce32a0a97a3b70f86437f685f880136677ac00d8bce0027e9fd9c2db7/multidict-6.7.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:39f1719f57adbb767ef592a50ae5ebb794220d1188f9ca93de471336401c34d2", size = 246374, upload-time = "2025-10-06T14:49:10.574Z" }, - { url = "https://files.pythonhosted.org/packages/23/ef/43d1c3ba205b5dec93dc97f3fba179dfa47910fc73aaaea4f7ceb41cec2a/multidict-6.7.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:0a13fb8e748dfc94749f622de065dd5c1def7e0d2216dba72b1d8069a389c6ff", size = 253345, upload-time = "2025-10-06T14:49:12.331Z" }, - { url = "https://files.pythonhosted.org/packages/6b/03/eaf95bcc2d19ead522001f6a650ef32811aa9e3624ff0ad37c445c7a588c/multidict-6.7.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:e3aa16de190d29a0ea1b48253c57d99a68492c8dd8948638073ab9e74dc9410b", size = 246940, upload-time = "2025-10-06T14:49:13.821Z" }, - { url = "https://files.pythonhosted.org/packages/e8/df/ec8a5fd66ea6cd6f525b1fcbb23511b033c3e9bc42b81384834ffa484a62/multidict-6.7.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a048ce45dcdaaf1defb76b2e684f997fb5abf74437b6cb7b22ddad934a964e34", size = 242229, upload-time = "2025-10-06T14:49:15.603Z" }, - { url = "https://files.pythonhosted.org/packages/8a/a2/59b405d59fd39ec86d1142630e9049243015a5f5291ba49cadf3c090c541/multidict-6.7.0-cp311-cp311-win32.whl", hash = "sha256:a90af66facec4cebe4181b9e62a68be65e45ac9b52b67de9eec118701856e7ff", size = 41308, upload-time = "2025-10-06T14:49:16.871Z" }, - { url = "https://files.pythonhosted.org/packages/32/0f/13228f26f8b882c34da36efa776c3b7348455ec383bab4a66390e42963ae/multidict-6.7.0-cp311-cp311-win_amd64.whl", hash = "sha256:95b5ffa4349df2887518bb839409bcf22caa72d82beec453216802f475b23c81", size = 46037, upload-time = "2025-10-06T14:49:18.457Z" }, - { url = "https://files.pythonhosted.org/packages/84/1f/68588e31b000535a3207fd3c909ebeec4fb36b52c442107499c18a896a2a/multidict-6.7.0-cp311-cp311-win_arm64.whl", hash = "sha256:329aa225b085b6f004a4955271a7ba9f1087e39dcb7e65f6284a988264a63912", size = 43023, upload-time = "2025-10-06T14:49:19.648Z" }, - { url = "https://files.pythonhosted.org/packages/c2/9e/9f61ac18d9c8b475889f32ccfa91c9f59363480613fc807b6e3023d6f60b/multidict-6.7.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:8a3862568a36d26e650a19bb5cbbba14b71789032aebc0423f8cc5f150730184", size = 76877, upload-time = "2025-10-06T14:49:20.884Z" }, - { url = "https://files.pythonhosted.org/packages/38/6f/614f09a04e6184f8824268fce4bc925e9849edfa654ddd59f0b64508c595/multidict-6.7.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:960c60b5849b9b4f9dcc9bea6e3626143c252c74113df2c1540aebce70209b45", size = 45467, upload-time = "2025-10-06T14:49:22.054Z" }, - { url = "https://files.pythonhosted.org/packages/b3/93/c4f67a436dd026f2e780c433277fff72be79152894d9fc36f44569cab1a6/multidict-6.7.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2049be98fb57a31b4ccf870bf377af2504d4ae35646a19037ec271e4c07998aa", size = 43834, upload-time = "2025-10-06T14:49:23.566Z" }, - { url = "https://files.pythonhosted.org/packages/7f/f5/013798161ca665e4a422afbc5e2d9e4070142a9ff8905e482139cd09e4d0/multidict-6.7.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0934f3843a1860dd465d38895c17fce1f1cb37295149ab05cd1b9a03afacb2a7", size = 250545, upload-time = "2025-10-06T14:49:24.882Z" }, - { url = "https://files.pythonhosted.org/packages/71/2f/91dbac13e0ba94669ea5119ba267c9a832f0cb65419aca75549fcf09a3dc/multidict-6.7.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b3e34f3a1b8131ba06f1a73adab24f30934d148afcd5f5de9a73565a4404384e", size = 258305, upload-time = "2025-10-06T14:49:26.778Z" }, - { url = "https://files.pythonhosted.org/packages/ef/b0/754038b26f6e04488b48ac621f779c341338d78503fb45403755af2df477/multidict-6.7.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:efbb54e98446892590dc2458c19c10344ee9a883a79b5cec4bc34d6656e8d546", size = 242363, upload-time = "2025-10-06T14:49:28.562Z" }, - { url = "https://files.pythonhosted.org/packages/87/15/9da40b9336a7c9fa606c4cf2ed80a649dffeb42b905d4f63a1d7eb17d746/multidict-6.7.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a35c5fc61d4f51eb045061e7967cfe3123d622cd500e8868e7c0c592a09fedc4", size = 268375, upload-time = "2025-10-06T14:49:29.96Z" }, - { url = "https://files.pythonhosted.org/packages/82/72/c53fcade0cc94dfaad583105fd92b3a783af2091eddcb41a6d5a52474000/multidict-6.7.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:29fe6740ebccba4175af1b9b87bf553e9c15cd5868ee967e010efcf94e4fd0f1", size = 269346, upload-time = "2025-10-06T14:49:31.404Z" }, - { url = "https://files.pythonhosted.org/packages/0d/e2/9baffdae21a76f77ef8447f1a05a96ec4bc0a24dae08767abc0a2fe680b8/multidict-6.7.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:123e2a72e20537add2f33a79e605f6191fba2afda4cbb876e35c1a7074298a7d", size = 256107, upload-time = "2025-10-06T14:49:32.974Z" }, - { url = "https://files.pythonhosted.org/packages/3c/06/3f06f611087dc60d65ef775f1fb5aca7c6d61c6db4990e7cda0cef9b1651/multidict-6.7.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b284e319754366c1aee2267a2036248b24eeb17ecd5dc16022095e747f2f4304", size = 253592, upload-time = "2025-10-06T14:49:34.52Z" }, - { url = "https://files.pythonhosted.org/packages/20/24/54e804ec7945b6023b340c412ce9c3f81e91b3bf5fa5ce65558740141bee/multidict-6.7.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:803d685de7be4303b5a657b76e2f6d1240e7e0a8aa2968ad5811fa2285553a12", size = 251024, upload-time = "2025-10-06T14:49:35.956Z" }, - { url = "https://files.pythonhosted.org/packages/14/48/011cba467ea0b17ceb938315d219391d3e421dfd35928e5dbdc3f4ae76ef/multidict-6.7.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c04a328260dfd5db8c39538f999f02779012268f54614902d0afc775d44e0a62", size = 251484, upload-time = "2025-10-06T14:49:37.631Z" }, - { url = "https://files.pythonhosted.org/packages/0d/2f/919258b43bb35b99fa127435cfb2d91798eb3a943396631ef43e3720dcf4/multidict-6.7.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8a19cdb57cd3df4cd865849d93ee14920fb97224300c88501f16ecfa2604b4e0", size = 263579, upload-time = "2025-10-06T14:49:39.502Z" }, - { url = "https://files.pythonhosted.org/packages/31/22/a0e884d86b5242b5a74cf08e876bdf299e413016b66e55511f7a804a366e/multidict-6.7.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:9b2fd74c52accced7e75de26023b7dccee62511a600e62311b918ec5c168fc2a", size = 259654, upload-time = "2025-10-06T14:49:41.32Z" }, - { url = "https://files.pythonhosted.org/packages/b2/e5/17e10e1b5c5f5a40f2fcbb45953c9b215f8a4098003915e46a93f5fcaa8f/multidict-6.7.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3e8bfdd0e487acf992407a140d2589fe598238eaeffa3da8448d63a63cd363f8", size = 251511, upload-time = "2025-10-06T14:49:46.021Z" }, - { url = "https://files.pythonhosted.org/packages/e3/9a/201bb1e17e7af53139597069c375e7b0dcbd47594604f65c2d5359508566/multidict-6.7.0-cp312-cp312-win32.whl", hash = "sha256:dd32a49400a2c3d52088e120ee00c1e3576cbff7e10b98467962c74fdb762ed4", size = 41895, upload-time = "2025-10-06T14:49:48.718Z" }, - { url = "https://files.pythonhosted.org/packages/46/e2/348cd32faad84eaf1d20cce80e2bb0ef8d312c55bca1f7fa9865e7770aaf/multidict-6.7.0-cp312-cp312-win_amd64.whl", hash = "sha256:92abb658ef2d7ef22ac9f8bb88e8b6c3e571671534e029359b6d9e845923eb1b", size = 46073, upload-time = "2025-10-06T14:49:50.28Z" }, - { url = "https://files.pythonhosted.org/packages/25/ec/aad2613c1910dce907480e0c3aa306905830f25df2e54ccc9dea450cb5aa/multidict-6.7.0-cp312-cp312-win_arm64.whl", hash = "sha256:490dab541a6a642ce1a9d61a4781656b346a55c13038f0b1244653828e3a83ec", size = 43226, upload-time = "2025-10-06T14:49:52.304Z" }, - { url = "https://files.pythonhosted.org/packages/b7/da/7d22601b625e241d4f23ef1ebff8acfc60da633c9e7e7922e24d10f592b3/multidict-6.7.0-py3-none-any.whl", hash = "sha256:394fc5c42a333c9ffc3e421a4c85e08580d990e08b99f6bf35b4132114c5dcb3", size = 12317, upload-time = "2025-10-06T14:52:29.272Z" }, + { url = "https://files.pythonhosted.org/packages/ce/f1/a90635c4f88fb913fbf4ce660b83b7445b7a02615bda034b2f8eb38fd597/multidict-6.7.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:7ff981b266af91d7b4b3793ca3382e53229088d193a85dfad6f5f4c27fc73e5d", size = 76626, upload-time = "2026-01-26T02:43:26.485Z" }, + { url = "https://files.pythonhosted.org/packages/a6/9b/267e64eaf6fc637a15b35f5de31a566634a2740f97d8d094a69d34f524a4/multidict-6.7.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:844c5bca0b5444adb44a623fb0a1310c2f4cd41f402126bb269cd44c9b3f3e1e", size = 44706, upload-time = "2026-01-26T02:43:27.607Z" }, + { url = "https://files.pythonhosted.org/packages/dd/a4/d45caf2b97b035c57267791ecfaafbd59c68212004b3842830954bb4b02e/multidict-6.7.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f2a0a924d4c2e9afcd7ec64f9de35fcd96915149b2216e1cb2c10a56df483855", size = 44356, upload-time = "2026-01-26T02:43:28.661Z" }, + { url = "https://files.pythonhosted.org/packages/fd/d2/0a36c8473f0cbaeadd5db6c8b72d15bbceeec275807772bfcd059bef487d/multidict-6.7.1-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:8be1802715a8e892c784c0197c2ace276ea52702a0ede98b6310c8f255a5afb3", size = 244355, upload-time = "2026-01-26T02:43:31.165Z" }, + { url = "https://files.pythonhosted.org/packages/5d/16/8c65be997fd7dd311b7d39c7b6e71a0cb449bad093761481eccbbe4b42a2/multidict-6.7.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2e2d2ed645ea29f31c4c7ea1552fcfd7cb7ba656e1eafd4134a6620c9f5fdd9e", size = 246433, upload-time = "2026-01-26T02:43:32.581Z" }, + { url = "https://files.pythonhosted.org/packages/01/fb/4dbd7e848d2799c6a026ec88ad39cf2b8416aa167fcc903baa55ecaa045c/multidict-6.7.1-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:95922cee9a778659e91db6497596435777bd25ed116701a4c034f8e46544955a", size = 225376, upload-time = "2026-01-26T02:43:34.417Z" }, + { url = "https://files.pythonhosted.org/packages/b6/8a/4a3a6341eac3830f6053062f8fbc9a9e54407c80755b3f05bc427295c2d0/multidict-6.7.1-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6b83cabdc375ffaaa15edd97eb7c0c672ad788e2687004990074d7d6c9b140c8", size = 257365, upload-time = "2026-01-26T02:43:35.741Z" }, + { url = "https://files.pythonhosted.org/packages/f7/a2/dd575a69c1aa206e12d27d0770cdf9b92434b48a9ef0cd0d1afdecaa93c4/multidict-6.7.1-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:38fb49540705369bab8484db0689d86c0a33a0a9f2c1b197f506b71b4b6c19b0", size = 254747, upload-time = "2026-01-26T02:43:36.976Z" }, + { url = "https://files.pythonhosted.org/packages/5a/56/21b27c560c13822ed93133f08aa6372c53a8e067f11fbed37b4adcdac922/multidict-6.7.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:439cbebd499f92e9aa6793016a8acaa161dfa749ae86d20960189f5398a19144", size = 246293, upload-time = "2026-01-26T02:43:38.258Z" }, + { url = "https://files.pythonhosted.org/packages/5a/a4/23466059dc3854763423d0ad6c0f3683a379d97673b1b89ec33826e46728/multidict-6.7.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6d3bc717b6fe763b8be3f2bee2701d3c8eb1b2a8ae9f60910f1b2860c82b6c49", size = 242962, upload-time = "2026-01-26T02:43:40.034Z" }, + { url = "https://files.pythonhosted.org/packages/1f/67/51dd754a3524d685958001e8fa20a0f5f90a6a856e0a9dcabff69be3dbb7/multidict-6.7.1-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:619e5a1ac57986dbfec9f0b301d865dddf763696435e2962f6d9cf2fdff2bb71", size = 237360, upload-time = "2026-01-26T02:43:41.752Z" }, + { url = "https://files.pythonhosted.org/packages/64/3f/036dfc8c174934d4b55d86ff4f978e558b0e585cef70cfc1ad01adc6bf18/multidict-6.7.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0b38ebffd9be37c1170d33bc0f36f4f262e0a09bc1aac1c34c7aa51a7293f0b3", size = 245940, upload-time = "2026-01-26T02:43:43.042Z" }, + { url = "https://files.pythonhosted.org/packages/3d/20/6214d3c105928ebc353a1c644a6ef1408bc5794fcb4f170bb524a3c16311/multidict-6.7.1-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:10ae39c9cfe6adedcdb764f5e8411d4a92b055e35573a2eaa88d3323289ef93c", size = 253502, upload-time = "2026-01-26T02:43:44.371Z" }, + { url = "https://files.pythonhosted.org/packages/b1/e2/c653bc4ae1be70a0f836b82172d643fcf1dade042ba2676ab08ec08bff0f/multidict-6.7.1-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:25167cc263257660290fba06b9318d2026e3c910be240a146e1f66dd114af2b0", size = 247065, upload-time = "2026-01-26T02:43:45.745Z" }, + { url = "https://files.pythonhosted.org/packages/c8/11/a854b4154cd3bd8b1fd375e8a8ca9d73be37610c361543d56f764109509b/multidict-6.7.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:128441d052254f42989ef98b7b6a6ecb1e6f708aa962c7984235316db59f50fa", size = 241870, upload-time = "2026-01-26T02:43:47.054Z" }, + { url = "https://files.pythonhosted.org/packages/13/bf/9676c0392309b5fdae322333d22a829715b570edb9baa8016a517b55b558/multidict-6.7.1-cp311-cp311-win32.whl", hash = "sha256:d62b7f64ffde3b99d06b707a280db04fb3855b55f5a06df387236051d0668f4a", size = 41302, upload-time = "2026-01-26T02:43:48.753Z" }, + { url = "https://files.pythonhosted.org/packages/c9/68/f16a3a8ba6f7b6dc92a1f19669c0810bd2c43fc5a02da13b1cbf8e253845/multidict-6.7.1-cp311-cp311-win_amd64.whl", hash = "sha256:bdbf9f3b332abd0cdb306e7c2113818ab1e922dc84b8f8fd06ec89ed2a19ab8b", size = 45981, upload-time = "2026-01-26T02:43:49.921Z" }, + { url = "https://files.pythonhosted.org/packages/ac/ad/9dd5305253fa00cd3c7555dbef69d5bf4133debc53b87ab8d6a44d411665/multidict-6.7.1-cp311-cp311-win_arm64.whl", hash = "sha256:b8c990b037d2fff2f4e33d3f21b9b531c5745b33a49a7d6dbe7a177266af44f6", size = 43159, upload-time = "2026-01-26T02:43:51.635Z" }, + { url = "https://files.pythonhosted.org/packages/8d/9c/f20e0e2cf80e4b2e4b1c365bf5fe104ee633c751a724246262db8f1a0b13/multidict-6.7.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a90f75c956e32891a4eda3639ce6dd86e87105271f43d43442a3aedf3cddf172", size = 76893, upload-time = "2026-01-26T02:43:52.754Z" }, + { url = "https://files.pythonhosted.org/packages/fe/cf/18ef143a81610136d3da8193da9d80bfe1cb548a1e2d1c775f26b23d024a/multidict-6.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fccb473e87eaa1382689053e4a4618e7ba7b9b9b8d6adf2027ee474597128cd", size = 45456, upload-time = "2026-01-26T02:43:53.893Z" }, + { url = "https://files.pythonhosted.org/packages/a9/65/1caac9d4cd32e8433908683446eebc953e82d22b03d10d41a5f0fefe991b/multidict-6.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0fa96985700739c4c7853a43c0b3e169360d6855780021bfc6d0f1ce7c123e7", size = 43872, upload-time = "2026-01-26T02:43:55.041Z" }, + { url = "https://files.pythonhosted.org/packages/cf/3b/d6bd75dc4f3ff7c73766e04e705b00ed6dbbaccf670d9e05a12b006f5a21/multidict-6.7.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cb2a55f408c3043e42b40cc8eecd575afa27b7e0b956dfb190de0f8499a57a53", size = 251018, upload-time = "2026-01-26T02:43:56.198Z" }, + { url = "https://files.pythonhosted.org/packages/fd/80/c959c5933adedb9ac15152e4067c702a808ea183a8b64cf8f31af8ad3155/multidict-6.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb0ce7b2a32d09892b3dd6cc44877a0d02a33241fafca5f25c8b6b62374f8b75", size = 258883, upload-time = "2026-01-26T02:43:57.499Z" }, + { url = "https://files.pythonhosted.org/packages/86/85/7ed40adafea3d4f1c8b916e3b5cc3a8e07dfcdcb9cd72800f4ed3ca1b387/multidict-6.7.1-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c3a32d23520ee37bf327d1e1a656fec76a2edd5c038bf43eddfa0572ec49c60b", size = 242413, upload-time = "2026-01-26T02:43:58.755Z" }, + { url = "https://files.pythonhosted.org/packages/d2/57/b8565ff533e48595503c785f8361ff9a4fde4d67de25c207cd0ba3befd03/multidict-6.7.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9c90fed18bffc0189ba814749fdcc102b536e83a9f738a9003e569acd540a733", size = 268404, upload-time = "2026-01-26T02:44:00.216Z" }, + { url = "https://files.pythonhosted.org/packages/e0/50/9810c5c29350f7258180dfdcb2e52783a0632862eb334c4896ac717cebcb/multidict-6.7.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:da62917e6076f512daccfbbde27f46fed1c98fee202f0559adec8ee0de67f71a", size = 269456, upload-time = "2026-01-26T02:44:02.202Z" }, + { url = "https://files.pythonhosted.org/packages/f3/8d/5e5be3ced1d12966fefb5c4ea3b2a5b480afcea36406559442c6e31d4a48/multidict-6.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfde23ef6ed9db7eaee6c37dcec08524cb43903c60b285b172b6c094711b3961", size = 256322, upload-time = "2026-01-26T02:44:03.56Z" }, + { url = "https://files.pythonhosted.org/packages/31/6e/d8a26d81ac166a5592782d208dd90dfdc0a7a218adaa52b45a672b46c122/multidict-6.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3758692429e4e32f1ba0df23219cd0b4fc0a52f476726fff9337d1a57676a582", size = 253955, upload-time = "2026-01-26T02:44:04.845Z" }, + { url = "https://files.pythonhosted.org/packages/59/4c/7c672c8aad41534ba619bcd4ade7a0dc87ed6b8b5c06149b85d3dd03f0cd/multidict-6.7.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:398c1478926eca669f2fd6a5856b6de9c0acf23a2cb59a14c0ba5844fa38077e", size = 251254, upload-time = "2026-01-26T02:44:06.133Z" }, + { url = "https://files.pythonhosted.org/packages/7b/bd/84c24de512cbafbdbc39439f74e967f19570ce7924e3007174a29c348916/multidict-6.7.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c102791b1c4f3ab36ce4101154549105a53dc828f016356b3e3bcae2e3a039d3", size = 252059, upload-time = "2026-01-26T02:44:07.518Z" }, + { url = "https://files.pythonhosted.org/packages/fa/ba/f5449385510825b73d01c2d4087bf6d2fccc20a2d42ac34df93191d3dd03/multidict-6.7.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a088b62bd733e2ad12c50dad01b7d0166c30287c166e137433d3b410add807a6", size = 263588, upload-time = "2026-01-26T02:44:09.382Z" }, + { url = "https://files.pythonhosted.org/packages/d7/11/afc7c677f68f75c84a69fe37184f0f82fce13ce4b92f49f3db280b7e92b3/multidict-6.7.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d51ff4785d58d3f6c91bdbffcb5e1f7ddfda557727043aa20d20ec4f65e324a", size = 259642, upload-time = "2026-01-26T02:44:10.73Z" }, + { url = "https://files.pythonhosted.org/packages/2b/17/ebb9644da78c4ab36403739e0e6e0e30ebb135b9caf3440825001a0bddcb/multidict-6.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc5907494fccf3e7d3f94f95c91d6336b092b5fc83811720fae5e2765890dfba", size = 251377, upload-time = "2026-01-26T02:44:12.042Z" }, + { url = "https://files.pythonhosted.org/packages/ca/a4/840f5b97339e27846c46307f2530a2805d9d537d8b8bd416af031cad7fa0/multidict-6.7.1-cp312-cp312-win32.whl", hash = "sha256:28ca5ce2fd9716631133d0e9a9b9a745ad7f60bac2bccafb56aa380fc0b6c511", size = 41887, upload-time = "2026-01-26T02:44:14.245Z" }, + { url = "https://files.pythonhosted.org/packages/80/31/0b2517913687895f5904325c2069d6a3b78f66cc641a86a2baf75a05dcbb/multidict-6.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcee94dfbd638784645b066074b338bc9cc155d4b4bffa4adce1615c5a426c19", size = 46053, upload-time = "2026-01-26T02:44:15.371Z" }, + { url = "https://files.pythonhosted.org/packages/0c/5b/aba28e4ee4006ae4c7df8d327d31025d760ffa992ea23812a601d226e682/multidict-6.7.1-cp312-cp312-win_arm64.whl", hash = "sha256:ba0a9fb644d0c1a2194cf7ffb043bd852cea63a57f66fbd33959f7dae18517bf", size = 43307, upload-time = "2026-01-26T02:44:16.852Z" }, + { url = "https://files.pythonhosted.org/packages/81/08/7036c080d7117f28a4af526d794aab6a84463126db031b007717c1a6676e/multidict-6.7.1-py3-none-any.whl", hash = "sha256:55d97cc6dae627efa6a6e548885712d4864b81110ac76fa4e534c03819fa4a56", size = 12319, upload-time = "2026-01-26T02:46:44.004Z" }, ] [[package]] name = "multiprocess" -version = "0.70.18" +version = "0.70.16" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "dill" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/72/fd/2ae3826f5be24c6ed87266bc4e59c46ea5b059a103f3d7e7eb76a52aeecb/multiprocess-0.70.18.tar.gz", hash = "sha256:f9597128e6b3e67b23956da07cf3d2e5cba79e2f4e0fba8d7903636663ec6d0d", size = 1798503, upload-time = "2025-04-17T03:11:27.742Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/ae/04f39c5d0d0def03247c2893d6f2b83c136bf3320a2154d7b8858f2ba72d/multiprocess-0.70.16.tar.gz", hash = "sha256:161af703d4652a0e1410be6abccecde4a7ddffd19341be0a7011b94aeb171ac1", size = 1772603, upload-time = "2024-01-28T18:52:34.85Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/55/4d/9af0d1279c84618bcd35bf5fd7e371657358c7b0a523e54a9cffb87461f8/multiprocess-0.70.18-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:8b8940ae30139e04b076da6c5b83e9398585ebdf0f2ad3250673fef5b2ff06d6", size = 144695, upload-time = "2025-04-17T03:11:09.161Z" }, - { url = "https://files.pythonhosted.org/packages/17/bf/87323e79dd0562474fad3373c21c66bc6c3c9963b68eb2a209deb4c8575e/multiprocess-0.70.18-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:0929ba95831adb938edbd5fb801ac45e705ecad9d100b3e653946b7716cb6bd3", size = 144742, upload-time = "2025-04-17T03:11:10.072Z" }, - { url = "https://files.pythonhosted.org/packages/dd/74/cb8c831e58dc6d5cf450b17c7db87f14294a1df52eb391da948b5e0a0b94/multiprocess-0.70.18-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:4d77f8e4bfe6c6e2e661925bbf9aed4d5ade9a1c6502d5dfc10129b9d1141797", size = 144745, upload-time = "2025-04-17T03:11:11.453Z" }, - { url = "https://files.pythonhosted.org/packages/ba/d8/0cba6cf51a1a31f20471fbc823a716170c73012ddc4fb85d706630ed6e8f/multiprocess-0.70.18-py310-none-any.whl", hash = "sha256:60c194974c31784019c1f459d984e8f33ee48f10fcf42c309ba97b30d9bd53ea", size = 134948, upload-time = "2025-04-17T03:11:20.223Z" }, - { url = "https://files.pythonhosted.org/packages/4b/88/9039f2fed1012ef584751d4ceff9ab4a51e5ae264898f0b7cbf44340a859/multiprocess-0.70.18-py311-none-any.whl", hash = "sha256:5aa6eef98e691281b3ad923be2832bf1c55dd2c859acd73e5ec53a66aae06a1d", size = 144462, upload-time = "2025-04-17T03:11:21.657Z" }, - { url = "https://files.pythonhosted.org/packages/bf/b6/5f922792be93b82ec6b5f270bbb1ef031fd0622847070bbcf9da816502cc/multiprocess-0.70.18-py312-none-any.whl", hash = "sha256:9b78f8e5024b573730bfb654783a13800c2c0f2dfc0c25e70b40d184d64adaa2", size = 150287, upload-time = "2025-04-17T03:11:22.69Z" }, - { url = "https://files.pythonhosted.org/packages/3b/c3/ca84c19bd14cdfc21c388fdcebf08b86a7a470ebc9f5c3c084fc2dbc50f7/multiprocess-0.70.18-py38-none-any.whl", hash = "sha256:dbf705e52a154fe5e90fb17b38f02556169557c2dd8bb084f2e06c2784d8279b", size = 132636, upload-time = "2025-04-17T03:11:24.936Z" }, - { url = "https://files.pythonhosted.org/packages/6c/28/dd72947e59a6a8c856448a5e74da6201cb5502ddff644fbc790e4bd40b9a/multiprocess-0.70.18-py39-none-any.whl", hash = "sha256:e78ca805a72b1b810c690b6b4cc32579eba34f403094bbbae962b7b5bf9dfcb8", size = 133478, upload-time = "2025-04-17T03:11:26.253Z" }, + { url = "https://files.pythonhosted.org/packages/bc/f7/7ec7fddc92e50714ea3745631f79bd9c96424cb2702632521028e57d3a36/multiprocess-0.70.16-py310-none-any.whl", hash = "sha256:c4a9944c67bd49f823687463660a2d6daae94c289adff97e0f9d696ba6371d02", size = 134824, upload-time = "2024-01-28T18:52:26.062Z" }, + { url = "https://files.pythonhosted.org/packages/50/15/b56e50e8debaf439f44befec5b2af11db85f6e0f344c3113ae0be0593a91/multiprocess-0.70.16-py311-none-any.whl", hash = "sha256:af4cabb0dac72abfb1e794fa7855c325fd2b55a10a44628a3c1ad3311c04127a", size = 143519, upload-time = "2024-01-28T18:52:28.115Z" }, + { url = "https://files.pythonhosted.org/packages/0a/7d/a988f258104dcd2ccf1ed40fdc97e26c4ac351eeaf81d76e266c52d84e2f/multiprocess-0.70.16-py312-none-any.whl", hash = "sha256:fc0544c531920dde3b00c29863377f87e1632601092ea2daca74e4beb40faa2e", size = 146741, upload-time = "2024-01-28T18:52:29.395Z" }, + { url = "https://files.pythonhosted.org/packages/ea/89/38df130f2c799090c978b366cfdf5b96d08de5b29a4a293df7f7429fa50b/multiprocess-0.70.16-py38-none-any.whl", hash = "sha256:a71d82033454891091a226dfc319d0cfa8019a4e888ef9ca910372a446de4435", size = 132628, upload-time = "2024-01-28T18:52:30.853Z" }, + { url = "https://files.pythonhosted.org/packages/da/d9/f7f9379981e39b8c2511c9e0326d212accacb82f12fbfdc1aa2ce2a7b2b6/multiprocess-0.70.16-py39-none-any.whl", hash = "sha256:a0bafd3ae1b732eac64be2e72038231c1ba97724b60b09400d68f229fcc2fbf3", size = 133351, upload-time = "2024-01-28T18:52:31.981Z" }, +] + +[[package]] +name = "murmurhash" +version = "1.0.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/23/2e/88c147931ea9725d634840d538622e94122bceaf346233349b7b5c62964b/murmurhash-1.0.15.tar.gz", hash = "sha256:58e2b27b7847f9e2a6edf10b47a8c8dd70a4705f45dccb7bf76aeadacf56ba01", size = 13291, upload-time = "2025-11-14T09:51:15.272Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6b/ca/77d3e69924a8eb4508bb4f0ad34e46adbeedeb93616a71080e61e53dad71/murmurhash-1.0.15-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f32307fb9347680bb4fe1cbef6362fb39bd994f1b59abd8c09ca174e44199081", size = 27397, upload-time = "2025-11-14T09:50:03.077Z" }, + { url = "https://files.pythonhosted.org/packages/e6/53/a936f577d35b245d47b310f29e5e9f09fcac776c8c992f1ab51a9fb0cee2/murmurhash-1.0.15-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:539d8405885d1d19c005f3a2313b47e8e54b0ee89915eb8dfbb430b194328e6c", size = 27692, upload-time = "2025-11-14T09:50:04.144Z" }, + { url = "https://files.pythonhosted.org/packages/4d/64/5f8cfd1fd9cbeb43fcff96672f5bd9e7e1598d1c970f808ecd915490dc20/murmurhash-1.0.15-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c4cd739a00f5a4602201b74568ddabae46ec304719d9be752fd8f534a9464b5e", size = 128396, upload-time = "2025-11-14T09:50:05.268Z" }, + { url = "https://files.pythonhosted.org/packages/ac/10/d9ce29d559a75db0d8a3f13ea12c7f541ec9de2afca38dc70418b890eedb/murmurhash-1.0.15-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:44d211bcc3ec203c47dac06f48ee871093fcbdffa6652a6cc5ea7180306680a8", size = 128687, upload-time = "2025-11-14T09:50:06.527Z" }, + { url = "https://files.pythonhosted.org/packages/48/cd/dc97ab7e68cdfa1537a56e36dbc846c5a66701cc39ecee2d4399fe61996c/murmurhash-1.0.15-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:f9bf47101354fb1dc4b2e313192566f04ba295c28a37e2f71c692759acc1ba3c", size = 128198, upload-time = "2025-11-14T09:50:08.062Z" }, + { url = "https://files.pythonhosted.org/packages/53/73/32f2aaa22c1e4afae337106baf0c938abf36a6cc879cfee83a00461bbbf7/murmurhash-1.0.15-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3c69b4d3bcd6233782a78907fe10b9b7a796bdc5d28060cf097d067bec280a5d", size = 127214, upload-time = "2025-11-14T09:50:09.265Z" }, + { url = "https://files.pythonhosted.org/packages/82/ed/812103a7f353eba2d83655b08205e13a38c93b4db0692f94756e1eb44516/murmurhash-1.0.15-cp311-cp311-win_amd64.whl", hash = "sha256:e43a69496342ce530bdd670264cb7c8f45490b296e4764c837ce577e3c7ebd53", size = 25241, upload-time = "2025-11-14T09:50:10.373Z" }, + { url = "https://files.pythonhosted.org/packages/eb/5f/2c511bdd28f7c24da37a00116ffd0432b65669d098f0d0260c66ac0ffdc2/murmurhash-1.0.15-cp311-cp311-win_arm64.whl", hash = "sha256:f3e99a6ee36ef5372df5f138e3d9c801420776d3641a34a49e5c2555f44edba7", size = 23216, upload-time = "2025-11-14T09:50:11.651Z" }, + { url = "https://files.pythonhosted.org/packages/b6/46/be8522d3456fdccf1b8b049c6d82e7a3c1114c4fc2cfe14b04cba4b3e701/murmurhash-1.0.15-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d37e3ae44746bca80b1a917c2ea625cf216913564ed43f69d2888e5df97db0cb", size = 27884, upload-time = "2025-11-14T09:50:13.133Z" }, + { url = "https://files.pythonhosted.org/packages/ed/cc/630449bf4f6178d7daf948ce46ad00b25d279065fc30abd8d706be3d87e0/murmurhash-1.0.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0861cb11039409eaf46878456b7d985ef17b6b484103a6fc367b2ecec846891d", size = 27855, upload-time = "2025-11-14T09:50:14.859Z" }, + { url = "https://files.pythonhosted.org/packages/ff/30/ea8f601a9bf44db99468696efd59eb9cff1157cd55cb586d67116697583f/murmurhash-1.0.15-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:5a301decfaccfec70fe55cb01dde2a012c3014a874542eaa7cc73477bb749616", size = 134088, upload-time = "2025-11-14T09:50:15.958Z" }, + { url = "https://files.pythonhosted.org/packages/c9/de/c40ce8c0877d406691e735b8d6e9c815f36a82b499d358313db5dbe219d7/murmurhash-1.0.15-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:32c6fde7bd7e9407003370a07b5f4addacabe1556ad3dc2cac246b7a2bba3400", size = 133978, upload-time = "2025-11-14T09:50:17.572Z" }, + { url = "https://files.pythonhosted.org/packages/47/84/bd49963ecd84ebab2fe66595e2d1ed41d5e8b5153af5dc930f0bd827007c/murmurhash-1.0.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5d8b43a7011540dc3c7ce66f2134df9732e2bc3bbb4a35f6458bc755e48bde26", size = 132956, upload-time = "2025-11-14T09:50:18.742Z" }, + { url = "https://files.pythonhosted.org/packages/4f/7c/2530769c545074417c862583f05f4245644599f1e9ff619b3dfe2969aafc/murmurhash-1.0.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43bf4541892ecd95963fcd307bf1c575fc0fee1682f41c93007adee71ca2bb40", size = 134184, upload-time = "2025-11-14T09:50:19.941Z" }, + { url = "https://files.pythonhosted.org/packages/84/a4/b249b042f5afe34d14ada2dc4afc777e883c15863296756179652e081c44/murmurhash-1.0.15-cp312-cp312-win_amd64.whl", hash = "sha256:f4ac15a2089dc42e6eb0966622d42d2521590a12c92480aafecf34c085302cca", size = 25647, upload-time = "2025-11-14T09:50:21.049Z" }, + { url = "https://files.pythonhosted.org/packages/13/bf/028179259aebc18fd4ba5cae2601d1d47517427a537ab44336446431a215/murmurhash-1.0.15-cp312-cp312-win_arm64.whl", hash = "sha256:4a70ca4ae19e600d9be3da64d00710e79dde388a4d162f22078d64844d0ebdda", size = 23338, upload-time = "2025-11-14T09:50:22.359Z" }, ] [[package]] name = "mypy" -version = "1.17.1" +version = "1.19.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "librt", marker = "platform_python_implementation != 'PyPy'" }, { name = "mypy-extensions" }, { name = "pathspec" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8e/22/ea637422dedf0bf36f3ef238eab4e455e2a0dcc3082b5cc067615347ab8e/mypy-1.17.1.tar.gz", hash = "sha256:25e01ec741ab5bb3eec8ba9cdb0f769230368a22c959c4937360efb89b7e9f01", size = 3352570, upload-time = "2025-07-31T07:54:19.204Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/db/4efed9504bc01309ab9c2da7e352cc223569f05478012b5d9ece38fd44d2/mypy-1.19.1.tar.gz", hash = "sha256:19d88bb05303fe63f71dd2c6270daca27cb9401c4ca8255fe50d1d920e0eb9ba", size = 3582404, upload-time = "2025-12-15T05:03:48.42Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/46/cf/eadc80c4e0a70db1c08921dcc220357ba8ab2faecb4392e3cebeb10edbfa/mypy-1.17.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:ad37544be07c5d7fba814eb370e006df58fed8ad1ef33ed1649cb1889ba6ff58", size = 10921009, upload-time = "2025-07-31T07:53:23.037Z" }, - { url = "https://files.pythonhosted.org/packages/5d/c1/c869d8c067829ad30d9bdae051046561552516cfb3a14f7f0347b7d973ee/mypy-1.17.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:064e2ff508e5464b4bd807a7c1625bc5047c5022b85c70f030680e18f37273a5", size = 10047482, upload-time = "2025-07-31T07:53:26.151Z" }, - { url = "https://files.pythonhosted.org/packages/98/b9/803672bab3fe03cee2e14786ca056efda4bb511ea02dadcedde6176d06d0/mypy-1.17.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:70401bbabd2fa1aa7c43bb358f54037baf0586f41e83b0ae67dd0534fc64edfd", size = 11832883, upload-time = "2025-07-31T07:53:47.948Z" }, - { url = "https://files.pythonhosted.org/packages/88/fb/fcdac695beca66800918c18697b48833a9a6701de288452b6715a98cfee1/mypy-1.17.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e92bdc656b7757c438660f775f872a669b8ff374edc4d18277d86b63edba6b8b", size = 12566215, upload-time = "2025-07-31T07:54:04.031Z" }, - { url = "https://files.pythonhosted.org/packages/7f/37/a932da3d3dace99ee8eb2043b6ab03b6768c36eb29a02f98f46c18c0da0e/mypy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c1fdf4abb29ed1cb091cf432979e162c208a5ac676ce35010373ff29247bcad5", size = 12751956, upload-time = "2025-07-31T07:53:36.263Z" }, - { url = "https://files.pythonhosted.org/packages/8c/cf/6438a429e0f2f5cab8bc83e53dbebfa666476f40ee322e13cac5e64b79e7/mypy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:ff2933428516ab63f961644bc49bc4cbe42bbffb2cd3b71cc7277c07d16b1a8b", size = 9507307, upload-time = "2025-07-31T07:53:59.734Z" }, - { url = "https://files.pythonhosted.org/packages/17/a2/7034d0d61af8098ec47902108553122baa0f438df8a713be860f7407c9e6/mypy-1.17.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:69e83ea6553a3ba79c08c6e15dbd9bfa912ec1e493bf75489ef93beb65209aeb", size = 11086295, upload-time = "2025-07-31T07:53:28.124Z" }, - { url = "https://files.pythonhosted.org/packages/14/1f/19e7e44b594d4b12f6ba8064dbe136505cec813549ca3e5191e40b1d3cc2/mypy-1.17.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1b16708a66d38abb1e6b5702f5c2c87e133289da36f6a1d15f6a5221085c6403", size = 10112355, upload-time = "2025-07-31T07:53:21.121Z" }, - { url = "https://files.pythonhosted.org/packages/5b/69/baa33927e29e6b4c55d798a9d44db5d394072eef2bdc18c3e2048c9ed1e9/mypy-1.17.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89e972c0035e9e05823907ad5398c5a73b9f47a002b22359b177d40bdaee7056", size = 11875285, upload-time = "2025-07-31T07:53:55.293Z" }, - { url = "https://files.pythonhosted.org/packages/90/13/f3a89c76b0a41e19490b01e7069713a30949d9a6c147289ee1521bcea245/mypy-1.17.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:03b6d0ed2b188e35ee6d5c36b5580cffd6da23319991c49ab5556c023ccf1341", size = 12737895, upload-time = "2025-07-31T07:53:43.623Z" }, - { url = "https://files.pythonhosted.org/packages/23/a1/c4ee79ac484241301564072e6476c5a5be2590bc2e7bfd28220033d2ef8f/mypy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c837b896b37cd103570d776bda106eabb8737aa6dd4f248451aecf53030cdbeb", size = 12931025, upload-time = "2025-07-31T07:54:17.125Z" }, - { url = "https://files.pythonhosted.org/packages/89/b8/7409477be7919a0608900e6320b155c72caab4fef46427c5cc75f85edadd/mypy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:665afab0963a4b39dff7c1fa563cc8b11ecff7910206db4b2e64dd1ba25aed19", size = 9584664, upload-time = "2025-07-31T07:54:12.842Z" }, - { url = "https://files.pythonhosted.org/packages/1d/f3/8fcd2af0f5b806f6cf463efaffd3c9548a28f84220493ecd38d127b6b66d/mypy-1.17.1-py3-none-any.whl", hash = "sha256:a9f52c0351c21fe24c21d8c0eb1f62967b262d6729393397b6f443c3b773c3b9", size = 2283411, upload-time = "2025-07-31T07:53:24.664Z" }, + { url = "https://files.pythonhosted.org/packages/ef/47/6b3ebabd5474d9cdc170d1342fbf9dddc1b0ec13ec90bf9004ee6f391c31/mypy-1.19.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d8dfc6ab58ca7dda47d9237349157500468e404b17213d44fc1cb77bce532288", size = 13028539, upload-time = "2025-12-15T05:03:44.129Z" }, + { url = "https://files.pythonhosted.org/packages/5c/a6/ac7c7a88a3c9c54334f53a941b765e6ec6c4ebd65d3fe8cdcfbe0d0fd7db/mypy-1.19.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e3f276d8493c3c97930e354b2595a44a21348b320d859fb4a2b9f66da9ed27ab", size = 12083163, upload-time = "2025-12-15T05:03:37.679Z" }, + { url = "https://files.pythonhosted.org/packages/67/af/3afa9cf880aa4a2c803798ac24f1d11ef72a0c8079689fac5cfd815e2830/mypy-1.19.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2abb24cf3f17864770d18d673c85235ba52456b36a06b6afc1e07c1fdcd3d0e6", size = 12687629, upload-time = "2025-12-15T05:02:31.526Z" }, + { url = "https://files.pythonhosted.org/packages/2d/46/20f8a7114a56484ab268b0ab372461cb3a8f7deed31ea96b83a4e4cfcfca/mypy-1.19.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a009ffa5a621762d0c926a078c2d639104becab69e79538a494bcccb62cc0331", size = 13436933, upload-time = "2025-12-15T05:03:15.606Z" }, + { url = "https://files.pythonhosted.org/packages/5b/f8/33b291ea85050a21f15da910002460f1f445f8007adb29230f0adea279cb/mypy-1.19.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f7cee03c9a2e2ee26ec07479f38ea9c884e301d42c6d43a19d20fb014e3ba925", size = 13661754, upload-time = "2025-12-15T05:02:26.731Z" }, + { url = "https://files.pythonhosted.org/packages/fd/a3/47cbd4e85bec4335a9cd80cf67dbc02be21b5d4c9c23ad6b95d6c5196bac/mypy-1.19.1-cp311-cp311-win_amd64.whl", hash = "sha256:4b84a7a18f41e167f7995200a1d07a4a6810e89d29859df936f1c3923d263042", size = 10055772, upload-time = "2025-12-15T05:03:26.179Z" }, + { url = "https://files.pythonhosted.org/packages/06/8a/19bfae96f6615aa8a0604915512e0289b1fad33d5909bf7244f02935d33a/mypy-1.19.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a8174a03289288c1f6c46d55cef02379b478bfbc8e358e02047487cad44c6ca1", size = 13206053, upload-time = "2025-12-15T05:03:46.622Z" }, + { url = "https://files.pythonhosted.org/packages/a5/34/3e63879ab041602154ba2a9f99817bb0c85c4df19a23a1443c8986e4d565/mypy-1.19.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffcebe56eb09ff0c0885e750036a095e23793ba6c2e894e7e63f6d89ad51f22e", size = 12219134, upload-time = "2025-12-15T05:03:24.367Z" }, + { url = "https://files.pythonhosted.org/packages/89/cc/2db6f0e95366b630364e09845672dbee0cbf0bbe753a204b29a944967cd9/mypy-1.19.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b64d987153888790bcdb03a6473d321820597ab8dd9243b27a92153c4fa50fd2", size = 12731616, upload-time = "2025-12-15T05:02:44.725Z" }, + { url = "https://files.pythonhosted.org/packages/00/be/dd56c1fd4807bc1eba1cf18b2a850d0de7bacb55e158755eb79f77c41f8e/mypy-1.19.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c35d298c2c4bba75feb2195655dfea8124d855dfd7343bf8b8c055421eaf0cf8", size = 13620847, upload-time = "2025-12-15T05:03:39.633Z" }, + { url = "https://files.pythonhosted.org/packages/6d/42/332951aae42b79329f743bf1da088cd75d8d4d9acc18fbcbd84f26c1af4e/mypy-1.19.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:34c81968774648ab5ac09c29a375fdede03ba253f8f8287847bd480782f73a6a", size = 13834976, upload-time = "2025-12-15T05:03:08.786Z" }, + { url = "https://files.pythonhosted.org/packages/6f/63/e7493e5f90e1e085c562bb06e2eb32cae27c5057b9653348d38b47daaecc/mypy-1.19.1-cp312-cp312-win_amd64.whl", hash = "sha256:b10e7c2cd7870ba4ad9b2d8a6102eb5ffc1f16ca35e3de6bfa390c1113029d13", size = 10118104, upload-time = "2025-12-15T05:03:10.834Z" }, + { url = "https://files.pythonhosted.org/packages/8d/f4/4ce9a05ce5ded1de3ec1c1d96cf9f9504a04e54ce0ed55cfa38619a32b8d/mypy-1.19.1-py3-none-any.whl", hash = "sha256:f1235f5ea01b7db5468d53ece6aaddf1ad0b88d9e7462b86ef96fe04995d7247", size = 2471239, upload-time = "2025-12-15T05:03:07.248Z" }, ] [[package]] name = "mypy-boto3-bedrock-runtime" -version = "1.41.2" +version = "1.42.42" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/f1/00aea4f91501728e7af7e899ce3a75d48d6df97daa720db11e46730fa123/mypy_boto3_bedrock_runtime-1.41.2.tar.gz", hash = "sha256:ba2c11f2f18116fd69e70923389ce68378fa1620f70e600efb354395a1a9e0e5", size = 28890, upload-time = "2025-11-21T20:35:30.074Z" } +sdist = { url = "https://files.pythonhosted.org/packages/46/bb/65dc1b2c5796a6ab5f60bdb57343bd6c3ecb82251c580eca415c8548333e/mypy_boto3_bedrock_runtime-1.42.42.tar.gz", hash = "sha256:3a4088218478b6fbbc26055c03c95bee4fc04624a801090b3cce3037e8275c8d", size = 29840, upload-time = "2026-02-04T20:53:05.999Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/cc/96a2af58c632701edb5be1dda95434464da43df40ae868a1ab1ddf033839/mypy_boto3_bedrock_runtime-1.41.2-py3-none-any.whl", hash = "sha256:a720ff1e98cf10723c37a61a46cff220b190c55b8fb57d4397e6cf286262cf02", size = 34967, upload-time = "2025-11-21T20:35:27.655Z" }, + { url = "https://files.pythonhosted.org/packages/00/43/7ea062f2228f47b5779dcfa14dab48d6e29f979b35d1a5102b0ba80b9c1b/mypy_boto3_bedrock_runtime-1.42.42-py3-none-any.whl", hash = "sha256:b2d16eae22607d0685f90796b3a0afc78c0b09d45872e00eafd634a31dd9358f", size = 36077, upload-time = "2026-02-04T20:53:01.768Z" }, ] [[package]] @@ -4103,21 +4240,21 @@ wheels = [ [[package]] name = "mysql-connector-python" -version = "9.5.0" +version = "9.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/39/33/b332b001bc8c5ee09255a0d4b09a254da674450edd6a3e5228b245ca82a0/mysql_connector_python-9.5.0.tar.gz", hash = "sha256:92fb924285a86d8c146ebd63d94f9eaefa548da7813bc46271508fdc6cc1d596", size = 12251077, upload-time = "2025-10-22T09:05:45.423Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6f/6e/c89babc7de3df01467d159854414659c885152579903a8220c8db02a3835/mysql_connector_python-9.6.0.tar.gz", hash = "sha256:c453bb55347174d87504b534246fb10c589daf5d057515bf615627198a3c7ef1", size = 12254999, upload-time = "2026-02-10T12:04:52.63Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/03/77347d58b0027ce93a41858477e08422e498c6ebc24348b1f725ed7a67ae/mysql_connector_python-9.5.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:653e70cd10cf2d18dd828fae58dff5f0f7a5cf7e48e244f2093314dddf84a4b9", size = 17578984, upload-time = "2025-10-22T09:01:41.213Z" }, - { url = "https://files.pythonhosted.org/packages/a5/bb/0f45c7ee55ebc56d6731a593d85c0e7f25f83af90a094efebfd5be9fe010/mysql_connector_python-9.5.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:5add93f60b3922be71ea31b89bc8a452b876adbb49262561bd559860dae96b3f", size = 18445067, upload-time = "2025-10-22T09:01:43.215Z" }, - { url = "https://files.pythonhosted.org/packages/1c/ec/054de99d4aa50d851a37edca9039280f7194cc1bfd30aab38f5bd6977ebe/mysql_connector_python-9.5.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:20950a5e44896c03e3dc93ceb3a5e9b48c9acae18665ca6e13249b3fe5b96811", size = 33668029, upload-time = "2025-10-22T09:01:45.74Z" }, - { url = "https://files.pythonhosted.org/packages/90/a2/e6095dc3a7ad5c959fe4a65681db63af131f572e57cdffcc7816bc84e3ad/mysql_connector_python-9.5.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:7fdd3205b9242c284019310fa84437f3357b13f598e3f9b5d80d337d4a6406b8", size = 34101687, upload-time = "2025-10-22T09:01:48.462Z" }, - { url = "https://files.pythonhosted.org/packages/9c/88/bc13c33fca11acaf808bd1809d8602d78f5bb84f7b1e7b1a288c383a14fd/mysql_connector_python-9.5.0-cp311-cp311-win_amd64.whl", hash = "sha256:c021d8b0830958b28712c70c53b206b4cf4766948dae201ea7ca588a186605e0", size = 16511749, upload-time = "2025-10-22T09:01:51.032Z" }, - { url = "https://files.pythonhosted.org/packages/02/89/167ebee82f4b01ba7339c241c3cc2518886a2be9f871770a1efa81b940a0/mysql_connector_python-9.5.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:a72c2ef9d50b84f3c567c31b3bf30901af740686baa2a4abead5f202e0b7ea61", size = 17581904, upload-time = "2025-10-22T09:01:53.21Z" }, - { url = "https://files.pythonhosted.org/packages/67/46/630ca969ce10b30fdc605d65dab4a6157556d8cc3b77c724f56c2d83cb79/mysql_connector_python-9.5.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:bd9ba5a946cfd3b3b2688a75135357e862834b0321ed936fd968049be290872b", size = 18448195, upload-time = "2025-10-22T09:01:55.378Z" }, - { url = "https://files.pythonhosted.org/packages/f6/87/4c421f41ad169d8c9065ad5c46673c7af889a523e4899c1ac1d6bfd37262/mysql_connector_python-9.5.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5ef7accbdf8b5f6ec60d2a1550654b7e27e63bf6f7b04020d5fb4191fb02bc4d", size = 33668638, upload-time = "2025-10-22T09:01:57.896Z" }, - { url = "https://files.pythonhosted.org/packages/a6/01/67cf210d50bfefbb9224b9a5c465857c1767388dade1004c903c8e22a991/mysql_connector_python-9.5.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:a6e0a4a0274d15e3d4c892ab93f58f46431222117dba20608178dfb2cc4d5fd8", size = 34102899, upload-time = "2025-10-22T09:02:00.291Z" }, - { url = "https://files.pythonhosted.org/packages/cd/ef/3d1a67d503fff38cc30e11d111cf28f0976987fb175f47b10d44494e1080/mysql_connector_python-9.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:b6c69cb37600b7e22f476150034e2afbd53342a175e20aea887f8158fc5e3ff6", size = 16512684, upload-time = "2025-10-22T09:02:02.411Z" }, - { url = "https://files.pythonhosted.org/packages/95/e1/45373c06781340c7b74fe9b88b85278ac05321889a307eaa5be079a997d4/mysql_connector_python-9.5.0-py2.py3-none-any.whl", hash = "sha256:ace137b88eb6fdafa1e5b2e03ac76ce1b8b1844b3a4af1192a02ae7c1a45bdee", size = 479047, upload-time = "2025-10-22T09:02:27.809Z" }, + { url = "https://files.pythonhosted.org/packages/2a/08/0e9bce000736454c2b8bb4c40bded79328887483689487dad7df4cf59fb7/mysql_connector_python-9.6.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:011931f7392a1087e10d305b0303f2a20cc1af2c1c8a15cd5691609aa95dfcbd", size = 17582646, upload-time = "2026-01-21T09:04:48.327Z" }, + { url = "https://files.pythonhosted.org/packages/93/aa/3dd4db039fc6a9bcbdbade83be9914ead6786c0be4918170dfaf89327b76/mysql_connector_python-9.6.0-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:b5212372aff6833473d2560ac87d3df9fb2498d0faacb7ebf231d947175fa36a", size = 18449358, upload-time = "2026-01-21T09:04:50.278Z" }, + { url = "https://files.pythonhosted.org/packages/53/38/ecd6d35382b6265ff5f030464d53b45e51ff2c2523ab88771c277fd84c05/mysql_connector_python-9.6.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:61deca6e243fafbb3cf08ae27bd0c83d0f8188de8456e46aeba0d3db15bb7230", size = 34169309, upload-time = "2026-01-21T09:04:52.402Z" }, + { url = "https://files.pythonhosted.org/packages/18/1d/fe1133eb76089342854d8fbe88e28598f7e06bc684a763d21fc7b23f1d5e/mysql_connector_python-9.6.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:adabbc5e1475cdf5fb6f1902a25edc3bd1e0726fa45f01ab1b8f479ff43b3337", size = 34541101, upload-time = "2026-01-21T09:04:55.897Z" }, + { url = "https://files.pythonhosted.org/packages/3f/99/da0f55beb970ca049fd7d37a6391d686222af89a8b13e636d8e9bbd06536/mysql_connector_python-9.6.0-cp311-cp311-win_amd64.whl", hash = "sha256:8732ca0b7417b45238bcbfc7e64d9c4d62c759672207c6284f0921c366efddc7", size = 16514767, upload-time = "2026-02-10T12:03:50.584Z" }, + { url = "https://files.pythonhosted.org/packages/8f/d9/2a4b4d90b52f4241f0f71618cd4bd8779dd6d18db8058b0a4dd83ec0541c/mysql_connector_python-9.6.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9664e217c72dd6fb700f4c8512af90261f72d2f5d7c00c4e13e4c1e09bfa3d5e", size = 17585672, upload-time = "2026-02-10T12:03:52.955Z" }, + { url = "https://files.pythonhosted.org/packages/33/91/2495835733a054e716a17dc28404748b33f2dc1da1ae4396fb45574adf40/mysql_connector_python-9.6.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:1ed4b5c4761e5333035293e746683890e4ef2e818e515d14023fd80293bc31fa", size = 18452624, upload-time = "2026-02-10T12:03:56.153Z" }, + { url = "https://files.pythonhosted.org/packages/7a/69/e83abbbbf7f8eed855b5a5ff7285bc0afb1199418ac036c7691edf41e154/mysql_connector_python-9.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5095758dcb89a6bce2379f349da336c268c407129002b595c5dba82ce387e2a5", size = 34169154, upload-time = "2026-02-10T12:03:58.831Z" }, + { url = "https://files.pythonhosted.org/packages/82/44/67bb61c71f398fbc739d07e8dcadad94e2f655874cb32ae851454066bea0/mysql_connector_python-9.6.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4ae4e7780fad950a4f267dea5851048d160f5b71314a342cdbf30b154f1c74f7", size = 34542947, upload-time = "2026-02-10T12:04:02.408Z" }, + { url = "https://files.pythonhosted.org/packages/ba/39/994c4f7e9c59d3ca534a831d18442ac4c529865db20aeaa4fd94e2af5efd/mysql_connector_python-9.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:c180e0b4100d7402e03993bfac5c97d18e01d7ca9d198d742fffc245077f8ffe", size = 16515709, upload-time = "2026-02-10T12:04:04.924Z" }, + { url = "https://files.pythonhosted.org/packages/15/dd/b3250826c29cee7816de4409a2fe5e469a68b9a89f6bfaa5eed74f05532c/mysql_connector_python-9.6.0-py2.py3-none-any.whl", hash = "sha256:44b0fb57207ebc6ae05b5b21b7968a9ed33b29187fe87b38951bad2a334d75d5", size = 480527, upload-time = "2026-02-10T12:04:36.176Z" }, ] [[package]] @@ -4131,35 +4268,11 @@ wheels = [ [[package]] name = "networkx" -version = "3.6" +version = "3.6.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/fc/7b6fd4d22c8c4dc5704430140d8b3f520531d4fe7328b8f8d03f5a7950e8/networkx-3.6.tar.gz", hash = "sha256:285276002ad1f7f7da0f7b42f004bcba70d381e936559166363707fdad3d72ad", size = 2511464, upload-time = "2025-11-24T03:03:47.158Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/51/63fe664f3908c97be9d2e4f1158eb633317598cfa6e1fc14af5383f17512/networkx-3.6.1.tar.gz", hash = "sha256:26b7c357accc0c8cde558ad486283728b65b6a95d85ee1cd66bafab4c8168509", size = 2517025, upload-time = "2025-12-08T17:02:39.908Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/07/c7/d64168da60332c17d24c0d2f08bdf3987e8d1ae9d84b5bbd0eec2eb26a55/networkx-3.6-py3-none-any.whl", hash = "sha256:cdb395b105806062473d3be36458d8f1459a4e4b98e236a66c3a48996e07684f", size = 2063713, upload-time = "2025-11-24T03:03:45.21Z" }, -] - -[[package]] -name = "nh3" -version = "0.3.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cc/37/ab55eb2b05e334ff9a1ad52c556ace1f9c20a3f63613a165d384d5387657/nh3-0.3.3.tar.gz", hash = "sha256:185ed41b88c910b9ca8edc89ca3b4be688a12cb9de129d84befa2f74a0039fee", size = 18968, upload-time = "2026-02-14T09:35:15.664Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/13/3e/aef8cf8e0419b530c95e96ae93a5078e9b36c1e6613eeb1df03a80d5194e/nh3-0.3.3-cp38-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:e8ee96156f7dfc6e30ecda650e480c5ae0a7d38f0c6fafc3c1c655e2500421d9", size = 1448640, upload-time = "2026-02-14T09:34:49.316Z" }, - { url = "https://files.pythonhosted.org/packages/ca/43/d2011a4f6c0272cb122eeff40062ee06bb2b6e57eabc3a5e057df0d582df/nh3-0.3.3-cp38-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45fe0d6a607264910daec30360c8a3b5b1500fd832d21b2da608256287bcb92d", size = 839405, upload-time = "2026-02-14T09:34:50.779Z" }, - { url = "https://files.pythonhosted.org/packages/f8/f3/965048510c1caf2a34ed04411a46a04a06eb05563cd06f1aa57b71eb2bc8/nh3-0.3.3-cp38-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:5bc1d4b30ba1ba896669d944b6003630592665974bd11a3dc2f661bde92798a7", size = 825849, upload-time = "2026-02-14T09:34:52.622Z" }, - { url = "https://files.pythonhosted.org/packages/78/99/b4bbc6ad16329d8db2c2c320423f00b549ca3b129c2b2f9136be2606dbb0/nh3-0.3.3-cp38-abi3-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f433a2dd66545aad4a720ad1b2150edcdca75bfff6f4e6f378ade1ec138d5e77", size = 1068303, upload-time = "2026-02-14T09:34:54.179Z" }, - { url = "https://files.pythonhosted.org/packages/3f/34/3420d97065aab1b35f3e93ce9c96c8ebd423ce86fe84dee3126790421a2a/nh3-0.3.3-cp38-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52e973cb742e95b9ae1b35822ce23992428750f4b46b619fe86eba4205255b30", size = 1029316, upload-time = "2026-02-14T09:34:56.186Z" }, - { url = "https://files.pythonhosted.org/packages/f1/9a/99eda757b14e596fdb2ca5f599a849d9554181aa899274d0d183faef4493/nh3-0.3.3-cp38-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4c730617bdc15d7092dcc0469dc2826b914c8f874996d105b4bc3842a41c1cd9", size = 919944, upload-time = "2026-02-14T09:34:57.886Z" }, - { url = "https://files.pythonhosted.org/packages/6f/84/c0dc75c7fb596135f999e59a410d9f45bdabb989f1cb911f0016d22b747b/nh3-0.3.3-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e98fa3dbfd54e25487e36ba500bc29bca3a4cab4ffba18cfb1a35a2d02624297", size = 811461, upload-time = "2026-02-14T09:34:59.65Z" }, - { url = "https://files.pythonhosted.org/packages/7e/ec/b1bf57cab6230eec910e4863528dc51dcf21b57aaf7c88ee9190d62c9185/nh3-0.3.3-cp38-abi3-manylinux_2_31_riscv64.whl", hash = "sha256:3a62b8ae7c235481715055222e54c682422d0495a5c73326807d4e44c5d14691", size = 840360, upload-time = "2026-02-14T09:35:01.444Z" }, - { url = "https://files.pythonhosted.org/packages/37/5e/326ae34e904dde09af1de51219a611ae914111f0970f2f111f4f0188f57e/nh3-0.3.3-cp38-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fc305a2264868ec8fa16548296f803d8fd9c1fa66cd28b88b605b1bd06667c0b", size = 859872, upload-time = "2026-02-14T09:35:03.348Z" }, - { url = "https://files.pythonhosted.org/packages/09/38/7eba529ce17ab4d3790205da37deabb4cb6edcba15f27b8562e467f2fc97/nh3-0.3.3-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:90126a834c18af03bfd6ff9a027bfa6bbf0e238527bc780a24de6bd7cc1041e2", size = 1023550, upload-time = "2026-02-14T09:35:04.829Z" }, - { url = "https://files.pythonhosted.org/packages/05/a2/556fdecd37c3681b1edee2cf795a6799c6ed0a5551b2822636960d7e7651/nh3-0.3.3-cp38-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:24769a428e9e971e4ccfb24628f83aaa7dc3c8b41b130c8ddc1835fa1c924489", size = 1105212, upload-time = "2026-02-14T09:35:06.821Z" }, - { url = "https://files.pythonhosted.org/packages/dd/e3/5db0b0ad663234967d83702277094687baf7c498831a2d3ad3451c11770f/nh3-0.3.3-cp38-abi3-musllinux_1_2_i686.whl", hash = "sha256:b7a18ee057761e455d58b9d31445c3e4b2594cff4ddb84d2e331c011ef46f462", size = 1069970, upload-time = "2026-02-14T09:35:08.504Z" }, - { url = "https://files.pythonhosted.org/packages/79/b2/2ea21b79c6e869581ce5f51549b6e185c4762233591455bf2a326fb07f3b/nh3-0.3.3-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5a4b2c1f3e6f3cbe7048e17f4fefad3f8d3e14cc0fd08fb8599e0d5653f6b181", size = 1047588, upload-time = "2026-02-14T09:35:09.911Z" }, - { url = "https://files.pythonhosted.org/packages/e2/92/2e434619e658c806d9c096eed2cdff9a883084299b7b19a3f0824eb8e63d/nh3-0.3.3-cp38-abi3-win32.whl", hash = "sha256:e974850b131fdffa75e7ad8e0d9c7a855b96227b093417fdf1bd61656e530f37", size = 616179, upload-time = "2026-02-14T09:35:11.366Z" }, - { url = "https://files.pythonhosted.org/packages/73/88/1ce287ef8649dc51365b5094bd3713b76454838140a32ab4f8349973883c/nh3-0.3.3-cp38-abi3-win_amd64.whl", hash = "sha256:2efd17c0355d04d39e6d79122b42662277ac10a17ea48831d90b46e5ef7e4fc0", size = 631159, upload-time = "2026-02-14T09:35:12.77Z" }, - { url = "https://files.pythonhosted.org/packages/31/f1/b4835dbde4fb06f29db89db027576d6014081cd278d9b6751facc3e69e43/nh3-0.3.3-cp38-abi3-win_arm64.whl", hash = "sha256:b838e619f483531483d26d889438e53a880510e832d2aafe73f93b7b1ac2bce2", size = 616645, upload-time = "2026-02-14T09:35:14.062Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" }, ] [[package]] @@ -4177,51 +4290,40 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" }, ] -[[package]] -name = "nodeenv" -version = "1.10.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/bf/d1bda4f6168e0b2e9e5958945e01910052158313224ada5ce1fb2e1113b8/nodeenv-1.10.0.tar.gz", hash = "sha256:996c191ad80897d076bdfba80a41994c2b47c68e224c542b48feba42ba00f8bb", size = 55611, upload-time = "2025-12-20T14:08:54.006Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/88/b2/d0896bdcdc8d28a7fc5717c305f1a861c26e18c05047949fb371034d98bd/nodeenv-1.10.0-py2.py3-none-any.whl", hash = "sha256:5bb13e3eed2923615535339b3c620e76779af4cb4c6a90deccc9e36b274d3827", size = 23438, upload-time = "2025-12-20T14:08:52.782Z" }, -] - [[package]] name = "nodejs-wheel-binaries" -version = "24.11.1" +version = "24.14.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e4/89/da307731fdbb05a5f640b26de5b8ac0dc463fef059162accfc89e32f73bc/nodejs_wheel_binaries-24.11.1.tar.gz", hash = "sha256:413dfffeadfb91edb4d8256545dea797c237bba9b3faefea973cde92d96bb922", size = 8059, upload-time = "2025-11-18T18:21:58.207Z" } +sdist = { url = "https://files.pythonhosted.org/packages/71/05/c75c0940b1ebf82975d14f37176679b6f3229eae8b47b6a70d1e1dae0723/nodejs_wheel_binaries-24.14.0.tar.gz", hash = "sha256:c87b515e44b0e4a523017d8c59f26ccbd05b54fe593338582825d4b51fc91e1c", size = 8057, upload-time = "2026-02-27T02:57:30.931Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e4/5f/be5a4112e678143d4c15264d918f9a2dc086905c6426eb44515cf391a958/nodejs_wheel_binaries-24.11.1-py2.py3-none-macosx_13_0_arm64.whl", hash = "sha256:0e14874c3579def458245cdbc3239e37610702b0aa0975c1dc55e2cb80e42102", size = 55114309, upload-time = "2025-11-18T18:21:21.697Z" }, - { url = "https://files.pythonhosted.org/packages/fa/1c/2e9d6af2ea32b65928c42b3e5baa7a306870711d93c3536cb25fc090a80d/nodejs_wheel_binaries-24.11.1-py2.py3-none-macosx_13_0_x86_64.whl", hash = "sha256:c2741525c9874b69b3e5a6d6c9179a6fe484ea0c3d5e7b7c01121c8e5d78b7e2", size = 55285957, upload-time = "2025-11-18T18:21:27.177Z" }, - { url = "https://files.pythonhosted.org/packages/d0/79/35696d7ba41b1bd35ef8682f13d46ba38c826c59e58b86b267458eb53d87/nodejs_wheel_binaries-24.11.1-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:5ef598101b0fb1c2bf643abb76dfbf6f76f1686198ed17ae46009049ee83c546", size = 59645875, upload-time = "2025-11-18T18:21:33.004Z" }, - { url = "https://files.pythonhosted.org/packages/b4/98/2a9694adee0af72bc602a046b0632a0c89e26586090c558b1c9199b187cc/nodejs_wheel_binaries-24.11.1-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:cde41d5e4705266688a8d8071debf4f8a6fcea264c61292782672ee75a6905f9", size = 60140941, upload-time = "2025-11-18T18:21:37.228Z" }, - { url = "https://files.pythonhosted.org/packages/d0/d6/573e5e2cba9d934f5f89d0beab00c3315e2e6604eb4df0fcd1d80c5a07a8/nodejs_wheel_binaries-24.11.1-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:78bc5bb889313b565df8969bb7423849a9c7fc218bf735ff0ce176b56b3e96f0", size = 61644243, upload-time = "2025-11-18T18:21:43.325Z" }, - { url = "https://files.pythonhosted.org/packages/c7/e6/643234d5e94067df8ce8d7bba10f3804106668f7a1050aeb10fdd226ead4/nodejs_wheel_binaries-24.11.1-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:c79a7e43869ccecab1cae8183778249cceb14ca2de67b5650b223385682c6239", size = 62225657, upload-time = "2025-11-18T18:21:47.708Z" }, - { url = "https://files.pythonhosted.org/packages/4d/1c/2fb05127102a80225cab7a75c0e9edf88a0a1b79f912e1e36c7c1aaa8f4e/nodejs_wheel_binaries-24.11.1-py2.py3-none-win_amd64.whl", hash = "sha256:10197b1c9c04d79403501766f76508b0dac101ab34371ef8a46fcf51773497d0", size = 41322308, upload-time = "2025-11-18T18:21:51.347Z" }, - { url = "https://files.pythonhosted.org/packages/ad/b7/bc0cdbc2cc3a66fcac82c79912e135a0110b37b790a14c477f18e18d90cd/nodejs_wheel_binaries-24.11.1-py2.py3-none-win_arm64.whl", hash = "sha256:376b9ea1c4bc1207878975dfeb604f7aa5668c260c6154dcd2af9d42f7734116", size = 39026497, upload-time = "2025-11-18T18:21:54.634Z" }, + { url = "https://files.pythonhosted.org/packages/58/8c/b057c2db3551a6fe04e93dd14e33d810ac8907891534ffcc7a051b253858/nodejs_wheel_binaries-24.14.0-py2.py3-none-macosx_13_0_arm64.whl", hash = "sha256:59bb78b8eb08c3e32186da1ef913f1c806b5473d8bd0bb4492702092747b674a", size = 54798488, upload-time = "2026-02-27T02:56:56.831Z" }, + { url = "https://files.pythonhosted.org/packages/30/88/7e1b29c067b6625c97c81eb8b0ef37cf5ad5b62bb81e23f4bde804910ec9/nodejs_wheel_binaries-24.14.0-py2.py3-none-macosx_13_0_x86_64.whl", hash = "sha256:348fa061b57625de7250d608e2d9b7c4bc170544da7e328325343860eadd59e5", size = 54972803, upload-time = "2026-02-27T02:57:01.696Z" }, + { url = "https://files.pythonhosted.org/packages/1e/e0/a83f0ff12faca2a56366462e572e38ac6f5cb361877bb29e289138eb7f24/nodejs_wheel_binaries-24.14.0-py2.py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:222dbf516ccc877afcad4e4789a81b4ee93daaa9f0ad97c464417d9597f49449", size = 59340859, upload-time = "2026-02-27T02:57:06.125Z" }, + { url = "https://files.pythonhosted.org/packages/e2/9f/06fad4ae8a723ae7096b5311eba67ad8b4df5f359c0a68e366750b7fef78/nodejs_wheel_binaries-24.14.0-py2.py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:b35d6fcccfe4fb0a409392d237fbc67796bac0d357b996bc12d057a1531a238b", size = 59838751, upload-time = "2026-02-27T02:57:10.449Z" }, + { url = "https://files.pythonhosted.org/packages/8c/72/4916dadc7307c3e9bcfa43b4b6f88237932d502c66f89eb2d90fb07810db/nodejs_wheel_binaries-24.14.0-py2.py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:519507fb74f3f2b296ab1e9f00dcc211f36bbfb93c60229e72dcdee9dafd301a", size = 61340534, upload-time = "2026-02-27T02:57:15.309Z" }, + { url = "https://files.pythonhosted.org/packages/2e/df/a8ba881ee5d04b04e0d93abc8ce501ff7292813583e97f9789eb3fc0472a/nodejs_wheel_binaries-24.14.0-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:68c93c52ff06d704bcb5ed160b4ba04ab1b291d238aaf996b03a5396e0e9a7ed", size = 61922394, upload-time = "2026-02-27T02:57:20.24Z" }, + { url = "https://files.pythonhosted.org/packages/60/8c/b8c5f61201c72a0c7dc694b459941f89a6defda85deff258a9940a4e2efc/nodejs_wheel_binaries-24.14.0-py2.py3-none-win_amd64.whl", hash = "sha256:60b83c4e98b0c7d836ac9ccb67dcb36e343691cbe62cd325799ff9ed936286f3", size = 41218783, upload-time = "2026-02-27T02:57:24.175Z" }, + { url = "https://files.pythonhosted.org/packages/91/23/1f904bc9cbd8eece393e20840c08ba3ac03440090c3a4e95168fa6d2709f/nodejs_wheel_binaries-24.14.0-py2.py3-none-win_arm64.whl", hash = "sha256:78a9bd1d6b11baf1433f9fb84962ff8aa71c87d48b6434f98224bc49a2253a6e", size = 38926103, upload-time = "2026-02-27T02:57:27.458Z" }, ] [[package]] name = "numba" -version = "0.62.1" +version = "0.64.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "llvmlite" }, { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a3/20/33dbdbfe60e5fd8e3dbfde299d106279a33d9f8308346022316781368591/numba-0.62.1.tar.gz", hash = "sha256:7b774242aa890e34c21200a1fc62e5b5757d5286267e71103257f4e2af0d5161", size = 2749817, upload-time = "2025-09-29T10:46:31.551Z" } +sdist = { url = "https://files.pythonhosted.org/packages/23/c9/a0fb41787d01d621046138da30f6c2100d80857bf34b3390dd68040f27a3/numba-0.64.0.tar.gz", hash = "sha256:95e7300af648baa3308127b1955b52ce6d11889d16e8cfe637b4f85d2fca52b1", size = 2765679, upload-time = "2026-02-18T18:41:20.974Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/dd/5f/8b3491dd849474f55e33c16ef55678ace1455c490555337899c35826836c/numba-0.62.1-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:f43e24b057714e480fe44bc6031de499e7cf8150c63eb461192caa6cc8530bc8", size = 2684279, upload-time = "2025-09-29T10:43:37.213Z" }, - { url = "https://files.pythonhosted.org/packages/bf/18/71969149bfeb65a629e652b752b80167fe8a6a6f6e084f1f2060801f7f31/numba-0.62.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:57cbddc53b9ee02830b828a8428757f5c218831ccc96490a314ef569d8342b7b", size = 2687330, upload-time = "2025-09-29T10:43:59.601Z" }, - { url = "https://files.pythonhosted.org/packages/0e/7d/403be3fecae33088027bc8a95dc80a2fda1e3beff3e0e5fc4374ada3afbe/numba-0.62.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:604059730c637c7885386521bb1b0ddcbc91fd56131a6dcc54163d6f1804c872", size = 3739727, upload-time = "2025-09-29T10:42:45.922Z" }, - { url = "https://files.pythonhosted.org/packages/e0/c3/3d910d08b659a6d4c62ab3cd8cd93c4d8b7709f55afa0d79a87413027ff6/numba-0.62.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d6c540880170bee817011757dc9049dba5a29db0c09b4d2349295991fe3ee55f", size = 3445490, upload-time = "2025-09-29T10:43:12.692Z" }, - { url = "https://files.pythonhosted.org/packages/5b/82/9d425c2f20d9f0a37f7cb955945a553a00fa06a2b025856c3550227c5543/numba-0.62.1-cp311-cp311-win_amd64.whl", hash = "sha256:03de6d691d6b6e2b76660ba0f38f37b81ece8b2cc524a62f2a0cfae2bfb6f9da", size = 2745550, upload-time = "2025-09-29T10:44:20.571Z" }, - { url = "https://files.pythonhosted.org/packages/5e/fa/30fa6873e9f821c0ae755915a3ca444e6ff8d6a7b6860b669a3d33377ac7/numba-0.62.1-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:1b743b32f8fa5fff22e19c2e906db2f0a340782caf024477b97801b918cf0494", size = 2685346, upload-time = "2025-09-29T10:43:43.677Z" }, - { url = "https://files.pythonhosted.org/packages/a9/d5/504ce8dc46e0dba2790c77e6b878ee65b60fe3e7d6d0006483ef6fde5a97/numba-0.62.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:90fa21b0142bcf08ad8e32a97d25d0b84b1e921bc9423f8dda07d3652860eef6", size = 2688139, upload-time = "2025-09-29T10:44:04.894Z" }, - { url = "https://files.pythonhosted.org/packages/50/5f/6a802741176c93f2ebe97ad90751894c7b0c922b52ba99a4395e79492205/numba-0.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6ef84d0ac19f1bf80431347b6f4ce3c39b7ec13f48f233a48c01e2ec06ecbc59", size = 3796453, upload-time = "2025-09-29T10:42:52.771Z" }, - { url = "https://files.pythonhosted.org/packages/7e/df/efd21527d25150c4544eccc9d0b7260a5dec4b7e98b5a581990e05a133c0/numba-0.62.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9315cc5e441300e0ca07c828a627d92a6802bcbf27c5487f31ae73783c58da53", size = 3496451, upload-time = "2025-09-29T10:43:19.279Z" }, - { url = "https://files.pythonhosted.org/packages/80/44/79bfdab12a02796bf4f1841630355c82b5a69933b1d50eb15c7fa37dabe8/numba-0.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:44e3aa6228039992f058f5ebfcfd372c83798e9464297bdad8cc79febcf7891e", size = 2745552, upload-time = "2025-09-29T10:44:26.399Z" }, + { url = "https://files.pythonhosted.org/packages/89/a3/1a4286a1c16136c8896d8e2090d950e79b3ec626d3a8dc9620f6234d5a38/numba-0.64.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:766156ee4b8afeeb2b2e23c81307c5d19031f18d5ce76ae2c5fb1429e72fa92b", size = 2682938, upload-time = "2026-02-18T18:40:52.897Z" }, + { url = "https://files.pythonhosted.org/packages/19/16/aa6e3ba3cd45435c117d1101b278b646444ed05b7c712af631b91353f573/numba-0.64.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d17071b4ffc9d39b75d8e6c101a36f0c81b646123859898c9799cb31807c8f78", size = 3747376, upload-time = "2026-02-18T18:40:54.925Z" }, + { url = "https://files.pythonhosted.org/packages/c0/f1/dd2f25e18d75fdf897f730b78c5a7b00cc4450f2405564dbebfaf359f21f/numba-0.64.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4ead5630434133bac87fa67526eacb264535e4e9a2d5ec780e0b4fc381a7d275", size = 3453292, upload-time = "2026-02-18T18:40:56.818Z" }, + { url = "https://files.pythonhosted.org/packages/31/29/e09d5630578a50a2b3fa154990b6b839cf95327aa0709e2d50d0b6816cd1/numba-0.64.0-cp311-cp311-win_amd64.whl", hash = "sha256:f2b1fd93e7aaac07d6fbaed059c00679f591f2423885c206d8c1b55d65ca3f2d", size = 2749824, upload-time = "2026-02-18T18:40:58.392Z" }, + { url = "https://files.pythonhosted.org/packages/70/a6/9fc52cb4f0d5e6d8b5f4d81615bc01012e3cf24e1052a60f17a68deb8092/numba-0.64.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:69440a8e8bc1a81028446f06b363e28635aa67bd51b1e498023f03b812e0ce68", size = 2683418, upload-time = "2026-02-18T18:40:59.886Z" }, + { url = "https://files.pythonhosted.org/packages/9b/89/1a74ea99b180b7a5587b0301ed1b183a2937c4b4b67f7994689b5d36fc34/numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f13721011f693ba558b8dd4e4db7f2640462bba1b855bdc804be45bbeb55031a", size = 3804087, upload-time = "2026-02-18T18:41:01.699Z" }, + { url = "https://files.pythonhosted.org/packages/91/e1/583c647404b15f807410510fec1eb9b80cb8474165940b7749f026f21cbc/numba-0.64.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0b180b1133f2b5d8b3f09d96b6d7a9e51a7da5dda3c09e998b5bcfac85d222c", size = 3504309, upload-time = "2026-02-18T18:41:03.252Z" }, + { url = "https://files.pythonhosted.org/packages/85/23/0fce5789b8a5035e7ace21216a468143f3144e02013252116616c58339aa/numba-0.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:e63dc94023b47894849b8b106db28ccb98b49d5498b98878fac1a38f83ac007a", size = 2752740, upload-time = "2026-02-18T18:41:05.097Z" }, ] [[package]] @@ -4277,14 +4379,14 @@ wheels = [ [[package]] name = "numpy-typing-compat" -version = "20250818.1.25" +version = "20251206.1.25" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ff/a7/780dc00f4fed2f2b653f76a196b3a6807c7c667f30ae95a7fd082c1081d8/numpy_typing_compat-20250818.1.25.tar.gz", hash = "sha256:8ff461725af0b436e9b0445d07712f1e6e3a97540a3542810f65f936dcc587a5", size = 5027, upload-time = "2025-08-18T23:46:39.062Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/63/f166333649396d083b9e95b5aa15feb56f9168f766a72540132206119937/numpy_typing_compat-20251206.1.25.tar.gz", hash = "sha256:27ff188fe70102312ea5e8553423897a4f3365eee15aa2a7ee1fcf6efc6fed12", size = 5060, upload-time = "2025-12-06T20:02:00.974Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/71/30e8d317b6896acbc347d3089764b6209ba299095550773e14d27dcf035f/numpy_typing_compat-20250818.1.25-py3-none-any.whl", hash = "sha256:4f91427369583074b236c804dd27559134f08ec4243485034c8e7d258cbd9cd3", size = 6355, upload-time = "2025-08-18T23:46:30.927Z" }, + { url = "https://files.pythonhosted.org/packages/b4/cb/99443f79c562466d128e3bf94d1507146fba386ec2ce85e97fe916225691/numpy_typing_compat-20251206.1.25-py3-none-any.whl", hash = "sha256:9be87412b68c1e9e193e7bfd996cae4ec07de5880c19d70bf81f890f51644e7f", size = 6354, upload-time = "2025-12-06T20:01:51.007Z" }, ] [[package]] @@ -4314,25 +4416,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/17/d3/b64c356a907242d719fc668b71befd73324e47ab46c8ebbbede252c154b2/olefile-0.47-py2.py3-none-any.whl", hash = "sha256:543c7da2a7adadf21214938bb79c83ea12b473a4b6ee4ad4bf854e7715e13d1f", size = 114565, upload-time = "2023-12-01T16:22:51.518Z" }, ] -[[package]] -name = "ollama" -version = "0.6.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "httpx" }, - { name = "pydantic" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/9d/5a/652dac4b7affc2b37b95386f8ae78f22808af09d720689e3d7a86b6ed98e/ollama-0.6.1.tar.gz", hash = "sha256:478c67546836430034b415ed64fa890fd3d1ff91781a9d548b3325274e69d7c6", size = 51620, upload-time = "2025-11-13T23:02:17.416Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/47/4f/4a617ee93d8208d2bcf26b2d8b9402ceaed03e3853c754940e2290fed063/ollama-0.6.1-py3-none-any.whl", hash = "sha256:fc4c984b345735c5486faeee67d8a265214a31cbb828167782dc642ce0a2bf8c", size = 14354, upload-time = "2025-11-13T23:02:16.292Z" }, -] - [[package]] name = "onnxruntime" -version = "1.23.2" +version = "1.24.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "coloredlogs" }, { name = "flatbuffers" }, { name = "numpy" }, { name = "packaging" }, @@ -4340,21 +4428,21 @@ dependencies = [ { name = "sympy" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/44/be/467b00f09061572f022ffd17e49e49e5a7a789056bad95b54dfd3bee73ff/onnxruntime-1.23.2-cp311-cp311-macosx_13_0_arm64.whl", hash = "sha256:6f91d2c9b0965e86827a5ba01531d5b669770b01775b23199565d6c1f136616c", size = 17196113, upload-time = "2025-10-22T03:47:33.526Z" }, - { url = "https://files.pythonhosted.org/packages/9f/a8/3c23a8f75f93122d2b3410bfb74d06d0f8da4ac663185f91866b03f7da1b/onnxruntime-1.23.2-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:87d8b6eaf0fbeb6835a60a4265fde7a3b60157cf1b2764773ac47237b4d48612", size = 19153857, upload-time = "2025-10-22T03:46:37.578Z" }, - { url = "https://files.pythonhosted.org/packages/3f/d8/506eed9af03d86f8db4880a4c47cd0dffee973ef7e4f4cff9f1d4bcf7d22/onnxruntime-1.23.2-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bbfd2fca76c855317568c1b36a885ddea2272c13cb0e395002c402f2360429a6", size = 15220095, upload-time = "2025-10-22T03:46:24.769Z" }, - { url = "https://files.pythonhosted.org/packages/e9/80/113381ba832d5e777accedc6cb41d10f9eca82321ae31ebb6bcede530cea/onnxruntime-1.23.2-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:da44b99206e77734c5819aa2142c69e64f3b46edc3bd314f6a45a932defc0b3e", size = 17372080, upload-time = "2025-10-22T03:47:00.265Z" }, - { url = "https://files.pythonhosted.org/packages/3a/db/1b4a62e23183a0c3fe441782462c0ede9a2a65c6bbffb9582fab7c7a0d38/onnxruntime-1.23.2-cp311-cp311-win_amd64.whl", hash = "sha256:902c756d8b633ce0dedd889b7c08459433fbcf35e9c38d1c03ddc020f0648c6e", size = 13468349, upload-time = "2025-10-22T03:47:25.783Z" }, - { url = "https://files.pythonhosted.org/packages/1b/9e/f748cd64161213adeef83d0cb16cb8ace1e62fa501033acdd9f9341fff57/onnxruntime-1.23.2-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:b8f029a6b98d3cf5be564d52802bb50a8489ab73409fa9db0bf583eabb7c2321", size = 17195929, upload-time = "2025-10-22T03:47:36.24Z" }, - { url = "https://files.pythonhosted.org/packages/91/9d/a81aafd899b900101988ead7fb14974c8a58695338ab6a0f3d6b0100f30b/onnxruntime-1.23.2-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:218295a8acae83905f6f1aed8cacb8e3eb3bd7513a13fe4ba3b2664a19fc4a6b", size = 19157705, upload-time = "2025-10-22T03:46:40.415Z" }, - { url = "https://files.pythonhosted.org/packages/3c/35/4e40f2fba272a6698d62be2cd21ddc3675edfc1a4b9ddefcc4648f115315/onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76ff670550dc23e58ea9bc53b5149b99a44e63b34b524f7b8547469aaa0dcb8c", size = 15226915, upload-time = "2025-10-22T03:46:27.773Z" }, - { url = "https://files.pythonhosted.org/packages/ef/88/9cc25d2bafe6bc0d4d3c1db3ade98196d5b355c0b273e6a5dc09c5d5d0d5/onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f9b4ae77f8e3c9bee50c27bc1beede83f786fe1d52e99ac85aa8d65a01e9b77", size = 17382649, upload-time = "2025-10-22T03:47:02.782Z" }, - { url = "https://files.pythonhosted.org/packages/c0/b4/569d298f9fc4d286c11c45e85d9ffa9e877af12ace98af8cab52396e8f46/onnxruntime-1.23.2-cp312-cp312-win_amd64.whl", hash = "sha256:25de5214923ce941a3523739d34a520aac30f21e631de53bba9174dc9c004435", size = 13470528, upload-time = "2025-10-22T03:47:28.106Z" }, + { url = "https://files.pythonhosted.org/packages/60/69/6c40720201012c6af9aa7d4ecdd620e521bd806dc6269d636fdd5c5aeebe/onnxruntime-1.24.4-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:0bdfce8e9a6497cec584aab407b71bf697dac5e1b7b7974adc50bf7533bdb3a2", size = 17332131, upload-time = "2026-03-17T22:05:49.005Z" }, + { url = "https://files.pythonhosted.org/packages/38/e9/8c901c150ce0c368da38638f44152fb411059c0c7364b497c9e5c957321a/onnxruntime-1.24.4-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:046ff290045a387676941a02a8ae5c3ebec6b4f551ae228711968c4a69d8f6b7", size = 15152472, upload-time = "2026-03-17T22:03:26.176Z" }, + { url = "https://files.pythonhosted.org/packages/d5/b6/7a4df417cdd01e8f067a509e123ac8b31af450a719fa7ed81787dd6057ec/onnxruntime-1.24.4-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e54ad52e61d2d4618dcff8fa1480ac66b24ee2eab73331322db1049f11ccf330", size = 17222993, upload-time = "2026-03-17T22:04:34.485Z" }, + { url = "https://files.pythonhosted.org/packages/dd/59/8febe015f391aa1757fa5ba82c759ea4b6c14ef970132efb5e316665ba61/onnxruntime-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b43b63eb24a2bc8fc77a09be67587a570967a412cccb837b6245ccb546691153", size = 12594863, upload-time = "2026-03-17T22:05:38.749Z" }, + { url = "https://files.pythonhosted.org/packages/32/84/4155fcd362e8873eb6ce305acfeeadacd9e0e59415adac474bea3d9281bb/onnxruntime-1.24.4-cp311-cp311-win_arm64.whl", hash = "sha256:e26478356dba25631fb3f20112e345f8e8bf62c499bb497e8a559f7d69cf7e7b", size = 12259895, upload-time = "2026-03-17T22:05:28.812Z" }, + { url = "https://files.pythonhosted.org/packages/d7/38/31db1b232b4ba960065a90c1506ad7a56995cd8482033184e97fadca17cc/onnxruntime-1.24.4-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cad1c2b3f455c55678ab2a8caa51fb420c25e6e3cf10f4c23653cdabedc8de78", size = 17341875, upload-time = "2026-03-17T22:05:51.669Z" }, + { url = "https://files.pythonhosted.org/packages/aa/60/c4d1c8043eb42f8a9aa9e931c8c293d289c48ff463267130eca97d13357f/onnxruntime-1.24.4-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1a5c5a544b22f90859c88617ecb30e161ee3349fcc73878854f43d77f00558b5", size = 15172485, upload-time = "2026-03-17T22:03:32.182Z" }, + { url = "https://files.pythonhosted.org/packages/6d/ab/5b68110e0460d73fad814d5bd11c7b1ddcce5c37b10177eb264d6a36e331/onnxruntime-1.24.4-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0d640eb9f3782689b55cfa715094474cd5662f2f137be6a6f847a594b6e9705c", size = 17244912, upload-time = "2026-03-17T22:04:37.251Z" }, + { url = "https://files.pythonhosted.org/packages/8b/f4/6b89e297b93704345f0f3f8c62229bee323ef25682a3f9b4f89a39324950/onnxruntime-1.24.4-cp312-cp312-win_amd64.whl", hash = "sha256:535b29475ca42b593c45fbb2152fbf1cdf3f287315bf650e6a724a0a1d065cdb", size = 12596856, upload-time = "2026-03-17T22:05:41.224Z" }, + { url = "https://files.pythonhosted.org/packages/43/06/8b8ec6e9e6a474fcd5d772453f627ad4549dfe3ab8c0bf70af5afcde551b/onnxruntime-1.24.4-cp312-cp312-win_arm64.whl", hash = "sha256:e6214096e14b7b52e3bee1903dc12dc7ca09cb65e26664668a4620cc5e6f9a90", size = 12270275, upload-time = "2026-03-17T22:05:31.132Z" }, ] [[package]] name = "openai" -version = "1.109.1" +version = "2.29.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -4366,9 +4454,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c6/a1/a303104dc55fc546a3f6914c842d3da471c64eec92043aef8f652eb6c524/openai-1.109.1.tar.gz", hash = "sha256:d173ed8dbca665892a6db099b4a2dfac624f94d20a93f46eb0b56aae940ed869", size = 564133, upload-time = "2025-09-24T13:00:53.075Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b4/15/203d537e58986b5673e7f232453a2a2f110f22757b15921cbdeea392e520/openai-2.29.0.tar.gz", hash = "sha256:32d09eb2f661b38d3edd7d7e1a2943d1633f572596febe64c0cd370c86d52bec", size = 671128, upload-time = "2026-03-17T17:53:49.599Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/2a/7dd3d207ec669cacc1f186fd856a0f61dbc255d24f6fdc1a6715d6051b0f/openai-1.109.1-py3-none-any.whl", hash = "sha256:6bcaf57086cf59159b8e27447e4e7dd019db5d29a438072fbd49c290c7e65315", size = 948627, upload-time = "2025-09-24T13:00:50.754Z" }, + { url = "https://files.pythonhosted.org/packages/d0/b1/35b6f9c8cf9318e3dbb7146cc82dab4cf61182a8d5406fc9b50864362895/openai-2.29.0-py3-none-any.whl", hash = "sha256:b7c5de513c3286d17c5e29b92c4c98ceaf0d775244ac8159aeb1bddf840eb42a", size = 1141533, upload-time = "2026-03-17T17:53:47.348Z" }, ] [[package]] @@ -4389,7 +4477,7 @@ wheels = [ [[package]] name = "openinference-instrumentation" -version = "0.1.42" +version = "0.1.46" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "openinference-semantic-conventions" }, @@ -4397,18 +4485,18 @@ dependencies = [ { name = "opentelemetry-sdk" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/00/d0/b19061a21fd6127d2857c77744a36073bba9c1502d1d5e8517b708eb8b7c/openinference_instrumentation-0.1.42.tar.gz", hash = "sha256:2275babc34022e151b5492cfba41d3b12e28377f8e08cb45e5d64fe2d9d7fe37", size = 23954, upload-time = "2025-11-05T01:37:46.869Z" } +sdist = { url = "https://files.pythonhosted.org/packages/64/8d/9b76b43e8b2ee2ccf1fe15b21c924095f9c0e4839919bcd4951b1c99c2ab/openinference_instrumentation-0.1.46.tar.gz", hash = "sha256:0b520002a1c682c525dcab49005c209bfd71611e8e4e4933b49779d5e899e6db", size = 23937, upload-time = "2026-03-04T10:13:48.883Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c3/71/43ee4616fc95dbd2f560550f199c6652a5eb93f84e8aa0039bc95c19cfe0/openinference_instrumentation-0.1.42-py3-none-any.whl", hash = "sha256:e7521ff90833ef7cc65db526a2f59b76a496180abeaaee30ec6abbbc0b43f8ec", size = 30086, upload-time = "2025-11-05T01:37:43.866Z" }, + { url = "https://files.pythonhosted.org/packages/25/d1/f6668492152a4180492044313e2dc427fbc237904f6bb1629abd030e3469/openinference_instrumentation-0.1.46-py3-none-any.whl", hash = "sha256:f7b63ccd5f93ce82e4e40035c9faa6b021984cbe06ad791f4cf033551533bc48", size = 30124, upload-time = "2026-03-04T10:13:47.613Z" }, ] [[package]] name = "openinference-semantic-conventions" -version = "0.1.25" +version = "0.1.28" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0b/68/81c8a0b90334ff11e4f285e4934c57f30bea3ef0c0b9f99b65e7b80fae3b/openinference_semantic_conventions-0.1.25.tar.gz", hash = "sha256:f0a8c2cfbd00195d1f362b4803518341e80867d446c2959bf1743f1894fce31d", size = 12767, upload-time = "2025-11-05T01:37:45.89Z" } +sdist = { url = "https://files.pythonhosted.org/packages/df/32/c79bf8bd3ea5a00e492449b31ca600bbc2a8e88a301e42c872af925a156c/openinference_semantic_conventions-0.1.28.tar.gz", hash = "sha256:6388465174e8ab3f27ebc6a9e9bb2e1b804d30caefb57234e16db874da1c6a7b", size = 12893, upload-time = "2026-03-11T04:45:46.543Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fd/3d/dd14ee2eb8a3f3054249562e76b253a1545c76adbbfd43a294f71acde5c3/openinference_semantic_conventions-0.1.25-py3-none-any.whl", hash = "sha256:3814240f3bd61f05d9562b761de70ee793d55b03bca1634edf57d7a2735af238", size = 10395, upload-time = "2025-11-05T01:37:43.697Z" }, + { url = "https://files.pythonhosted.org/packages/04/40/34b570462c3ce250277254bb0cca655eb39b64c0dffe63cd7751f103f8d6/openinference_semantic_conventions-0.1.28-py3-none-any.whl", hash = "sha256:a2fed5bb167aa56c1c7448cdb7a8d775f989339ba1f8b04a7b45d4f8388cccfb", size = 10522, upload-time = "2026-03-11T04:45:45.423Z" }, ] [[package]] @@ -4454,59 +4542,59 @@ wheels = [ [[package]] name = "opentelemetry-api" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "deprecated" }, { name = "importlib-metadata" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/83/93114b6de85a98963aec218a51509a52ed3f8de918fe91eb0f7299805c3f/opentelemetry_api-1.27.0.tar.gz", hash = "sha256:ed673583eaa5f81b5ce5e86ef7cdaf622f88ef65f0b9aab40b843dcae5bef342", size = 62693, upload-time = "2024-08-28T21:35:31.445Z" } +sdist = { url = "https://files.pythonhosted.org/packages/79/36/260eaea0f74fdd0c0d8f22ed3a3031109ea1c85531f94f4fde266c29e29a/opentelemetry_api-1.28.0.tar.gz", hash = "sha256:578610bcb8aa5cdcb11169d136cc752958548fb6ccffb0969c1036b0ee9e5353", size = 62803, upload-time = "2024-11-05T19:14:45.497Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/1f/737dcdbc9fea2fa96c1b392ae47275165a7c641663fbb08a8d252968eed2/opentelemetry_api-1.27.0-py3-none-any.whl", hash = "sha256:953d5871815e7c30c81b56d910c707588000fff7a3ca1c73e6531911d53065e7", size = 63970, upload-time = "2024-08-28T21:35:00.598Z" }, + { url = "https://files.pythonhosted.org/packages/22/e4/3b25d8b856791c04d8a62b1257b5fc09dc41a057800db06885af8ddcdce1/opentelemetry_api-1.28.0-py3-none-any.whl", hash = "sha256:8457cd2c59ea1bd0988560f021656cecd254ad7ef6be4ba09dbefeca2409ce52", size = 64314, upload-time = "2024-11-05T19:14:21.659Z" }, ] [[package]] name = "opentelemetry-distro" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-sdk" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f4/09/423e17c439ed24c45110affe84aad886a536b7871a42637d2ad14a179b47/opentelemetry_distro-0.48b0.tar.gz", hash = "sha256:5cb15915780ac4972583286a56683d43bd4ca95371d72f5f3f179c8b0b2ddc91", size = 2556, upload-time = "2024-08-28T21:27:40.455Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4d/75/7cb7c33899e66bb366d40a889111a78c22df0951038b6699f1663e715a9f/opentelemetry_distro-0.49b0.tar.gz", hash = "sha256:1bafa274f9e83baa0d2a5d47ed02caffcf9bcca60107b389b145400d82b07513", size = 2560, upload-time = "2024-11-05T19:21:39.379Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/82/cf/fa9a5fe954f1942e03b319ae0e319ebc93d9f984b548bcd9b3f232a1434d/opentelemetry_distro-0.48b0-py3-none-any.whl", hash = "sha256:b2f8fce114325b020769af3b9bf503efb8af07efc190bd1b9deac7843171664a", size = 3321, upload-time = "2024-08-28T21:26:26.584Z" }, + { url = "https://files.pythonhosted.org/packages/4c/db/806172b6a4933966eee518db814b375e620602f7fe776b74ef795690f135/opentelemetry_distro-0.49b0-py3-none-any.whl", hash = "sha256:1af4074702f605ea210753dd41947dc2fd61b39724f23cdcf15d5654867cd3c2", size = 3318, upload-time = "2024-11-05T19:20:34.065Z" }, ] [[package]] name = "opentelemetry-exporter-otlp" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-exporter-otlp-proto-grpc" }, { name = "opentelemetry-exporter-otlp-proto-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/d3/8156cc14e8f4573a3572ee7f30badc7aabd02961a09acc72ab5f2c789ef1/opentelemetry_exporter_otlp-1.27.0.tar.gz", hash = "sha256:4a599459e623868cc95d933c301199c2367e530f089750e115599fccd67cb2a1", size = 6166, upload-time = "2024-08-28T21:35:33.746Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/16/14e3fc163930ea68f0980a4cdd4ae5796e60aeb898965990e13263d64baf/opentelemetry_exporter_otlp-1.28.0.tar.gz", hash = "sha256:31ae7495831681dd3da34ac457f6970f147465ae4b9aae3a888d7a581c7cd868", size = 6170, upload-time = "2024-11-05T19:14:47.349Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/6d/95e1fc2c8d945a734db32e87a5aa7a804f847c1657a21351df9338bd1c9c/opentelemetry_exporter_otlp-1.27.0-py3-none-any.whl", hash = "sha256:7688791cbdd951d71eb6445951d1cfbb7b6b2d7ee5948fac805d404802931145", size = 7001, upload-time = "2024-08-28T21:35:04.02Z" }, + { url = "https://files.pythonhosted.org/packages/c2/82/3f521b3c1f2a411ed60a24a8c9f486c1beeaf8c6c55337c87d3ae1642151/opentelemetry_exporter_otlp-1.28.0-py3-none-any.whl", hash = "sha256:1fd02d70f2c1b7ac5579c81e78de4594b188d3317c8ceb69e8b53900fb7b40fd", size = 7024, upload-time = "2024-11-05T19:14:24.534Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-common" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-proto" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cd/2e/7eaf4ba595fb5213cf639c9158dfb64aacb2e4c7d74bfa664af89fa111f4/opentelemetry_exporter_otlp_proto_common-1.27.0.tar.gz", hash = "sha256:159d27cf49f359e3798c4c3eb8da6ef4020e292571bd8c5604a2a573231dd5c8", size = 17860, upload-time = "2024-08-28T21:35:34.896Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/8d/5d411084ac441052f4c9bae03a1aec65ae5d16b439fea7b9c5ac3842c013/opentelemetry_exporter_otlp_proto_common-1.28.0.tar.gz", hash = "sha256:5fa0419b0c8e291180b0fc8430a20dd44a3f3236f8e0827992145914f273ec4f", size = 18505, upload-time = "2024-11-05T19:14:48.204Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/27/4610ab3d9bb3cde4309b6505f98b3aabca04a26aa480aa18cede23149837/opentelemetry_exporter_otlp_proto_common-1.27.0-py3-none-any.whl", hash = "sha256:675db7fffcb60946f3a5c43e17d1168a3307a94a930ecf8d2ea1f286f3d4f79a", size = 17848, upload-time = "2024-08-28T21:35:05.412Z" }, + { url = "https://files.pythonhosted.org/packages/e1/72/3c44aabc74db325aaba09361b6a0d80f6d601f0ff86ecea8ee655c9538fc/opentelemetry_exporter_otlp_proto_common-1.28.0-py3-none-any.whl", hash = "sha256:467e6437d24e020156dffecece8c0a4471a8a60f6a34afeda7386df31a092410", size = 18403, upload-time = "2024-11-05T19:14:25.798Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-grpc" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "deprecated" }, @@ -4517,14 +4605,14 @@ dependencies = [ { name = "opentelemetry-proto" }, { name = "opentelemetry-sdk" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a1/d0/c1e375b292df26e0ffebf194e82cd197e4c26cc298582bda626ce3ce74c5/opentelemetry_exporter_otlp_proto_grpc-1.27.0.tar.gz", hash = "sha256:af6f72f76bcf425dfb5ad11c1a6d6eca2863b91e63575f89bb7b4b55099d968f", size = 26244, upload-time = "2024-08-28T21:35:36.314Z" } +sdist = { url = "https://files.pythonhosted.org/packages/43/4d/f215162e58041afb4bdf5dbd0d8faf0b7fc9bf7b3d3fc0e44e06f9e7e869/opentelemetry_exporter_otlp_proto_grpc-1.28.0.tar.gz", hash = "sha256:47a11c19dc7f4289e220108e113b7de90d59791cb4c37fc29f69a6a56f2c3735", size = 26237, upload-time = "2024-11-05T19:14:49.026Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/80/32217460c2c64c0568cea38410124ff680a9b65f6732867bbf857c4d8626/opentelemetry_exporter_otlp_proto_grpc-1.27.0-py3-none-any.whl", hash = "sha256:56b5bbd5d61aab05e300d9d62a6b3c134827bbd28d0b12f2649c2da368006c9e", size = 18541, upload-time = "2024-08-28T21:35:06.493Z" }, + { url = "https://files.pythonhosted.org/packages/1d/b5/afabc8106abc0f9cfeecf5b3e682622b3e04bba1d9b967dbfcd91b9c4ebe/opentelemetry_exporter_otlp_proto_grpc-1.28.0-py3-none-any.whl", hash = "sha256:edbdc53e7783f88d4535db5807cb91bd7b1ec9e9b9cdbfee14cd378f29a3b328", size = 18532, upload-time = "2024-11-05T19:14:26.853Z" }, ] [[package]] name = "opentelemetry-exporter-otlp-proto-http" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "deprecated" }, @@ -4535,28 +4623,29 @@ dependencies = [ { name = "opentelemetry-sdk" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/31/0a/f05c55e8913bf58a033583f2580a0ec31a5f4cf2beacc9e286dcb74d6979/opentelemetry_exporter_otlp_proto_http-1.27.0.tar.gz", hash = "sha256:2103479092d8eb18f61f3fbff084f67cc7f2d4a7d37e75304b8b56c1d09ebef5", size = 15059, upload-time = "2024-08-28T21:35:37.079Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f1/2a/555f2845928086cd51aa6941c7a546470805b68ed631ec139ce7d841763d/opentelemetry_exporter_otlp_proto_http-1.28.0.tar.gz", hash = "sha256:d83a9a03a8367ead577f02a64127d827c79567de91560029688dd5cfd0152a8e", size = 15051, upload-time = "2024-11-05T19:14:49.813Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2d/8d/4755884afc0b1db6000527cac0ca17273063b6142c773ce4ecd307a82e72/opentelemetry_exporter_otlp_proto_http-1.27.0-py3-none-any.whl", hash = "sha256:688027575c9da42e179a69fe17e2d1eba9b14d81de8d13553a21d3114f3b4d75", size = 17203, upload-time = "2024-08-28T21:35:08.141Z" }, + { url = "https://files.pythonhosted.org/packages/b2/ce/80d5adabbf7ab4a0ca7b5e0f4039b24d273be370c3ba85fc05b13794411c/opentelemetry_exporter_otlp_proto_http-1.28.0-py3-none-any.whl", hash = "sha256:e8f3f7961b747edb6b44d51de4901a61e9c01d50debd747b120a08c4996c7e7b", size = 17228, upload-time = "2024-11-05T19:14:28.613Z" }, ] [[package]] name = "opentelemetry-instrumentation" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, - { name = "setuptools" }, + { name = "opentelemetry-semantic-conventions" }, + { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/04/0e/d9394839af5d55c8feb3b22cd11138b953b49739b20678ca96289e30f904/opentelemetry_instrumentation-0.48b0.tar.gz", hash = "sha256:94929685d906380743a71c3970f76b5f07476eea1834abd5dd9d17abfe23cc35", size = 24724, upload-time = "2024-08-28T21:27:42.82Z" } +sdist = { url = "https://files.pythonhosted.org/packages/de/6b/6c25b15063c92a011cf3f68375971e2c58a9c764690847edc97df2d94eeb/opentelemetry_instrumentation-0.49b0.tar.gz", hash = "sha256:398a93e0b9dc2d11cc8627e1761665c506fe08c6b2df252a2ab3ade53d751c46", size = 26478, upload-time = "2024-11-05T19:21:41.402Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0a/7f/405c41d4f359121376c9d5117dcf68149b8122d3f6c718996d037bd4d800/opentelemetry_instrumentation-0.48b0-py3-none-any.whl", hash = "sha256:a69750dc4ba6a5c3eb67986a337185a25b739966d80479befe37b546fc870b44", size = 29449, upload-time = "2024-08-28T21:26:31.288Z" }, + { url = "https://files.pythonhosted.org/packages/93/61/e0d21e958d6072ce25c4f5e26a1d22835fc86f80836660adf6badb6038ce/opentelemetry_instrumentation-0.49b0-py3-none-any.whl", hash = "sha256:68364d73a1ff40894574cbc6138c5f98674790cae1f3b0865e21cf702f24dcb3", size = 30694, upload-time = "2024-11-05T19:20:38.584Z" }, ] [[package]] name = "opentelemetry-instrumentation-asgi" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "asgiref" }, @@ -4565,28 +4654,28 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/44/ac/fd3d40bab3234ec3f5c052a815100676baaae1832fa1067935f11e5c59c6/opentelemetry_instrumentation_asgi-0.48b0.tar.gz", hash = "sha256:04c32174b23c7fa72ddfe192dad874954968a6a924608079af9952964ecdf785", size = 23435, upload-time = "2024-08-28T21:27:47.276Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e8/55/693c3d0938ba5fead5c3aa4ac7022a992b4ff99a8e9979800d0feb843ff4/opentelemetry_instrumentation_asgi-0.49b0.tar.gz", hash = "sha256:959fd9b1345c92f20c6ef1d42f92ef6a76b3c3083fbc4104d59da6859b15b083", size = 24117, upload-time = "2024-11-05T19:21:46.769Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/74/a0e0d38622856597dd8e630f2bd793760485eb165708e11b8be1696bbb5a/opentelemetry_instrumentation_asgi-0.48b0-py3-none-any.whl", hash = "sha256:ddb1b5fc800ae66e85a4e2eca4d9ecd66367a8c7b556169d9e7b57e10676e44d", size = 15958, upload-time = "2024-08-28T21:26:38.139Z" }, + { url = "https://files.pythonhosted.org/packages/2c/0b/7900c782a1dfaa584588d724bc3bbdf8405a32497537dd96b3fcbf8461b9/opentelemetry_instrumentation_asgi-0.49b0-py3-none-any.whl", hash = "sha256:722a90856457c81956c88f35a6db606cc7db3231046b708aae2ddde065723dbe", size = 16326, upload-time = "2024-11-05T19:20:46.176Z" }, ] [[package]] name = "opentelemetry-instrumentation-celery" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/42/68/72975eff50cc22d8f65f96c425a2e8844f91488e78ffcfb603ac7cee0e5a/opentelemetry_instrumentation_celery-0.48b0.tar.gz", hash = "sha256:1d33aa6c4a1e6c5d17a64215245208a96e56c9d07611685dbae09a557704af26", size = 14445, upload-time = "2024-08-28T21:27:56.392Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/8b/9b8a9dda3ed53354c6f707a45cdb7a4730e1c109b50fc1b413525493f811/opentelemetry_instrumentation_celery-0.49b0.tar.gz", hash = "sha256:afbaee97cc9c75f29bcc9784f16f8e37c415d4fe9b334748c5b90a3d30d12473", size = 14702, upload-time = "2024-11-05T19:21:53.672Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/28/59/f09e8f9f596d375fd86b7677751525bbc485c8cc8c5388e39786a3d3b968/opentelemetry_instrumentation_celery-0.48b0-py3-none-any.whl", hash = "sha256:c1904e38cc58fb2a33cd657d6e296285c5ffb0dca3f164762f94b905e5abc88e", size = 13697, upload-time = "2024-08-28T21:26:50.01Z" }, + { url = "https://files.pythonhosted.org/packages/21/8c/d7d4adb36abbc0e517a69f7a069f32742122ae22d6017202f64570d9f4c5/opentelemetry_instrumentation_celery-0.49b0-py3-none-any.whl", hash = "sha256:38d4a78c78f33020032ef77ef0ead756bdf7838bcfb603de10f5925d39f14929", size = 13749, upload-time = "2024-11-05T19:20:54.98Z" }, ] [[package]] name = "opentelemetry-instrumentation-fastapi" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4595,17 +4684,16 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/58/20/43477da5850ef2cd3792715d442aecd051e885e0603b6ee5783b2104ba8f/opentelemetry_instrumentation_fastapi-0.48b0.tar.gz", hash = "sha256:21a72563ea412c0b535815aeed75fc580240f1f02ebc72381cfab672648637a2", size = 18497, upload-time = "2024-08-28T21:28:01.14Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/bf/8e6d2a4807360f2203192017eb4845f5628dbeaf0597adf3d141cc5c24e1/opentelemetry_instrumentation_fastapi-0.49b0.tar.gz", hash = "sha256:6d14935c41fd3e49328188b6a59dd4c37bd17a66b01c15b0c64afa9714a1f905", size = 19230, upload-time = "2024-11-05T19:21:59.361Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/50/745ab075a3041b7a5f29a579d2c28eaad54f64b4589d8f9fd364c62cf0f3/opentelemetry_instrumentation_fastapi-0.48b0-py3-none-any.whl", hash = "sha256:afeb820a59e139d3e5d96619600f11ce0187658b8ae9e3480857dd790bc024f2", size = 11777, upload-time = "2024-08-28T21:26:57.457Z" }, + { url = "https://files.pythonhosted.org/packages/b1/f4/0895b9410c10abf987c90dee1b7688a8f2214a284fe15e575648f6a1473a/opentelemetry_instrumentation_fastapi-0.49b0-py3-none-any.whl", hash = "sha256:646e1b18523cbe6860ae9711eb2c7b9c85466c3c7697cd6b8fb5180d85d3fe6e", size = 12101, upload-time = "2024-11-05T19:21:01.805Z" }, ] [[package]] name = "opentelemetry-instrumentation-flask" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "importlib-metadata" }, { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-instrumentation-wsgi" }, @@ -4613,29 +4701,30 @@ dependencies = [ { name = "opentelemetry-util-http" }, { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ed/2f/5c3af780a69f9ba78445fe0e5035c41f67281a31b08f3c3e7ec460bda726/opentelemetry_instrumentation_flask-0.48b0.tar.gz", hash = "sha256:e03a34428071aebf4864ea6c6a564acef64f88c13eb3818e64ea90da61266c3d", size = 19196, upload-time = "2024-08-28T21:28:01.986Z" } +sdist = { url = "https://files.pythonhosted.org/packages/17/12/dc72873fb1e35699941d8eb6a53ef25e8c5843dea37665dad33bd720f047/opentelemetry_instrumentation_flask-0.49b0.tar.gz", hash = "sha256:f7c5ab67753c4781a2e21c8f43dc5fc02ece74fdd819466c75d025db80aa7576", size = 19176, upload-time = "2024-11-05T19:22:00.816Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/3d/fcde4f8f0bf9fa1ee73a12304fa538076fb83fe0a2ae966ab0f0b7da5109/opentelemetry_instrumentation_flask-0.48b0-py3-none-any.whl", hash = "sha256:26b045420b9d76e85493b1c23fcf27517972423480dc6cf78fd6924248ba5808", size = 14588, upload-time = "2024-08-28T21:26:58.504Z" }, + { url = "https://files.pythonhosted.org/packages/a2/fc/354da8f33ef0daebfc8e4eac995d342ae13a35097bbad512cfe0d2f3c61a/opentelemetry_instrumentation_flask-0.49b0-py3-none-any.whl", hash = "sha256:f3ef330c3cee3e2c161f27f1e7017c8800b9bfb6f9204f2f7bfb0b274874be0e", size = 14582, upload-time = "2024-11-05T19:21:02.793Z" }, ] [[package]] name = "opentelemetry-instrumentation-httpx" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-instrumentation" }, { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, + { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d3/d9/c65d818607c16d1b7ea8d2de6111c6cecadf8d2fd38c1885a72733a7c6d3/opentelemetry_instrumentation_httpx-0.48b0.tar.gz", hash = "sha256:ee977479e10398931921fb995ac27ccdeea2e14e392cb27ef012fc549089b60a", size = 16931, upload-time = "2024-08-28T21:28:03.794Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/53/8b5e05e55a513d846ead5afb0509bec37a34a1c3e82f30b13d14156334b1/opentelemetry_instrumentation_httpx-0.49b0.tar.gz", hash = "sha256:07165b624f3e58638cee47ecf1c81939a8c2beb7e42ce9f69e25a9f21dc3f4cf", size = 17750, upload-time = "2024-11-05T19:22:02.911Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/fe/f2daa9d6d988c093b8c7b1d35df675761a8ece0b600b035dc04982746c9d/opentelemetry_instrumentation_httpx-0.48b0-py3-none-any.whl", hash = "sha256:d94f9d612c82d09fe22944d1904a30a464c19bea2ba76be656c99a28ad8be8e5", size = 13900, upload-time = "2024-08-28T21:27:01.566Z" }, + { url = "https://files.pythonhosted.org/packages/3b/9f/843391c6d645cd4f6914b27bc807fc1ff52b97f84cbe3ca675641976b23f/opentelemetry_instrumentation_httpx-0.49b0-py3-none-any.whl", hash = "sha256:e59e0d2fda5ef841630c68da1d78ff9192f63590a9099f12f0eab614abdf239a", size = 14110, upload-time = "2024-11-05T19:21:04.698Z" }, ] [[package]] name = "opentelemetry-instrumentation-redis" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4643,14 +4732,14 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/70/be/92e98e4c7f275be3d373899a41b0a7d4df64266657d985dbbdb9a54de0d5/opentelemetry_instrumentation_redis-0.48b0.tar.gz", hash = "sha256:61e33e984b4120e1b980d9fba6e9f7ca0c8d972f9970654d8f6e9f27fa115a8c", size = 10511, upload-time = "2024-08-28T21:28:15.061Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/5b/1398eb2f92fd76787ccec28d24dc4c7dfaaf97a7557e7729e2f7c2c05d84/opentelemetry_instrumentation_redis-0.49b0.tar.gz", hash = "sha256:922542c3bd192ad4ba74e2c7e0a253c7c58a5cefbd6f89da2aba4d193a974703", size = 11353, upload-time = "2024-11-05T19:22:12.822Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/94/40/892f30d400091106309cc047fd3f6d76a828fedd984a953fd5386b78a2fb/opentelemetry_instrumentation_redis-0.48b0-py3-none-any.whl", hash = "sha256:48c7f2e25cbb30bde749dc0d8b9c74c404c851f554af832956b9630b27f5bcb7", size = 11610, upload-time = "2024-08-28T21:27:18.759Z" }, + { url = "https://files.pythonhosted.org/packages/24/e4/4f258fef0759629f2e8a0210d5533cfef3ecad69ff35be044637a3e2783e/opentelemetry_instrumentation_redis-0.49b0-py3-none-any.whl", hash = "sha256:b7d8f758bac53e77b7e7ca98ce80f91230577502dacb619ebe8e8b6058042067", size = 12453, upload-time = "2024-11-05T19:21:18.534Z" }, ] [[package]] name = "opentelemetry-instrumentation-sqlalchemy" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4659,14 +4748,14 @@ dependencies = [ { name = "packaging" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/77/3fcebbca8bd729da50dc2130d8ca869a235aa5483a85ef06c5dc8643476b/opentelemetry_instrumentation_sqlalchemy-0.48b0.tar.gz", hash = "sha256:dbf2d5a755b470e64e5e2762b56f8d56313787e4c7d71a87fe25c33f48eb3493", size = 13194, upload-time = "2024-08-28T21:28:18.122Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a0/a7/24f6cce3808ae1802dd1b60d752fbab877db5655198929cf4ee8ea416923/opentelemetry_instrumentation_sqlalchemy-0.49b0.tar.gz", hash = "sha256:32658e520fc8b35823c722f5d8831d3a410b76dd2724adb2887befc041ddef04", size = 13194, upload-time = "2024-11-05T19:22:14.92Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/84/4b6f1e9e9f83a52d966e91963f5a8424edc4a3d5ea32854c96c2d1618284/opentelemetry_instrumentation_sqlalchemy-0.48b0-py3-none-any.whl", hash = "sha256:625848a34aa5770cb4b1dcdbd95afce4307a0230338711101325261d739f391f", size = 13360, upload-time = "2024-08-28T21:27:22.102Z" }, + { url = "https://files.pythonhosted.org/packages/ec/6b/a1a3685fed593282999cdc374ece15efbd56f8d774bd368bf7ff2cf5923c/opentelemetry_instrumentation_sqlalchemy-0.49b0-py3-none-any.whl", hash = "sha256:d854052d2b02cd0562e5628a514c8153fceada7f585137e173165dfd0a46ef6a", size = 13358, upload-time = "2024-11-05T19:21:23.654Z" }, ] [[package]] name = "opentelemetry-instrumentation-wsgi" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, @@ -4674,75 +4763,75 @@ dependencies = [ { name = "opentelemetry-semantic-conventions" }, { name = "opentelemetry-util-http" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/de/a5/f45cdfba18f22aefd2378eac8c07c1f8c9656d6bf7ce315ced48c67f3437/opentelemetry_instrumentation_wsgi-0.48b0.tar.gz", hash = "sha256:1a1e752367b0df4397e0b835839225ef5c2c3c053743a261551af13434fc4d51", size = 17974, upload-time = "2024-08-28T21:28:24.902Z" } +sdist = { url = "https://files.pythonhosted.org/packages/17/2b/91b022b004ac9e9ab0eefd10bc4257975291f88adc81b4ef2c601ddb1adf/opentelemetry_instrumentation_wsgi-0.49b0.tar.gz", hash = "sha256:0812a02e132f8fc3d5c897bba84e530c37b85c315b199bb97ca6508279e7eb23", size = 17733, upload-time = "2024-11-05T19:22:24.3Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/87/fa420007e0ba7e8cd43799ab204717ab515f000236fa2726a6be3299efdd/opentelemetry_instrumentation_wsgi-0.48b0-py3-none-any.whl", hash = "sha256:c6051124d741972090fe94b2fa302555e1e2a22e9cdda32dd39ed49a5b34e0c6", size = 13691, upload-time = "2024-08-28T21:27:33.257Z" }, + { url = "https://files.pythonhosted.org/packages/02/1d/59979665778ed8c85bc31c92b75571cd7afb8e3322fb513c87fe1bad6d78/opentelemetry_instrumentation_wsgi-0.49b0-py3-none-any.whl", hash = "sha256:8869ccf96611827e4448417718920e9eec6d25bffb5bf72c7952c7346ec33fbc", size = 13699, upload-time = "2024-11-05T19:21:35.039Z" }, ] [[package]] name = "opentelemetry-propagator-b3" -version = "1.27.0" +version = "1.40.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "deprecated" }, { name = "opentelemetry-api" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/a3/3ceeb5ff5a1906371834d5c594e24e5b84f35528d219054833deca4ac44c/opentelemetry_propagator_b3-1.27.0.tar.gz", hash = "sha256:39377b6aa619234e08fbc6db79bf880aff36d7e2761efa9afa28b78d5937308f", size = 9590, upload-time = "2024-08-28T21:35:43.971Z" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/fe/e0c84af5c654ec42165ba57af83c7f67e4b8af77f836ddc29dee59ff73c6/opentelemetry_propagator_b3-1.40.0.tar.gz", hash = "sha256:59b6925498947c08a1b7e0dd38193ff97e5009bec74ec23824300c2e32f77bcf", size = 9587, upload-time = "2026-03-04T14:17:30.079Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/03/3f/75ba77b8d9938bae575bc457a5c56ca2246ff5367b54c7d4252a31d1c91f/opentelemetry_propagator_b3-1.27.0-py3-none-any.whl", hash = "sha256:1dd75e9801ba02e870df3830097d35771a64c123127c984d9b05c352a35aa9cc", size = 8899, upload-time = "2024-08-28T21:35:18.317Z" }, + { url = "https://files.pythonhosted.org/packages/8f/84/8654cc0539b5145046b2e60d058cebad401a600dd0b1240f1711c6788643/opentelemetry_propagator_b3-1.40.0-py3-none-any.whl", hash = "sha256:cb72a1698fd1d1b434f70dc90c1de62da8ade1dd84850d1f040eccf6a420fa7b", size = 8922, upload-time = "2026-03-04T14:17:14.732Z" }, ] [[package]] name = "opentelemetry-proto" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9a/59/959f0beea798ae0ee9c979b90f220736fbec924eedbefc60ca581232e659/opentelemetry_proto-1.27.0.tar.gz", hash = "sha256:33c9345d91dafd8a74fc3d7576c5a38f18b7fdf8d02983ac67485386132aedd6", size = 34749, upload-time = "2024-08-28T21:35:45.839Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/63/ac4cef4d30ea0ca1d2153ad2fc62d91d1cf3b89b0e4e5cbd61a8c567885f/opentelemetry_proto-1.28.0.tar.gz", hash = "sha256:4a45728dfefa33f7908b828b9b7c9f2c6de42a05d5ec7b285662ddae71c4c870", size = 34331, upload-time = "2024-11-05T19:14:59.503Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/94/56/3d2d826834209b19a5141eed717f7922150224d1a982385d19a9444cbf8d/opentelemetry_proto-1.27.0-py3-none-any.whl", hash = "sha256:b133873de5581a50063e1e4b29cdcf0c5e253a8c2d8dc1229add20a4c3830ace", size = 52464, upload-time = "2024-08-28T21:35:21.434Z" }, + { url = "https://files.pythonhosted.org/packages/86/94/c0b43d16e1d96ee1e699373aa59f14a3aa2e7126af3f11d6adc5dcc531cd/opentelemetry_proto-1.28.0-py3-none-any.whl", hash = "sha256:d5ad31b997846543b8e15504657d9a8cf1ad3c71dcbbb6c4799b1ab29e38f7f9", size = 55832, upload-time = "2024-11-05T19:14:40.446Z" }, ] [[package]] name = "opentelemetry-sdk" -version = "1.27.0" +version = "1.28.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "opentelemetry-api" }, { name = "opentelemetry-semantic-conventions" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0d/9a/82a6ac0f06590f3d72241a587cb8b0b751bd98728e896cc4cbd4847248e6/opentelemetry_sdk-1.27.0.tar.gz", hash = "sha256:d525017dea0ccce9ba4e0245100ec46ecdc043f2d7b8315d56b19aff0904fa6f", size = 145019, upload-time = "2024-08-28T21:35:46.708Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/5b/a509ccab93eacc6044591d5ec437d8266e76f893d0389bbf7e5592c7da32/opentelemetry_sdk-1.28.0.tar.gz", hash = "sha256:41d5420b2e3fb7716ff4981b510d551eff1fc60eb5a95cf7335b31166812a893", size = 156155, upload-time = "2024-11-05T19:15:00.451Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/bd/a6602e71e315055d63b2ff07172bd2d012b4cba2d4e00735d74ba42fc4d6/opentelemetry_sdk-1.27.0-py3-none-any.whl", hash = "sha256:365f5e32f920faf0fd9e14fdfd92c086e317eaa5f860edba9cdc17a380d9197d", size = 110505, upload-time = "2024-08-28T21:35:24.769Z" }, + { url = "https://files.pythonhosted.org/packages/c3/fe/c8decbebb5660529f1d6ba65e50a45b1294022dfcba2968fc9c8697c42b2/opentelemetry_sdk-1.28.0-py3-none-any.whl", hash = "sha256:4b37da81d7fad67f6683c4420288c97f4ed0d988845d5886435f428ec4b8429a", size = 118692, upload-time = "2024-11-05T19:14:41.669Z" }, ] [[package]] name = "opentelemetry-semantic-conventions" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "deprecated" }, { name = "opentelemetry-api" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/89/1724ad69f7411772446067cdfa73b598694c8c91f7f8c922e344d96d81f9/opentelemetry_semantic_conventions-0.48b0.tar.gz", hash = "sha256:12d74983783b6878162208be57c9effcb89dc88691c64992d70bb89dc00daa1a", size = 89445, upload-time = "2024-08-28T21:35:47.673Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/c8/433b0e54143f8c9369f5c4a7a83e73eec7eb2ee7d0b7e81a9243e78c8e80/opentelemetry_semantic_conventions-0.49b0.tar.gz", hash = "sha256:dbc7b28339e5390b6b28e022835f9bac4e134a80ebf640848306d3c5192557e8", size = 95227, upload-time = "2024-11-05T19:15:01.443Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/7a/4f0063dbb0b6c971568291a8bc19a4ca70d3c185db2d956230dd67429dfc/opentelemetry_semantic_conventions-0.48b0-py3-none-any.whl", hash = "sha256:a0de9f45c413a8669788a38569c7e0a11ce6ce97861a628cca785deecdc32a1f", size = 149685, upload-time = "2024-08-28T21:35:25.983Z" }, + { url = "https://files.pythonhosted.org/packages/25/05/20104df4ef07d3bf5c3fd6bcc796ef70ab4ea4309378a9ba57bc4b4d01fa/opentelemetry_semantic_conventions-0.49b0-py3-none-any.whl", hash = "sha256:0458117f6ead0b12e3221813e3e511d85698c31901cac84682052adb9c17c7cd", size = 159214, upload-time = "2024-11-05T19:14:43.047Z" }, ] [[package]] name = "opentelemetry-util-http" -version = "0.48b0" +version = "0.49b0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d6/d7/185c494754340e0a3928fd39fde2616ee78f2c9d66253affaad62d5b7935/opentelemetry_util_http-0.48b0.tar.gz", hash = "sha256:60312015153580cc20f322e5cdc3d3ecad80a71743235bdb77716e742814623c", size = 7863, upload-time = "2024-08-28T21:28:27.266Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/99/377ef446928808211b127b9ab31c348bc465c8da4514ebeec6e4a3de3d21/opentelemetry_util_http-0.49b0.tar.gz", hash = "sha256:02928496afcffd58a7c15baf99d2cedae9b8325a8ac52b0d0877b2e8f936dd1b", size = 7863, upload-time = "2024-11-05T19:22:26.973Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/2e/36097c0a4d0115b8c7e377c90bab7783ac183bc5cb4071308f8959454311/opentelemetry_util_http-0.48b0-py3-none-any.whl", hash = "sha256:76f598af93aab50328d2a69c786beaedc8b6a7770f7a818cc307eb353debfffb", size = 6946, upload-time = "2024-08-28T21:27:37.975Z" }, + { url = "https://files.pythonhosted.org/packages/66/0e/ab0a89b315d0bacdd355a345bb69b20c50fc1f0804b52b56fe1c35a60e68/opentelemetry_util_http-0.49b0-py3-none-any.whl", hash = "sha256:8661bbd6aea1839badc44de067ec9c15c05eab05f729f496c856c50a1203caf1", size = 6945, upload-time = "2024-11-05T19:21:37.81Z" }, ] [[package]] name = "opik" -version = "1.8.102" +version = "1.10.45" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4761,21 +4850,21 @@ dependencies = [ { name = "tqdm" }, { name = "uuid6" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/30/af/f6382cea86bdfbfd0f9571960a15301da4a6ecd1506070d9252a0c0a7564/opik-1.8.102.tar.gz", hash = "sha256:c836a113e8b7fdf90770a3854dcc859b3c30d6347383d7c11e52971a530ed2c3", size = 490462, upload-time = "2025-11-05T18:54:50.142Z" } +sdist = { url = "https://files.pythonhosted.org/packages/85/17/edea6308347cec62e6828de7c573c596559c502b54fa4f0c88a52e2e81f5/opik-1.10.45.tar.gz", hash = "sha256:d8d8627ba03d12def46965e03d58f611daaf5cf878b3d087c53fe1159788c140", size = 789876, upload-time = "2026-03-20T11:35:12.457Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/8b/9b15a01f8360201100b9a5d3e0aeeeda57833fca2b16d34b9fada147fc4b/opik-1.8.102-py3-none-any.whl", hash = "sha256:d8501134bf62bf95443de036f6eaa4f66006f81f9b99e0a8a09e21d8be8c1628", size = 885834, upload-time = "2025-11-05T18:54:48.22Z" }, + { url = "https://files.pythonhosted.org/packages/b7/17/150e9eecfa28cb23f7a0bfe83ae1486a11022b97fe6d12328b455784658d/opik-1.10.45-py3-none-any.whl", hash = "sha256:e8050d9e5e0d92ff587f156eacbdd02099897f39cfe79a98380b6c8ae9906b95", size = 1337714, upload-time = "2026-03-20T11:35:10.237Z" }, ] [[package]] name = "optype" -version = "0.14.0" +version = "0.17.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/ca/d3a2abcf12cc8c18ccac1178ef87ab50a235bf386d2401341776fdad18aa/optype-0.14.0.tar.gz", hash = "sha256:925cf060b7d1337647f880401f6094321e7d8e837533b8e159b9a92afa3157c6", size = 100880, upload-time = "2025-10-01T04:49:56.232Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/9f/3b13bab05debf685678b8af004e46b8c67c6f98ffa08eaf5d33bcf162c16/optype-0.17.0.tar.gz", hash = "sha256:31351a1e64d9eba7bf67e14deefb286e85c66458db63c67dd5e26dd72e4664e5", size = 53484, upload-time = "2026-03-08T23:03:12.594Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/84/a6/11b0eb65eeafa87260d36858b69ec4e0072d09e37ea6714280960030bc93/optype-0.14.0-py3-none-any.whl", hash = "sha256:50d02edafd04edf2e5e27d6249760a51b2198adb9f6ffd778030b3d2806b026b", size = 89465, upload-time = "2025-10-01T04:49:54.674Z" }, + { url = "https://files.pythonhosted.org/packages/6b/44/dca78187415947d1bb90b2ee2a58e47d9573528331e8dc6196996b53612a/optype-0.17.0-py3-none-any.whl", hash = "sha256:8c2d88ff13149454bcf6eb47502f80d288bc542e7238fcc412ac4d222c439397", size = 65854, upload-time = "2026-03-08T23:03:11.425Z" }, ] [package.optional-dependencies] @@ -4786,66 +4875,93 @@ numpy = [ [[package]] name = "oracledb" -version = "3.3.0" +version = "3.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cryptography" }, + { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/51/c9/fae18fa5d803712d188486f8e86ad4f4e00316793ca19745d7c11092c360/oracledb-3.3.0.tar.gz", hash = "sha256:e830d3544a1578296bcaa54c6e8c8ae10a58c7db467c528c4b27adbf9c8b4cb0", size = 811776, upload-time = "2025-07-29T22:34:10.489Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f7/02/70a872d1a4a739b4f7371ab8d3d5ed8c6e57e142e2503531aafcb220893c/oracledb-3.4.2.tar.gz", hash = "sha256:46e0f2278ff1fe83fbc33a3b93c72d429323ec7eed47bc9484e217776cd437e5", size = 855467, upload-time = "2026-01-28T17:25:39.91Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/35/95d9a502fdc48ce1ef3a513ebd027488353441e15aa0448619abb3d09d32/oracledb-3.3.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:d9adb74f837838e21898d938e3a725cf73099c65f98b0b34d77146b453e945e0", size = 3963945, upload-time = "2025-07-29T22:34:28.633Z" }, - { url = "https://files.pythonhosted.org/packages/16/a7/8f1ef447d995bb51d9fdc36356697afeceb603932f16410c12d52b2df1a4/oracledb-3.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4b063d1007882570f170ebde0f364e78d4a70c8f015735cc900663278b9ceef7", size = 2449385, upload-time = "2025-07-29T22:34:30.592Z" }, - { url = "https://files.pythonhosted.org/packages/b3/fa/6a78480450bc7d256808d0f38ade3385735fb5a90dab662167b4257dcf94/oracledb-3.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:187728f0a2d161676b8c581a9d8f15d9631a8fea1e628f6d0e9fa2f01280cd22", size = 2634943, upload-time = "2025-07-29T22:34:33.142Z" }, - { url = "https://files.pythonhosted.org/packages/5b/90/ea32b569a45fb99fac30b96f1ac0fb38b029eeebb78357bc6db4be9dde41/oracledb-3.3.0-cp311-cp311-win32.whl", hash = "sha256:920f14314f3402c5ab98f2efc5932e0547e9c0a4ca9338641357f73844e3e2b1", size = 1483549, upload-time = "2025-07-29T22:34:35.015Z" }, - { url = "https://files.pythonhosted.org/packages/81/55/ae60f72836eb8531b630299f9ed68df3fe7868c6da16f820a108155a21f9/oracledb-3.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:825edb97976468db1c7e52c78ba38d75ce7e2b71a2e88f8629bcf02be8e68a8a", size = 1834737, upload-time = "2025-07-29T22:34:36.824Z" }, - { url = "https://files.pythonhosted.org/packages/08/a8/f6b7809d70e98e113786d5a6f1294da81c046d2fa901ad656669fc5d7fae/oracledb-3.3.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:9d25e37d640872731ac9b73f83cbc5fc4743cd744766bdb250488caf0d7696a8", size = 3943512, upload-time = "2025-07-29T22:34:39.237Z" }, - { url = "https://files.pythonhosted.org/packages/df/b9/8145ad8991f4864d3de4a911d439e5bc6cdbf14af448f3ab1e846a54210c/oracledb-3.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0bf7cdc2b668f939aa364f552861bc7a149d7cd3f3794730d43ef07613b2bf9", size = 2276258, upload-time = "2025-07-29T22:34:41.547Z" }, - { url = "https://files.pythonhosted.org/packages/56/bf/f65635ad5df17d6e4a2083182750bb136ac663ff0e9996ce59d77d200f60/oracledb-3.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2fe20540fde64a6987046807ea47af93be918fd70b9766b3eb803c01e6d4202e", size = 2458811, upload-time = "2025-07-29T22:34:44.648Z" }, - { url = "https://files.pythonhosted.org/packages/7d/30/e0c130b6278c10b0e6cd77a3a1a29a785c083c549676cf701c5d180b8e63/oracledb-3.3.0-cp312-cp312-win32.whl", hash = "sha256:db080be9345cbf9506ffdaea3c13d5314605355e76d186ec4edfa49960ffb813", size = 1445525, upload-time = "2025-07-29T22:34:46.603Z" }, - { url = "https://files.pythonhosted.org/packages/1a/5c/7254f5e1a33a5d6b8bf6813d4f4fdcf5c4166ec8a7af932d987879d5595c/oracledb-3.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:be81e3afe79f6c8ece79a86d6067ad1572d2992ce1c590a086f3755a09535eb4", size = 1789976, upload-time = "2025-07-29T22:34:48.5Z" }, + { url = "https://files.pythonhosted.org/packages/64/80/be263b668ba32b258d07c85f7bfb6967a9677e016c299207b28734f04c4b/oracledb-3.4.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b8e4b8a852251cef09038b75f30fce1227010835f4e19cfbd436027acba2697c", size = 4228552, upload-time = "2026-01-28T17:25:54.844Z" }, + { url = "https://files.pythonhosted.org/packages/91/bc/e832a649529da7c60409a81be41f3213b4c7ffda4fe424222b2145e8d43c/oracledb-3.4.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1617a1db020346883455af005efbefd51be2c4d797e43b1b38455a19f8526b48", size = 2421924, upload-time = "2026-01-28T17:25:56.984Z" }, + { url = "https://files.pythonhosted.org/packages/86/21/d867c37e493a63b5521bd248110ad5b97b18253d64a30703e3e8f3d9631e/oracledb-3.4.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5ed78d7e7079a778062744ccf42141ce4806818c3f4dd6463e4a7edd561c9f86", size = 2599301, upload-time = "2026-01-28T17:25:58.529Z" }, + { url = "https://files.pythonhosted.org/packages/2a/de/9b1843ea27f7791449652d7f340f042c3053336d2c11caf29e59bab86189/oracledb-3.4.2-cp311-cp311-win32.whl", hash = "sha256:0e16fe3d057e0c41a23ad2ae95bfa002401690773376d476be608f79ac74bf05", size = 1492890, upload-time = "2026-01-28T17:26:00.662Z" }, + { url = "https://files.pythonhosted.org/packages/d6/10/cbc8afa2db0cec80530858d3e4574f9734fae8c0b7f1df261398aa026c5f/oracledb-3.4.2-cp311-cp311-win_amd64.whl", hash = "sha256:f93cae08e8ed20f2d5b777a8602a71f9418389c661d2c937e84d94863e7e7011", size = 1843355, upload-time = "2026-01-28T17:26:02.637Z" }, + { url = "https://files.pythonhosted.org/packages/8f/81/2e6154f34b71cd93b4946c73ea13b69d54b8d45a5f6bbffe271793240d21/oracledb-3.4.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a7396664e592881225ba66385ee83ce339d864f39003d6e4ca31a894a7e7c552", size = 4220806, upload-time = "2026-01-28T17:26:04.322Z" }, + { url = "https://files.pythonhosted.org/packages/ab/a9/a1d59aaac77d8f727156ec6a3b03399917c90b7da4f02d057f92e5601f56/oracledb-3.4.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0f04a2d62073407672f114d02529921de0677c6883ed7c64d8d1a3c04caa3238", size = 2233795, upload-time = "2026-01-28T17:26:05.877Z" }, + { url = "https://files.pythonhosted.org/packages/94/ec/8c4a38020cd251572bd406ddcbde98ca052ec94b5684f9aa9ef1ddfcc68c/oracledb-3.4.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d8d75e4f879b908be66cce05ba6c05791a5dbb4a15e39abc01aa25c8a2492bd9", size = 2424756, upload-time = "2026-01-28T17:26:07.35Z" }, + { url = "https://files.pythonhosted.org/packages/fa/7d/c251c2a8567151ccfcfbe3467ea9a60fb5480dc4719342e2e6b7a9679e5d/oracledb-3.4.2-cp312-cp312-win32.whl", hash = "sha256:31b7ee83c23d0439778303de8a675717f805f7e8edb5556d48c4d8343bcf14f5", size = 1453486, upload-time = "2026-01-28T17:26:08.869Z" }, + { url = "https://files.pythonhosted.org/packages/4c/78/c939f3c16fb39400c4734d5a3340db5659ba4e9dce23032d7b33ccfd3fe5/oracledb-3.4.2-cp312-cp312-win_amd64.whl", hash = "sha256:ac25a0448fc830fb7029ad50cd136cdbfcd06975d53967e269772cc5cb8c203a", size = 1794445, upload-time = "2026-01-28T17:26:10.66Z" }, ] [[package]] name = "orjson" -version = "3.11.4" +version = "3.11.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c6/fe/ed708782d6709cc60eb4c2d8a361a440661f74134675c72990f2c48c785f/orjson-3.11.4.tar.gz", hash = "sha256:39485f4ab4c9b30a3943cfe99e1a213c4776fb69e8abd68f66b83d5a0b0fdc6d", size = 5945188, upload-time = "2025-10-24T15:50:38.027Z" } +sdist = { url = "https://files.pythonhosted.org/packages/53/45/b268004f745ede84e5798b48ee12b05129d19235d0e15267aa57dcdb400b/orjson-3.11.7.tar.gz", hash = "sha256:9b1a67243945819ce55d24a30b59d6a168e86220452d2c96f4d1f093e71c0c49", size = 6144992, upload-time = "2026-02-02T15:38:49.29Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/63/1d/1ea6005fffb56715fd48f632611e163d1604e8316a5bad2288bee9a1c9eb/orjson-3.11.4-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:5e59d23cd93ada23ec59a96f215139753fbfe3a4d989549bcb390f8c00370b39", size = 243498, upload-time = "2025-10-24T15:48:48.101Z" }, - { url = "https://files.pythonhosted.org/packages/37/d7/ffed10c7da677f2a9da307d491b9eb1d0125b0307019c4ad3d665fd31f4f/orjson-3.11.4-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:5c3aedecfc1beb988c27c79d52ebefab93b6c3921dbec361167e6559aba2d36d", size = 128961, upload-time = "2025-10-24T15:48:49.571Z" }, - { url = "https://files.pythonhosted.org/packages/a2/96/3e4d10a18866d1368f73c8c44b7fe37cc8a15c32f2a7620be3877d4c55a3/orjson-3.11.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da9e5301f1c2caa2a9a4a303480d79c9ad73560b2e7761de742ab39fe59d9175", size = 130321, upload-time = "2025-10-24T15:48:50.713Z" }, - { url = "https://files.pythonhosted.org/packages/eb/1f/465f66e93f434f968dd74d5b623eb62c657bdba2332f5a8be9f118bb74c7/orjson-3.11.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:8873812c164a90a79f65368f8f96817e59e35d0cc02786a5356f0e2abed78040", size = 129207, upload-time = "2025-10-24T15:48:52.193Z" }, - { url = "https://files.pythonhosted.org/packages/28/43/d1e94837543321c119dff277ae8e348562fe8c0fafbb648ef7cb0c67e521/orjson-3.11.4-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5d7feb0741ebb15204e748f26c9638e6665a5fa93c37a2c73d64f1669b0ddc63", size = 136323, upload-time = "2025-10-24T15:48:54.806Z" }, - { url = "https://files.pythonhosted.org/packages/bf/04/93303776c8890e422a5847dd012b4853cdd88206b8bbd3edc292c90102d1/orjson-3.11.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:01ee5487fefee21e6910da4c2ee9eef005bee568a0879834df86f888d2ffbdd9", size = 137440, upload-time = "2025-10-24T15:48:56.326Z" }, - { url = "https://files.pythonhosted.org/packages/1e/ef/75519d039e5ae6b0f34d0336854d55544ba903e21bf56c83adc51cd8bf82/orjson-3.11.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3d40d46f348c0321df01507f92b95a377240c4ec31985225a6668f10e2676f9a", size = 136680, upload-time = "2025-10-24T15:48:57.476Z" }, - { url = "https://files.pythonhosted.org/packages/b5/18/bf8581eaae0b941b44efe14fee7b7862c3382fbc9a0842132cfc7cf5ecf4/orjson-3.11.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95713e5fc8af84d8edc75b785d2386f653b63d62b16d681687746734b4dfc0be", size = 136160, upload-time = "2025-10-24T15:48:59.631Z" }, - { url = "https://files.pythonhosted.org/packages/c4/35/a6d582766d351f87fc0a22ad740a641b0a8e6fc47515e8614d2e4790ae10/orjson-3.11.4-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ad73ede24f9083614d6c4ca9a85fe70e33be7bf047ec586ee2363bc7418fe4d7", size = 140318, upload-time = "2025-10-24T15:49:00.834Z" }, - { url = "https://files.pythonhosted.org/packages/76/b3/5a4801803ab2e2e2d703bce1a56540d9f99a9143fbec7bf63d225044fef8/orjson-3.11.4-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:842289889de515421f3f224ef9c1f1efb199a32d76d8d2ca2706fa8afe749549", size = 406330, upload-time = "2025-10-24T15:49:02.327Z" }, - { url = "https://files.pythonhosted.org/packages/80/55/a8f682f64833e3a649f620eafefee175cbfeb9854fc5b710b90c3bca45df/orjson-3.11.4-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:3b2427ed5791619851c52a1261b45c233930977e7de8cf36de05636c708fa905", size = 149580, upload-time = "2025-10-24T15:49:03.517Z" }, - { url = "https://files.pythonhosted.org/packages/ad/e4/c132fa0c67afbb3eb88274fa98df9ac1f631a675e7877037c611805a4413/orjson-3.11.4-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3c36e524af1d29982e9b190573677ea02781456b2e537d5840e4538a5ec41907", size = 139846, upload-time = "2025-10-24T15:49:04.761Z" }, - { url = "https://files.pythonhosted.org/packages/54/06/dc3491489efd651fef99c5908e13951abd1aead1257c67f16135f95ce209/orjson-3.11.4-cp311-cp311-win32.whl", hash = "sha256:87255b88756eab4a68ec61837ca754e5d10fa8bc47dc57f75cedfeaec358d54c", size = 135781, upload-time = "2025-10-24T15:49:05.969Z" }, - { url = "https://files.pythonhosted.org/packages/79/b7/5e5e8d77bd4ea02a6ac54c42c818afb01dd31961be8a574eb79f1d2cfb1e/orjson-3.11.4-cp311-cp311-win_amd64.whl", hash = "sha256:e2d5d5d798aba9a0e1fede8d853fa899ce2cb930ec0857365f700dffc2c7af6a", size = 131391, upload-time = "2025-10-24T15:49:07.355Z" }, - { url = "https://files.pythonhosted.org/packages/0f/dc/9484127cc1aa213be398ed735f5f270eedcb0c0977303a6f6ddc46b60204/orjson-3.11.4-cp311-cp311-win_arm64.whl", hash = "sha256:6bb6bb41b14c95d4f2702bce9975fda4516f1db48e500102fc4d8119032ff045", size = 126252, upload-time = "2025-10-24T15:49:08.869Z" }, - { url = "https://files.pythonhosted.org/packages/63/51/6b556192a04595b93e277a9ff71cd0cc06c21a7df98bcce5963fa0f5e36f/orjson-3.11.4-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:d4371de39319d05d3f482f372720b841c841b52f5385bd99c61ed69d55d9ab50", size = 243571, upload-time = "2025-10-24T15:49:10.008Z" }, - { url = "https://files.pythonhosted.org/packages/1c/2c/2602392ddf2601d538ff11848b98621cd465d1a1ceb9db9e8043181f2f7b/orjson-3.11.4-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:e41fd3b3cac850eaae78232f37325ed7d7436e11c471246b87b2cd294ec94853", size = 128891, upload-time = "2025-10-24T15:49:11.297Z" }, - { url = "https://files.pythonhosted.org/packages/4e/47/bf85dcf95f7a3a12bf223394a4f849430acd82633848d52def09fa3f46ad/orjson-3.11.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:600e0e9ca042878c7fdf189cf1b028fe2c1418cc9195f6cb9824eb6ed99cb938", size = 130137, upload-time = "2025-10-24T15:49:12.544Z" }, - { url = "https://files.pythonhosted.org/packages/b4/4d/a0cb31007f3ab6f1fd2a1b17057c7c349bc2baf8921a85c0180cc7be8011/orjson-3.11.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7bbf9b333f1568ef5da42bc96e18bf30fd7f8d54e9ae066d711056add508e415", size = 129152, upload-time = "2025-10-24T15:49:13.754Z" }, - { url = "https://files.pythonhosted.org/packages/f7/ef/2811def7ce3d8576b19e3929fff8f8f0d44bc5eb2e0fdecb2e6e6cc6c720/orjson-3.11.4-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4806363144bb6e7297b8e95870e78d30a649fdc4e23fc84daa80c8ebd366ce44", size = 136834, upload-time = "2025-10-24T15:49:15.307Z" }, - { url = "https://files.pythonhosted.org/packages/00/d4/9aee9e54f1809cec8ed5abd9bc31e8a9631d19460e3b8470145d25140106/orjson-3.11.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ad355e8308493f527d41154e9053b86a5be892b3b359a5c6d5d95cda23601cb2", size = 137519, upload-time = "2025-10-24T15:49:16.557Z" }, - { url = "https://files.pythonhosted.org/packages/db/ea/67bfdb5465d5679e8ae8d68c11753aaf4f47e3e7264bad66dc2f2249e643/orjson-3.11.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c8a7517482667fb9f0ff1b2f16fe5829296ed7a655d04d68cd9711a4d8a4e708", size = 136749, upload-time = "2025-10-24T15:49:17.796Z" }, - { url = "https://files.pythonhosted.org/packages/01/7e/62517dddcfce6d53a39543cd74d0dccfcbdf53967017c58af68822100272/orjson-3.11.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97eb5942c7395a171cbfecc4ef6701fc3c403e762194683772df4c54cfbb2210", size = 136325, upload-time = "2025-10-24T15:49:19.347Z" }, - { url = "https://files.pythonhosted.org/packages/18/ae/40516739f99ab4c7ec3aaa5cc242d341fcb03a45d89edeeaabc5f69cb2cf/orjson-3.11.4-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:149d95d5e018bdd822e3f38c103b1a7c91f88d38a88aada5c4e9b3a73a244241", size = 140204, upload-time = "2025-10-24T15:49:20.545Z" }, - { url = "https://files.pythonhosted.org/packages/82/18/ff5734365623a8916e3a4037fcef1cd1782bfc14cf0992afe7940c5320bf/orjson-3.11.4-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:624f3951181eb46fc47dea3d221554e98784c823e7069edb5dbd0dc826ac909b", size = 406242, upload-time = "2025-10-24T15:49:21.884Z" }, - { url = "https://files.pythonhosted.org/packages/e1/43/96436041f0a0c8c8deca6a05ebeaf529bf1de04839f93ac5e7c479807aec/orjson-3.11.4-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:03bfa548cf35e3f8b3a96c4e8e41f753c686ff3d8e182ce275b1751deddab58c", size = 150013, upload-time = "2025-10-24T15:49:23.185Z" }, - { url = "https://files.pythonhosted.org/packages/1b/48/78302d98423ed8780479a1e682b9aecb869e8404545d999d34fa486e573e/orjson-3.11.4-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:525021896afef44a68148f6ed8a8bf8375553d6066c7f48537657f64823565b9", size = 139951, upload-time = "2025-10-24T15:49:24.428Z" }, - { url = "https://files.pythonhosted.org/packages/4a/7b/ad613fdcdaa812f075ec0875143c3d37f8654457d2af17703905425981bf/orjson-3.11.4-cp312-cp312-win32.whl", hash = "sha256:b58430396687ce0f7d9eeb3dd47761ca7d8fda8e9eb92b3077a7a353a75efefa", size = 136049, upload-time = "2025-10-24T15:49:25.973Z" }, - { url = "https://files.pythonhosted.org/packages/b9/3c/9cf47c3ff5f39b8350fb21ba65d789b6a1129d4cbb3033ba36c8a9023520/orjson-3.11.4-cp312-cp312-win_amd64.whl", hash = "sha256:c6dbf422894e1e3c80a177133c0dda260f81428f9de16d61041949f6a2e5c140", size = 131461, upload-time = "2025-10-24T15:49:27.259Z" }, - { url = "https://files.pythonhosted.org/packages/c6/3b/e2425f61e5825dc5b08c2a5a2b3af387eaaca22a12b9c8c01504f8614c36/orjson-3.11.4-cp312-cp312-win_arm64.whl", hash = "sha256:d38d2bc06d6415852224fcc9c0bfa834c25431e466dc319f0edd56cca81aa96e", size = 126167, upload-time = "2025-10-24T15:49:28.511Z" }, + { url = "https://files.pythonhosted.org/packages/37/02/da6cb01fc6087048d7f61522c327edf4250f1683a58a839fdcc435746dd5/orjson-3.11.7-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9487abc2c2086e7c8eb9a211d2ce8855bae0e92586279d0d27b341d5ad76c85c", size = 228664, upload-time = "2026-02-02T15:37:25.542Z" }, + { url = "https://files.pythonhosted.org/packages/c1/c2/5885e7a5881dba9a9af51bc564e8967225a642b3e03d089289a35054e749/orjson-3.11.7-cp311-cp311-macosx_15_0_arm64.whl", hash = "sha256:79cacb0b52f6004caf92405a7e1f11e6e2de8bdf9019e4f76b44ba045125cd6b", size = 125344, upload-time = "2026-02-02T15:37:26.92Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1d/4e7688de0a92d1caf600dfd5fb70b4c5bfff51dfa61ac555072ef2d0d32a/orjson-3.11.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2e85fe4698b6a56d5e2ebf7ae87544d668eb6bde1ad1226c13f44663f20ec9e", size = 128404, upload-time = "2026-02-02T15:37:28.108Z" }, + { url = "https://files.pythonhosted.org/packages/2f/b2/ec04b74ae03a125db7bd69cffd014b227b7f341e3261bf75b5eb88a1aa92/orjson-3.11.7-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b8d14b71c0b12963fe8a62aac87119f1afdf4cb88a400f61ca5ae581449efcb5", size = 123677, upload-time = "2026-02-02T15:37:30.287Z" }, + { url = "https://files.pythonhosted.org/packages/4c/69/f95bdf960605f08f827f6e3291fe243d8aa9c5c9ff017a8d7232209184c3/orjson-3.11.7-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91c81ef070c8f3220054115e1ef468b1c9ce8497b4e526cb9f68ab4dc0a7ac62", size = 128950, upload-time = "2026-02-02T15:37:31.595Z" }, + { url = "https://files.pythonhosted.org/packages/a4/1b/de59c57bae1d148ef298852abd31909ac3089cff370dfd4cd84cc99cbc42/orjson-3.11.7-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:411ebaf34d735e25e358a6d9e7978954a9c9d58cfb47bc6683cdc3964cd2f910", size = 141756, upload-time = "2026-02-02T15:37:32.985Z" }, + { url = "https://files.pythonhosted.org/packages/ee/9e/9decc59f4499f695f65c650f6cfa6cd4c37a3fbe8fa235a0a3614cb54386/orjson-3.11.7-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a16bcd08ab0bcdfc7e8801d9c4a9cc17e58418e4d48ddc6ded4e9e4b1a94062b", size = 130812, upload-time = "2026-02-02T15:37:34.204Z" }, + { url = "https://files.pythonhosted.org/packages/28/e6/59f932bcabd1eac44e334fe8e3281a92eacfcb450586e1f4bde0423728d8/orjson-3.11.7-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9c0b51672e466fd7e56230ffbae7f1639e18d0ce023351fb75da21b71bc2c960", size = 133444, upload-time = "2026-02-02T15:37:35.446Z" }, + { url = "https://files.pythonhosted.org/packages/f1/36/b0f05c0eaa7ca30bc965e37e6a2956b0d67adb87a9872942d3568da846ae/orjson-3.11.7-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:136dcd6a2e796dfd9ffca9fc027d778567b0b7c9968d092842d3c323cef88aa8", size = 138609, upload-time = "2026-02-02T15:37:36.657Z" }, + { url = "https://files.pythonhosted.org/packages/b8/03/58ec7d302b8d86944c60c7b4b82975d5161fcce4c9bc8c6cb1d6741b6115/orjson-3.11.7-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:7ba61079379b0ae29e117db13bda5f28d939766e410d321ec1624afc6a0b0504", size = 408918, upload-time = "2026-02-02T15:37:38.076Z" }, + { url = "https://files.pythonhosted.org/packages/06/3a/868d65ef9a8b99be723bd510de491349618abd9f62c826cf206d962db295/orjson-3.11.7-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:0527a4510c300e3b406591b0ba69b5dc50031895b0a93743526a3fc45f59d26e", size = 143998, upload-time = "2026-02-02T15:37:39.706Z" }, + { url = "https://files.pythonhosted.org/packages/5b/c7/1e18e1c83afe3349f4f6dc9e14910f0ae5f82eac756d1412ea4018938535/orjson-3.11.7-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a709e881723c9b18acddcfb8ba357322491ad553e277cf467e1e7e20e2d90561", size = 134802, upload-time = "2026-02-02T15:37:41.002Z" }, + { url = "https://files.pythonhosted.org/packages/d4/0b/ccb7ee1a65b37e8eeb8b267dc953561d72370e85185e459616d4345bab34/orjson-3.11.7-cp311-cp311-win32.whl", hash = "sha256:c43b8b5bab288b6b90dac410cca7e986a4fa747a2e8f94615aea407da706980d", size = 127828, upload-time = "2026-02-02T15:37:42.241Z" }, + { url = "https://files.pythonhosted.org/packages/af/9e/55c776dffda3f381e0f07d010a4f5f3902bf48eaba1bb7684d301acd4924/orjson-3.11.7-cp311-cp311-win_amd64.whl", hash = "sha256:6543001328aa857187f905308a028935864aefe9968af3848401b6fe80dbb471", size = 124941, upload-time = "2026-02-02T15:37:43.444Z" }, + { url = "https://files.pythonhosted.org/packages/aa/8e/424a620fa7d263b880162505fb107ef5e0afaa765b5b06a88312ac291560/orjson-3.11.7-cp311-cp311-win_arm64.whl", hash = "sha256:1ee5cc7160a821dfe14f130bc8e63e7611051f964b463d9e2a3a573204446a4d", size = 126245, upload-time = "2026-02-02T15:37:45.18Z" }, + { url = "https://files.pythonhosted.org/packages/80/bf/76f4f1665f6983385938f0e2a5d7efa12a58171b8456c252f3bae8a4cf75/orjson-3.11.7-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:bd03ea7606833655048dab1a00734a2875e3e86c276e1d772b2a02556f0d895f", size = 228545, upload-time = "2026-02-02T15:37:46.376Z" }, + { url = "https://files.pythonhosted.org/packages/79/53/6c72c002cb13b5a978a068add59b25a8bdf2800ac1c9c8ecdb26d6d97064/orjson-3.11.7-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:89e440ebc74ce8ab5c7bc4ce6757b4a6b1041becb127df818f6997b5c71aa60b", size = 125224, upload-time = "2026-02-02T15:37:47.697Z" }, + { url = "https://files.pythonhosted.org/packages/2c/83/10e48852865e5dd151bdfe652c06f7da484578ed02c5fca938e3632cb0b8/orjson-3.11.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ede977b5fe5ac91b1dffc0a517ca4542d2ec8a6a4ff7b2652d94f640796342a", size = 128154, upload-time = "2026-02-02T15:37:48.954Z" }, + { url = "https://files.pythonhosted.org/packages/6e/52/a66e22a2b9abaa374b4a081d410edab6d1e30024707b87eab7c734afe28d/orjson-3.11.7-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b7b1dae39230a393df353827c855a5f176271c23434cfd2db74e0e424e693e10", size = 123548, upload-time = "2026-02-02T15:37:50.187Z" }, + { url = "https://files.pythonhosted.org/packages/de/38/605d371417021359f4910c496f764c48ceb8997605f8c25bf1dfe58c0ebe/orjson-3.11.7-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ed46f17096e28fb28d2975834836a639af7278aa87c84f68ab08fbe5b8bd75fa", size = 129000, upload-time = "2026-02-02T15:37:51.426Z" }, + { url = "https://files.pythonhosted.org/packages/44/98/af32e842b0ffd2335c89714d48ca4e3917b42f5d6ee5537832e069a4b3ac/orjson-3.11.7-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3726be79e36e526e3d9c1aceaadbfb4a04ee80a72ab47b3f3c17fefb9812e7b8", size = 141686, upload-time = "2026-02-02T15:37:52.607Z" }, + { url = "https://files.pythonhosted.org/packages/96/0b/fc793858dfa54be6feee940c1463370ece34b3c39c1ca0aa3845f5ba9892/orjson-3.11.7-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0724e265bc548af1dedebd9cb3d24b4e1c1e685a343be43e87ba922a5c5fff2f", size = 130812, upload-time = "2026-02-02T15:37:53.944Z" }, + { url = "https://files.pythonhosted.org/packages/dc/91/98a52415059db3f374757d0b7f0f16e3b5cd5976c90d1c2b56acaea039e6/orjson-3.11.7-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e7745312efa9e11c17fbd3cb3097262d079da26930ae9ae7ba28fb738367cbad", size = 133440, upload-time = "2026-02-02T15:37:55.615Z" }, + { url = "https://files.pythonhosted.org/packages/dc/b6/cb540117bda61791f46381f8c26c8f93e802892830a6055748d3bb1925ab/orjson-3.11.7-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f904c24bdeabd4298f7a977ef14ca2a022ca921ed670b92ecd16ab6f3d01f867", size = 138386, upload-time = "2026-02-02T15:37:56.814Z" }, + { url = "https://files.pythonhosted.org/packages/63/1a/50a3201c334a7f17c231eee5f841342190723794e3b06293f26e7cf87d31/orjson-3.11.7-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b9fc4d0f81f394689e0814617aadc4f2ea0e8025f38c226cbf22d3b5ddbf025d", size = 408853, upload-time = "2026-02-02T15:37:58.291Z" }, + { url = "https://files.pythonhosted.org/packages/87/cd/8de1c67d0be44fdc22701e5989c0d015a2adf391498ad42c4dc589cd3013/orjson-3.11.7-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:849e38203e5be40b776ed2718e587faf204d184fc9a008ae441f9442320c0cab", size = 144130, upload-time = "2026-02-02T15:38:00.163Z" }, + { url = "https://files.pythonhosted.org/packages/0f/fe/d605d700c35dd55f51710d159fc54516a280923cd1b7e47508982fbb387d/orjson-3.11.7-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4682d1db3bcebd2b64757e0ddf9e87ae5f00d29d16c5cdf3a62f561d08cc3dd2", size = 134818, upload-time = "2026-02-02T15:38:01.507Z" }, + { url = "https://files.pythonhosted.org/packages/e4/e4/15ecc67edb3ddb3e2f46ae04475f2d294e8b60c1825fbe28a428b93b3fbd/orjson-3.11.7-cp312-cp312-win32.whl", hash = "sha256:f4f7c956b5215d949a1f65334cf9d7612dde38f20a95f2315deef167def91a6f", size = 127923, upload-time = "2026-02-02T15:38:02.75Z" }, + { url = "https://files.pythonhosted.org/packages/34/70/2e0855361f76198a3965273048c8e50a9695d88cd75811a5b46444895845/orjson-3.11.7-cp312-cp312-win_amd64.whl", hash = "sha256:bf742e149121dc5648ba0a08ea0871e87b660467ef168a3a5e53bc1fbd64bb74", size = 125007, upload-time = "2026-02-02T15:38:04.032Z" }, + { url = "https://files.pythonhosted.org/packages/68/40/c2051bd19fc467610fed469dc29e43ac65891571138f476834ca192bc290/orjson-3.11.7-cp312-cp312-win_arm64.whl", hash = "sha256:26c3b9132f783b7d7903bf1efb095fed8d4a3a85ec0d334ee8beff3d7a4749d5", size = 126089, upload-time = "2026-02-02T15:38:05.297Z" }, +] + +[[package]] +name = "ormsgpack" +version = "1.12.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/0c/f1761e21486942ab9bb6feaebc610fa074f7c5e496e6962dea5873348077/ormsgpack-1.12.2.tar.gz", hash = "sha256:944a2233640273bee67521795a73cf1e959538e0dfb7ac635505010455e53b33", size = 39031, upload-time = "2026-01-18T20:55:28.023Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4b/08/8b68f24b18e69d92238aa8f258218e6dfeacf4381d9d07ab8df303f524a9/ormsgpack-1.12.2-cp311-cp311-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:bd5f4bf04c37888e864f08e740c5a573c4017f6fd6e99fa944c5c935fabf2dd9", size = 378266, upload-time = "2026-01-18T20:55:59.876Z" }, + { url = "https://files.pythonhosted.org/packages/0d/24/29fc13044ecb7c153523ae0a1972269fcd613650d1fa1a9cec1044c6b666/ormsgpack-1.12.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d5b28b3570e9fed9a5a76528fc7230c3c76333bc214798958e58e9b79cc18a", size = 203035, upload-time = "2026-01-18T20:55:30.59Z" }, + { url = "https://files.pythonhosted.org/packages/ad/c2/00169fb25dd8f9213f5e8a549dfb73e4d592009ebc85fbbcd3e1dcac575b/ormsgpack-1.12.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3708693412c28f3538fb5a65da93787b6bbab3484f6bc6e935bfb77a62400ae5", size = 210539, upload-time = "2026-01-18T20:55:48.569Z" }, + { url = "https://files.pythonhosted.org/packages/1b/33/543627f323ff3c73091f51d6a20db28a1a33531af30873ea90c5ac95a9b5/ormsgpack-1.12.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:43013a3f3e2e902e1d05e72c0f1aeb5bedbb8e09240b51e26792a3c89267e181", size = 212401, upload-time = "2026-01-18T20:56:10.101Z" }, + { url = "https://files.pythonhosted.org/packages/e8/5d/f70e2c3da414f46186659d24745483757bcc9adccb481a6eb93e2b729301/ormsgpack-1.12.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7c8b1667a72cbba74f0ae7ecf3105a5e01304620ed14528b2cb4320679d2869b", size = 387082, upload-time = "2026-01-18T20:56:12.047Z" }, + { url = "https://files.pythonhosted.org/packages/c0/d6/06e8dc920c7903e051f30934d874d4afccc9bb1c09dcaf0bc03a7de4b343/ormsgpack-1.12.2-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:df6961442140193e517303d0b5d7bc2e20e69a879c2d774316125350c4a76b92", size = 482346, upload-time = "2026-01-18T20:56:05.152Z" }, + { url = "https://files.pythonhosted.org/packages/66/c4/f337ac0905eed9c393ef990c54565cd33644918e0a8031fe48c098c71dbf/ormsgpack-1.12.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c6a4c34ddef109647c769d69be65fa1de7a6022b02ad45546a69b3216573eb4a", size = 425181, upload-time = "2026-01-18T20:55:37.83Z" }, + { url = "https://files.pythonhosted.org/packages/78/29/6d5758fabef3babdf4bbbc453738cc7de9cd3334e4c38dd5737e27b85653/ormsgpack-1.12.2-cp311-cp311-win_amd64.whl", hash = "sha256:73670ed0375ecc303858e3613f407628dd1fca18fe6ac57b7b7ce66cc7bb006c", size = 117182, upload-time = "2026-01-18T20:55:31.472Z" }, + { url = "https://files.pythonhosted.org/packages/c4/57/17a15549233c37e7fd054c48fe9207492e06b026dbd872b826a0b5f833b6/ormsgpack-1.12.2-cp311-cp311-win_arm64.whl", hash = "sha256:c2be829954434e33601ae5da328cccce3266b098927ca7a30246a0baec2ce7bd", size = 111464, upload-time = "2026-01-18T20:55:38.811Z" }, + { url = "https://files.pythonhosted.org/packages/4c/36/16c4b1921c308a92cef3bf6663226ae283395aa0ff6e154f925c32e91ff5/ormsgpack-1.12.2-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7a29d09b64b9694b588ff2f80e9826bdceb3a2b91523c5beae1fab27d5c940e7", size = 378618, upload-time = "2026-01-18T20:55:50.835Z" }, + { url = "https://files.pythonhosted.org/packages/c0/68/468de634079615abf66ed13bb5c34ff71da237213f29294363beeeca5306/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b39e629fd2e1c5b2f46f99778450b59454d1f901bc507963168985e79f09c5d", size = 203186, upload-time = "2026-01-18T20:56:11.163Z" }, + { url = "https://files.pythonhosted.org/packages/73/a9/d756e01961442688b7939bacd87ce13bfad7d26ce24f910f6028178b2cc8/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:958dcb270d30a7cb633a45ee62b9444433fa571a752d2ca484efdac07480876e", size = 210738, upload-time = "2026-01-18T20:56:09.181Z" }, + { url = "https://files.pythonhosted.org/packages/7b/ba/795b1036888542c9113269a3f5690ab53dd2258c6fb17676ac4bd44fcf94/ormsgpack-1.12.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58d379d72b6c5e964851c77cfedfb386e474adee4fd39791c2c5d9efb53505cc", size = 212569, upload-time = "2026-01-18T20:56:06.135Z" }, + { url = "https://files.pythonhosted.org/packages/6c/aa/bff73c57497b9e0cba8837c7e4bcab584b1a6dbc91a5dd5526784a5030c8/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8463a3fc5f09832e67bdb0e2fda6d518dc4281b133166146a67f54c08496442e", size = 387166, upload-time = "2026-01-18T20:55:36.738Z" }, + { url = "https://files.pythonhosted.org/packages/d3/cf/f8283cba44bcb7b14f97b6274d449db276b3a86589bdb363169b51bc12de/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:eddffb77eff0bad4e67547d67a130604e7e2dfbb7b0cde0796045be4090f35c6", size = 482498, upload-time = "2026-01-18T20:55:29.626Z" }, + { url = "https://files.pythonhosted.org/packages/05/be/71e37b852d723dfcbe952ad04178c030df60d6b78eba26bfd14c9a40575e/ormsgpack-1.12.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fcd55e5f6ba0dbce624942adf9f152062135f991a0126064889f68eb850de0dd", size = 425518, upload-time = "2026-01-18T20:55:49.556Z" }, + { url = "https://files.pythonhosted.org/packages/7a/0c/9803aa883d18c7ef197213cd2cbf73ba76472a11fe100fb7dab2884edf48/ormsgpack-1.12.2-cp312-cp312-win_amd64.whl", hash = "sha256:d024b40828f1dde5654faebd0d824f9cc29ad46891f626272dd5bfd7af2333a4", size = 117462, upload-time = "2026-01-18T20:55:47.726Z" }, + { url = "https://files.pythonhosted.org/packages/c8/9e/029e898298b2cc662f10d7a15652a53e3b525b1e7f07e21fef8536a09bb8/ormsgpack-1.12.2-cp312-cp312-win_arm64.whl", hash = "sha256:da538c542bac7d1c8f3f2a937863dba36f013108ce63e55745941dda4b75dbb6", size = 111559, upload-time = "2026-01-18T20:55:54.273Z" }, ] [[package]] name = "oss2" -version = "2.18.5" +version = "2.19.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aliyun-python-sdk-core" }, @@ -4855,7 +4971,7 @@ dependencies = [ { name = "requests" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/61/ce/d23a9d44268dc992ae1a878d24341dddaea4de4ae374c261209bb6e9554b/oss2-2.18.5.tar.gz", hash = "sha256:555c857f4441ae42a2c0abab8fc9482543fba35d65a4a4be73101c959a2b4011", size = 283388, upload-time = "2024-04-29T12:49:07.686Z" } +sdist = { url = "https://files.pythonhosted.org/packages/df/b5/f2cb1950dda46ac2284d6c950489fdacd0e743c2d79a347924d3cc44b86f/oss2-2.19.1.tar.gz", hash = "sha256:a8ab9ee7eb99e88a7e1382edc6ea641d219d585a7e074e3776e9dec9473e59c1", size = 298845, upload-time = "2024-10-25T11:37:46.638Z" } [[package]] name = "overrides" @@ -4877,30 +4993,31 @@ wheels = [ [[package]] name = "pandas" -version = "2.2.3" +version = "3.0.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "python-dateutil" }, - { name = "pytz" }, - { name = "tzdata" }, + { name = "tzdata", marker = "sys_platform == 'emscripten' or sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9c/d6/9f8431bacc2e19dca897724cd097b1bb224a6ad5433784a44b587c7c13af/pandas-2.2.3.tar.gz", hash = "sha256:4f18ba62b61d7e192368b84517265a99b4d7ee8912f8708660fb4a366cc82667", size = 4399213, upload-time = "2024-09-20T13:10:04.827Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/0c/b28ed414f080ee0ad153f848586d61d1878f91689950f037f976ce15f6c8/pandas-3.0.1.tar.gz", hash = "sha256:4186a699674af418f655dbd420ed87f50d56b4cd6603784279d9eef6627823c8", size = 4641901, upload-time = "2026-02-17T22:20:16.434Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/44/d9502bf0ed197ba9bf1103c9867d5904ddcaf869e52329787fc54ed70cc8/pandas-2.2.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:66108071e1b935240e74525006034333f98bcdb87ea116de573a6a0dccb6c039", size = 12602222, upload-time = "2024-09-20T13:08:56.254Z" }, - { url = "https://files.pythonhosted.org/packages/52/11/9eac327a38834f162b8250aab32a6781339c69afe7574368fffe46387edf/pandas-2.2.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:7c2875855b0ff77b2a64a0365e24455d9990730d6431b9e0ee18ad8acee13dbd", size = 11321274, upload-time = "2024-09-20T13:08:58.645Z" }, - { url = "https://files.pythonhosted.org/packages/45/fb/c4beeb084718598ba19aa9f5abbc8aed8b42f90930da861fcb1acdb54c3a/pandas-2.2.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cd8d0c3be0515c12fed0bdbae072551c8b54b7192c7b1fda0ba56059a0179698", size = 15579836, upload-time = "2024-09-20T19:01:57.571Z" }, - { url = "https://files.pythonhosted.org/packages/cd/5f/4dba1d39bb9c38d574a9a22548c540177f78ea47b32f99c0ff2ec499fac5/pandas-2.2.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c124333816c3a9b03fbeef3a9f230ba9a737e9e5bb4060aa2107a86cc0a497fc", size = 13058505, upload-time = "2024-09-20T13:09:01.501Z" }, - { url = "https://files.pythonhosted.org/packages/b9/57/708135b90391995361636634df1f1130d03ba456e95bcf576fada459115a/pandas-2.2.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:63cc132e40a2e084cf01adf0775b15ac515ba905d7dcca47e9a251819c575ef3", size = 16744420, upload-time = "2024-09-20T19:02:00.678Z" }, - { url = "https://files.pythonhosted.org/packages/86/4a/03ed6b7ee323cf30404265c284cee9c65c56a212e0a08d9ee06984ba2240/pandas-2.2.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:29401dbfa9ad77319367d36940cd8a0b3a11aba16063e39632d98b0e931ddf32", size = 14440457, upload-time = "2024-09-20T13:09:04.105Z" }, - { url = "https://files.pythonhosted.org/packages/ed/8c/87ddf1fcb55d11f9f847e3c69bb1c6f8e46e2f40ab1a2d2abadb2401b007/pandas-2.2.3-cp311-cp311-win_amd64.whl", hash = "sha256:3fc6873a41186404dad67245896a6e440baacc92f5b716ccd1bc9ed2995ab2c5", size = 11617166, upload-time = "2024-09-20T13:09:06.917Z" }, - { url = "https://files.pythonhosted.org/packages/17/a3/fb2734118db0af37ea7433f57f722c0a56687e14b14690edff0cdb4b7e58/pandas-2.2.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b1d432e8d08679a40e2a6d8b2f9770a5c21793a6f9f47fdd52c5ce1948a5a8a9", size = 12529893, upload-time = "2024-09-20T13:09:09.655Z" }, - { url = "https://files.pythonhosted.org/packages/e1/0c/ad295fd74bfac85358fd579e271cded3ac969de81f62dd0142c426b9da91/pandas-2.2.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a5a1595fe639f5988ba6a8e5bc9649af3baf26df3998a0abe56c02609392e0a4", size = 11363475, upload-time = "2024-09-20T13:09:14.718Z" }, - { url = "https://files.pythonhosted.org/packages/c6/2a/4bba3f03f7d07207481fed47f5b35f556c7441acddc368ec43d6643c5777/pandas-2.2.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5de54125a92bb4d1c051c0659e6fcb75256bf799a732a87184e5ea503965bce3", size = 15188645, upload-time = "2024-09-20T19:02:03.88Z" }, - { url = "https://files.pythonhosted.org/packages/38/f8/d8fddee9ed0d0c0f4a2132c1dfcf0e3e53265055da8df952a53e7eaf178c/pandas-2.2.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fffb8ae78d8af97f849404f21411c95062db1496aeb3e56f146f0355c9989319", size = 12739445, upload-time = "2024-09-20T13:09:17.621Z" }, - { url = "https://files.pythonhosted.org/packages/20/e8/45a05d9c39d2cea61ab175dbe6a2de1d05b679e8de2011da4ee190d7e748/pandas-2.2.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:6dfcb5ee8d4d50c06a51c2fffa6cff6272098ad6540aed1a76d15fb9318194d8", size = 16359235, upload-time = "2024-09-20T19:02:07.094Z" }, - { url = "https://files.pythonhosted.org/packages/1d/99/617d07a6a5e429ff90c90da64d428516605a1ec7d7bea494235e1c3882de/pandas-2.2.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:062309c1b9ea12a50e8ce661145c6aab431b1e99530d3cd60640e255778bd43a", size = 14056756, upload-time = "2024-09-20T13:09:20.474Z" }, - { url = "https://files.pythonhosted.org/packages/29/d4/1244ab8edf173a10fd601f7e13b9566c1b525c4f365d6bee918e68381889/pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13", size = 11504248, upload-time = "2024-09-20T13:09:23.137Z" }, + { url = "https://files.pythonhosted.org/packages/ff/07/c7087e003ceee9b9a82539b40414ec557aa795b584a1a346e89180853d79/pandas-3.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:de09668c1bf3b925c07e5762291602f0d789eca1b3a781f99c1c78f6cac0e7ea", size = 10323380, upload-time = "2026-02-17T22:18:16.133Z" }, + { url = "https://files.pythonhosted.org/packages/c1/27/90683c7122febeefe84a56f2cde86a9f05f68d53885cebcc473298dfc33e/pandas-3.0.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:24ba315ba3d6e5806063ac6eb717504e499ce30bd8c236d8693a5fd3f084c796", size = 9923455, upload-time = "2026-02-17T22:18:19.13Z" }, + { url = "https://files.pythonhosted.org/packages/0e/f1/ed17d927f9950643bc7631aa4c99ff0cc83a37864470bc419345b656a41f/pandas-3.0.1-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:406ce835c55bac912f2a0dcfaf27c06d73c6b04a5dde45f1fd3169ce31337389", size = 10753464, upload-time = "2026-02-17T22:18:21.134Z" }, + { url = "https://files.pythonhosted.org/packages/2e/7c/870c7e7daec2a6c7ff2ac9e33b23317230d4e4e954b35112759ea4a924a7/pandas-3.0.1-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:830994d7e1f31dd7e790045235605ab61cff6c94defc774547e8b7fdfbff3dc7", size = 11255234, upload-time = "2026-02-17T22:18:24.175Z" }, + { url = "https://files.pythonhosted.org/packages/5c/39/3653fe59af68606282b989c23d1a543ceba6e8099cbcc5f1d506a7bae2aa/pandas-3.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:a64ce8b0f2de1d2efd2ae40b0abe7f8ae6b29fbfb3812098ed5a6f8e235ad9bf", size = 11767299, upload-time = "2026-02-17T22:18:26.824Z" }, + { url = "https://files.pythonhosted.org/packages/9b/31/1daf3c0c94a849c7a8dab8a69697b36d313b229918002ba3e409265c7888/pandas-3.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:9832c2c69da24b602c32e0c7b1b508a03949c18ba08d4d9f1c1033426685b447", size = 12333292, upload-time = "2026-02-17T22:18:28.996Z" }, + { url = "https://files.pythonhosted.org/packages/1f/67/af63f83cd6ca603a00fe8530c10a60f0879265b8be00b5930e8e78c5b30b/pandas-3.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:84f0904a69e7365f79a0c77d3cdfccbfb05bf87847e3a51a41e1426b0edb9c79", size = 9892176, upload-time = "2026-02-17T22:18:31.79Z" }, + { url = "https://files.pythonhosted.org/packages/79/ab/9c776b14ac4b7b4140788eca18468ea39894bc7340a408f1d1e379856a6b/pandas-3.0.1-cp311-cp311-win_arm64.whl", hash = "sha256:4a68773d5a778afb31d12e34f7dd4612ab90de8c6fb1d8ffe5d4a03b955082a1", size = 9151328, upload-time = "2026-02-17T22:18:35.721Z" }, + { url = "https://files.pythonhosted.org/packages/37/51/b467209c08dae2c624873d7491ea47d2b47336e5403309d433ea79c38571/pandas-3.0.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:476f84f8c20c9f5bc47252b66b4bb25e1a9fc2fa98cead96744d8116cb85771d", size = 10344357, upload-time = "2026-02-17T22:18:38.262Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f1/e2567ffc8951ab371db2e40b2fe068e36b81d8cf3260f06ae508700e5504/pandas-3.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0ab749dfba921edf641d4036c4c21c0b3ea70fea478165cb98a998fb2a261955", size = 9884543, upload-time = "2026-02-17T22:18:41.476Z" }, + { url = "https://files.pythonhosted.org/packages/d7/39/327802e0b6d693182403c144edacbc27eb82907b57062f23ef5a4c4a5ea7/pandas-3.0.1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b8e36891080b87823aff3640c78649b91b8ff6eea3c0d70aeabd72ea43ab069b", size = 10396030, upload-time = "2026-02-17T22:18:43.822Z" }, + { url = "https://files.pythonhosted.org/packages/3d/fe/89d77e424365280b79d99b3e1e7d606f5165af2f2ecfaf0c6d24c799d607/pandas-3.0.1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:532527a701281b9dd371e2f582ed9094f4c12dd9ffb82c0c54ee28d8ac9520c4", size = 10876435, upload-time = "2026-02-17T22:18:45.954Z" }, + { url = "https://files.pythonhosted.org/packages/b5/a6/2a75320849dd154a793f69c951db759aedb8d1dd3939eeacda9bdcfa1629/pandas-3.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:356e5c055ed9b0da1580d465657bc7d00635af4fd47f30afb23025352ba764d1", size = 11405133, upload-time = "2026-02-17T22:18:48.533Z" }, + { url = "https://files.pythonhosted.org/packages/58/53/1d68fafb2e02d7881df66aa53be4cd748d25cbe311f3b3c85c93ea5d30ca/pandas-3.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:9d810036895f9ad6345b8f2a338dd6998a74e8483847403582cab67745bff821", size = 11932065, upload-time = "2026-02-17T22:18:50.837Z" }, + { url = "https://files.pythonhosted.org/packages/75/08/67cc404b3a966b6df27b38370ddd96b3b023030b572283d035181854aac5/pandas-3.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:536232a5fe26dd989bd633e7a0c450705fdc86a207fec7254a55e9a22950fe43", size = 9741627, upload-time = "2026-02-17T22:18:53.905Z" }, + { url = "https://files.pythonhosted.org/packages/86/4f/caf9952948fb00d23795f09b893d11f1cacb384e666854d87249530f7cbe/pandas-3.0.1-cp312-cp312-win_arm64.whl", hash = "sha256:0f463ebfd8de7f326d38037c7363c6dacb857c5881ab8961fb387804d6daf2f7", size = 9052483, upload-time = "2026-02-17T22:18:57.31Z" }, ] [package.optional-dependencies] @@ -4924,37 +5041,36 @@ performance = [ [[package]] name = "pandas-stubs" -version = "2.2.3.250527" +version = "3.0.0.260204" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, - { name = "types-pytz" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5f/0d/5fe7f7f3596eb1c2526fea151e9470f86b379183d8b9debe44b2098651ca/pandas_stubs-2.2.3.250527.tar.gz", hash = "sha256:e2d694c4e72106055295ad143664e5c99e5815b07190d1ff85b73b13ff019e63", size = 106312, upload-time = "2025-05-27T15:24:29.716Z" } +sdist = { url = "https://files.pythonhosted.org/packages/27/1d/297ff2c7ea50a768a2247621d6451abb2a07c0e9be7ca6d36ebe371658e5/pandas_stubs-3.0.0.260204.tar.gz", hash = "sha256:bf9294b76352effcffa9cb85edf0bed1339a7ec0c30b8e1ac3d66b4228f1fbc3", size = 109383, upload-time = "2026-02-04T15:17:17.247Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/f8/46141ba8c9d7064dc5008bfb4a6ae5bd3c30e4c61c28b5c5ed485bf358ba/pandas_stubs-2.2.3.250527-py3-none-any.whl", hash = "sha256:cd0a49a95b8c5f944e605be711042a4dd8550e2c559b43d70ba2c4b524b66163", size = 159683, upload-time = "2025-05-27T15:24:28.4Z" }, + { url = "https://files.pythonhosted.org/packages/7c/2f/f91e4eee21585ff548e83358332d5632ee49f6b2dcd96cb5dca4e0468951/pandas_stubs-3.0.0.260204-py3-none-any.whl", hash = "sha256:5ab9e4d55a6e2752e9720828564af40d48c4f709e6a2c69b743014a6fcb6c241", size = 168540, upload-time = "2026-02-04T15:17:15.615Z" }, ] [[package]] name = "pathspec" -version = "0.12.1" +version = "1.0.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043, upload-time = "2023-12-10T22:30:45Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/36/e27608899f9b8d4dff0617b2d9ab17ca5608956ca44461ac14ac48b44015/pathspec-1.0.4.tar.gz", hash = "sha256:0210e2ae8a21a9137c0d470578cb0e595af87edaa6ebf12ff176f14a02e0e645", size = 131200, upload-time = "2026-01-27T03:59:46.938Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191, upload-time = "2023-12-10T22:30:43.14Z" }, + { url = "https://files.pythonhosted.org/packages/ef/3c/2c197d226f9ea224a9ab8d197933f9da0ae0aac5b6e0f884e2b8d9c8e9f7/pathspec-1.0.4-py3-none-any.whl", hash = "sha256:fb6ae2fd4e7c921a165808a552060e722767cfa526f99ca5156ed2ce45a5c723", size = 55206, upload-time = "2026-01-27T03:59:45.137Z" }, ] [[package]] name = "pdfminer-six" -version = "20251230" +version = "20260107" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "charset-normalizer" }, { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/a4/5cec1112009f0439a5ca6afa8ace321f0ab2f48da3255b7a1c8953014670/pdfminer_six-20260107.tar.gz", hash = "sha256:96bfd431e3577a55a0efd25676968ca4ce8fd5b53f14565f85716ff363889602", size = 8512094, upload-time = "2026-01-07T13:29:12.937Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" }, + { url = "https://files.pythonhosted.org/packages/20/8b/28c4eaec9d6b036a52cb44720408f26b1a143ca9bce76cc19e8f5de00ab4/pdfminer_six-20260107-py3-none-any.whl", hash = "sha256:366585ba97e80dffa8f00cebe303d2f381884d8637af4ce422f1df3ef38111a9", size = 6592252, upload-time = "2026-01-07T13:29:10.742Z" }, ] [[package]] @@ -4977,13 +5093,14 @@ sqlalchemy = [ [[package]] name = "pgvector" -version = "0.2.5" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/25/6c/6d8b4b03b958c02fa8687ec6063c49d952a189f8c91ebbe51e877dfab8f7/pgvector-0.4.2.tar.gz", hash = "sha256:322cac0c1dc5d41c9ecf782bd9991b7966685dee3a00bc873631391ed949513a", size = 31354, upload-time = "2025-12-05T01:07:17.87Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/bb/4686b1090a7c68fa367e981130a074dc6c1236571d914ffa6e05c882b59d/pgvector-0.2.5-py2.py3-none-any.whl", hash = "sha256:5e5e93ec4d3c45ab1fa388729d56c602f6966296e19deee8878928c6d567e41b", size = 9638, upload-time = "2024-02-07T19:35:03.8Z" }, + { url = "https://files.pythonhosted.org/packages/5a/26/6cee8a1ce8c43625ec561aff19df07f9776b7525d9002c86bceb3e0ac970/pgvector-0.4.2-py3-none-any.whl", hash = "sha256:549d45f7a18593783d5eec609ea1684a724ba8405c4cb182a0b2b08aeff04e08", size = 27441, upload-time = "2025-12-05T01:07:16.536Z" }, ] [[package]] @@ -5023,22 +5140,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f2/26/c56ce33ca856e358d27fda9676c055395abddb82c35ac0f593877ed4562e/pillow-12.1.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:cb9bb857b2d057c6dfc72ac5f3b44836924ba15721882ef103cecb40d002d80e", size = 7029880, upload-time = "2026-02-11T04:23:04.783Z" }, ] -[[package]] -name = "pkginfo" -version = "1.10.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/2f/72/347ec5be4adc85c182ed2823d8d1c7b51e13b9a6b0c1aae59582eca652df/pkginfo-1.10.0.tar.gz", hash = "sha256:5df73835398d10db79f8eecd5cd86b1f6d29317589ea70796994d49399af6297", size = 378457, upload-time = "2024-03-03T08:34:21.011Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/56/09/054aea9b7534a15ad38a363a2bd974c20646ab1582a387a95b8df1bfea1c/pkginfo-1.10.0-py3-none-any.whl", hash = "sha256:889a6da2ed7ffc58ab5b900d888ddce90bce912f2d2de1dc1c26f4cb9fe65097", size = 30392, upload-time = "2024-03-03T08:34:18.891Z" }, -] - [[package]] name = "platformdirs" -version = "4.5.0" +version = "4.9.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/61/33/9611380c2bdb1225fdef633e2a9610622310fed35ab11dac9620972ee088/platformdirs-4.5.0.tar.gz", hash = "sha256:70ddccdd7c99fc5942e9fc25636a8b34d04c24b335100223152c2803e4063312", size = 21632, upload-time = "2025-10-08T17:44:48.791Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/56/8d4c30c8a1d07013911a8fdbd8f89440ef9f08d07a1b50ab8ca8be5a20f9/platformdirs-4.9.4.tar.gz", hash = "sha256:1ec356301b7dc906d83f371c8f487070e99d3ccf9e501686456394622a01a934", size = 28737, upload-time = "2026-03-05T18:34:13.271Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/cb/ac7874b3e5d58441674fb70742e6c374b28b0c7cb988d37d991cde47166c/platformdirs-4.5.0-py3-none-any.whl", hash = "sha256:e578a81bb873cbb89a41fcc904c7ef523cc18284b7e3b3ccf06aca1403b7ebd3", size = 18651, upload-time = "2025-10-08T17:44:47.223Z" }, + { url = "https://files.pythonhosted.org/packages/63/d7/97f7e3a6abb67d8080dd406fd4df842c2be0efaf712d1c899c32a075027c/platformdirs-4.9.4-py3-none-any.whl", hash = "sha256:68a9a4619a666ea6439f2ff250c12a853cd1cbd5158d258bd824a7df6be2f868", size = 21216, upload-time = "2026-03-05T18:34:12.172Z" }, ] [[package]] @@ -5050,24 +5158,16 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] -[[package]] -name = "ply" -version = "3.11" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e5/69/882ee5c9d017149285cab114ebeab373308ef0f874fcdac9beb90e0ac4da/ply-3.11.tar.gz", hash = "sha256:00c7c1aaa88358b9c765b6d3000c6eec0ba42abca5351b095321aef446081da3", size = 159130, upload-time = "2018-02-15T19:01:31.097Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/58/35da89ee790598a0700ea49b2a66594140f44dec458c07e8e3d4979137fc/ply-3.11-py2.py3-none-any.whl", hash = "sha256:096f9b8350b65ebd2fd1346b12452efe5b9607f7482813ffca50c22722a807ce", size = 49567, upload-time = "2018-02-15T19:01:27.172Z" }, -] - [[package]] name = "polyfile-weave" -version = "0.5.8" +version = "0.5.9" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "abnf" }, { name = "chardet" }, { name = "cint" }, { name = "fickling" }, + { name = "filelock" }, { name = "graphviz" }, { name = "intervaltree" }, { name = "jinja2" }, @@ -5077,11 +5177,10 @@ dependencies = [ { name = "pillow" }, { name = "pyreadline3", marker = "sys_platform == 'win32'" }, { name = "pyyaml" }, - { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e7/d4/76e56e4429646d9353b4287794f8324ff94201bdb0a2c35ce88cf3de90d0/polyfile_weave-0.5.8.tar.gz", hash = "sha256:cf2ca6a1351165fbbf2971ace4b8bebbb03b2c00e4f2159ff29bed88854e7b32", size = 5989602, upload-time = "2026-01-08T04:21:26.689Z" } +sdist = { url = "https://files.pythonhosted.org/packages/70/55/e5400762e3884f743d59291e71eaaa9c52dd7e144b75a11911e74ec1bac9/polyfile_weave-0.5.9.tar.gz", hash = "sha256:12341fab03e06ede1bfebbd3627dd24015fde5353ea74ece2da186321b818bdb", size = 6024974, upload-time = "2026-01-22T22:08:48.081Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/54/32/c09fd626366c00325d1981e310be5cac8661c09206098d267a592e0c5000/polyfile_weave-0.5.8-py3-none-any.whl", hash = "sha256:f68c570ef189a4219798a7c797730fc3b7feace7ff5bd7e662490f89b772964a", size = 1656208, upload-time = "2026-01-08T04:21:15.213Z" }, + { url = "https://files.pythonhosted.org/packages/52/94/215005530a48c5f7d4ec4a31acdb5828f2bfb985cc6e577b0eaa5882c0e2/polyfile_weave-0.5.9-py3-none-any.whl", hash = "sha256:6ae4b1b5eeac9f5bfc862474484d6d3e33655fab31749d93af0b0a91fddabfc7", size = 1700174, upload-time = "2026-01-22T22:08:46.346Z" }, ] [[package]] @@ -5112,7 +5211,7 @@ wheels = [ [[package]] name = "posthog" -version = "7.0.1" +version = "5.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, @@ -5120,27 +5219,38 @@ dependencies = [ { name = "python-dateutil" }, { name = "requests" }, { name = "six" }, - { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a2/d4/b9afe855a8a7a1bf4459c28ae4c300b40338122dc850acabefcf2c3df24d/posthog-7.0.1.tar.gz", hash = "sha256:21150562c2630a599c1d7eac94bc5c64eb6f6acbf3ff52ccf1e57345706db05a", size = 126985, upload-time = "2025-11-15T12:44:22.465Z" } +sdist = { url = "https://files.pythonhosted.org/packages/48/20/60ae67bb9d82f00427946218d49e2e7e80fb41c15dc5019482289ec9ce8d/posthog-5.4.0.tar.gz", hash = "sha256:701669261b8d07cdde0276e5bc096b87f9e200e3b9589c5ebff14df658c5893c", size = 88076, upload-time = "2025-06-20T23:19:23.485Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/05/0c/8b6b20b0be71725e6e8a32dcd460cdbf62fe6df9bc656a650150dc98fedd/posthog-7.0.1-py3-none-any.whl", hash = "sha256:efe212d8d88a9ba80a20c588eab4baf4b1a5e90e40b551160a5603bb21e96904", size = 145234, upload-time = "2025-11-15T12:44:21.247Z" }, + { url = "https://files.pythonhosted.org/packages/4f/98/e480cab9a08d1c09b1c59a93dade92c1bb7544826684ff2acbfd10fcfbd4/posthog-5.4.0-py3-none-any.whl", hash = "sha256:284dfa302f64353484420b52d4ad81ff5c2c2d1d607c4e2db602ac72761831bd", size = 105364, upload-time = "2025-06-20T23:19:22.001Z" }, ] [[package]] -name = "pre-commit" -version = "4.5.1" +name = "preshed" +version = "3.0.12" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cfgv" }, - { name = "identify" }, - { name = "nodeenv" }, - { name = "pyyaml" }, - { name = "virtualenv" }, + { name = "cymem" }, + { name = "murmurhash" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/40/f1/6d86a29246dfd2e9b6237f0b5823717f60cad94d47ddc26afa916d21f525/pre_commit-4.5.1.tar.gz", hash = "sha256:eb545fcff725875197837263e977ea257a402056661f09dae08e4b149b030a61", size = 198232, upload-time = "2025-12-16T21:14:33.552Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/34/eb4f5f0f678e152a96e826da867d2f41c4b18a2d589e40e1dd3347219e91/preshed-3.0.12.tar.gz", hash = "sha256:b73f9a8b54ee1d44529cc6018356896cff93d48f755f29c134734d9371c0d685", size = 15027, upload-time = "2025-11-17T13:00:33.621Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/19/fd3ef348460c80af7bb4669ea7926651d1f95c23ff2df18b9d24bab4f3fa/pre_commit-4.5.1-py2.py3-none-any.whl", hash = "sha256:3b3afd891e97337708c1674210f8eba659b52a38ea5f822ff142d10786221f77", size = 226437, upload-time = "2025-12-16T21:14:32.409Z" }, + { url = "https://files.pythonhosted.org/packages/1e/54/d1e02d0a0ea348fb6a769506166e366abfe87ee917c2f11f7139c7acbf10/preshed-3.0.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:bc45fda3fd4ae1ae15c37f18f0777cf389ce9184ef8884b39b18894416fd1341", size = 128439, upload-time = "2025-11-17T12:59:21.317Z" }, + { url = "https://files.pythonhosted.org/packages/8c/cb/685ca57ca6e438345b3f6c20226705a0e056a3de399a5bf8a9ee89b3dd2b/preshed-3.0.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:75d6e628bc78c022dbb9267242715718f862c3105927732d166076ff009d65de", size = 124544, upload-time = "2025-11-17T12:59:22.944Z" }, + { url = "https://files.pythonhosted.org/packages/f8/07/018fcd3bf298304e1570065cf80601ac16acd29f799578fd47b715dd3ca2/preshed-3.0.12-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b901cff5c814facf7a864b0a4c14a16d45fa1379899a585b3fb48ee36a2dccdb", size = 824728, upload-time = "2025-11-17T12:59:24.614Z" }, + { url = "https://files.pythonhosted.org/packages/79/dc/d888b328fcedae530df53396d9fc0006026aa8793fec54d7d34f57f31ff5/preshed-3.0.12-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d1099253bf73dd3c39313280bd5331841f769637b27ddb576ff362c4e7bad298", size = 825969, upload-time = "2025-11-17T12:59:26.493Z" }, + { url = "https://files.pythonhosted.org/packages/21/51/f19933301f42ece1ffef1f7f4c370d09f0351c43c528e66fac24560e44d2/preshed-3.0.12-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1af4a049ffe9d0246e5dc10d6f54820ed064c40e5c3f7b6526127c664008297c", size = 842346, upload-time = "2025-11-17T12:59:28.092Z" }, + { url = "https://files.pythonhosted.org/packages/51/46/025f60fd3d51bf60606a0f8f0cd39c40068b9b5e4d249bca1682e4ff09c3/preshed-3.0.12-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:57159bcedca0cb4c99390f8a6e730f8659fdb663a5a3efcd9c4531e0f54b150e", size = 865504, upload-time = "2025-11-17T12:59:29.648Z" }, + { url = "https://files.pythonhosted.org/packages/88/b5/2e6ee5ab19b03e7983fc5e1850c812fb71dc178dd140d6aca3b45306bdf7/preshed-3.0.12-cp311-cp311-win_amd64.whl", hash = "sha256:8fe9cf1745e203e5aa58b8700436f78da1dcf0f0e2efb0054b467effd9d7d19d", size = 117736, upload-time = "2025-11-17T12:59:30.974Z" }, + { url = "https://files.pythonhosted.org/packages/1e/17/8a0a8f4b01e71b5fb7c5cd4c9fec04d7b852d42f1f9e096b01e7d2b16b17/preshed-3.0.12-cp311-cp311-win_arm64.whl", hash = "sha256:12d880f8786cb6deac34e99b8b07146fb92d22fbca0023208e03325f5944606b", size = 105127, upload-time = "2025-11-17T12:59:32.171Z" }, + { url = "https://files.pythonhosted.org/packages/4b/f7/ff3aca937eeaee19c52c45ddf92979546e52ed0686e58be4bc09c47e7d88/preshed-3.0.12-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2779861f5d69480493519ed123a622a13012d1182126779036b99d9d989bf7e9", size = 129958, upload-time = "2025-11-17T12:59:33.391Z" }, + { url = "https://files.pythonhosted.org/packages/80/24/fd654a9c0f5f3ed1a9b1d8a392f063ae9ca29ad0b462f0732ae0147f7cee/preshed-3.0.12-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ffe1fd7d92f51ed34383e20d8b734780c814ca869cfdb7e07f2d31651f90cdf4", size = 124550, upload-time = "2025-11-17T12:59:34.688Z" }, + { url = "https://files.pythonhosted.org/packages/71/49/8271c7f680696f4b0880f44357d2a903d649cb9f6e60a1efc97a203104df/preshed-3.0.12-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:91893404858502cc4e856d338fef3d2a4a552135f79a1041c24eb919817c19db", size = 874987, upload-time = "2025-11-17T12:59:36.062Z" }, + { url = "https://files.pythonhosted.org/packages/a3/a5/ca200187ca1632f1e2c458b72f1bd100fa8b55deecd5d72e1e4ebf09e98c/preshed-3.0.12-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9e06e8f2ba52f183eb9817a616cdebe84a211bb859a2ffbc23f3295d0b189638", size = 866499, upload-time = "2025-11-17T12:59:37.586Z" }, + { url = "https://files.pythonhosted.org/packages/87/a1/943b61f850c44899910c21996cb542d0ef5931744c6d492fdfdd8457e693/preshed-3.0.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bbe8b8a2d4f9af14e8a39ecca524b9de6defc91d8abcc95eb28f42da1c23272c", size = 878064, upload-time = "2025-11-17T12:59:39.651Z" }, + { url = "https://files.pythonhosted.org/packages/3e/75/d7fff7f1fa3763619aa85d6ba70493a5d9c6e6ea7958a6e8c9d3e6e88bbe/preshed-3.0.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5d0aaac9c5862f5471fddd0c931dc64d3af2efc5fe3eb48b50765adb571243b9", size = 900540, upload-time = "2025-11-17T12:59:41.384Z" }, + { url = "https://files.pythonhosted.org/packages/e4/12/a2285b78bd097a1e53fb90a1743bc8ce0d35e5b65b6853f3b3c47da398ca/preshed-3.0.12-cp312-cp312-win_amd64.whl", hash = "sha256:0eb8d411afcb1e3b12a0602fb6a0e33140342a732a795251a0ce452aba401dc0", size = 118298, upload-time = "2025-11-17T12:59:42.65Z" }, + { url = "https://files.pythonhosted.org/packages/0b/34/4e8443fe99206a2fcfc63659969a8f8c8ab184836533594a519f3899b1ad/preshed-3.0.12-cp312-cp312-win_arm64.whl", hash = "sha256:dcd3d12903c9f720a39a5c5f1339f7f46e3ab71279fb7a39776768fb840b6077", size = 104746, upload-time = "2025-11-17T12:59:43.934Z" }, ] [[package]] @@ -5196,42 +5306,44 @@ wheels = [ [[package]] name = "proto-plus" -version = "1.26.1" +version = "1.27.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142, upload-time = "2025-03-10T15:54:38.843Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/02/8832cde80e7380c600fbf55090b6ab7b62bd6825dbedde6d6657c15a1f8e/proto_plus-1.27.1.tar.gz", hash = "sha256:912a7460446625b792f6448bade9e55cd4e41e6ac10e27009ef71a7f317fa147", size = 56929, upload-time = "2026-02-02T17:34:49.035Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163, upload-time = "2025-03-10T15:54:37.335Z" }, + { url = "https://files.pythonhosted.org/packages/5d/79/ac273cbbf744691821a9cca88957257f41afe271637794975ca090b9588b/proto_plus-1.27.1-py3-none-any.whl", hash = "sha256:e4643061f3a4d0de092d62aa4ad09fa4756b2cbb89d4627f3985018216f9fefc", size = 50480, upload-time = "2026-02-02T17:34:47.339Z" }, ] [[package]] name = "protobuf" -version = "4.25.8" +version = "5.29.6" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/df/01/34c8d2b6354906d728703cb9d546a0e534de479e25f1b581e4094c4a85cc/protobuf-4.25.8.tar.gz", hash = "sha256:6135cf8affe1fc6f76cced2641e4ea8d3e59518d1f24ae41ba97bcad82d397cd", size = 380920, upload-time = "2025-05-28T14:22:25.153Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7e/57/394a763c103e0edf87f0938dafcd918d53b4c011dfc5c8ae80f3b0452dbb/protobuf-5.29.6.tar.gz", hash = "sha256:da9ee6a5424b6b30fd5e45c5ea663aef540ca95f9ad99d1e887e819cdf9b8723", size = 425623, upload-time = "2026-02-04T22:54:40.584Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/45/ff/05f34305fe6b85bbfbecbc559d423a5985605cad5eda4f47eae9e9c9c5c5/protobuf-4.25.8-cp310-abi3-win32.whl", hash = "sha256:504435d831565f7cfac9f0714440028907f1975e4bed228e58e72ecfff58a1e0", size = 392745, upload-time = "2025-05-28T14:22:10.524Z" }, - { url = "https://files.pythonhosted.org/packages/08/35/8b8a8405c564caf4ba835b1fdf554da869954712b26d8f2a98c0e434469b/protobuf-4.25.8-cp310-abi3-win_amd64.whl", hash = "sha256:bd551eb1fe1d7e92c1af1d75bdfa572eff1ab0e5bf1736716814cdccdb2360f9", size = 413736, upload-time = "2025-05-28T14:22:13.156Z" }, - { url = "https://files.pythonhosted.org/packages/28/d7/ab27049a035b258dab43445eb6ec84a26277b16105b277cbe0a7698bdc6c/protobuf-4.25.8-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:ca809b42f4444f144f2115c4c1a747b9a404d590f18f37e9402422033e464e0f", size = 394537, upload-time = "2025-05-28T14:22:14.768Z" }, - { url = "https://files.pythonhosted.org/packages/bd/6d/a4a198b61808dd3d1ee187082ccc21499bc949d639feb948961b48be9a7e/protobuf-4.25.8-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:9ad7ef62d92baf5a8654fbb88dac7fa5594cfa70fd3440488a5ca3bfc6d795a7", size = 294005, upload-time = "2025-05-28T14:22:16.052Z" }, - { url = "https://files.pythonhosted.org/packages/d6/c6/c9deaa6e789b6fc41b88ccbdfe7a42d2b82663248b715f55aa77fbc00724/protobuf-4.25.8-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:83e6e54e93d2b696a92cad6e6efc924f3850f82b52e1563778dfab8b355101b0", size = 294924, upload-time = "2025-05-28T14:22:17.105Z" }, - { url = "https://files.pythonhosted.org/packages/0c/c1/6aece0ab5209981a70cd186f164c133fdba2f51e124ff92b73de7fd24d78/protobuf-4.25.8-py3-none-any.whl", hash = "sha256:15a0af558aa3b13efef102ae6e4f3efac06f1eea11afb3a57db2901447d9fb59", size = 156757, upload-time = "2025-05-28T14:22:24.135Z" }, + { url = "https://files.pythonhosted.org/packages/d4/88/9ee58ff7863c479d6f8346686d4636dd4c415b0cbeed7a6a7d0617639c2a/protobuf-5.29.6-cp310-abi3-win32.whl", hash = "sha256:62e8a3114992c7c647bce37dcc93647575fc52d50e48de30c6fcb28a6a291eb1", size = 423357, upload-time = "2026-02-04T22:54:25.805Z" }, + { url = "https://files.pythonhosted.org/packages/1c/66/2dc736a4d576847134fb6d80bd995c569b13cdc7b815d669050bf0ce2d2c/protobuf-5.29.6-cp310-abi3-win_amd64.whl", hash = "sha256:7e6ad413275be172f67fdee0f43484b6de5a904cc1c3ea9804cb6fe2ff366eda", size = 435175, upload-time = "2026-02-04T22:54:28.592Z" }, + { url = "https://files.pythonhosted.org/packages/06/db/49b05966fd208ae3f44dcd33837b6243b4915c57561d730a43f881f24dea/protobuf-5.29.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:b5a169e664b4057183a34bdc424540e86eea47560f3c123a0d64de4e137f9269", size = 418619, upload-time = "2026-02-04T22:54:30.266Z" }, + { url = "https://files.pythonhosted.org/packages/b7/d7/48cbf6b0c3c39761e47a99cb483405f0fde2be22cf00d71ef316ce52b458/protobuf-5.29.6-cp38-abi3-manylinux2014_aarch64.whl", hash = "sha256:a8866b2cff111f0f863c1b3b9e7572dc7eaea23a7fae27f6fc613304046483e6", size = 320284, upload-time = "2026-02-04T22:54:31.782Z" }, + { url = "https://files.pythonhosted.org/packages/e3/dd/cadd6ec43069247d91f6345fa7a0d2858bef6af366dbd7ba8f05d2c77d3b/protobuf-5.29.6-cp38-abi3-manylinux2014_x86_64.whl", hash = "sha256:e3387f44798ac1106af0233c04fb8abf543772ff241169946f698b3a9a3d3ab9", size = 320478, upload-time = "2026-02-04T22:54:32.909Z" }, + { url = "https://files.pythonhosted.org/packages/5a/cb/e3065b447186cb70aa65acc70c86baf482d82bf75625bf5a2c4f6919c6a3/protobuf-5.29.6-py3-none-any.whl", hash = "sha256:6b9edb641441b2da9fa8f428760fc136a49cf97a52076010cf22a2ff73438a86", size = 173126, upload-time = "2026-02-04T22:54:39.462Z" }, ] [[package]] name = "psutil" -version = "7.1.3" +version = "7.2.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e1/88/bdd0a41e5857d5d703287598cbf08dad90aed56774ea52ae071bae9071b6/psutil-7.1.3.tar.gz", hash = "sha256:6c86281738d77335af7aec228328e944b30930899ea760ecf33a4dba66be5e74", size = 489059, upload-time = "2025-11-02T12:25:54.619Z" } +sdist = { url = "https://files.pythonhosted.org/packages/aa/c6/d1ddf4abb55e93cebc4f2ed8b5d6dbad109ecb8d63748dd2b20ab5e57ebe/psutil-7.2.2.tar.gz", hash = "sha256:0746f5f8d406af344fd547f1c8daa5f5c33dbc293bb8d6a16d80b4bb88f59372", size = 493740, upload-time = "2026-01-28T18:14:54.428Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ef/94/46b9154a800253e7ecff5aaacdf8ebf43db99de4a2dfa18575b02548654e/psutil-7.1.3-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2bdbcd0e58ca14996a42adf3621a6244f1bb2e2e528886959c72cf1e326677ab", size = 238359, upload-time = "2025-11-02T12:26:25.284Z" }, - { url = "https://files.pythonhosted.org/packages/68/3a/9f93cff5c025029a36d9a92fef47220ab4692ee7f2be0fba9f92813d0cb8/psutil-7.1.3-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:bc31fa00f1fbc3c3802141eede66f3a2d51d89716a194bf2cd6fc68310a19880", size = 239171, upload-time = "2025-11-02T12:26:27.23Z" }, - { url = "https://files.pythonhosted.org/packages/ce/b1/5f49af514f76431ba4eea935b8ad3725cdeb397e9245ab919dbc1d1dc20f/psutil-7.1.3-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3bb428f9f05c1225a558f53e30ccbad9930b11c3fc206836242de1091d3e7dd3", size = 263261, upload-time = "2025-11-02T12:26:29.48Z" }, - { url = "https://files.pythonhosted.org/packages/e0/95/992c8816a74016eb095e73585d747e0a8ea21a061ed3689474fabb29a395/psutil-7.1.3-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:56d974e02ca2c8eb4812c3f76c30e28836fffc311d55d979f1465c1feeb2b68b", size = 264635, upload-time = "2025-11-02T12:26:31.74Z" }, - { url = "https://files.pythonhosted.org/packages/55/4c/c3ed1a622b6ae2fd3c945a366e64eb35247a31e4db16cf5095e269e8eb3c/psutil-7.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:f39c2c19fe824b47484b96f9692932248a54c43799a84282cfe58d05a6449efd", size = 247633, upload-time = "2025-11-02T12:26:33.887Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ad/33b2ccec09bf96c2b2ef3f9a6f66baac8253d7565d8839e024a6b905d45d/psutil-7.1.3-cp37-abi3-win_arm64.whl", hash = "sha256:bd0d69cee829226a761e92f28140bec9a5ee9d5b4fb4b0cc589068dbfff559b1", size = 244608, upload-time = "2025-11-02T12:26:36.136Z" }, + { url = "https://files.pythonhosted.org/packages/e7/36/5ee6e05c9bd427237b11b3937ad82bb8ad2752d72c6969314590dd0c2f6e/psutil-7.2.2-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ed0cace939114f62738d808fdcecd4c869222507e266e574799e9c0faa17d486", size = 129090, upload-time = "2026-01-28T18:15:22.168Z" }, + { url = "https://files.pythonhosted.org/packages/80/c4/f5af4c1ca8c1eeb2e92ccca14ce8effdeec651d5ab6053c589b074eda6e1/psutil-7.2.2-cp36-abi3-macosx_11_0_arm64.whl", hash = "sha256:1a7b04c10f32cc88ab39cbf606e117fd74721c831c98a27dc04578deb0c16979", size = 129859, upload-time = "2026-01-28T18:15:23.795Z" }, + { url = "https://files.pythonhosted.org/packages/b5/70/5d8df3b09e25bce090399cf48e452d25c935ab72dad19406c77f4e828045/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:076a2d2f923fd4821644f5ba89f059523da90dc9014e85f8e45a5774ca5bc6f9", size = 155560, upload-time = "2026-01-28T18:15:25.976Z" }, + { url = "https://files.pythonhosted.org/packages/63/65/37648c0c158dc222aba51c089eb3bdfa238e621674dc42d48706e639204f/psutil-7.2.2-cp36-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b0726cecd84f9474419d67252add4ac0cd9811b04d61123054b9fb6f57df6e9e", size = 156997, upload-time = "2026-01-28T18:15:27.794Z" }, + { url = "https://files.pythonhosted.org/packages/8e/13/125093eadae863ce03c6ffdbae9929430d116a246ef69866dad94da3bfbc/psutil-7.2.2-cp36-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:fd04ef36b4a6d599bbdb225dd1d3f51e00105f6d48a28f006da7f9822f2606d8", size = 148972, upload-time = "2026-01-28T18:15:29.342Z" }, + { url = "https://files.pythonhosted.org/packages/04/78/0acd37ca84ce3ddffaa92ef0f571e073faa6d8ff1f0559ab1272188ea2be/psutil-7.2.2-cp36-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:b58fabe35e80b264a4e3bb23e6b96f9e45a3df7fb7eed419ac0e5947c61e47cc", size = 148266, upload-time = "2026-01-28T18:15:31.597Z" }, + { url = "https://files.pythonhosted.org/packages/b4/90/e2159492b5426be0c1fef7acba807a03511f97c5f86b3caeda6ad92351a7/psutil-7.2.2-cp37-abi3-win_amd64.whl", hash = "sha256:eb7e81434c8d223ec4a219b5fc1c47d0417b12be7ea866e24fb5ad6e84b3d988", size = 137737, upload-time = "2026-01-28T18:15:33.849Z" }, + { url = "https://files.pythonhosted.org/packages/8c/c7/7bb2e321574b10df20cbde462a94e2b71d05f9bbda251ef27d104668306a/psutil-7.2.2-cp37-abi3-win_arm64.whl", hash = "sha256:8c233660f575a5a89e6d4cb65d9f938126312bca76d8fe087b947b3a1aaac9ee", size = 134617, upload-time = "2026-01-28T18:15:36.514Z" }, ] [[package]] @@ -5240,6 +5352,53 @@ version = "1.0.2" source = { registry = "https://pypi.org/simple" } sdist = { url = "https://files.pythonhosted.org/packages/eb/72/4a7965cf54e341006ad74cdc72cd6572c789bc4f4e3fadc78672f1fbcfbd/psycogreen-1.0.2.tar.gz", hash = "sha256:c429845a8a49cf2f76b71265008760bcd7c7c77d80b806db4dc81116dbcd130d", size = 5411, upload-time = "2020-02-22T19:55:22.02Z" } +[[package]] +name = "psycopg" +version = "3.3.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d3/b6/379d0a960f8f435ec78720462fd94c4863e7a31237cf81bf76d0af5883bf/psycopg-3.3.3.tar.gz", hash = "sha256:5e9a47458b3c1583326513b2556a2a9473a1001a56c9efe9e587245b43148dd9", size = 165624, upload-time = "2026-02-18T16:52:16.546Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/5b/181e2e3becb7672b502f0ed7f16ed7352aca7c109cfb94cf3878a9186db9/psycopg-3.3.3-py3-none-any.whl", hash = "sha256:f96525a72bcfade6584ab17e89de415ff360748c766f0106959144dcbb38c698", size = 212768, upload-time = "2026-02-18T16:46:27.365Z" }, +] + +[package.optional-dependencies] +binary = [ + { name = "psycopg-binary", marker = "implementation_name != 'pypy'" }, +] + +[[package]] +name = "psycopg-binary" +version = "3.3.3" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/be/c0/b389119dd754483d316805260f3e73cdcad97925839107cc7a296f6132b1/psycopg_binary-3.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a89bb9ee11177b2995d87186b1d9fa892d8ea725e85eab28c6525e4cc14ee048", size = 4609740, upload-time = "2026-02-18T16:47:51.093Z" }, + { url = "https://files.pythonhosted.org/packages/cf/e3/9976eef20f61840285174d360da4c820a311ab39d6b82fa09fbb545be825/psycopg_binary-3.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9f7d0cf072c6fbac3795b08c98ef9ea013f11db609659dcfc6b1f6cc31f9e181", size = 4676837, upload-time = "2026-02-18T16:47:55.523Z" }, + { url = "https://files.pythonhosted.org/packages/9f/f2/d28ba2f7404fd7f68d41e8a11df86313bd646258244cb12a8dd83b868a97/psycopg_binary-3.3.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:90eecd93073922f085967f3ed3a98ba8c325cbbc8c1a204e300282abd2369e13", size = 5497070, upload-time = "2026-02-18T16:47:59.929Z" }, + { url = "https://files.pythonhosted.org/packages/de/2f/6c5c54b815edeb30a281cfcea96dc93b3bb6be939aea022f00cab7aa1420/psycopg_binary-3.3.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:dac7ee2f88b4d7bb12837989ca354c38d400eeb21bce3b73dac02622f0a3c8d6", size = 5172410, upload-time = "2026-02-18T16:48:05.665Z" }, + { url = "https://files.pythonhosted.org/packages/51/75/8206c7008b57de03c1ada46bd3110cc3743f3fd9ed52031c4601401d766d/psycopg_binary-3.3.3-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b62cf8784eb6d35beaee1056d54caf94ec6ecf2b7552395e305518ab61eb8fd2", size = 6763408, upload-time = "2026-02-18T16:48:13.541Z" }, + { url = "https://files.pythonhosted.org/packages/d4/5a/ea1641a1e6c8c8b3454b0fcb43c3045133a8b703e6e824fae134088e63bd/psycopg_binary-3.3.3-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a39f34c9b18e8f6794cca17bfbcd64572ca2482318db644268049f8c738f35a6", size = 5006255, upload-time = "2026-02-18T16:48:22.176Z" }, + { url = "https://files.pythonhosted.org/packages/aa/fb/538df099bf55ae1637d52d7ccb6b9620b535a40f4c733897ac2b7bb9e14c/psycopg_binary-3.3.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:883d68d48ca9ff3cb3d10c5fdebea02c79b48eecacdddbf7cce6e7cdbdc216b8", size = 4532694, upload-time = "2026-02-18T16:48:27.338Z" }, + { url = "https://files.pythonhosted.org/packages/a1/d1/00780c0e187ea3c13dfc53bd7060654b2232cd30df562aac91a5f1c545ac/psycopg_binary-3.3.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:cab7bc3d288d37a80aa8c0820033250c95e40b1c2b5c57cf59827b19c2a8b69d", size = 4222833, upload-time = "2026-02-18T16:48:31.221Z" }, + { url = "https://files.pythonhosted.org/packages/7a/34/a07f1ff713c51d64dc9f19f2c32be80299a2055d5d109d5853662b922cb4/psycopg_binary-3.3.3-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:56c767007ca959ca32f796b42379fc7e1ae2ed085d29f20b05b3fc394f3715cc", size = 3952818, upload-time = "2026-02-18T16:48:35.869Z" }, + { url = "https://files.pythonhosted.org/packages/d3/67/d33f268a7759b4445f3c9b5a181039b01af8c8263c865c1be7a6444d4749/psycopg_binary-3.3.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:da2f331a01af232259a21573a01338530c6016dcfad74626c01330535bcd8628", size = 4258061, upload-time = "2026-02-18T16:48:41.365Z" }, + { url = "https://files.pythonhosted.org/packages/b4/3b/0d8d2c5e8e29ccc07d28c8af38445d9d9abcd238d590186cac82ee71fc84/psycopg_binary-3.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:19f93235ece6dbfc4036b5e4f6d8b13f0b8f2b3eeb8b0bd2936d406991bcdd40", size = 3558915, upload-time = "2026-02-18T16:48:46.679Z" }, + { url = "https://files.pythonhosted.org/packages/90/15/021be5c0cbc5b7c1ab46e91cc3434eb42569f79a0592e67b8d25e66d844d/psycopg_binary-3.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6698dbab5bcef8fdb570fc9d35fd9ac52041771bfcfe6fd0fc5f5c4e36f1e99d", size = 4591170, upload-time = "2026-02-18T16:48:55.594Z" }, + { url = "https://files.pythonhosted.org/packages/f1/54/a60211c346c9a2f8c6b272b5f2bbe21f6e11800ce7f61e99ba75cf8b63e1/psycopg_binary-3.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:329ff393441e75f10b673ae99ab45276887993d49e65f141da20d915c05aafd8", size = 4670009, upload-time = "2026-02-18T16:49:03.608Z" }, + { url = "https://files.pythonhosted.org/packages/c1/53/ac7c18671347c553362aadbf65f92786eef9540676ca24114cc02f5be405/psycopg_binary-3.3.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:eb072949b8ebf4082ae24289a2b0fd724da9adc8f22743409d6fd718ddb379df", size = 5469735, upload-time = "2026-02-18T16:49:10.128Z" }, + { url = "https://files.pythonhosted.org/packages/7f/c3/4f4e040902b82a344eff1c736cde2f2720f127fe939c7e7565706f96dd44/psycopg_binary-3.3.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:263a24f39f26e19ed7fc982d7859a36f17841b05bebad3eb47bb9cd2dd785351", size = 5152919, upload-time = "2026-02-18T16:49:16.335Z" }, + { url = "https://files.pythonhosted.org/packages/0c/e7/d929679c6a5c212bcf738806c7c89f5b3d0919f2e1685a0e08d6ff877945/psycopg_binary-3.3.3-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5152d50798c2fa5bd9b68ec68eb68a1b71b95126c1d70adaa1a08cd5eefdc23d", size = 6738785, upload-time = "2026-02-18T16:49:22.687Z" }, + { url = "https://files.pythonhosted.org/packages/69/b0/09703aeb69a9443d232d7b5318d58742e8ca51ff79f90ffe6b88f1db45e7/psycopg_binary-3.3.3-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:9d6a1e56dd267848edb824dbeb08cf5bac649e02ee0b03ba883ba3f4f0bd54f2", size = 4979008, upload-time = "2026-02-18T16:49:27.313Z" }, + { url = "https://files.pythonhosted.org/packages/cc/a6/e662558b793c6e13a7473b970fee327d635270e41eded3090ef14045a6a5/psycopg_binary-3.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:73eaaf4bb04709f545606c1db2f65f4000e8a04cdbf3e00d165a23004692093e", size = 4508255, upload-time = "2026-02-18T16:49:31.575Z" }, + { url = "https://files.pythonhosted.org/packages/5f/7f/0f8b2e1d5e0093921b6f324a948a5c740c1447fbb45e97acaf50241d0f39/psycopg_binary-3.3.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:162e5675efb4704192411eaf8e00d07f7960b679cd3306e7efb120bb8d9456cc", size = 4189166, upload-time = "2026-02-18T16:49:35.801Z" }, + { url = "https://files.pythonhosted.org/packages/92/ec/ce2e91c33bc8d10b00c87e2f6b0fb570641a6a60042d6a9ae35658a3a797/psycopg_binary-3.3.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:fab6b5e37715885c69f5d091f6ff229be71e235f272ebaa35158d5a46fd548a0", size = 3924544, upload-time = "2026-02-18T16:49:41.129Z" }, + { url = "https://files.pythonhosted.org/packages/c5/2f/7718141485f73a924205af60041c392938852aa447a94c8cbd222ff389a1/psycopg_binary-3.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a4aab31bd6d1057f287c96c0effca3a25584eb9cc702f282ecb96ded7814e830", size = 4235297, upload-time = "2026-02-18T16:49:46.726Z" }, + { url = "https://files.pythonhosted.org/packages/57/f9/1add717e2643a003bbde31b1b220172e64fbc0cb09f06429820c9173f7fc/psycopg_binary-3.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:59aa31fe11a0e1d1bcc2ce37ed35fe2ac84cd65bb9036d049b1a1c39064d0f14", size = 3547659, upload-time = "2026-02-18T16:49:52.999Z" }, +] + [[package]] name = "psycopg2-binary" version = "2.9.11" @@ -5290,33 +5449,45 @@ wheels = [ [[package]] name = "pyarrow" -version = "23.0.1" +version = "14.0.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/88/22/134986a4cc224d593c1afde5494d18ff629393d74cc2eddb176669f234a4/pyarrow-23.0.1.tar.gz", hash = "sha256:b8c5873e33440b2bc2f4a79d2b47017a89c5a24116c055625e6f2ee50523f019", size = 1167336, upload-time = "2026-02-16T10:14:12.39Z" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d7/8b/d18b7eb6fb22e5ed6ffcbc073c85dae635778dbd1270a6cf5d750b031e84/pyarrow-14.0.2.tar.gz", hash = "sha256:36cef6ba12b499d864d1def3e990f97949e0b79400d08b7cf74504ffbd3eb025", size = 1063645, upload-time = "2023-12-18T15:43:41.625Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b0/41/8e6b6ef7e225d4ceead8459427a52afdc23379768f54dd3566014d7618c1/pyarrow-23.0.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:6f0147ee9e0386f519c952cc670eb4a8b05caa594eeffe01af0e25f699e4e9bb", size = 34302230, upload-time = "2026-02-16T10:09:03.859Z" }, - { url = "https://files.pythonhosted.org/packages/bf/4a/1472c00392f521fea03ae93408bf445cc7bfa1ab81683faf9bc188e36629/pyarrow-23.0.1-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:0ae6e17c828455b6265d590100c295193f93cc5675eb0af59e49dbd00d2de350", size = 35850050, upload-time = "2026-02-16T10:09:11.877Z" }, - { url = "https://files.pythonhosted.org/packages/0c/b2/bd1f2f05ded56af7f54d702c8364c9c43cd6abb91b0e9933f3d77b4f4132/pyarrow-23.0.1-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:fed7020203e9ef273360b9e45be52a2a47d3103caf156a30ace5247ffb51bdbd", size = 44491918, upload-time = "2026-02-16T10:09:18.144Z" }, - { url = "https://files.pythonhosted.org/packages/0b/62/96459ef5b67957eac38a90f541d1c28833d1b367f014a482cb63f3b7cd2d/pyarrow-23.0.1-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:26d50dee49d741ac0e82185033488d28d35be4d763ae6f321f97d1140eb7a0e9", size = 47562811, upload-time = "2026-02-16T10:09:25.792Z" }, - { url = "https://files.pythonhosted.org/packages/7d/94/1170e235add1f5f45a954e26cd0e906e7e74e23392dcb560de471f7366ec/pyarrow-23.0.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:3c30143b17161310f151f4a2bcfe41b5ff744238c1039338779424e38579d701", size = 48183766, upload-time = "2026-02-16T10:09:34.645Z" }, - { url = "https://files.pythonhosted.org/packages/0e/2d/39a42af4570377b99774cdb47f63ee6c7da7616bd55b3d5001aa18edfe4f/pyarrow-23.0.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:db2190fa79c80a23fdd29fef4b8992893f024ae7c17d2f5f4db7171fa30c2c78", size = 50607669, upload-time = "2026-02-16T10:09:44.153Z" }, - { url = "https://files.pythonhosted.org/packages/00/ca/db94101c187f3df742133ac837e93b1f269ebdac49427f8310ee40b6a58f/pyarrow-23.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:f00f993a8179e0e1c9713bcc0baf6d6c01326a406a9c23495ec1ba9c9ebf2919", size = 27527698, upload-time = "2026-02-16T10:09:50.263Z" }, - { url = "https://files.pythonhosted.org/packages/9a/4b/4166bb5abbfe6f750fc60ad337c43ecf61340fa52ab386da6e8dbf9e63c4/pyarrow-23.0.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:f4b0dbfa124c0bb161f8b5ebb40f1a680b70279aa0c9901d44a2b5a20806039f", size = 34214575, upload-time = "2026-02-16T10:09:56.225Z" }, - { url = "https://files.pythonhosted.org/packages/e1/da/3f941e3734ac8088ea588b53e860baeddac8323ea40ce22e3d0baa865cc9/pyarrow-23.0.1-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:7707d2b6673f7de054e2e83d59f9e805939038eebe1763fe811ee8fa5c0cd1a7", size = 35832540, upload-time = "2026-02-16T10:10:03.428Z" }, - { url = "https://files.pythonhosted.org/packages/88/7c/3d841c366620e906d54430817531b877ba646310296df42ef697308c2705/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:86ff03fb9f1a320266e0de855dee4b17da6794c595d207f89bba40d16b5c78b9", size = 44470940, upload-time = "2026-02-16T10:10:10.704Z" }, - { url = "https://files.pythonhosted.org/packages/2c/a5/da83046273d990f256cb79796a190bbf7ec999269705ddc609403f8c6b06/pyarrow-23.0.1-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:813d99f31275919c383aab17f0f455a04f5a429c261cc411b1e9a8f5e4aaaa05", size = 47586063, upload-time = "2026-02-16T10:10:17.95Z" }, - { url = "https://files.pythonhosted.org/packages/5b/3c/b7d2ebcff47a514f47f9da1e74b7949138c58cfeb108cdd4ee62f43f0cf3/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bf5842f960cddd2ef757d486041d57c96483efc295a8c4a0e20e704cbbf39c67", size = 48173045, upload-time = "2026-02-16T10:10:25.363Z" }, - { url = "https://files.pythonhosted.org/packages/43/b2/b40961262213beaba6acfc88698eb773dfce32ecdf34d19291db94c2bd73/pyarrow-23.0.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:564baf97c858ecc03ec01a41062e8f4698abc3e6e2acd79c01c2e97880a19730", size = 50621741, upload-time = "2026-02-16T10:10:33.477Z" }, - { url = "https://files.pythonhosted.org/packages/f6/70/1fdda42d65b28b078e93d75d371b2185a61da89dda4def8ba6ba41ebdeb4/pyarrow-23.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:07deae7783782ac7250989a7b2ecde9b3c343a643f82e8a4df03d93b633006f0", size = 27620678, upload-time = "2026-02-16T10:10:39.31Z" }, + { url = "https://files.pythonhosted.org/packages/94/8a/411ef0b05483076b7f548c74ccaa0f90c1e60d3875db71a821f6ffa8cf42/pyarrow-14.0.2-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:87482af32e5a0c0cce2d12eb3c039dd1d853bd905b04f3f953f147c7a196915b", size = 26904455, upload-time = "2023-12-18T15:40:43.477Z" }, + { url = "https://files.pythonhosted.org/packages/6c/6c/882a57798877e3a49ba54d8e0540bea24aed78fb42e1d860f08c3449c75e/pyarrow-14.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:059bd8f12a70519e46cd64e1ba40e97eae55e0cbe1695edd95384653d7626b23", size = 23997116, upload-time = "2023-12-18T15:40:48.533Z" }, + { url = "https://files.pythonhosted.org/packages/ec/3f/ef47fe6192ce4d82803a073db449b5292135406c364a7fc49dfbcd34c987/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3f16111f9ab27e60b391c5f6d197510e3ad6654e73857b4e394861fc79c37200", size = 35944575, upload-time = "2023-12-18T15:40:55.128Z" }, + { url = "https://files.pythonhosted.org/packages/1a/90/2021e529d7f234a3909f419d4341d53382541ef77d957fa274a99c533b18/pyarrow-14.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:06ff1264fe4448e8d02073f5ce45a9f934c0f3db0a04460d0b01ff28befc3696", size = 38079719, upload-time = "2023-12-18T15:41:02.565Z" }, + { url = "https://files.pythonhosted.org/packages/30/a9/474caf5fd54a6d5315aaf9284c6e8f5d071ca825325ad64c53137b646e1f/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6dd4f4b472ccf4042f1eab77e6c8bce574543f54d2135c7e396f413046397d5a", size = 35429706, upload-time = "2023-12-18T15:41:09.955Z" }, + { url = "https://files.pythonhosted.org/packages/d9/f8/cfba56f5353e51c19b0c240380ce39483f4c76e5c4aee5a000f3d75b72da/pyarrow-14.0.2-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:32356bfb58b36059773f49e4e214996888eeea3a08893e7dbde44753799b2a02", size = 38001476, upload-time = "2023-12-18T15:41:16.372Z" }, + { url = "https://files.pythonhosted.org/packages/43/3f/7bdf7dc3b3b0cfdcc60760e7880954ba99ccd0bc1e0df806f3dd61bc01cd/pyarrow-14.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:52809ee69d4dbf2241c0e4366d949ba035cbcf48409bf404f071f624ed313a2b", size = 24576230, upload-time = "2023-12-18T15:41:22.561Z" }, + { url = "https://files.pythonhosted.org/packages/69/5b/d8ab6c20c43b598228710e4e4a6cba03a01f6faa3d08afff9ce76fd0fd47/pyarrow-14.0.2-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:c87824a5ac52be210d32906c715f4ed7053d0180c1060ae3ff9b7e560f53f944", size = 26819585, upload-time = "2023-12-18T15:41:27.59Z" }, + { url = "https://files.pythonhosted.org/packages/2d/29/bed2643d0dd5e9570405244a61f6db66c7f4704a6e9ce313f84fa5a3675a/pyarrow-14.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a25eb2421a58e861f6ca91f43339d215476f4fe159eca603c55950c14f378cc5", size = 23965222, upload-time = "2023-12-18T15:41:32.449Z" }, + { url = "https://files.pythonhosted.org/packages/2a/34/da464632e59a8cdd083370d69e6c14eae30221acb284f671c6bc9273fadd/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c1da70d668af5620b8ba0a23f229030a4cd6c5f24a616a146f30d2386fec422", size = 35942036, upload-time = "2023-12-18T15:41:38.767Z" }, + { url = "https://files.pythonhosted.org/packages/a8/ff/cbed4836d543b29f00d2355af67575c934999ff1d43e3f438ab0b1b394f1/pyarrow-14.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2cc61593c8e66194c7cdfae594503e91b926a228fba40b5cf25cc593563bcd07", size = 38089266, upload-time = "2023-12-18T15:41:47.617Z" }, + { url = "https://files.pythonhosted.org/packages/38/41/345011cb831d3dbb2dab762fc244c745a5df94b199223a99af52a5f7dff6/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:78ea56f62fb7c0ae8ecb9afdd7893e3a7dbeb0b04106f5c08dbb23f9c0157591", size = 35404468, upload-time = "2023-12-18T15:41:54.49Z" }, + { url = "https://files.pythonhosted.org/packages/fd/af/2fc23ca2068ff02068d8dabf0fb85b6185df40ec825973470e613dbd8790/pyarrow-14.0.2-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:37c233ddbce0c67a76c0985612fef27c0c92aef9413cf5aa56952f359fcb7379", size = 38003134, upload-time = "2023-12-18T15:42:01.593Z" }, + { url = "https://files.pythonhosted.org/packages/95/1f/9d912f66a87e3864f694e000977a6a70a644ea560289eac1d733983f215d/pyarrow-14.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:e4b123ad0f6add92de898214d404e488167b87b5dd86e9a434126bc2b7a5578d", size = 25043754, upload-time = "2023-12-18T15:42:07.108Z" }, +] + +[[package]] +name = "pyarrow-hotfix" +version = "0.7" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/ed/c3e8677f7abf3981838c2af7b5ac03e3589b3ef94fcb31d575426abae904/pyarrow_hotfix-0.7.tar.gz", hash = "sha256:59399cd58bdd978b2e42816a4183a55c6472d4e33d183351b6069f11ed42661d", size = 9910, upload-time = "2025-04-25T10:17:06.247Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2e/c3/94ade4906a2f88bc935772f59c934013b4205e773bcb4239db114a6da136/pyarrow_hotfix-0.7-py3-none-any.whl", hash = "sha256:3236f3b5f1260f0e2ac070a55c1a7b339c4bb7267839bd2015e283234e758100", size = 7923, upload-time = "2025-04-25T10:17:05.224Z" }, ] [[package]] name = "pyasn1" -version = "0.6.2" +version = "0.6.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/b6/6e630dff89739fcd427e3f72b3d905ce0acb85a45d4ec3e2678718a3487f/pyasn1-0.6.2.tar.gz", hash = "sha256:9b59a2b25ba7e4f8197db7686c09fb33e658b98339fadb826e9512629017833b", size = 146586, upload-time = "2026-01-16T18:04:18.534Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/5f/6583902b6f79b399c9c40674ac384fd9cd77805f9e6205075f828ef11fb2/pyasn1-0.6.3.tar.gz", hash = "sha256:697a8ecd6d98891189184ca1fa05d1bb00e2f84b5977c481452050549c8a72cf", size = 148685, upload-time = "2026-03-17T01:06:53.382Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/44/b5/a96872e5184f354da9c84ae119971a0a4c221fe9b27a4d94bd43f2596727/pyasn1-0.6.2-py3-none-any.whl", hash = "sha256:1eb26d860996a18e9b6ed05e7aae0e9fc21619fcee6af91cca9bad4fbea224bf", size = 83371, upload-time = "2026-01-16T18:04:17.174Z" }, + { url = "https://files.pythonhosted.org/packages/5d/a0/7d793dce3fa811fe047d6ae2431c672364b462850c6235ae306c0efd025f/pyasn1-0.6.3-py3-none-any.whl", hash = "sha256:a80184d120f0864a52a073acc6fc642847d0be408e7c7252f31390c0f4eadcde", size = 83997, upload-time = "2026-03-17T01:06:52.036Z" }, ] [[package]] @@ -5333,11 +5504,11 @@ wheels = [ [[package]] name = "pycparser" -version = "2.23" +version = "3.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/fe/cf/d2d3b9f5699fb1e4615c8e32ff220203e43b248e1dfcc6736ad9057731ca/pycparser-2.23.tar.gz", hash = "sha256:78816d4f24add8f10a06d6f05b4d424ad9e96cfebf68a4ddc99c65c0720d00c2", size = 173734, upload-time = "2025-09-09T13:23:47.91Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1b/7d/92392ff7815c21062bea51aa7b87d45576f649f16458d78b7cf94b9ab2e6/pycparser-3.0.tar.gz", hash = "sha256:600f49d217304a5902ac3c37e1281c9fe94e4d0489de643a9504c5cdfdfc6b29", size = 103492, upload-time = "2026-01-21T14:26:51.89Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/e3/59cd50310fc9b59512193629e1984c1f95e5c8ae6e5d8c69532ccc65a7fe/pycparser-2.23-py3-none-any.whl", hash = "sha256:e5c6e8d3fbad53479cab09ac03729e0a9faf2bee3db8208a550daf5af81a5934", size = 118140, upload-time = "2025-09-09T13:23:46.651Z" }, + { url = "https://files.pythonhosted.org/packages/0c/c3/44f3fbbfa403ea2a7c779186dc20772604442dde72947e7d01069cbe98e3/pycparser-3.0-py3-none-any.whl", hash = "sha256:b727414169a36b7d524c1c3e31839a521725078d7b2ff038656844266160a992", size = 48172, upload-time = "2026-01-21T14:26:50.693Z" }, ] [[package]] @@ -5431,29 +5602,38 @@ wheels = [ [[package]] name = "pydantic-extra-types" -version = "2.10.6" +version = "2.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/3a/10/fb64987804cde41bcc39d9cd757cd5f2bb5d97b389d81aa70238b14b8a7e/pydantic_extra_types-2.10.6.tar.gz", hash = "sha256:c63d70bf684366e6bbe1f4ee3957952ebe6973d41e7802aea0b770d06b116aeb", size = 141858, upload-time = "2025-10-08T13:47:49.483Z" } +sdist = { url = "https://files.pythonhosted.org/packages/66/71/dba38ee2651f84f7842206adbd2233d8bbdb59fb85e9fa14232486a8c471/pydantic_extra_types-2.11.1.tar.gz", hash = "sha256:46792d2307383859e923d8fcefa82108b1a141f8a9c0198982b3832ab5ef1049", size = 172002, upload-time = "2026-03-16T08:08:03.92Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/93/04/5c918669096da8d1c9ec7bb716bd72e755526103a61bc5e76a3e4fb23b53/pydantic_extra_types-2.10.6-py3-none-any.whl", hash = "sha256:6106c448316d30abf721b5b9fecc65e983ef2614399a24142d689c7546cc246a", size = 40949, upload-time = "2025-10-08T13:47:48.268Z" }, + { url = "https://files.pythonhosted.org/packages/17/c1/3226e6d7f5a4f736f38ac11a6fbb262d701889802595cdb0f53a885ac2e0/pydantic_extra_types-2.11.1-py3-none-any.whl", hash = "sha256:1722ea2bddae5628ace25f2aa685b69978ef533123e5638cfbddb999e0100ec1", size = 79526, upload-time = "2026-03-16T08:08:02.533Z" }, ] [[package]] name = "pydantic-settings" -version = "2.12.0" +version = "2.13.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pydantic" }, { name = "python-dotenv" }, { name = "typing-inspection" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/43/4b/ac7e0aae12027748076d72a8764ff1c9d82ca75a7a52622e67ed3f765c54/pydantic_settings-2.12.0.tar.gz", hash = "sha256:005538ef951e3c2a68e1c08b292b5f2e71490def8589d4221b95dab00dafcfd0", size = 194184, upload-time = "2025-11-10T14:25:47.013Z" } +sdist = { url = "https://files.pythonhosted.org/packages/52/6d/fffca34caecc4a3f97bda81b2098da5e8ab7efc9a66e819074a11955d87e/pydantic_settings-2.13.1.tar.gz", hash = "sha256:b4c11847b15237fb0171e1462bf540e294affb9b86db4d9aa5c01730bdbe4025", size = 223826, upload-time = "2026-02-19T13:45:08.055Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/60/5d4751ba3f4a40a6891f24eec885f51afd78d208498268c734e256fb13c4/pydantic_settings-2.12.0-py3-none-any.whl", hash = "sha256:fddb9fd99a5b18da837b29710391e945b1e30c135477f484084ee513adb93809", size = 51880, upload-time = "2025-11-10T14:25:45.546Z" }, + { url = "https://files.pythonhosted.org/packages/00/4b/ccc026168948fec4f7555b9164c724cf4125eac006e176541483d2c959be/pydantic_settings-2.13.1-py3-none-any.whl", hash = "sha256:d56fd801823dbeae7f0975e1f8c8e25c258eb75d278ea7abb5d9cebb01b56237", size = 58929, upload-time = "2026-02-19T13:45:06.034Z" }, +] + +[[package]] +name = "pyfiglet" +version = "1.0.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c8/e3/0a86276ad2c383ce08d76110a8eec2fe22e7051c4b8ba3fa163a0b08c428/pyfiglet-1.0.4.tar.gz", hash = "sha256:db9c9940ed1bf3048deff534ed52ff2dafbbc2cd7610b17bb5eca1df6d4278ef", size = 1560615, upload-time = "2025-08-15T18:32:47.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9f/5c/fe9f95abd5eaedfa69f31e450f7e2768bef121dbdf25bcddee2cd3087a16/pyfiglet-1.0.4-py3-none-any.whl", hash = "sha256:65b57b7a8e1dff8a67dc8e940a117238661d5e14c3e49121032bd404d9b2b39f", size = 1806118, upload-time = "2025-08-15T18:32:45.556Z" }, ] [[package]] @@ -5467,11 +5647,11 @@ wheels = [ [[package]] name = "pyjwt" -version = "2.11.0" +version = "2.12.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c2/27/a3b6e5bf6ff856d2509292e95c8f57f0df7017cf5394921fc4e4ef40308a/pyjwt-2.12.1.tar.gz", hash = "sha256:c74a7a2adf861c04d002db713dd85f84beb242228e671280bf709d765b03672b", size = 102564, upload-time = "2026-03-13T19:27:37.25Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" }, + { url = "https://files.pythonhosted.org/packages/e5/7a/8dd906bd22e79e47397a61742927f6747fe93242ef86645ee9092e610244/pyjwt-2.12.1-py3-none-any.whl", hash = "sha256:28ca37c070cad8ba8cd9790cd940535d40274d22f80ab87f3ac6a713e6e8454c", size = 29726, upload-time = "2026-03-13T19:27:35.677Z" }, ] [package.optional-dependencies] @@ -5481,34 +5661,35 @@ crypto = [ [[package]] name = "pymilvus" -version = "2.5.17" +version = "2.6.10" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "cachetools" }, { name = "grpcio" }, - { name = "milvus-lite", marker = "sys_platform != 'win32'" }, + { name = "orjson" }, { name = "pandas" }, { name = "protobuf" }, { name = "python-dotenv" }, + { name = "requests" }, { name = "setuptools" }, - { name = "ujson" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/dc/85/91828a9282bb7f9b210c0a93831979c5829cba5533ac12e87014b6e2208b/pymilvus-2.5.17.tar.gz", hash = "sha256:48ff55db9598e1b4cc25f4fe645b00d64ebcfb03f79f9f741267fc2a35526d43", size = 1281485, upload-time = "2025-11-10T03:24:53.058Z" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/85/90362066ccda5ff6fec693a55693cde659fdcd36d08f1bd7012ae958248d/pymilvus-2.6.10.tar.gz", hash = "sha256:58a44ee0f1dddd7727ae830ef25325872d8946f029d801a37105164e6699f1b8", size = 1561042, upload-time = "2026-03-13T09:54:22.441Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/44/ee0c64617f58c123f570293f36b40f7b56fc123a2aa9573aa22e6ff0fb86/pymilvus-2.5.17-py3-none-any.whl", hash = "sha256:a43d36f2e5f793040917d35858d1ed2532307b7dfb03bc3eaf813aac085bc5a4", size = 244036, upload-time = "2025-11-10T03:24:51.496Z" }, + { url = "https://files.pythonhosted.org/packages/88/10/fe7fbb6795aa20038afd55e9c653991e7c69fb24c741ebb39ba3b0aa5c13/pymilvus-2.6.10-py3-none-any.whl", hash = "sha256:a048b6f3ebad93742bca559beabf44fe578f0983555a109c4436b5fb2c1dbd40", size = 312797, upload-time = "2026-03-13T09:54:21.081Z" }, ] [[package]] name = "pymochow" -version = "2.2.9" +version = "2.3.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "future" }, { name = "orjson" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b5/29/d9b112684ce490057b90bddede3fb6a69cf2787a3fd7736bdce203e77388/pymochow-2.2.9.tar.gz", hash = "sha256:5a28058edc8861deb67524410e786814571ed9fe0700c8c9fc0bc2ad5835b06c", size = 50079, upload-time = "2025-06-05T08:33:19.59Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/04/2edda5447aa7c87a0b2b7c75406cc0fbcceeddd09c76b04edfb84eb47499/pymochow-2.3.6.tar.gz", hash = "sha256:6249a2fa410ef22e9e702710d725e7e052f492af87233ffe911845f931557632", size = 51123, upload-time = "2025-12-12T06:23:24.162Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/9b/be18f9709dfd8187ff233be5acb253a9f4f1b07f1db0e7b09d84197c28e2/pymochow-2.2.9-py3-none-any.whl", hash = "sha256:639192b97f143d4a22fc163872be12aee19523c46f12e22416e8f289f1354d15", size = 77899, upload-time = "2025-06-05T08:33:17.424Z" }, + { url = "https://files.pythonhosted.org/packages/aa/86/588c75acbcc7dd9860252f1ef2233212f36b6751ac0cdec15867fc2fc4d6/pymochow-2.3.6-py3-none-any.whl", hash = "sha256:d46cb3af4d908f0c15d875190b1945c0353b907d7e32f068636ee04433cf06b1", size = 78963, upload-time = "2025-12-12T06:23:21.419Z" }, ] [[package]] @@ -5522,7 +5703,7 @@ wheels = [ [[package]] name = "pyobvector" -version = "0.2.20" +version = "0.2.25" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiomysql" }, @@ -5532,72 +5713,89 @@ dependencies = [ { name = "sqlalchemy" }, { name = "sqlglot" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ca/6f/24ae2d4ba811e5e112c89bb91ba7c50eb79658563650c8fc65caa80655f8/pyobvector-0.2.20.tar.gz", hash = "sha256:72a54044632ba3bb27d340fb660c50b22548d34c6a9214b6653bc18eee4287c4", size = 46648, upload-time = "2025-11-20T09:30:16.354Z" } +sdist = { url = "https://files.pythonhosted.org/packages/38/8a/c459f45844f1f90e9edf80c0f434ec3b1a65132efb240cfab8f26b1836c3/pyobvector-0.2.25.tar.gz", hash = "sha256:94d987583255ed8aba701d37a5d7c2727ec5fd7e0288cd9dd87a1f5ee36dd923", size = 78511, upload-time = "2026-03-10T07:18:32.283Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/21/630c4e9f0d30b7a6eebe0590cd97162e82a2d3ac4ed3a33259d0a67e0861/pyobvector-0.2.20-py3-none-any.whl", hash = "sha256:9a3c1d3eb5268eae64185f8807b10fd182f271acf33323ee731c2ad554d1c076", size = 60131, upload-time = "2025-11-20T09:30:14.88Z" }, + { url = "https://files.pythonhosted.org/packages/d1/7d/037401cecb34728d1c28ea05e196ea3c9d50a1ce0f2172e586e075ff55d8/pyobvector-0.2.25-py3-none-any.whl", hash = "sha256:ae0153f99bd0222783ed7e3951efc31a0d2b462d926b6f86ebd2033409aede8f", size = 64663, upload-time = "2026-03-10T07:18:29.789Z" }, ] [[package]] name = "pypandoc" -version = "1.16.2" +version = "1.17" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0b/18/9f5f70567b97758625335209b98d5cb857e19aa1a9306e9749567a240634/pypandoc-1.16.2.tar.gz", hash = "sha256:7a72a9fbf4a5dc700465e384c3bb333d22220efc4e972cb98cf6fc723cdca86b", size = 31477, upload-time = "2025-11-13T16:30:29.608Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ea/d6/410615fc433e5d1eacc00db2044ae2a9c82302df0d35366fe2bd15de024d/pypandoc-1.17.tar.gz", hash = "sha256:51179abfd6e582a25ed03477541b48836b5bba5a4c3b282a547630793934d799", size = 69071, upload-time = "2026-03-14T22:39:07.21Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bb/e9/b145683854189bba84437ea569bfa786f408c8dc5bc16d8eb0753f5583bf/pypandoc-1.16.2-py3-none-any.whl", hash = "sha256:c200c1139c8e3247baf38d1e9279e85d9f162499d1999c6aa8418596558fe79b", size = 19451, upload-time = "2025-11-13T16:30:07.66Z" }, + { url = "https://files.pythonhosted.org/packages/0c/86/e2ffa604eacfbec3f430b1d850e7e04c4101eca1a5828f9ae54bf51dfba4/pypandoc-1.17-py3-none-any.whl", hash = "sha256:01fdbffa61edb9f8e82e8faad6954efcb7b6f8f0634aead4d89e322a00225a67", size = 23554, upload-time = "2026-03-14T22:38:46.007Z" }, +] + +[[package]] +name = "pypandoc-binary" +version = "1.17" +source = { registry = "https://pypi.org/simple" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/80/85/681a54111f0948821a5cf87ce30a88bb0a3f6848af5112c912abac4a2b77/pypandoc_binary-1.17-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:734726dc618ef276343e272e1a6b4567e59c2ef9ef41d5533042deac3b0531f1", size = 25553945, upload-time = "2026-03-14T22:38:47.91Z" }, + { url = "https://files.pythonhosted.org/packages/15/58/8fd107c68522957868c1e785fbea7595608df118e440e424d189668294df/pypandoc_binary-1.17-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fcfd28f347ed998dda28823fc6bc24f9310e7fdf3ddceaf925bf0563a100ab5b", size = 25553944, upload-time = "2026-03-14T22:38:50.74Z" }, + { url = "https://files.pythonhosted.org/packages/f4/27/ac1078239aae14b94c51975b7f46ad8e099e47d7ae26c175a5486b1c0099/pypandoc_binary-1.17-py3-none-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d6b620b21c9374e3e48aabd518492bf0776b148442ee28816f6aaf52da3d4387", size = 34460960, upload-time = "2026-03-14T22:38:53.391Z" }, + { url = "https://files.pythonhosted.org/packages/8d/7f/1e5612b52900ebe590862dabeadf546f739b27527dcd8bfd632f8adac1be/pypandoc_binary-1.17-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9ada156cb980cd54fd6534231788e668c00dbb591cbd24f0be0bd86812eb8788", size = 36867598, upload-time = "2026-03-14T22:38:56.351Z" }, + { url = "https://files.pythonhosted.org/packages/3b/31/a5a867159c4080e5d368f4a53540a727501a2f31affc297dc8e0fced96a7/pypandoc_binary-1.17-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:2f439dcd211183bb3460253ca4511101df6e1acf4a01f45f5617e1fa2ad24279", size = 36867584, upload-time = "2026-03-14T22:38:59.087Z" }, + { url = "https://files.pythonhosted.org/packages/0d/2d/6a51cd4e54bdf132c19416801077c34bd40ba182e85d843360d36ae03a2d/pypandoc_binary-1.17-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:f6e6d3e4cfafbe23189a08db3d41f8def260bacd6e7e382bceadab7ba1f17da6", size = 34460949, upload-time = "2026-03-14T22:39:01.71Z" }, + { url = "https://files.pythonhosted.org/packages/c6/b9/f47b77ba75ed5d47ec85fcc2ecfbf7f78e3a73347f3a09836634d930de98/pypandoc_binary-1.17-py3-none-win_amd64.whl", hash = "sha256:76fae066cd2d7e78fb97f0ec8e9e36f437b07187b689b0b415ca18216f8f898a", size = 40891661, upload-time = "2026-03-14T22:39:04.782Z" }, ] [[package]] name = "pyparsing" -version = "3.2.5" +version = "3.3.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f2/a5/181488fc2b9d093e3972d2a472855aae8a03f000592dbfce716a512b3359/pyparsing-3.2.5.tar.gz", hash = "sha256:2df8d5b7b2802ef88e8d016a2eb9c7aeaa923529cd251ed0fe4608275d4105b6", size = 1099274, upload-time = "2025-09-21T04:11:06.277Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f3/91/9c6ee907786a473bf81c5f53cf703ba0957b23ab84c264080fb5a450416f/pyparsing-3.3.2.tar.gz", hash = "sha256:c777f4d763f140633dcb6d8a3eda953bf7a214dc4eff598413c070bcdc117cbc", size = 6851574, upload-time = "2026-01-21T03:57:59.36Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/10/5e/1aa9a93198c6b64513c9d7752de7422c06402de6600a8767da1524f9570b/pyparsing-3.2.5-py3-none-any.whl", hash = "sha256:e38a4f02064cf41fe6593d328d0512495ad1f3d8a91c4f73fc401b3079a59a5e", size = 113890, upload-time = "2025-09-21T04:11:04.117Z" }, + { url = "https://files.pythonhosted.org/packages/10/bd/c038d7cc38edc1aa5bf91ab8068b63d4308c66c4c8bb3cbba7dfbc049f9c/pyparsing-3.3.2-py3-none-any.whl", hash = "sha256:850ba148bd908d7e2411587e247a1e4f0327839c40e2e5e6d05a007ecc69911d", size = 122781, upload-time = "2026-01-21T03:57:55.912Z" }, ] [[package]] name = "pypdf" -version = "6.7.4" +version = "6.9.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/09/dc/f52deef12797ad58b88e4663f097a343f53b9361338aef6573f135ac302f/pypdf-6.7.4.tar.gz", hash = "sha256:9edd1cd47938bb35ec87795f61225fd58a07cfaf0c5699018ae1a47d6f8ab0e3", size = 5304821, upload-time = "2026-02-27T10:44:39.395Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/be/cded021305f5c81b47265b8c5292b99388615a4391c21ff00fd538d34a56/pypdf-6.7.4-py3-none-any.whl", hash = "sha256:527d6da23274a6c70a9cb59d1986d93946ba8e36a6bc17f3f7cce86331492dda", size = 331496, upload-time = "2026-02-27T10:44:37.527Z" }, + { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, ] [[package]] name = "pypdfium2" -version = "5.2.0" +version = "5.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f6/ab/73c7d24e4eac9ba952569403b32b7cca9412fc5b9bef54fdbd669551389f/pypdfium2-5.2.0.tar.gz", hash = "sha256:43863625231ce999c1ebbed6721a88de818b2ab4d909c1de558d413b9a400256", size = 269999, upload-time = "2025-12-12T13:20:15.353Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3b/01/be763b9081c7eb823196e7d13d9c145bf75ac43f3c1466de81c21c24b381/pypdfium2-5.6.0.tar.gz", hash = "sha256:bcb9368acfe3547054698abbdae68ba0cbd2d3bda8e8ee437e061deef061976d", size = 270714, upload-time = "2026-03-08T01:05:06.5Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fb/0c/9108ae5266ee4cdf495f99205c44d4b5c83b4eb227c2b610d35c9e9fe961/pypdfium2-5.2.0-py3-none-android_23_arm64_v8a.whl", hash = "sha256:1ba4187a45ce4cf08f2a8c7e0f8970c36b9aa1770c8a3412a70781c1d80fb145", size = 2763268, upload-time = "2025-12-12T13:19:37.354Z" }, - { url = "https://files.pythonhosted.org/packages/35/8c/55f5c8a2c6b293f5c020be4aa123eaa891e797c514e5eccd8cb042740d37/pypdfium2-5.2.0-py3-none-android_23_armeabi_v7a.whl", hash = "sha256:80c55e10a8c9242f0901d35a9a306dd09accce8e497507bb23fcec017d45fe2e", size = 2301821, upload-time = "2025-12-12T13:19:39.484Z" }, - { url = "https://files.pythonhosted.org/packages/5e/7d/efa013e3795b41c59dd1e472f7201c241232c3a6553be4917e3a26b9f225/pypdfium2-5.2.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:73523ae69cd95c084c1342096893b2143ea73c36fdde35494780ba431e6a7d6e", size = 2816428, upload-time = "2025-12-12T13:19:41.735Z" }, - { url = "https://files.pythonhosted.org/packages/ec/ae/8c30af6ff2ab41a7cb84753ee79dd1e0a8932c9bda9fe19759d69cbbf115/pypdfium2-5.2.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:19c501d22ef5eb98e42416d22cc3ac66d4808b436e3d06686392f24d8d9f708d", size = 2939486, upload-time = "2025-12-12T13:19:43.176Z" }, - { url = "https://files.pythonhosted.org/packages/64/64/454a73c49a04c2c290917ad86184e4da959e9e5aba94b3b046328c89be93/pypdfium2-5.2.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6ed15a3f58d6ee4905f0d0a731e30b381b457c30689512589c7f57950b0cdcec", size = 2979235, upload-time = "2025-12-12T13:19:44.635Z" }, - { url = "https://files.pythonhosted.org/packages/4e/29/f1cab8e31192dd367dc7b1afa71f45cfcb8ff0b176f1d2a0f528faf04052/pypdfium2-5.2.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:329cd1e9f068e8729e0d0b79a070d6126f52bc48ff1e40505cb207a5e20ce0ba", size = 2763001, upload-time = "2025-12-12T13:19:47.598Z" }, - { url = "https://files.pythonhosted.org/packages/bc/5d/e95fad8fdac960854173469c4b6931d5de5e09d05e6ee7d9756f8b95eef0/pypdfium2-5.2.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:325259759886e66619504df4721fef3b8deabf8a233e4f4a66e0c32ebae60c2f", size = 3057024, upload-time = "2025-12-12T13:19:49.179Z" }, - { url = "https://files.pythonhosted.org/packages/f4/32/468591d017ab67f8142d40f4db8163b6d8bb404fe0d22da75a5c661dc144/pypdfium2-5.2.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5683e8f08ab38ed05e0e59e611451ec74332803d4e78f8c45658ea1d372a17af", size = 3448598, upload-time = "2025-12-12T13:19:50.979Z" }, - { url = "https://files.pythonhosted.org/packages/f9/a5/57b4e389b77ab5f7e9361dc7fc03b5378e678ba81b21e791e85350fbb235/pypdfium2-5.2.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da4815426a5adcf03bf4d2c5f26c0ff8109dbfaf2c3415984689931bc6006ef9", size = 2993946, upload-time = "2025-12-12T13:19:53.154Z" }, - { url = "https://files.pythonhosted.org/packages/84/3a/e03e9978f817632aa56183bb7a4989284086fdd45de3245ead35f147179b/pypdfium2-5.2.0-py3-none-manylinux_2_27_s390x.manylinux_2_28_s390x.whl", hash = "sha256:64bf5c039b2c314dab1fd158bfff99db96299a5b5c6d96fc056071166056f1de", size = 3673148, upload-time = "2025-12-12T13:19:54.528Z" }, - { url = "https://files.pythonhosted.org/packages/13/ee/e581506806553afa4b7939d47bf50dca35c1151b8cc960f4542a6eb135ce/pypdfium2-5.2.0-py3-none-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:76b42a17748ac7dc04d5ef04d0561c6a0a4b546d113ec1d101d59650c6a340f7", size = 2964757, upload-time = "2025-12-12T13:19:56.406Z" }, - { url = "https://files.pythonhosted.org/packages/00/be/3715c652aff30f12284523dd337843d0efe3e721020f0ec303a99ffffd8d/pypdfium2-5.2.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:9d4367d471439fae846f0aba91ff9e8d66e524edcf3c8d6e02fe96fa306e13b9", size = 4130319, upload-time = "2025-12-12T13:19:57.889Z" }, - { url = "https://files.pythonhosted.org/packages/b0/0b/28aa2ede9004dd4192266bbad394df0896787f7c7bcfa4d1a6e091ad9a2c/pypdfium2-5.2.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:613f6bb2b47d76b66c0bf2ca581c7c33e3dd9dcb29d65d8c34fef4135f933149", size = 3746488, upload-time = "2025-12-12T13:19:59.469Z" }, - { url = "https://files.pythonhosted.org/packages/bc/04/1b791e1219652bbfc51df6498267d8dcec73ad508b99388b2890902ccd9d/pypdfium2-5.2.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c03fad3f2fa68d358f5dd4deb07e438482fa26fae439c49d127576d969769ca1", size = 4336534, upload-time = "2025-12-12T13:20:01.28Z" }, - { url = "https://files.pythonhosted.org/packages/4f/e3/6f00f963bb702ffd2e3e2d9c7286bc3bb0bebcdfa96ca897d466f66976c6/pypdfium2-5.2.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:f10be1900ae21879d02d9f4d58c2d2db3a2e6da611736a8e9decc22d1fb02909", size = 4375079, upload-time = "2025-12-12T13:20:03.117Z" }, - { url = "https://files.pythonhosted.org/packages/3a/2a/7ec2b191b5e1b7716a0dfc14e6860e89bb355fb3b94ed0c1d46db526858c/pypdfium2-5.2.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:97c1a126d30378726872f94866e38c055740cae80313638dafd1cd448d05e7c0", size = 3928648, upload-time = "2025-12-12T13:20:05.041Z" }, - { url = "https://files.pythonhosted.org/packages/bf/c3/c6d972fa095ff3ace76f9d3a91ceaf8a9dbbe0d9a5a84ac1d6178a46630e/pypdfium2-5.2.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:c369f183a90781b788af9a357a877bc8caddc24801e8346d0bf23f3295f89f3a", size = 4997772, upload-time = "2025-12-12T13:20:06.453Z" }, - { url = "https://files.pythonhosted.org/packages/22/45/2c64584b7a3ca5c4652280a884f4b85b8ed24e27662adeebdc06d991c917/pypdfium2-5.2.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b391f1cceb454934b612a05b54e90f98aafeffe5e73830d71700b17f0812226b", size = 4180046, upload-time = "2025-12-12T13:20:08.715Z" }, - { url = "https://files.pythonhosted.org/packages/d6/99/8d1ff87b626649400e62a2840e6e10fe258443ba518798e071fee4cd86f9/pypdfium2-5.2.0-py3-none-win32.whl", hash = "sha256:c68067938f617c37e4d17b18de7cac231fc7ce0eb7b6653b7283ebe8764d4999", size = 2990175, upload-time = "2025-12-12T13:20:10.241Z" }, - { url = "https://files.pythonhosted.org/packages/93/fc/114fff8895b620aac4984808e93d01b6d7b93e342a1635fcfe2a5f39cf39/pypdfium2-5.2.0-py3-none-win_amd64.whl", hash = "sha256:eb0591b720e8aaeab9475c66d653655ec1be0464b946f3f48a53922e843f0f3b", size = 3098615, upload-time = "2025-12-12T13:20:11.795Z" }, - { url = "https://files.pythonhosted.org/packages/08/97/eb738bff5998760d6e0cbcb7dd04cbf1a95a97b997fac6d4e57562a58992/pypdfium2-5.2.0-py3-none-win_arm64.whl", hash = "sha256:5dd1ef579f19fa3719aee4959b28bda44b1072405756708b5e83df8806a19521", size = 2939479, upload-time = "2025-12-12T13:20:13.815Z" }, + { url = "https://files.pythonhosted.org/packages/9d/b1/129ed0177521a93a892f8a6a215dd3260093e30e77ef7035004bb8af7b6c/pypdfium2-5.6.0-py3-none-android_23_arm64_v8a.whl", hash = "sha256:fb7858c9707708555b4a719b5548a6e7f5d26bc82aef55ae4eb085d7a2190b11", size = 3346059, upload-time = "2026-03-08T01:04:21.37Z" }, + { url = "https://files.pythonhosted.org/packages/86/34/cbdece6886012180a7f2c7b2c360c415cf5e1f83f1973d2c9201dae3506a/pypdfium2-5.6.0-py3-none-android_23_armeabi_v7a.whl", hash = "sha256:6a7e1f4597317786f994bfb947eef480e53933f804a990193ab89eef8243f805", size = 2804418, upload-time = "2026-03-08T01:04:23.384Z" }, + { url = "https://files.pythonhosted.org/packages/6e/f6/9f9e190fe0e5a6b86b82f83bd8b5d3490348766062381140ca5cad8e00b1/pypdfium2-5.6.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e468c38997573f0e86f03273c2c1fbdea999de52ba43fee96acaa2f6b2ad35f7", size = 3412541, upload-time = "2026-03-08T01:04:25.45Z" }, + { url = "https://files.pythonhosted.org/packages/ee/8d/e57492cb2228ba56ed57de1ff044c8ac114b46905f8b1445c33299ba0488/pypdfium2-5.6.0-py3-none-macosx_11_0_x86_64.whl", hash = "sha256:ad3abddc5805424f962e383253ccad6a0d1d2ebd86afa9a9e1b9ca659773cd0d", size = 3592320, upload-time = "2026-03-08T01:04:27.509Z" }, + { url = "https://files.pythonhosted.org/packages/f9/8a/8ab82e33e9c551494cbe1526ea250ca8cc4e9e98d6a4fc6b6f8d959aa1d1/pypdfium2-5.6.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6b5eb9eae5c45076395454522ca26add72ba8bd1fe473e1e4721aa58521470c", size = 3596450, upload-time = "2026-03-08T01:04:29.183Z" }, + { url = "https://files.pythonhosted.org/packages/f5/b5/602a792282312ccb158cc63849528079d94b0a11efdc61f2a359edfb41e9/pypdfium2-5.6.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:258624da8ef45cdc426e11b33e9d83f9fb723c1c201c6e0f4ab5a85966c6b876", size = 3325442, upload-time = "2026-03-08T01:04:30.886Z" }, + { url = "https://files.pythonhosted.org/packages/81/1f/9e48ec05ed8d19d736c2d1f23c1bd0f20673f02ef846a2576c69e237f15d/pypdfium2-5.6.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9367451c8a00931d6612db0822525a18c06f649d562cd323a719e46ac19c9bb", size = 3727434, upload-time = "2026-03-08T01:04:33.619Z" }, + { url = "https://files.pythonhosted.org/packages/33/90/0efd020928b4edbd65f4f3c2af0c84e20b43a3ada8fa6d04f999a97afe7a/pypdfium2-5.6.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a757869f891eac1cc1372e38a4aa01adac8abc8fe2a8a4e2ebf50595e3bf5937", size = 4139029, upload-time = "2026-03-08T01:04:36.08Z" }, + { url = "https://files.pythonhosted.org/packages/ff/49/a640b288a48dab1752281dd9b72c0679fccea107874e80a65a606b00efa9/pypdfium2-5.6.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:515be355222cc57ae9e62cd5c7c350b8e0c863efc539f80c7d75e2811ba45cb6", size = 3646387, upload-time = "2026-03-08T01:04:38.151Z" }, + { url = "https://files.pythonhosted.org/packages/b0/3b/a344c19c01021eeb5d830c102e4fc9b1602f19c04aa7d11abbe2d188fd8e/pypdfium2-5.6.0-py3-none-manylinux_2_27_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d1c4753c7caf7d004211d7f57a21f10d127f5e0e5510a14d24bc073e7220a3ea", size = 3097212, upload-time = "2026-03-08T01:04:40.776Z" }, + { url = "https://files.pythonhosted.org/packages/50/96/e48e13789ace22aeb9b7510904a1b1493ec588196e11bbacc122da330b3d/pypdfium2-5.6.0-py3-none-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c49729090281fdd85775fb8912c10bd19e99178efaa98f145ab06e7ce68554d2", size = 2965026, upload-time = "2026-03-08T01:04:42.857Z" }, + { url = "https://files.pythonhosted.org/packages/cb/06/3100e44d4935f73af8f5d633d3bd40f0d36d606027085a0ef1f0566a6320/pypdfium2-5.6.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:a4a1749a8d4afd62924a8d95cfa4f2e26fc32957ce34ac3b674be6f127ed252e", size = 4131431, upload-time = "2026-03-08T01:04:44.982Z" }, + { url = "https://files.pythonhosted.org/packages/64/ef/d8df63569ce9a66c8496057782eb8af78e0d28667922d62ec958434e3d4b/pypdfium2-5.6.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:36469ebd0fdffb7130ce45ed9c44f8232d91571c89eb851bd1633c64b6f6114f", size = 3747469, upload-time = "2026-03-08T01:04:46.702Z" }, + { url = "https://files.pythonhosted.org/packages/a6/47/fd2c6a67a49fade1acd719fbd11f7c375e7219912923ef2de0ea0ac1544e/pypdfium2-5.6.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9da900df09be3cf546b637a127a7b6428fb22d705951d731269e25fd3adef457", size = 4337578, upload-time = "2026-03-08T01:04:49.007Z" }, + { url = "https://files.pythonhosted.org/packages/6b/f5/836c83e54b01e09478c4d6bf4912651d6053c932250fcee953f5c72d8e4a/pypdfium2-5.6.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:45fccd5622233c5ec91a885770ae7dd4004d4320ac05a4ad8fa03a66dea40244", size = 4376104, upload-time = "2026-03-08T01:04:51.04Z" }, + { url = "https://files.pythonhosted.org/packages/6e/7f/b940b6a1664daf8f9bad87c6c99b84effa3611615b8708d10392dc33036c/pypdfium2-5.6.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:282dc030e767cd61bd0299f9d581052b91188e2b87561489057a8e7963e7e0cb", size = 3929824, upload-time = "2026-03-08T01:04:53.544Z" }, + { url = "https://files.pythonhosted.org/packages/88/79/00267d92a6a58c229e364d474f5698efe446e0c7f4f152f58d0138715e99/pypdfium2-5.6.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:a1c1dfe950382c76a7bba1ba160ec5e40df8dd26b04a1124ae268fda55bc4cbe", size = 4270201, upload-time = "2026-03-08T01:04:55.81Z" }, + { url = "https://files.pythonhosted.org/packages/e1/ab/b127f38aba41746bdf9ace15ba08411d7ef6ecba1326d529ba414eb1ed50/pypdfium2-5.6.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:43b0341ca6feb6c92e4b7a9eb4813e5466f5f5e8b6baeb14df0a94d5f312c00b", size = 4180793, upload-time = "2026-03-08T01:04:57.961Z" }, + { url = "https://files.pythonhosted.org/packages/0e/8c/a01c8e4302448b614d25a85c08298b0d3e9dfbdac5bd1b2f32c9b02e83d9/pypdfium2-5.6.0-py3-none-win32.whl", hash = "sha256:9dfcd4ff49a2b9260d00e38539ab28190d59e785e83030b30ffaf7a29c42155d", size = 3596753, upload-time = "2026-03-08T01:05:00.566Z" }, + { url = "https://files.pythonhosted.org/packages/9b/5f/2d871adf46761bb002a62686545da6348afe838d19af03df65d1ece786a2/pypdfium2-5.6.0-py3-none-win_amd64.whl", hash = "sha256:c6bc8dd63d0568f4b592f3e03de756afafc0e44aa1fe8878cc4aba1b11ae7374", size = 3716526, upload-time = "2026-03-08T01:05:02.433Z" }, + { url = "https://files.pythonhosted.org/packages/3a/80/0d9b162098597fbe3ac2b269b1682c0c3e8db9ba87679603fdd9b19afaa6/pypdfium2-5.6.0-py3-none-win_arm64.whl", hash = "sha256:5538417b199bdcb3207370c88df61f2ba3dac7a3253f82e1aa2708e6376b6f90", size = 3515049, upload-time = "2026-03-08T01:05:04.587Z" }, ] [[package]] name = "pypika" -version = "0.48.9" +version = "0.51.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/2c/94ed7b91db81d61d7096ac8f2d325ec562fc75e35f3baea8749c85b28784/PyPika-0.48.9.tar.gz", hash = "sha256:838836a61747e7c8380cd1b7ff638694b7a7335345d0f559b04b2cd832ad5378", size = 67259, upload-time = "2022-03-15T11:22:57.066Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f8/78/cbaebba88e05e2dcda13ca203131b38d3640219f20ebb49676d26714861b/pypika-0.51.1.tar.gz", hash = "sha256:c30c7c1048fbf056fd3920c5a2b88b0c29dd190a9b2bee971fd17e4abe4d0ebe", size = 80919, upload-time = "2026-02-04T11:27:48.304Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/57/83/c77dfeed04022e8930b08eedca2b6e5efed256ab3321396fde90066efb65/pypika-0.51.1-py2.py3-none-any.whl", hash = "sha256:77985b4d7ce71b9905255bf12468cf598349e98837c037541cfc240e528aec46", size = 60585, upload-time = "2026-02-04T11:27:46.251Z" }, +] [[package]] name = "pyproject-hooks" @@ -5619,33 +5817,34 @@ wheels = [ [[package]] name = "pyrefly" -version = "0.54.0" +version = "0.57.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/81/44/c10b16a302fda90d0af1328f880b232761b510eab546616a7be2fdf35a57/pyrefly-0.54.0.tar.gz", hash = "sha256:c6663be64d492f0d2f2a411ada9f28a6792163d34133639378b7f3dd9a8dca94", size = 5098893, upload-time = "2026-02-23T15:44:35.111Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/c1/c17211e5bbd2b90a24447484713da7cc2cee4e9455e57b87016ffc69d426/pyrefly-0.57.1.tar.gz", hash = "sha256:b05f6f5ee3a6a5d502ca19d84cb9ab62d67f05083819964a48c1510f2993efc6", size = 5310800, upload-time = "2026-03-18T18:42:35.614Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5f/99/8fdcdb4e55f0227fdd9f6abce36b619bab1ecb0662b83b66adc8cba3c788/pyrefly-0.54.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:58a3f092b6dc25ef79b2dc6c69a40f36784ca157c312bfc0baea463926a9db6d", size = 12223973, upload-time = "2026-02-23T15:44:14.278Z" }, - { url = "https://files.pythonhosted.org/packages/90/35/c2aaf87a76003ad27b286594d2e5178f811eaa15bfe3d98dba2b47d56dd1/pyrefly-0.54.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:615081414106dd95873bc39c3a4bed68754c6cc24a8177ac51d22f88f88d3eb3", size = 11785585, upload-time = "2026-02-23T15:44:17.468Z" }, - { url = "https://files.pythonhosted.org/packages/c4/4a/ced02691ed67e5a897714979196f08ad279ec7ec7f63c45e00a75a7f3c0e/pyrefly-0.54.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbcaf20f5fe585079079a95205c1f3cd4542d17228cdf1df560288880623b70", size = 33381977, upload-time = "2026-02-23T15:44:19.736Z" }, - { url = "https://files.pythonhosted.org/packages/0b/ce/72a117ed437c8f6950862181014b41e36f3c3997580e29b772b71e78d587/pyrefly-0.54.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:66d5da116c0d34acfbd66663addd3ca8aa78a636f6692a66e078126d3620a883", size = 35962821, upload-time = "2026-02-23T15:44:22.357Z" }, - { url = "https://files.pythonhosted.org/packages/85/de/89013f5ae0a35d2b6b01274a92a35ee91431ea001050edf0a16748d39875/pyrefly-0.54.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ef3ac27f1a4baaf67aead64287d3163350844794aca6315ad1a9650b16ec26a", size = 38496689, upload-time = "2026-02-23T15:44:25.236Z" }, - { url = "https://files.pythonhosted.org/packages/9f/9a/33b097c7bf498b924742dca32dd5d9c6a3fa6c2b52b63a58eb9e1980ca89/pyrefly-0.54.0-py3-none-win32.whl", hash = "sha256:7d607d72200a8afbd2db10bfefb40160a7a5d709d207161c21649cedd5cfc09a", size = 11295268, upload-time = "2026-02-23T15:44:27.551Z" }, - { url = "https://files.pythonhosted.org/packages/d4/21/9263fd1144d2a3d7342b474f183f7785b3358a1565c864089b780110b933/pyrefly-0.54.0-py3-none-win_amd64.whl", hash = "sha256:fd416f04f89309385696f685bd5c9141011f18c8072f84d31ca20c748546e791", size = 12081810, upload-time = "2026-02-23T15:44:29.461Z" }, - { url = "https://files.pythonhosted.org/packages/ea/5b/fad062a196c064cbc8564de5b2f4d3cb6315f852e3b31e8a1ce74c69a1ea/pyrefly-0.54.0-py3-none-win_arm64.whl", hash = "sha256:f06ab371356c7b1925e0bffe193b738797e71e5dbbff7fb5a13f90ee7521211d", size = 11564930, upload-time = "2026-02-23T15:44:33.053Z" }, + { url = "https://files.pythonhosted.org/packages/b7/58/8af37856c8d45b365ece635a6728a14b0356b08d1ff1ac601d7120def1e0/pyrefly-0.57.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:91974bfbe951eebf5a7bc959c1f3921f0371c789cad84761511d695e9ab2265f", size = 12681847, upload-time = "2026-03-18T18:42:10.963Z" }, + { url = "https://files.pythonhosted.org/packages/5f/d7/fae6dd9d0355fc5b8df7793f1423b7433ca8e10b698ea934c35f0e4e6522/pyrefly-0.57.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:808087298537c70f5e7cdccb5bbaad482e7e056e947c0adf00fb612cbace9fdc", size = 12219634, upload-time = "2026-03-18T18:42:13.469Z" }, + { url = "https://files.pythonhosted.org/packages/29/8f/9511ae460f0690e837b9ba0f7e5e192079e16ff9a9ba8a272450e81f11f8/pyrefly-0.57.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b01f454fa5539e070c0cba17ddec46b3d2107d571d519bd8eca8f3142ba02a6", size = 34947757, upload-time = "2026-03-18T18:42:17.152Z" }, + { url = "https://files.pythonhosted.org/packages/07/43/f053bf9c65218f70e6a49561e9942c7233f8c3e4da8d42e5fe2aae50b3d2/pyrefly-0.57.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:02ad59ea722191f51635f23e37574662116b82ca9d814529f7cb5528f041f381", size = 37621018, upload-time = "2026-03-18T18:42:20.79Z" }, + { url = "https://files.pythonhosted.org/packages/0e/76/9cea46de01665bbc125e4f215340c9365c8d56cda6198ff238a563ea8e75/pyrefly-0.57.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:54bc0afe56776145e37733ff763e7e9679ee8a76c467b617dc3f227d4124a9e2", size = 40203649, upload-time = "2026-03-18T18:42:24.519Z" }, + { url = "https://files.pythonhosted.org/packages/fd/8b/2fb4a96d75e2a57df698a43e2970e441ba2704e3906cdc0386a055daa05a/pyrefly-0.57.1-py3-none-win32.whl", hash = "sha256:468e5839144b25bb0dce839bfc5fd879c9f38e68ebf5de561f30bed9ae19d8ca", size = 11732953, upload-time = "2026-03-18T18:42:27.379Z" }, + { url = "https://files.pythonhosted.org/packages/13/5a/4a197910fe2e9b102b15ae5e7687c45b7b5981275a11a564b41e185dd907/pyrefly-0.57.1-py3-none-win_amd64.whl", hash = "sha256:46db9c97093673c4fb7fab96d610e74d140661d54688a92d8e75ad885a56c141", size = 12537319, upload-time = "2026-03-18T18:42:30.196Z" }, + { url = "https://files.pythonhosted.org/packages/b5/c6/bc442874be1d9b63da1f9debb4f04b7d0c590a8dc4091921f3c288207242/pyrefly-0.57.1-py3-none-win_arm64.whl", hash = "sha256:feb1bbe3b0d8d5a70121dcdf1476e6a99cc056a26a49379a156f040729244dcb", size = 12013455, upload-time = "2026-03-18T18:42:32.928Z" }, ] [[package]] name = "pytest" -version = "8.3.5" +version = "9.0.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "iniconfig" }, { name = "packaging" }, { name = "pluggy" }, + { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ae/3c/c9d525a414d506893f0cd8a8d0de7706446213181570cdbd766691164e40/pytest-8.3.5.tar.gz", hash = "sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845", size = 1450891, upload-time = "2025-03-02T12:54:54.503Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d1/db/7ef3487e0fb0049ddb5ce41d3a49c235bf9ad299b6a25d5780a89f19230f/pytest-9.0.2.tar.gz", hash = "sha256:75186651a92bd89611d1d9fc20f0b4345fd827c41ccd5c299a868a05d70edf11", size = 1568901, upload-time = "2025-12-06T21:30:51.014Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/3d/64ad57c803f1fa1e963a7946b6e0fea4a70df53c1a7fed304586539c2bac/pytest-8.3.5-py3-none-any.whl", hash = "sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820", size = 343634, upload-time = "2025-03-02T12:54:52.069Z" }, + { url = "https://files.pythonhosted.org/packages/3b/ab/b3226f0bd7cdcf710fbede2b3548584366da3b19b5021e74f5bde2a8fa3f/pytest-9.0.2-py3-none-any.whl", hash = "sha256:711ffd45bf766d5264d487b917733b453d917afd2b0ad65223959f59089f875b", size = 374801, upload-time = "2025-12-06T21:30:49.154Z" }, ] [[package]] @@ -5663,52 +5862,54 @@ wheels = [ [[package]] name = "pytest-benchmark" -version = "4.0.0" +version = "5.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "py-cpuinfo" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/28/08/e6b0067efa9a1f2a1eb3043ecd8a0c48bfeb60d3255006dcc829d72d5da2/pytest-benchmark-4.0.0.tar.gz", hash = "sha256:fb0785b83efe599a6a956361c0691ae1dbb5318018561af10f3e915caa0048d1", size = 334641, upload-time = "2022-10-25T21:21:55.686Z" } +sdist = { url = "https://files.pythonhosted.org/packages/24/34/9f732b76456d64faffbef6232f1f9dbec7a7c4999ff46282fa418bd1af66/pytest_benchmark-5.2.3.tar.gz", hash = "sha256:deb7317998a23c650fd4ff76e1230066a76cb45dcece0aca5607143c619e7779", size = 341340, upload-time = "2025-11-09T18:48:43.215Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/a1/3b70862b5b3f830f0422844f25a823d0470739d994466be9dbbbb414d85a/pytest_benchmark-4.0.0-py3-none-any.whl", hash = "sha256:fdb7db64e31c8b277dff9850d2a2556d8b60bcb0ea6524e36e28ffd7c87f71d6", size = 43951, upload-time = "2022-10-25T21:21:53.208Z" }, + { url = "https://files.pythonhosted.org/packages/33/29/e756e715a48959f1c0045342088d7ca9762a2f509b945f362a316e9412b7/pytest_benchmark-5.2.3-py3-none-any.whl", hash = "sha256:bc839726ad20e99aaa0d11a127445457b4219bdb9e80a1afc4b51da7f96b0803", size = 45255, upload-time = "2025-11-09T18:48:39.765Z" }, ] [[package]] name = "pytest-cov" -version = "4.1.0" +version = "7.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "coverage", extra = ["toml"] }, + { name = "pluggy" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7a/15/da3df99fd551507694a9b01f512a2f6cf1254f33601605843c3775f39460/pytest-cov-4.1.0.tar.gz", hash = "sha256:3904b13dfbfec47f003b8e77fd5b589cd11904a21ddf1ab38a64f204d6a10ef6", size = 63245, upload-time = "2023-05-24T18:44:56.845Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/51/a849f96e117386044471c8ec2bd6cfebacda285da9525c9106aeb28da671/pytest_cov-7.1.0.tar.gz", hash = "sha256:30674f2b5f6351aa09702a9c8c364f6a01c27aae0c1366ae8016160d1efc56b2", size = 55592, upload-time = "2026-03-21T20:11:16.284Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a7/4b/8b78d126e275efa2379b1c2e09dc52cf70df16fc3b90613ef82531499d73/pytest_cov-4.1.0-py3-none-any.whl", hash = "sha256:6ba70b9e97e69fcc3fb45bfeab2d0a138fb65c4d0d6a41ef33983ad114be8c3a", size = 21949, upload-time = "2023-05-24T18:44:54.079Z" }, + { url = "https://files.pythonhosted.org/packages/9d/7a/d968e294073affff457b041c2be9868a40c1c71f4a35fcc1e45e5493067b/pytest_cov-7.1.0-py3-none-any.whl", hash = "sha256:a0461110b7865f9a271aa1b51e516c9a95de9d696734a2f71e3e78f46e1d4678", size = 22876, upload-time = "2026-03-21T20:11:14.438Z" }, ] [[package]] name = "pytest-env" -version = "1.1.5" +version = "1.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, + { name = "python-dotenv" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/31/27f28431a16b83cab7a636dce59cf397517807d247caa38ee67d65e71ef8/pytest_env-1.1.5.tar.gz", hash = "sha256:91209840aa0e43385073ac464a554ad2947cc2fd663a9debf88d03b01e0cc1cf", size = 8911, upload-time = "2024-09-17T22:39:18.566Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ff/69/4db1c30625af0621df8dbe73797b38b6d1b04e15d021dd5d26a6d297f78c/pytest_env-1.6.0.tar.gz", hash = "sha256:ac02d6fba16af54d61e311dd70a3c61024a4e966881ea844affc3c8f0bf207d3", size = 16163, upload-time = "2026-03-12T22:39:43.78Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/de/b8/87cfb16045c9d4092cfcf526135d73b88101aac83bc1adcf82dfb5fd3833/pytest_env-1.1.5-py3-none-any.whl", hash = "sha256:ce90cf8772878515c24b31cd97c7fa1f4481cd68d588419fd45f10ecaee6bc30", size = 6141, upload-time = "2024-09-17T22:39:16.942Z" }, + { url = "https://files.pythonhosted.org/packages/27/16/ad52f56b96d851a2bcfdc1e754c3531341885bd7177a128c13ff2ca72ab4/pytest_env-1.6.0-py3-none-any.whl", hash = "sha256:1e7f8a62215e5885835daaed694de8657c908505b964ec8097a7ce77b403d9a3", size = 10400, upload-time = "2026-03-12T22:39:41.887Z" }, ] [[package]] name = "pytest-mock" -version = "3.14.1" +version = "3.15.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/71/28/67172c96ba684058a4d24ffe144d64783d2a270d0af0d9e792737bddc75c/pytest_mock-3.14.1.tar.gz", hash = "sha256:159e9edac4c451ce77a5cdb9fc5d1100708d2dd4ba3c3df572f14097351af80e", size = 33241, upload-time = "2025-05-26T13:58:45.167Z" } +sdist = { url = "https://files.pythonhosted.org/packages/68/14/eb014d26be205d38ad5ad20d9a80f7d201472e08167f0bb4361e251084a9/pytest_mock-3.15.1.tar.gz", hash = "sha256:1849a238f6f396da19762269de72cb1814ab44416fa73a8686deac10b0d87a0f", size = 34036, upload-time = "2025-09-16T16:37:27.081Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b2/05/77b60e520511c53d1c1ca75f1930c7dd8e971d0c4379b7f4b3f9644685ba/pytest_mock-3.14.1-py3-none-any.whl", hash = "sha256:178aefcd11307d874b4cd3100344e7e2d888d9791a6a1d9bfe90fbc1b74fd1d0", size = 9923, upload-time = "2025-05-26T13:58:43.487Z" }, + { url = "https://files.pythonhosted.org/packages/5a/cc/06253936f4a7fa2e0f48dfe6d851d9c56df896a9ab09ac019d70b760619c/pytest_mock-3.15.1-py3-none-any.whl", hash = "sha256:0a25e2eb88fe5168d535041d09a4529a188176ae608a6d249ee65abc0949630d", size = 10095, upload-time = "2025-09-16T16:37:25.734Z" }, ] [[package]] @@ -5725,15 +5926,15 @@ wheels = [ [[package]] name = "pytest-rerunfailures" -version = "12.0" +version = "16.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, { name = "pytest" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/97/66/40f778791860c5234c5c677026d45c1a8708873b3dba8111de672bceac4f/pytest-rerunfailures-12.0.tar.gz", hash = "sha256:784f462fa87fe9bdf781d0027d856b47a4bfe6c12af108f6bd887057a917b48e", size = 21154, upload-time = "2023-07-05T05:53:46.014Z" } +sdist = { url = "https://files.pythonhosted.org/packages/de/04/71e9520551fc8fe2cf5c1a1842e4e600265b0815f2016b7c27ec85688682/pytest_rerunfailures-16.1.tar.gz", hash = "sha256:c38b266db8a808953ebd71ac25c381cb1981a78ff9340a14bcb9f1b9bff1899e", size = 30889, upload-time = "2025-10-10T07:06:01.238Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/25/14/e02206388902a828cc26894996dfc68eec50f7583bcddc4b5605d0c18b51/pytest_rerunfailures-12.0-py3-none-any.whl", hash = "sha256:9a1afd04e21b8177faf08a9bbbf44de7a0fe3fc29f8ddbe83b9684bd5f8f92a9", size = 12977, upload-time = "2023-07-05T05:53:43.909Z" }, + { url = "https://files.pythonhosted.org/packages/77/54/60eabb34445e3db3d3d874dc1dfa72751bfec3265bd611cb13c8b290adea/pytest_rerunfailures-16.1-py3-none-any.whl", hash = "sha256:5d11b12c0ca9a1665b5054052fcc1084f8deadd9328962745ef6b04e26382e86", size = 14093, upload-time = "2025-10-10T07:06:00.019Z" }, ] [[package]] @@ -5763,46 +5964,46 @@ wheels = [ [[package]] name = "python-calamine" -version = "0.5.4" +version = "0.6.2" source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/18/e1e53ade001b30a3c6642d876e5defe8431da8c31fb7798909e6c8ab8c34/python_calamine-0.6.2.tar.gz", hash = "sha256:2c90e5224c5e92db9fcd8f22b6085ce63b935cfe7a893ac9a1c3c56793bafd9d", size = 138000, upload-time = "2026-02-18T13:38:17.389Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/25/1a/ff59788a7e8bfeded91a501abdd068dc7e2f5865ee1a55432133b0f7f08c/python_calamine-0.5.4-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:944bcc072aca29d346456b4e42675c4831c52c25641db3e976c6013cdd07d4cd", size = 854308, upload-time = "2025-10-21T07:10:55.17Z" }, - { url = "https://files.pythonhosted.org/packages/24/7d/33fc441a70b771093d10fa5086831be289766535cbcb2b443ff1d5e549d8/python_calamine-0.5.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e637382e50cabc263a37eda7a3cd33f054271e4391a304f68cecb2e490827533", size = 830841, upload-time = "2025-10-21T07:10:57.353Z" }, - { url = "https://files.pythonhosted.org/packages/0f/38/b5b25e6ce0a983c9751fb026bd8c5d77eb81a775948cc3d9ce2b18b2fc91/python_calamine-0.5.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1b2a31d1e711c5661b4f04efd89975d311788bd9a43a111beff74d7c4c8f8d7a", size = 898287, upload-time = "2025-10-21T07:10:58.977Z" }, - { url = "https://files.pythonhosted.org/packages/0f/e9/ab288cd489999f962f791d6c8544803c29dcf24e9b6dde24634c41ec09dd/python_calamine-0.5.4-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2078ede35cbd26cf7186673405ff13321caacd9e45a5e57b54ce7b3ef0eec2ff", size = 886960, upload-time = "2025-10-21T07:11:00.462Z" }, - { url = "https://files.pythonhosted.org/packages/f0/4d/2a261f2ccde7128a683cdb20733f9bc030ab37a90803d8de836bf6113e5b/python_calamine-0.5.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:faab9f59bb9cedba2b35c6e1f5dc72461d8f2837e8f6ab24fafff0d054ddc4b5", size = 1044123, upload-time = "2025-10-21T07:11:02.153Z" }, - { url = "https://files.pythonhosted.org/packages/20/dc/a84c5a5a2c38816570bcc96ae4c9c89d35054e59c4199d3caef9c60b65cf/python_calamine-0.5.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:300d8d5e6c63bdecf79268d3b6d2a84078cda39cb3394ed09c5c00a61ce9ff32", size = 941997, upload-time = "2025-10-21T07:11:03.537Z" }, - { url = "https://files.pythonhosted.org/packages/dd/92/b970d8316c54f274d9060e7c804b79dbfa250edeb6390cd94f5fcfeb5f87/python_calamine-0.5.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0019a74f1c0b1cbf08fee9ece114d310522837cdf63660a46fe46d3688f215ea", size = 905881, upload-time = "2025-10-21T07:11:05.228Z" }, - { url = "https://files.pythonhosted.org/packages/ac/88/9186ac8d3241fc6f90995cc7539bdbd75b770d2dab20978a702c36fbce5f/python_calamine-0.5.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:30b40ffb374f7fb9ce20ca87f43a609288f568e41872f8a72e5af313a9e20af0", size = 947224, upload-time = "2025-10-21T07:11:06.618Z" }, - { url = "https://files.pythonhosted.org/packages/ee/ec/6ac1882dc6b6fa829e2d1d94ffa58bd0c67df3dba074b2e2f3134d7f573a/python_calamine-0.5.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:206242690a5a5dff73a193fb1a1ca3c7a8aed95e2f9f10c875dece5a22068801", size = 1078351, upload-time = "2025-10-21T07:11:08.368Z" }, - { url = "https://files.pythonhosted.org/packages/3e/f1/07aff6966b04b7452c41a802b37199d9e9ac656d66d6092b83ab0937e212/python_calamine-0.5.4-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:88628e1a17a6f352d6433b0abf6edc4cb2295b8fbb3451392390f3a6a7a8cada", size = 1150148, upload-time = "2025-10-21T07:11:10.18Z" }, - { url = "https://files.pythonhosted.org/packages/4e/be/90aedeb0b77ea592a698a20db09014a5217ce46a55b699121849e239c8e7/python_calamine-0.5.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:22524cfb7720d15894a02392bbd49f8e7a8c173493f0628a45814d78e4243fff", size = 1080101, upload-time = "2025-10-21T07:11:11.489Z" }, - { url = "https://files.pythonhosted.org/packages/30/89/1fadd511d132d5ea9326c003c8753b6d234d61d9a72775fb1632cc94beb9/python_calamine-0.5.4-cp311-cp311-win32.whl", hash = "sha256:d159e98ef3475965555b67354f687257648f5c3686ed08e7faa34d54cc9274e1", size = 679593, upload-time = "2025-10-21T07:11:12.758Z" }, - { url = "https://files.pythonhosted.org/packages/e9/ba/d7324400a02491549ef30e0e480561a3a841aa073ac7c096313bc2cea555/python_calamine-0.5.4-cp311-cp311-win_amd64.whl", hash = "sha256:0d019b082f9a114cf1e130dc52b77f9f881325ab13dc31485d7b4563ad9e0812", size = 721570, upload-time = "2025-10-21T07:11:14.336Z" }, - { url = "https://files.pythonhosted.org/packages/4f/15/8c7895e603b4ae63ff279aae4aa6120658a15f805750ccdb5d8b311df616/python_calamine-0.5.4-cp311-cp311-win_arm64.whl", hash = "sha256:bb20875776e5b4c85134c2bf49fea12288e64448ed49f1d89a3a83f5bb16bd59", size = 685789, upload-time = "2025-10-21T07:11:15.646Z" }, - { url = "https://files.pythonhosted.org/packages/ff/60/b1ace7a0fd636581b3bb27f1011cb7b2fe4d507b58401c4d328cfcb5c849/python_calamine-0.5.4-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:4d711f91283d28f19feb111ed666764de69e6d2a0201df8f84e81a238f68d193", size = 850087, upload-time = "2025-10-21T07:11:17.002Z" }, - { url = "https://files.pythonhosted.org/packages/7f/32/32ca71ce50f9b7c7d6e7ec5fcc579a97ddd8b8ce314fe143ba2a19441dc7/python_calamine-0.5.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ed67afd3adedb5bcfb428cf1f2d7dfd936dea9fe979ab631194495ab092973ba", size = 825659, upload-time = "2025-10-21T07:11:18.248Z" }, - { url = "https://files.pythonhosted.org/packages/63/c5/27ba71a9da2a09be9ff2f0dac522769956c8c89d6516565b21c9c78bfae6/python_calamine-0.5.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:13662895dac487315ccce25ea272a1ea7e7ac05d899cde4e33d59d6c43274c54", size = 897332, upload-time = "2025-10-21T07:11:19.89Z" }, - { url = "https://files.pythonhosted.org/packages/5a/e7/c4be6ff8e8899ace98cacc9604a2dd1abc4901839b733addfb6ef32c22ba/python_calamine-0.5.4-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23e354755583cfaa824ddcbe8b099c5c7ac19bf5179320426e7a88eea2f14bc5", size = 886885, upload-time = "2025-10-21T07:11:21.912Z" }, - { url = "https://files.pythonhosted.org/packages/38/24/80258fb041435021efa10d0b528df6842e442585e48cbf130e73fed2529b/python_calamine-0.5.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4e1bc3f22107dcbdeb32d4d3c5c1e8831d3c85d4b004a8606dd779721b29843d", size = 1043907, upload-time = "2025-10-21T07:11:23.3Z" }, - { url = "https://files.pythonhosted.org/packages/f2/20/157340787d03ef6113a967fd8f84218e867ba4c2f7fc58cc645d8665a61a/python_calamine-0.5.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:182b314117e47dbd952adaa2b19c515555083a48d6f9146f46faaabd9dab2f81", size = 942376, upload-time = "2025-10-21T07:11:24.866Z" }, - { url = "https://files.pythonhosted.org/packages/98/f5/aec030f567ee14c60b6fc9028a78767687f484071cb080f7cfa328d6496e/python_calamine-0.5.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8f882e092ab23f72ea07e2e48f5f2efb1885c1836fb949f22fd4540ae11742e", size = 906455, upload-time = "2025-10-21T07:11:26.203Z" }, - { url = "https://files.pythonhosted.org/packages/29/58/4affc0d1389f837439ad45f400f3792e48030b75868ec757e88cb35d7626/python_calamine-0.5.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:62a9b4b7b9bd99d03373e58884dfb60d5a1c292c8e04e11f8b7420b77a46813e", size = 948132, upload-time = "2025-10-21T07:11:27.507Z" }, - { url = "https://files.pythonhosted.org/packages/b4/2e/70ed04f39e682a9116730f56b7fbb54453244ccc1c3dae0662d4819f1c1d/python_calamine-0.5.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:98bb011d33c0e2d183ff30ab3d96792c3493f56f67a7aa2fcadad9a03539e79b", size = 1077436, upload-time = "2025-10-21T07:11:28.801Z" }, - { url = "https://files.pythonhosted.org/packages/cb/ce/806f8ce06b5bb9db33007f85045c304cda410970e7aa07d08f6eaee67913/python_calamine-0.5.4-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:6b218a95489ff2f1cc1de0bba2a16fcc82981254bbb23f31d41d29191282b9ad", size = 1150570, upload-time = "2025-10-21T07:11:30.237Z" }, - { url = "https://files.pythonhosted.org/packages/18/da/61f13c8d107783128c1063cf52ca9cacdc064c58d58d3cf49c1728ce8296/python_calamine-0.5.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e8296a4872dbe834205d25d26dd6cfcb33ee9da721668d81b21adc25a07c07e4", size = 1080286, upload-time = "2025-10-21T07:11:31.564Z" }, - { url = "https://files.pythonhosted.org/packages/99/85/c5612a63292eb7d0648b17c5ff32ad5d6c6f3e1d78825f01af5c765f4d3f/python_calamine-0.5.4-cp312-cp312-win32.whl", hash = "sha256:cebb9c88983ae676c60c8c02aa29a9fe13563f240579e66de5c71b969ace5fd9", size = 676617, upload-time = "2025-10-21T07:11:32.833Z" }, - { url = "https://files.pythonhosted.org/packages/bb/18/5a037942de8a8df0c805224b2fba06df6d25c1be3c9484ba9db1ca4f3ee6/python_calamine-0.5.4-cp312-cp312-win_amd64.whl", hash = "sha256:15abd7aff98fde36d7df91ac051e86e66e5d5326a7fa98d54697afe95a613501", size = 721464, upload-time = "2025-10-21T07:11:34.383Z" }, - { url = "https://files.pythonhosted.org/packages/d1/8b/89ca17b44bcd8be5d0e8378d87b880ae17a837573553bd2147cceca7e759/python_calamine-0.5.4-cp312-cp312-win_arm64.whl", hash = "sha256:1cef0d0fc936974020a24acf1509ed2a285b30a4e1adf346c057112072e84251", size = 687268, upload-time = "2025-10-21T07:11:36.324Z" }, - { url = "https://files.pythonhosted.org/packages/ab/a8/0e05992489f8ca99eadfb52e858a7653b01b27a7c66d040abddeb4bdf799/python_calamine-0.5.4-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:8d4be45952555f129584e0ca6ddb442bed5cb97b8d7cd0fd5ae463237b98eb15", size = 856420, upload-time = "2025-10-21T07:13:20.962Z" }, - { url = "https://files.pythonhosted.org/packages/f0/b0/5bbe52c97161acb94066e7020c2fed7eafbca4bf6852a4b02ed80bf0b24b/python_calamine-0.5.4-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:5b387d12cb8cae98c8e0c061c5400f80bad1f43f26fafcf95ff5934df995f50b", size = 833240, upload-time = "2025-10-21T07:13:22.801Z" }, - { url = "https://files.pythonhosted.org/packages/c7/b9/44fa30f6bf479072d9042856d3fab8bdd1532d2d901e479e199bc1de0e6c/python_calamine-0.5.4-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2103714954b7dbed72a0b0eff178b08e854bba130be283e3ae3d7c95521e8f69", size = 899470, upload-time = "2025-10-21T07:13:25.176Z" }, - { url = "https://files.pythonhosted.org/packages/0e/f2/acbb2c1d6acba1eaf6b1efb6485c98995050bddedfb6b93ce05be2753a85/python_calamine-0.5.4-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c09fdebe23a5045d09e12b3366ff8fd45165b6fb56f55e9a12342a5daddbd11a", size = 906108, upload-time = "2025-10-21T07:13:26.709Z" }, - { url = "https://files.pythonhosted.org/packages/77/28/ff007e689539d6924223565995db876ac044466b8859bade371696294659/python_calamine-0.5.4-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:fa992d72fbd38f09107430100b7688c03046d8c1994e4cff9bbbd2a825811796", size = 948580, upload-time = "2025-10-21T07:13:30.816Z" }, - { url = "https://files.pythonhosted.org/packages/a4/06/b423655446fb27e22bfc1ca5e5b11f3449e0350fe8fefa0ebd68675f7e85/python_calamine-0.5.4-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:88e608c7589412d3159be40d270a90994e38c9eafc125bf8ad5a9c92deffd6dd", size = 1079516, upload-time = "2025-10-21T07:13:32.288Z" }, - { url = "https://files.pythonhosted.org/packages/76/f5/c7132088978b712a5eddf1ca6bf64ae81335fbca9443ed486330519954c3/python_calamine-0.5.4-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:51a007801aef12f6bc93a545040a36df48e9af920a7da9ded915584ad9a002b1", size = 1152379, upload-time = "2025-10-21T07:13:33.739Z" }, - { url = "https://files.pythonhosted.org/packages/bd/c8/37a8d80b7e55e7cfbe649f7a92a7e838defc746aac12dca751aad5dd06a6/python_calamine-0.5.4-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b056db205e45ab9381990a5c15d869f1021c1262d065740c9cd296fc5d3fb248", size = 1080420, upload-time = "2025-10-21T07:13:35.33Z" }, - { url = "https://files.pythonhosted.org/packages/10/52/9a96d06e75862d356dc80a4a465ad88fba544a19823568b4ff484e7a12f2/python_calamine-0.5.4-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:dd8f4123b2403fc22c92ec4f5e51c495427cf3739c5cb614b9829745a80922db", size = 722350, upload-time = "2025-10-21T07:13:37.074Z" }, + { url = "https://files.pythonhosted.org/packages/b8/89/739bf1033d7ec4d01f62b5a60a7ea2017f985f724baec95c39628a840ea3/python_calamine-0.6.2-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:626a9b342d3df5596e8a4c7410ad65d05447e77f915ef14cb0a9dd90f2d42608", size = 879847, upload-time = "2026-02-18T13:35:58.633Z" }, + { url = "https://files.pythonhosted.org/packages/29/e3/5f62d3e6cafb9a99290791f2206a2d752f9bffe65d19673b0f8e3e9c7320/python_calamine-0.6.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:65f427b50e659b3752b6a088bd703f1c4161e491b6745d849fe8cc997adbe296", size = 858617, upload-time = "2026-02-18T13:35:59.869Z" }, + { url = "https://files.pythonhosted.org/packages/49/01/683d0233c86d44ffe568ab627a6e99259f2fc868a8b3d343a8fd2f3ce6ec/python_calamine-0.6.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:316e41a9bc1fc0ee6c3e01f1e252112fae7cc3eba698870f11e0ef5d63d37b7c", size = 929696, upload-time = "2026-02-18T13:36:01.184Z" }, + { url = "https://files.pythonhosted.org/packages/e6/78/98d69f46db4adeab662463e119f5ac68ebfdc4f6122356c8db3e1442e8a6/python_calamine-0.6.2-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a9c0fa9138f2c3e9aff604744c78e8558c88f5e57ad5da23a3b6472d01b13696", size = 905627, upload-time = "2026-02-18T13:36:02.964Z" }, + { url = "https://files.pythonhosted.org/packages/e0/5f/9c4641e93c600e8bc52d19ad5b680baff242844cabee55f4625ec56ee6c1/python_calamine-0.6.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:fc42307129e59e48993ca974cf0d5402464690a15c87e5e5e879eb4a4d8ea362", size = 1064524, upload-time = "2026-02-18T13:36:07.778Z" }, + { url = "https://files.pythonhosted.org/packages/b6/48/907e98853008f21a2c77c12c9515903d2891572d453269f052bbd9a14c6a/python_calamine-0.6.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:aff0652a0d87d68de25c628338fb5c06e07f9d28d7120c5987e094282acf6d7a", size = 969734, upload-time = "2026-02-18T13:36:09.076Z" }, + { url = "https://files.pythonhosted.org/packages/01/f8/576d245d278ea3e13058a99cb2e3ab27eacf51a4ed6e387556e11e532114/python_calamine-0.6.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a170adfa3abb4724884617d12664c8518cd9eccb03522838553a06e353ed698", size = 937397, upload-time = "2026-02-18T13:36:10.695Z" }, + { url = "https://files.pythonhosted.org/packages/66/0e/bfe248e6d9e53b143ad371200613322e39c907598aed2e4915031c1d1635/python_calamine-0.6.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ee3f210e84eca9d8337cf0ea2f02757b2e0d123a07ecb43186dc5fd63c5c9d0b", size = 979083, upload-time = "2026-02-18T13:36:12.842Z" }, + { url = "https://files.pythonhosted.org/packages/9d/ad/4fa49ac4ee0a1fd97749149e4742f20146c706e31bc1ffac19d074a70817/python_calamine-0.6.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:20e0c2f28688786ea713f48dcc54a6877432da4ed277ae886a10a5ed802fcad6", size = 1106237, upload-time = "2026-02-18T13:36:14.084Z" }, + { url = "https://files.pythonhosted.org/packages/42/d4/ecd096f77bdf5927806591fe60b2937f301f098268ea97e045d9a552c7f4/python_calamine-0.6.2-cp311-cp311-musllinux_1_1_armv7l.whl", hash = "sha256:c743582bef238ea98176e47dffd9027d096df2bff5c4ad8bfad03920921aa021", size = 1180994, upload-time = "2026-02-18T13:36:15.253Z" }, + { url = "https://files.pythonhosted.org/packages/29/14/3128c1beddb5c12dc1eb0a7f0704a7f3fdd46ac1c79ae60f67cbc895707e/python_calamine-0.6.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8ddd518c62cb14e4ec5c81f8f389fd021d9a2ed6a5ab9348fbafb810c57f7717", size = 1148215, upload-time = "2026-02-18T13:36:16.45Z" }, + { url = "https://files.pythonhosted.org/packages/6b/7b/244514d60c41b65b2c1b16f9ffba01b48e0d323c61085bd1a98f6a49ab97/python_calamine-0.6.2-cp311-cp311-win32.whl", hash = "sha256:14f4fdb50ac7bcfe59b415bae865cdc958edff6a9adf85735dab4f7fa9f153f3", size = 699276, upload-time = "2026-02-18T13:36:17.755Z" }, + { url = "https://files.pythonhosted.org/packages/c4/67/749b482001d3c9ec601b6785ff715a8613e8bb6311ccd3d953485e08c6b3/python_calamine-0.6.2-cp311-cp311-win_amd64.whl", hash = "sha256:ec90a5d8358c0139bbcbf7f3a61d44baf41846a54367d4b34d5329102c4d4f59", size = 751919, upload-time = "2026-02-18T13:36:19.257Z" }, + { url = "https://files.pythonhosted.org/packages/65/a4/f2cc3243b836256d0bbb254cc9f766270813c4df77f45e8971d24c936bab/python_calamine-0.6.2-cp311-cp311-win_arm64.whl", hash = "sha256:a6112601f889e951333ff79ab0b6aa814a18df247d861db58d4272ad7be263d5", size = 719281, upload-time = "2026-02-18T13:36:20.469Z" }, + { url = "https://files.pythonhosted.org/packages/7a/ec/e111c1a3a4c138ebc41e416e33730ee6d7c54e714af21c2a4e59b41715a5/python_calamine-0.6.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:857e4cddadba9b55c76dc583c58c5dc101a6cd5320190c10f8b2ab98d66c9040", size = 879539, upload-time = "2026-02-18T13:36:21.674Z" }, + { url = "https://files.pythonhosted.org/packages/53/26/fe4c2138ff21542e2f1130a4d83c330d7f9486b62775196e998b88a03de6/python_calamine-0.6.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cd89d6a53e4b22328cd685fc054c31d359cb3ae67bd24bc57e1c1db62a4cfc97", size = 858642, upload-time = "2026-02-18T13:36:23.847Z" }, + { url = "https://files.pythonhosted.org/packages/6a/b0/bfeaf45ac5e2f6553723dd2fbe127d1d17c6f26496db5781de42a933776a/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d6c9af39db39e0c70710ae79cd1b5d980f9c0aea55fc16d194460c1561a0c6a", size = 925242, upload-time = "2026-02-18T13:36:25.236Z" }, + { url = "https://files.pythonhosted.org/packages/fe/6e/81106aa80609075015d400584030605b05f5e12931717160dcc58fdc4980/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a2382dbc410dd48c99d89ee460662cc70892fe1b2901ab982604b923e8eb8f6", size = 905295, upload-time = "2026-02-18T13:36:27.152Z" }, + { url = "https://files.pythonhosted.org/packages/4d/ba/6311b24f9889246be63b664630c5601039ef771f7ed04c8f51aace39b7a9/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ebb93255709874ede5b5e62828cb5758e60097e5390b6c9a3eb7751b617b12e", size = 1063473, upload-time = "2026-02-18T13:36:29.226Z" }, + { url = "https://files.pythonhosted.org/packages/23/e4/027a1b046d30768872307ebe808dc4cdc5357295cdcda98b30b3ea924904/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:837bca19bd945cb83aded433f4cf76e80d70a5400404d876400ca7e88e5ea311", size = 965355, upload-time = "2026-02-18T13:36:31.376Z" }, + { url = "https://files.pythonhosted.org/packages/5c/4d/da8716a1b3a66938aaabe36873f6fa210fa063bab1b20c2ec236013de6b3/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:723990a47668cb819f307ccc634741370d3cd3804a0ee8cda392a522ae6d5016", size = 935091, upload-time = "2026-02-18T13:36:32.777Z" }, + { url = "https://files.pythonhosted.org/packages/36/40/9521e8da5496cbc4b18027626a40018301f546b3e9802ca2f3a6cb5b4739/python_calamine-0.6.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b067630d693e1d7de41e3d44a99c7dd3feebb52db8dda8636ac3f70d8b6a4ad6", size = 974070, upload-time = "2026-02-18T13:36:34.055Z" }, + { url = "https://files.pythonhosted.org/packages/cd/b0/7a63963512c5ba7e9539b7452e2b1561625e63e4e29c044e487e2e93dcbe/python_calamine-0.6.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6ab09c9da53a2b33633e9f940aed11c08e083810a0fd6885826cdc52ba4f86a5", size = 1100321, upload-time = "2026-02-18T13:36:35.475Z" }, + { url = "https://files.pythonhosted.org/packages/22/81/e2bc38a5cf9629f656adcdabe8e134028f60c236e4bb96375dda90db3fdd/python_calamine-0.6.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:ae08e1308a0d0c6b8b4cc0a039ed8a85fc9ee2f8a3ca9ea57b1af9f97ed68fe4", size = 1181039, upload-time = "2026-02-18T13:36:37.195Z" }, + { url = "https://files.pythonhosted.org/packages/b1/ea/513117015fd5903ca6dde9c8fb8502af60af6965642f4e3311623943e673/python_calamine-0.6.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c441a20c7aff0e904ca01b5cdc1e5be2c6d4a41a24a0ea4d5ea6d211343bb95f", size = 1144843, upload-time = "2026-02-18T13:36:38.393Z" }, + { url = "https://files.pythonhosted.org/packages/a2/14/8846478dacf31535f5f15448ade3bc688b51f3183f1b52844451aa27b0e6/python_calamine-0.6.2-cp312-cp312-win32.whl", hash = "sha256:39cae8e66f8bce499f5f965f4575ddf61e30184cc97f02e1c7031a57abe0903b", size = 692411, upload-time = "2026-02-18T13:36:39.741Z" }, + { url = "https://files.pythonhosted.org/packages/bb/e2/2d2dcf4ec7e5ec08e33bf966ab010a7be178a4b623bd5f7601d47f2c734c/python_calamine-0.6.2-cp312-cp312-win_amd64.whl", hash = "sha256:1617efa24532f2420934a8cf77e6d33ff1740cae1d39355cab4f4cf141fdab49", size = 748960, upload-time = "2026-02-18T13:36:40.922Z" }, + { url = "https://files.pythonhosted.org/packages/f1/eb/2f50f3395c0435e6186cab56c36d04c06581ba827264bca1f1acae523aa3/python_calamine-0.6.2-cp312-cp312-win_arm64.whl", hash = "sha256:c2b378db494740e540e8157a7e5fe61dadae69ad2d988a7c80f9583f434acf07", size = 718992, upload-time = "2026-02-18T13:36:42.671Z" }, + { url = "https://files.pythonhosted.org/packages/42/8c/6b72c1eebbcd922feb1597751bbce4bdea569128b7e06f7983f0b6ad0cf0/python_calamine-0.6.2-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:0c6dc23fe9a00df4909db603966e020116742fbc2a998a121f2a9bf092ab9529", size = 883251, upload-time = "2026-02-18T13:38:05.881Z" }, + { url = "https://files.pythonhosted.org/packages/07/ef/5655a616e3d74ef74de0f32d0736357e3dc2982875823ac405e861c1d799/python_calamine-0.6.2-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:4f434d5db2a015ef1368fd633375844a6897a9b19df2cbdd6337ef3af2326a8e", size = 862349, upload-time = "2026-02-18T13:38:07.204Z" }, + { url = "https://files.pythonhosted.org/packages/77/54/cc754259fb1e56eff23792b5724d612689736176b98d09b755e6d9837676/python_calamine-0.6.2-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fc2f65c8edacf36f76d108d760413abccf2ace7478d9c2eafa795a0e01a89c44", size = 931049, upload-time = "2026-02-18T13:38:08.504Z" }, + { url = "https://files.pythonhosted.org/packages/af/05/c415832f625c94bc1d91e579e7f092a7109c83dd4841e033429db0621619/python_calamine-0.6.2-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c56d68cf4cc13e2865c66508b1ffde75de610c539603354240d366d3764968fc", size = 938030, upload-time = "2026-02-18T13:38:10.02Z" }, + { url = "https://files.pythonhosted.org/packages/f0/cb/ef5c257a52dd9d1c927f32048c4c254025012dd74ca96044c7dac9e8d9eb/python_calamine-0.6.2-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:8e5f3e959793f697101a96857e3e3e2e84b27847eb3b80370675483eeea5b07f", size = 980050, upload-time = "2026-02-18T13:38:11.515Z" }, + { url = "https://files.pythonhosted.org/packages/cb/5e/7437b7a9106bf331c7f802ec29d4c0208519203b56a3517fa1ba743d8398/python_calamine-0.6.2-pp311-pypy311_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:bb705b9acf993002bce5d2cab182cd04e539ae712be51f31a33a64528792869d", size = 1108010, upload-time = "2026-02-18T13:38:12.746Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d8/58cb7ac7ba2b937abc1b95e0c3cf6124b2d001d2bd03a62b1049afdb5c20/python_calamine-0.6.2-pp311-pypy311_pp73-musllinux_1_1_armv7l.whl", hash = "sha256:03faea1948e6cf82382819cb3f330616e0ec4b622e98102fd9a0522d260b1161", size = 1182431, upload-time = "2026-02-18T13:38:14.077Z" }, + { url = "https://files.pythonhosted.org/packages/12/2a/38ed80df1a21f27c4682d4855f42f78b19e96b77efbb80519c242780c161/python_calamine-0.6.2-pp311-pypy311_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:b149def7d77c8bc9e2d7b2a6c75215367426a0be430906fd6ea9803ae1c63e14", size = 1149093, upload-time = "2026-02-18T13:38:15.315Z" }, ] [[package]] @@ -5832,11 +6033,11 @@ wheels = [ [[package]] name = "python-dotenv" -version = "1.0.1" +version = "1.2.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bc/57/e84d88dfe0aec03b7a2d4327012c1627ab5f03652216c63d49846d7a6c58/python-dotenv-1.0.1.tar.gz", hash = "sha256:e324ee90a023d808f1959c46bcbc04446a10ced277783dc6ee09987c37ec10ca", size = 39115, upload-time = "2024-01-23T06:33:00.505Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/ed/0301aeeac3e5353ef3d94b6ec08bbcabd04a72018415dcb29e588514bba8/python_dotenv-1.2.2.tar.gz", hash = "sha256:2c371a91fbd7ba082c2c1dc1f8bf89ca22564a087c2c287cd9b662adde799cf3", size = 50135, upload-time = "2026-03-01T16:00:26.196Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/3e/b68c118422ec867fa7ab88444e1274aa40681c606d59ac27de5a5588f082/python_dotenv-1.0.1-py3-none-any.whl", hash = "sha256:f7b63ef50f1b690dddf550d03497b66d609393b40b564ed0d674909a68ebf16a", size = 19863, upload-time = "2024-01-23T06:32:58.246Z" }, + { url = "https://files.pythonhosted.org/packages/0b/d7/1959b9648791274998a9c3526f6d0ec8fd2233e4d4acce81bbae76b44b2a/python_dotenv-1.2.2-py3-none-any.whl", hash = "sha256:1d8214789a24de455a8b8bd8ae6fe3c6b69a5e3d64aa8a8e5d68e694bbcb285a", size = 22101, upload-time = "2026-03-01T16:00:25.09Z" }, ] [[package]] @@ -5850,11 +6051,11 @@ wheels = [ [[package]] name = "python-iso639" -version = "2025.11.16" +version = "2026.1.31" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/3b/3e07aadeeb7bbb2574d6aa6ccacbc58b17bd2b1fb6c7196bf96ab0e45129/python_iso639-2025.11.16.tar.gz", hash = "sha256:aabe941267898384415a509f5236d7cfc191198c84c5c6f73dac73d9783f5169", size = 174186, upload-time = "2025-11-16T21:53:37.031Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a3/da/701fc47ea3b0579a8ae489d50d5b54f2ef3aeb7768afd31db1d1cfe9f24e/python_iso639-2026.1.31.tar.gz", hash = "sha256:55a1612c15e5fbd3a1fa269a309cbf1e7c13019356e3d6f75bb435ed44c45ddb", size = 174144, upload-time = "2026-01-31T15:04:48.105Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b5/2d/563849c31e58eb2e273fa0c391a7d9987db32f4d9152fe6ecdac0a8ffe93/python_iso639-2025.11.16-py3-none-any.whl", hash = "sha256:65f6ac6c6d8e8207f6175f8bf7fff7db486c6dc5c1d8866c2b77d2a923370896", size = 167818, upload-time = "2025-11-16T21:53:35.36Z" }, + { url = "https://files.pythonhosted.org/packages/5b/3a/03ee682b04099e6b02b591955851b0347deb2e3691ae850112000c54ba12/python_iso639-2026.1.31-py3-none-any.whl", hash = "sha256:b2c48fa1300af1299dff4f1e1995ad1059996ed9f22270ea2d6d6bdc5fb03d4c", size = 167757, upload-time = "2026-01-31T15:04:46.458Z" }, ] [[package]] @@ -5895,32 +6096,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d9/4f/00be2196329ebbff56ce564aa94efb0fbc828d00de250b1980de1a34ab49/python_pptx-1.0.2-py3-none-any.whl", hash = "sha256:160838e0b8565a8b1f67947675886e9fea18aa5e795db7ae531606d68e785cba", size = 472788, upload-time = "2024-08-07T17:33:28.192Z" }, ] -[[package]] -name = "pytokens" -version = "0.4.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b6/34/b4e015b99031667a7b960f888889c5bd34ef585c85e1cb56a594b92836ac/pytokens-0.4.1.tar.gz", hash = "sha256:292052fe80923aae2260c073f822ceba21f3872ced9a68bb7953b348e561179a", size = 23015, upload-time = "2026-01-30T01:03:45.924Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/3d/92/790ebe03f07b57e53b10884c329b9a1a308648fc083a6d4a39a10a28c8fc/pytokens-0.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d70e77c55ae8380c91c0c18dea05951482e263982911fc7410b1ffd1dadd3440", size = 160864, upload-time = "2026-01-30T01:02:57.882Z" }, - { url = "https://files.pythonhosted.org/packages/13/25/a4f555281d975bfdd1eba731450e2fe3a95870274da73fb12c40aeae7625/pytokens-0.4.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a58d057208cb9075c144950d789511220b07636dd2e4708d5645d24de666bdc", size = 248565, upload-time = "2026-01-30T01:02:59.912Z" }, - { url = "https://files.pythonhosted.org/packages/17/50/bc0394b4ad5b1601be22fa43652173d47e4c9efbf0044c62e9a59b747c56/pytokens-0.4.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b49750419d300e2b5a3813cf229d4e5a4c728dae470bcc89867a9ad6f25a722d", size = 260824, upload-time = "2026-01-30T01:03:01.471Z" }, - { url = "https://files.pythonhosted.org/packages/4e/54/3e04f9d92a4be4fc6c80016bc396b923d2a6933ae94b5f557c939c460ee0/pytokens-0.4.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:d9907d61f15bf7261d7e775bd5d7ee4d2930e04424bab1972591918497623a16", size = 264075, upload-time = "2026-01-30T01:03:04.143Z" }, - { url = "https://files.pythonhosted.org/packages/d1/1b/44b0326cb5470a4375f37988aea5d61b5cc52407143303015ebee94abfd6/pytokens-0.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:ee44d0f85b803321710f9239f335aafe16553b39106384cef8e6de40cb4ef2f6", size = 103323, upload-time = "2026-01-30T01:03:05.412Z" }, - { url = "https://files.pythonhosted.org/packages/41/5d/e44573011401fb82e9d51e97f1290ceb377800fb4eed650b96f4753b499c/pytokens-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:140709331e846b728475786df8aeb27d24f48cbcf7bcd449f8de75cae7a45083", size = 160663, upload-time = "2026-01-30T01:03:06.473Z" }, - { url = "https://files.pythonhosted.org/packages/f0/e6/5bbc3019f8e6f21d09c41f8b8654536117e5e211a85d89212d59cbdab381/pytokens-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6d6c4268598f762bc8e91f5dbf2ab2f61f7b95bdc07953b602db879b3c8c18e1", size = 255626, upload-time = "2026-01-30T01:03:08.177Z" }, - { url = "https://files.pythonhosted.org/packages/bf/3c/2d5297d82286f6f3d92770289fd439956b201c0a4fc7e72efb9b2293758e/pytokens-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24afde1f53d95348b5a0eb19488661147285ca4dd7ed752bbc3e1c6242a304d1", size = 269779, upload-time = "2026-01-30T01:03:09.756Z" }, - { url = "https://files.pythonhosted.org/packages/20/01/7436e9ad693cebda0551203e0bf28f7669976c60ad07d6402098208476de/pytokens-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:5ad948d085ed6c16413eb5fec6b3e02fa00dc29a2534f088d3302c47eb59adf9", size = 268076, upload-time = "2026-01-30T01:03:10.957Z" }, - { url = "https://files.pythonhosted.org/packages/2e/df/533c82a3c752ba13ae7ef238b7f8cdd272cf1475f03c63ac6cf3fcfb00b6/pytokens-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:3f901fe783e06e48e8cbdc82d631fca8f118333798193e026a50ce1b3757ea68", size = 103552, upload-time = "2026-01-30T01:03:12.066Z" }, - { url = "https://files.pythonhosted.org/packages/c6/78/397db326746f0a342855b81216ae1f0a32965deccfd7c830a2dbc66d2483/pytokens-0.4.1-py3-none-any.whl", hash = "sha256:26cef14744a8385f35d0e095dc8b3a7583f6c953c2e3d269c7f82484bf5ad2de", size = 13729, upload-time = "2026-01-30T01:03:45.029Z" }, -] - [[package]] name = "pytz" -version = "2025.2" +version = "2026.1.post1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f8/bf/abbd3cdfb8fbc7fb3d4d38d320f2441b1e7cbe29be4f23797b4a2b5d8aac/pytz-2025.2.tar.gz", hash = "sha256:360b9e3dbb49a209c21ad61809c7fb453643e048b38924c765813546746e81c3", size = 320884, upload-time = "2025-03-25T02:25:00.538Z" } +sdist = { url = "https://files.pythonhosted.org/packages/56/db/b8721d71d945e6a8ac63c0fc900b2067181dbb50805958d4d4661cf7d277/pytz-2026.1.post1.tar.gz", hash = "sha256:3378dde6a0c3d26719182142c56e60c7f9af7e968076f31aae569d72a0358ee1", size = 321088, upload-time = "2026-03-03T07:47:50.683Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/81/c4/34e93fe5f5429d7570ec1fa436f1986fb1f00c3e0f43a589fe2bbcd22c3f/pytz-2025.2-py2.py3-none-any.whl", hash = "sha256:5ddf76296dd8c44c26eb8f4b6f35488f3ccbf6fbbd7adee0b7262d43f0ec2f00", size = 509225, upload-time = "2025-03-25T02:24:58.468Z" }, + { url = "https://files.pythonhosted.org/packages/10/99/781fe0c827be2742bcc775efefccb3b048a3a9c6ce9aec0cbf4a101677e5/pytz-2026.1.post1-py2.py3-none-any.whl", hash = "sha256:f2fd16142fda348286a75e1a524be810bb05d444e5a081f37f7affc635035f7a", size = 510489, upload-time = "2026-03-03T07:47:49.167Z" }, ] [[package]] @@ -5936,15 +6118,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852", size = 8710102, upload-time = "2025-07-14T20:13:24.682Z" }, ] -[[package]] -name = "pywin32-ctypes" -version = "0.2.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/85/9f/01a1a99704853cb63f253eea009390c88e7131c67e66a0a02099a8c917cb/pywin32-ctypes-0.2.3.tar.gz", hash = "sha256:d162dc04946d704503b2edc4d55f3dba5c1d539ead017afa00142c38b9885755", size = 29471, upload-time = "2024-08-14T10:15:34.626Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/de/3d/8161f7711c017e01ac9f008dfddd9410dff3674334c233bde66e7ba65bbf/pywin32_ctypes-0.2.3-py3-none-any.whl", hash = "sha256:8a1513379d709975552d202d942d9837758905c8d01eb82b8bcc30918929e7b8", size = 30756, upload-time = "2024-08-14T10:15:33.187Z" }, -] - [[package]] name = "pyxlsb" version = "1.0.10" @@ -6001,32 +6174,31 @@ wheels = [ [[package]] name = "ragas" -version = "0.4.3" +version = "0.3.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "appdirs" }, { name = "datasets" }, { name = "diskcache" }, + { name = "gitpython" }, { name = "instructor" }, { name = "langchain" }, { name = "langchain-community" }, { name = "langchain-core" }, { name = "langchain-openai" }, { name = "nest-asyncio" }, - { name = "networkx" }, { name = "numpy" }, { name = "openai" }, { name = "pillow" }, { name = "pydantic" }, { name = "rich" }, - { name = "scikit-network" }, { name = "tiktoken" }, { name = "tqdm" }, { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d2/bc/3234517692ac0ffae1ec2ec940992e4057844c49ee6c51c07ce385bb98f1/ragas-0.4.3.tar.gz", hash = "sha256:1eb1f61dbc8613ad014fdb8d630cbe9a1caec1ea01664a106993cb756128c001", size = 44029626, upload-time = "2026-01-13T17:48:01.043Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5c/3f/44ffcfcd10b885c330f5e07f327edcf8965674103f81a8766a01672c3f27/ragas-0.3.2.tar.gz", hash = "sha256:a800a5326f0d5bfa086cf8832d4b1031e4b21ff063f83d7d9388ee36e1a93fe6", size = 257091, upload-time = "2025-08-19T12:04:18.655Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/e0/1fecd22c93d3ed66453cbbdefd05528331af4d33b2b76a370d751231912c/ragas-0.4.3-py3-none-any.whl", hash = "sha256:ef1d75f674c294e9a6e7d8e9ad261b6bf4697dad1c9cbd1a756ba7a6b4849a38", size = 466452, upload-time = "2026-01-13T17:47:59.2Z" }, + { url = "https://files.pythonhosted.org/packages/1f/1a/d08bb5fd256c375dae885e508c00636cba5692e772616e49c6d4cfdec52c/ragas-0.3.2-py3-none-any.whl", hash = "sha256:ccf4447fdc6daf69b5fafb26d8baa5095e4194a4dc87091b2a118f68322e7f1b", size = 277283, upload-time = "2025-08-19T12:04:17.329Z" }, ] [[package]] @@ -6079,20 +6251,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/dd/46/8a640c6de1a6c6af971f858b2fb178ca5e1db91f223d8ba5f40efe1491e5/readabilipy-0.3.0-py3-none-any.whl", hash = "sha256:d106da0fad11d5fdfcde21f5c5385556bfa8ff0258483037d39ea6b1d6db3943", size = 22158, upload-time = "2024-12-02T23:03:00.438Z" }, ] -[[package]] -name = "readme-renderer" -version = "44.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "docutils" }, - { name = "nh3" }, - { name = "pygments" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/5a/a9/104ec9234c8448c4379768221ea6df01260cd6c2ce13182d4eac531c8342/readme_renderer-44.0.tar.gz", hash = "sha256:8712034eabbfa6805cacf1402b4eeb2a73028f72d1166d6f5cb7f9c047c5d1e1", size = 32056, upload-time = "2024-07-08T15:00:57.805Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e1/67/921ec3024056483db83953ae8e48079ad62b92db7880013ca77632921dd0/readme_renderer-44.0-py3-none-any.whl", hash = "sha256:2fbca89b81a08526aadf1357a8c2ae889ec05fb03f5da67f9769c9a592166151", size = 13310, upload-time = "2024-07-08T15:00:56.577Z" }, -] - [[package]] name = "realtime" version = "2.7.0" @@ -6109,14 +6267,14 @@ wheels = [ [[package]] name = "redis" -version = "7.2.0" +version = "7.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout", marker = "python_full_version < '3.11.3'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9f/32/6fac13a11e73e1bc67a2ae821a72bfe4c2d8c4c48f0267e4a952be0f1bae/redis-7.2.0.tar.gz", hash = "sha256:4dd5bf4bd4ae80510267f14185a15cba2a38666b941aff68cccf0256b51c1f26", size = 4901247, upload-time = "2026-02-16T17:16:22.797Z" } +sdist = { url = "https://files.pythonhosted.org/packages/da/82/4d1a5279f6c1251d3d2a603a798a1137c657de9b12cfc1fba4858232c4d2/redis-7.3.0.tar.gz", hash = "sha256:4d1b768aafcf41b01022410b3cc4f15a07d9b3d6fe0c66fc967da2c88e551034", size = 4928081, upload-time = "2026-03-06T18:18:16.287Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/86/cf/f6180b67f99688d83e15c84c5beda831d1d341e95872d224f87ccafafe61/redis-7.2.0-py3-none-any.whl", hash = "sha256:01f591f8598e483f1842d429e8ae3a820804566f1c73dca1b80e23af9fba0497", size = 394898, upload-time = "2026-02-16T17:16:20.693Z" }, + { url = "https://files.pythonhosted.org/packages/f0/28/84e57fce7819e81ec5aa1bd31c42b89607241f4fb1a3ea5b0d2dbeaea26c/redis-7.3.0-py3-none-any.whl", hash = "sha256:9d4fcb002a12a5e3c3fbe005d59c48a2cc231f87fbb2f6b70c2d89bb64fec364", size = 404379, upload-time = "2026-03-06T18:18:14.583Z" }, ] [package.optional-dependencies] @@ -6140,38 +6298,42 @@ wheels = [ [[package]] name = "regex" -version = "2025.11.3" +version = "2026.2.28" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cc/a9/546676f25e573a4cf00fe8e119b78a37b6a8fe2dc95cda877b30889c9c45/regex-2025.11.3.tar.gz", hash = "sha256:1fedc720f9bb2494ce31a58a1631f9c82df6a09b49c19517ea5cc280b4541e01", size = 414669, upload-time = "2025-11-03T21:34:22.089Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/71/41455aa99a5a5ac1eaf311f5d8efd9ce6433c03ac1e0962de163350d0d97/regex-2026.2.28.tar.gz", hash = "sha256:a729e47d418ea11d03469f321aaf67cdee8954cde3ff2cf8403ab87951ad10f2", size = 415184, upload-time = "2026-02-28T02:19:42.792Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f7/90/4fb5056e5f03a7048abd2b11f598d464f0c167de4f2a51aa868c376b8c70/regex-2025.11.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:eadade04221641516fa25139273505a1c19f9bf97589a05bc4cfcd8b4a618031", size = 488081, upload-time = "2025-11-03T21:31:11.946Z" }, - { url = "https://files.pythonhosted.org/packages/85/23/63e481293fac8b069d84fba0299b6666df720d875110efd0338406b5d360/regex-2025.11.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:feff9e54ec0dd3833d659257f5c3f5322a12eee58ffa360984b716f8b92983f4", size = 290554, upload-time = "2025-11-03T21:31:13.387Z" }, - { url = "https://files.pythonhosted.org/packages/2b/9d/b101d0262ea293a0066b4522dfb722eb6a8785a8c3e084396a5f2c431a46/regex-2025.11.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3b30bc921d50365775c09a7ed446359e5c0179e9e2512beec4a60cbcef6ddd50", size = 288407, upload-time = "2025-11-03T21:31:14.809Z" }, - { url = "https://files.pythonhosted.org/packages/0c/64/79241c8209d5b7e00577ec9dca35cd493cc6be35b7d147eda367d6179f6d/regex-2025.11.3-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f99be08cfead2020c7ca6e396c13543baea32343b7a9a5780c462e323bd8872f", size = 793418, upload-time = "2025-11-03T21:31:16.556Z" }, - { url = "https://files.pythonhosted.org/packages/3d/e2/23cd5d3573901ce8f9757c92ca4db4d09600b865919b6d3e7f69f03b1afd/regex-2025.11.3-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:6dd329a1b61c0ee95ba95385fb0c07ea0d3fe1a21e1349fa2bec272636217118", size = 860448, upload-time = "2025-11-03T21:31:18.12Z" }, - { url = "https://files.pythonhosted.org/packages/2a/4c/aecf31beeaa416d0ae4ecb852148d38db35391aac19c687b5d56aedf3a8b/regex-2025.11.3-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4c5238d32f3c5269d9e87be0cf096437b7622b6920f5eac4fd202468aaeb34d2", size = 907139, upload-time = "2025-11-03T21:31:20.753Z" }, - { url = "https://files.pythonhosted.org/packages/61/22/b8cb00df7d2b5e0875f60628594d44dba283e951b1ae17c12f99e332cc0a/regex-2025.11.3-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:10483eefbfb0adb18ee9474498c9a32fcf4e594fbca0543bb94c48bac6183e2e", size = 800439, upload-time = "2025-11-03T21:31:22.069Z" }, - { url = "https://files.pythonhosted.org/packages/02/a8/c4b20330a5cdc7a8eb265f9ce593f389a6a88a0c5f280cf4d978f33966bc/regex-2025.11.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:78c2d02bb6e1da0720eedc0bad578049cad3f71050ef8cd065ecc87691bed2b0", size = 782965, upload-time = "2025-11-03T21:31:23.598Z" }, - { url = "https://files.pythonhosted.org/packages/b4/4c/ae3e52988ae74af4b04d2af32fee4e8077f26e51b62ec2d12d246876bea2/regex-2025.11.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:e6b49cd2aad93a1790ce9cffb18964f6d3a4b0b3dbdbd5de094b65296fce6e58", size = 854398, upload-time = "2025-11-03T21:31:25.008Z" }, - { url = "https://files.pythonhosted.org/packages/06/d1/a8b9cf45874eda14b2e275157ce3b304c87e10fb38d9fc26a6e14eb18227/regex-2025.11.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:885b26aa3ee56433b630502dc3d36ba78d186a00cc535d3806e6bfd9ed3c70ab", size = 845897, upload-time = "2025-11-03T21:31:26.427Z" }, - { url = "https://files.pythonhosted.org/packages/ea/fe/1830eb0236be93d9b145e0bd8ab499f31602fe0999b1f19e99955aa8fe20/regex-2025.11.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ddd76a9f58e6a00f8772e72cff8ebcff78e022be95edf018766707c730593e1e", size = 788906, upload-time = "2025-11-03T21:31:28.078Z" }, - { url = "https://files.pythonhosted.org/packages/66/47/dc2577c1f95f188c1e13e2e69d8825a5ac582ac709942f8a03af42ed6e93/regex-2025.11.3-cp311-cp311-win32.whl", hash = "sha256:3e816cc9aac1cd3cc9a4ec4d860f06d40f994b5c7b4d03b93345f44e08cc68bf", size = 265812, upload-time = "2025-11-03T21:31:29.72Z" }, - { url = "https://files.pythonhosted.org/packages/50/1e/15f08b2f82a9bbb510621ec9042547b54d11e83cb620643ebb54e4eb7d71/regex-2025.11.3-cp311-cp311-win_amd64.whl", hash = "sha256:087511f5c8b7dfbe3a03f5d5ad0c2a33861b1fc387f21f6f60825a44865a385a", size = 277737, upload-time = "2025-11-03T21:31:31.422Z" }, - { url = "https://files.pythonhosted.org/packages/f4/fc/6500eb39f5f76c5e47a398df82e6b535a5e345f839581012a418b16f9cc3/regex-2025.11.3-cp311-cp311-win_arm64.whl", hash = "sha256:1ff0d190c7f68ae7769cd0313fe45820ba07ffebfddfaa89cc1eb70827ba0ddc", size = 270290, upload-time = "2025-11-03T21:31:33.041Z" }, - { url = "https://files.pythonhosted.org/packages/e8/74/18f04cb53e58e3fb107439699bd8375cf5a835eec81084e0bddbd122e4c2/regex-2025.11.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:bc8ab71e2e31b16e40868a40a69007bc305e1109bd4658eb6cad007e0bf67c41", size = 489312, upload-time = "2025-11-03T21:31:34.343Z" }, - { url = "https://files.pythonhosted.org/packages/78/3f/37fcdd0d2b1e78909108a876580485ea37c91e1acf66d3bb8e736348f441/regex-2025.11.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:22b29dda7e1f7062a52359fca6e58e548e28c6686f205e780b02ad8ef710de36", size = 291256, upload-time = "2025-11-03T21:31:35.675Z" }, - { url = "https://files.pythonhosted.org/packages/bf/26/0a575f58eb23b7ebd67a45fccbc02ac030b737b896b7e7a909ffe43ffd6a/regex-2025.11.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3a91e4a29938bc1a082cc28fdea44be420bf2bebe2665343029723892eb073e1", size = 288921, upload-time = "2025-11-03T21:31:37.07Z" }, - { url = "https://files.pythonhosted.org/packages/ea/98/6a8dff667d1af907150432cf5abc05a17ccd32c72a3615410d5365ac167a/regex-2025.11.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:08b884f4226602ad40c5d55f52bf91a9df30f513864e0054bad40c0e9cf1afb7", size = 798568, upload-time = "2025-11-03T21:31:38.784Z" }, - { url = "https://files.pythonhosted.org/packages/64/15/92c1db4fa4e12733dd5a526c2dd2b6edcbfe13257e135fc0f6c57f34c173/regex-2025.11.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:3e0b11b2b2433d1c39c7c7a30e3f3d0aeeea44c2a8d0bae28f6b95f639927a69", size = 864165, upload-time = "2025-11-03T21:31:40.559Z" }, - { url = "https://files.pythonhosted.org/packages/f9/e7/3ad7da8cdee1ce66c7cd37ab5ab05c463a86ffeb52b1a25fe7bd9293b36c/regex-2025.11.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:87eb52a81ef58c7ba4d45c3ca74e12aa4b4e77816f72ca25258a85b3ea96cb48", size = 912182, upload-time = "2025-11-03T21:31:42.002Z" }, - { url = "https://files.pythonhosted.org/packages/84/bd/9ce9f629fcb714ffc2c3faf62b6766ecb7a585e1e885eb699bcf130a5209/regex-2025.11.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a12ab1f5c29b4e93db518f5e3872116b7e9b1646c9f9f426f777b50d44a09e8c", size = 803501, upload-time = "2025-11-03T21:31:43.815Z" }, - { url = "https://files.pythonhosted.org/packages/7c/0f/8dc2e4349d8e877283e6edd6c12bdcebc20f03744e86f197ab6e4492bf08/regex-2025.11.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7521684c8c7c4f6e88e35ec89680ee1aa8358d3f09d27dfbdf62c446f5d4c695", size = 787842, upload-time = "2025-11-03T21:31:45.353Z" }, - { url = "https://files.pythonhosted.org/packages/f9/73/cff02702960bc185164d5619c0c62a2f598a6abff6695d391b096237d4ab/regex-2025.11.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7fe6e5440584e94cc4b3f5f4d98a25e29ca12dccf8873679a635638349831b98", size = 858519, upload-time = "2025-11-03T21:31:46.814Z" }, - { url = "https://files.pythonhosted.org/packages/61/83/0e8d1ae71e15bc1dc36231c90b46ee35f9d52fab2e226b0e039e7ea9c10a/regex-2025.11.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:8e026094aa12b43f4fd74576714e987803a315c76edb6b098b9809db5de58f74", size = 850611, upload-time = "2025-11-03T21:31:48.289Z" }, - { url = "https://files.pythonhosted.org/packages/c8/f5/70a5cdd781dcfaa12556f2955bf170cd603cb1c96a1827479f8faea2df97/regex-2025.11.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:435bbad13e57eb5606a68443af62bed3556de2f46deb9f7d4237bc2f1c9fb3a0", size = 789759, upload-time = "2025-11-03T21:31:49.759Z" }, - { url = "https://files.pythonhosted.org/packages/59/9b/7c29be7903c318488983e7d97abcf8ebd3830e4c956c4c540005fcfb0462/regex-2025.11.3-cp312-cp312-win32.whl", hash = "sha256:3839967cf4dc4b985e1570fd8d91078f0c519f30491c60f9ac42a8db039be204", size = 266194, upload-time = "2025-11-03T21:31:51.53Z" }, - { url = "https://files.pythonhosted.org/packages/1a/67/3b92df89f179d7c367be654ab5626ae311cb28f7d5c237b6bb976cd5fbbb/regex-2025.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:e721d1b46e25c481dc5ded6f4b3f66c897c58d2e8cfdf77bbced84339108b0b9", size = 277069, upload-time = "2025-11-03T21:31:53.151Z" }, - { url = "https://files.pythonhosted.org/packages/d7/55/85ba4c066fe5094d35b249c3ce8df0ba623cfd35afb22d6764f23a52a1c5/regex-2025.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:64350685ff08b1d3a6fff33f45a9ca183dc1d58bbfe4981604e70ec9801bbc26", size = 270330, upload-time = "2025-11-03T21:31:54.514Z" }, + { url = "https://files.pythonhosted.org/packages/04/db/8cbfd0ba3f302f2d09dd0019a9fcab74b63fee77a76c937d0e33161fb8c1/regex-2026.2.28-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:e621fb7c8dc147419b28e1702f58a0177ff8308a76fa295c71f3e7827849f5d9", size = 488462, upload-time = "2026-02-28T02:16:22.616Z" }, + { url = "https://files.pythonhosted.org/packages/5d/10/ccc22c52802223f2368731964ddd117799e1390ffc39dbb31634a83022ee/regex-2026.2.28-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0d5bef2031cbf38757a0b0bc4298bb4824b6332d28edc16b39247228fbdbad97", size = 290774, upload-time = "2026-02-28T02:16:23.993Z" }, + { url = "https://files.pythonhosted.org/packages/62/b9/6796b3bf3101e64117201aaa3a5a030ec677ecf34b3cd6141b5d5c6c67d5/regex-2026.2.28-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bcb399ed84eabf4282587ba151f2732ad8168e66f1d3f85b1d038868fe547703", size = 288724, upload-time = "2026-02-28T02:16:25.403Z" }, + { url = "https://files.pythonhosted.org/packages/9c/02/291c0ae3f3a10cea941d0f5366da1843d8d1fa8a25b0671e20a0e454bb38/regex-2026.2.28-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7c1b34dfa72f826f535b20712afa9bb3ba580020e834f3c69866c5bddbf10098", size = 791924, upload-time = "2026-02-28T02:16:26.863Z" }, + { url = "https://files.pythonhosted.org/packages/0f/57/f0235cc520d9672742196c5c15098f8f703f2758d48d5a7465a56333e496/regex-2026.2.28-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:851fa70df44325e1e4cdb79c5e676e91a78147b1b543db2aec8734d2add30ec2", size = 860095, upload-time = "2026-02-28T02:16:28.772Z" }, + { url = "https://files.pythonhosted.org/packages/b3/7c/393c94cbedda79a0f5f2435ebd01644aba0b338d327eb24b4aa5b8d6c07f/regex-2026.2.28-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:516604edd17b1c2c3e579cf4e9b25a53bf8fa6e7cedddf1127804d3e0140ca64", size = 906583, upload-time = "2026-02-28T02:16:30.977Z" }, + { url = "https://files.pythonhosted.org/packages/2c/73/a72820f47ca5abf2b5d911d0407ba5178fc52cf9780191ed3a54f5f419a2/regex-2026.2.28-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e7ce83654d1ab701cb619285a18a8e5a889c1216d746ddc710c914ca5fd71022", size = 800234, upload-time = "2026-02-28T02:16:32.55Z" }, + { url = "https://files.pythonhosted.org/packages/34/b3/6e6a4b7b31fa998c4cf159a12cbeaf356386fbd1a8be743b1e80a3da51e4/regex-2026.2.28-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f2791948f7c70bb9335a9102df45e93d428f4b8128020d85920223925d73b9e1", size = 772803, upload-time = "2026-02-28T02:16:34.029Z" }, + { url = "https://files.pythonhosted.org/packages/10/e7/5da0280c765d5a92af5e1cd324b3fe8464303189cbaa449de9a71910e273/regex-2026.2.28-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:03a83cc26aa2acda6b8b9dfe748cf9e84cbd390c424a1de34fdcef58961a297a", size = 781117, upload-time = "2026-02-28T02:16:36.253Z" }, + { url = "https://files.pythonhosted.org/packages/76/39/0b8d7efb256ae34e1b8157acc1afd8758048a1cf0196e1aec2e71fd99f4b/regex-2026.2.28-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ec6f5674c5dc836994f50f1186dd1fafde4be0666aae201ae2fcc3d29d8adf27", size = 854224, upload-time = "2026-02-28T02:16:38.119Z" }, + { url = "https://files.pythonhosted.org/packages/21/ff/a96d483ebe8fe6d1c67907729202313895d8de8495569ec319c6f29d0438/regex-2026.2.28-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:50c2fc924749543e0eacc93ada6aeeb3ea5f6715825624baa0dccaec771668ae", size = 761898, upload-time = "2026-02-28T02:16:40.333Z" }, + { url = "https://files.pythonhosted.org/packages/89/bd/d4f2e75cb4a54b484e796017e37c0d09d8a0a837de43d17e238adf163f4e/regex-2026.2.28-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:ba55c50f408fb5c346a3a02d2ce0ebc839784e24f7c9684fde328ff063c3cdea", size = 844832, upload-time = "2026-02-28T02:16:41.875Z" }, + { url = "https://files.pythonhosted.org/packages/8a/a7/428a135cf5e15e4e11d1e696eb2bf968362f8ea8a5f237122e96bc2ae950/regex-2026.2.28-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:edb1b1b3a5576c56f08ac46f108c40333f222ebfd5cf63afdfa3aab0791ebe5b", size = 788347, upload-time = "2026-02-28T02:16:43.472Z" }, + { url = "https://files.pythonhosted.org/packages/a9/59/68691428851cf9c9c3707217ab1d9b47cfeec9d153a49919e6c368b9e926/regex-2026.2.28-cp311-cp311-win32.whl", hash = "sha256:948c12ef30ecedb128903c2c2678b339746eb7c689c5c21957c4a23950c96d15", size = 266033, upload-time = "2026-02-28T02:16:45.094Z" }, + { url = "https://files.pythonhosted.org/packages/42/8b/1483de1c57024e89296cbcceb9cccb3f625d416ddb46e570be185c9b05a9/regex-2026.2.28-cp311-cp311-win_amd64.whl", hash = "sha256:fd63453f10d29097cc3dc62d070746523973fb5aa1c66d25f8558bebd47fed61", size = 277978, upload-time = "2026-02-28T02:16:46.75Z" }, + { url = "https://files.pythonhosted.org/packages/a4/36/abec45dc6e7252e3dbc797120496e43bb5730a7abf0d9cb69340696a2f2d/regex-2026.2.28-cp311-cp311-win_arm64.whl", hash = "sha256:00f2b8d9615aa165fdff0a13f1a92049bfad555ee91e20d246a51aa0b556c60a", size = 270340, upload-time = "2026-02-28T02:16:48.626Z" }, + { url = "https://files.pythonhosted.org/packages/07/42/9061b03cf0fc4b5fa2c3984cbbaed54324377e440a5c5a29d29a72518d62/regex-2026.2.28-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fcf26c3c6d0da98fada8ae4ef0aa1c3405a431c0a77eb17306d38a89b02adcd7", size = 489574, upload-time = "2026-02-28T02:16:50.455Z" }, + { url = "https://files.pythonhosted.org/packages/77/83/0c8a5623a233015595e3da499c5a1c13720ac63c107897a6037bb97af248/regex-2026.2.28-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02473c954af35dd2defeb07e44182f5705b30ea3f351a7cbffa9177beb14da5d", size = 291426, upload-time = "2026-02-28T02:16:52.52Z" }, + { url = "https://files.pythonhosted.org/packages/9e/06/3ef1ac6910dc3295ebd71b1f9bfa737e82cfead211a18b319d45f85ddd09/regex-2026.2.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b65d33a17101569f86d9c5966a8b1d7fbf8afdda5a8aa219301b0a80f58cf7d", size = 289200, upload-time = "2026-02-28T02:16:54.08Z" }, + { url = "https://files.pythonhosted.org/packages/dd/c9/8cc8d850b35ab5650ff6756a1cb85286e2000b66c97520b29c1587455344/regex-2026.2.28-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e71dcecaa113eebcc96622c17692672c2d104b1d71ddf7adeda90da7ddeb26fc", size = 796765, upload-time = "2026-02-28T02:16:55.905Z" }, + { url = "https://files.pythonhosted.org/packages/e9/5d/57702597627fc23278ebf36fbb497ac91c0ce7fec89ac6c81e420ca3e38c/regex-2026.2.28-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:481df4623fa4969c8b11f3433ed7d5e3dc9cec0f008356c3212b3933fb77e3d8", size = 863093, upload-time = "2026-02-28T02:16:58.094Z" }, + { url = "https://files.pythonhosted.org/packages/02/6d/f3ecad537ca2811b4d26b54ca848cf70e04fcfc138667c146a9f3157779c/regex-2026.2.28-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:64e7c6ad614573e0640f271e811a408d79a9e1fe62a46adb602f598df42a818d", size = 909455, upload-time = "2026-02-28T02:17:00.918Z" }, + { url = "https://files.pythonhosted.org/packages/9e/40/bb226f203caa22c1043c1ca79b36340156eca0f6a6742b46c3bb222a3a57/regex-2026.2.28-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6b08a06976ff4fb0d83077022fde3eca06c55432bb997d8c0495b9a4e9872f4", size = 802037, upload-time = "2026-02-28T02:17:02.842Z" }, + { url = "https://files.pythonhosted.org/packages/44/7c/c6d91d8911ac6803b45ca968e8e500c46934e58c0903cbc6d760ee817a0a/regex-2026.2.28-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:864cdd1a2ef5716b0ab468af40139e62ede1b3a53386b375ec0786bb6783fc05", size = 775113, upload-time = "2026-02-28T02:17:04.506Z" }, + { url = "https://files.pythonhosted.org/packages/dc/8d/4a9368d168d47abd4158580b8c848709667b1cd293ff0c0c277279543bd0/regex-2026.2.28-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:511f7419f7afab475fd4d639d4aedfc54205bcb0800066753ef68a59f0f330b5", size = 784194, upload-time = "2026-02-28T02:17:06.888Z" }, + { url = "https://files.pythonhosted.org/packages/cc/bf/2c72ab5d8b7be462cb1651b5cc333da1d0068740342f350fcca3bca31947/regex-2026.2.28-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:b42f7466e32bf15a961cf09f35fa6323cc72e64d3d2c990b10de1274a5da0a59", size = 856846, upload-time = "2026-02-28T02:17:09.11Z" }, + { url = "https://files.pythonhosted.org/packages/7c/f4/6b65c979bb6d09f51bb2d2a7bc85de73c01ec73335d7ddd202dcb8cd1c8f/regex-2026.2.28-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8710d61737b0c0ce6836b1da7109f20d495e49b3809f30e27e9560be67a257bf", size = 763516, upload-time = "2026-02-28T02:17:11.004Z" }, + { url = "https://files.pythonhosted.org/packages/8e/32/29ea5e27400ee86d2cc2b4e80aa059df04eaf78b4f0c18576ae077aeff68/regex-2026.2.28-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4390c365fd2d45278f45afd4673cb90f7285f5701607e3ad4274df08e36140ae", size = 849278, upload-time = "2026-02-28T02:17:12.693Z" }, + { url = "https://files.pythonhosted.org/packages/1d/91/3233d03b5f865111cd517e1c95ee8b43e8b428d61fa73764a80c9bb6f537/regex-2026.2.28-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cb3b1db8ff6c7b8bf838ab05583ea15230cb2f678e569ab0e3a24d1e8320940b", size = 790068, upload-time = "2026-02-28T02:17:14.9Z" }, + { url = "https://files.pythonhosted.org/packages/76/92/abc706c1fb03b4580a09645b206a3fc032f5a9f457bc1a8038ac555658ab/regex-2026.2.28-cp312-cp312-win32.whl", hash = "sha256:f8ed9a5d4612df9d4de15878f0bc6aa7a268afbe5af21a3fdd97fa19516e978c", size = 266416, upload-time = "2026-02-28T02:17:17.15Z" }, + { url = "https://files.pythonhosted.org/packages/fa/06/2a6f7dff190e5fa9df9fb4acf2fdf17a1aa0f7f54596cba8de608db56b3a/regex-2026.2.28-cp312-cp312-win_amd64.whl", hash = "sha256:01d65fd24206c8e1e97e2e31b286c59009636c022eb5d003f52760b0f42155d4", size = 277297, upload-time = "2026-02-28T02:17:18.723Z" }, + { url = "https://files.pythonhosted.org/packages/b7/f0/58a2484851fadf284458fdbd728f580d55c1abac059ae9f048c63b92f427/regex-2026.2.28-cp312-cp312-win_arm64.whl", hash = "sha256:c0b5ccbb8ffb433939d248707d4a8b31993cb76ab1a0187ca886bf50e96df952", size = 270408, upload-time = "2026-02-28T02:17:20.328Z" }, ] [[package]] @@ -6216,15 +6378,15 @@ wheels = [ [[package]] name = "resend" -version = "2.9.0" +version = "2.26.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/2a/535a794e5b64f6ef4abc1342ef1a43465af2111c5185e98b4cca2a6b6b7a/resend-2.9.0.tar.gz", hash = "sha256:e8d4c909a7fe7701119789f848a6befb0a4a668e2182d7bbfe764742f1952bd3", size = 13600, upload-time = "2025-05-06T00:35:20.363Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/ff/6a4e5e758fc2145c6a7d8563934d8ee24bf96a0212d7ec7d1af1f155bb74/resend-2.26.0.tar.gz", hash = "sha256:957a6a59dc597ce27fbd6d5383220dd9cc497fab99d4f3d775c8a42a449a569e", size = 36238, upload-time = "2026-03-20T22:49:09.728Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/96/81/ba1feb9959bafbcde6466b78d4628405d69cd14613f6eba12b928a77b86a/resend-2.9.0-py2.py3-none-any.whl", hash = "sha256:6607f75e3a9257a219c0640f935b8d1211338190d553eb043c25732affb92949", size = 20173, upload-time = "2025-05-06T00:35:18.963Z" }, + { url = "https://files.pythonhosted.org/packages/16/c2/f88d3299d97aa1d36a923d0846fe185fcf5355ca898c954b2e5a79f090b5/resend-2.26.0-py2.py3-none-any.whl", hash = "sha256:5e25a804a84a68df504f2ade5369ac37e0139e37788a1f20b66c88696595b4bc", size = 57699, upload-time = "2026-03-20T22:49:08.354Z" }, ] [[package]] @@ -6240,126 +6402,104 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4b/0d/53aea75710af4528a25ed6837d71d117602b01946b307a3912cb3cfcbcba/retry-0.9.2-py2.py3-none-any.whl", hash = "sha256:ccddf89761fa2c726ab29391837d4327f819ea14d244c232a1d24c67a2f98606", size = 7986, upload-time = "2016-05-11T13:58:39.925Z" }, ] -[[package]] -name = "rfc3986" -version = "2.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/85/40/1520d68bfa07ab5a6f065a186815fb6610c86fe957bc065754e47f7b0840/rfc3986-2.0.0.tar.gz", hash = "sha256:97aacf9dbd4bfd829baad6e6309fa6573aaf1be3f6fa735c8ab05e46cecb261c", size = 49026, upload-time = "2022-01-10T00:52:30.832Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/9a/9afaade874b2fa6c752c36f1548f718b5b83af81ed9b76628329dab81c1b/rfc3986-2.0.0-py2.py3-none-any.whl", hash = "sha256:50b1502b60e289cb37883f3dfd34532b8873c7de9f49bb546641ce9cbd256ebd", size = 31326, upload-time = "2022-01-10T00:52:29.594Z" }, -] - [[package]] name = "rich" -version = "13.9.4" +version = "14.3.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ab/3a/0316b28d0761c6734d6bc14e770d85506c986c85ffb239e688eeaab2c2bc/rich-13.9.4.tar.gz", hash = "sha256:439594978a49a09530cff7ebc4b5c7103ef57baf48d5ea3184f21d9a2befa098", size = 223149, upload-time = "2024-11-01T16:43:57.873Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/c6/f3b320c27991c46f43ee9d856302c70dc2d0fb2dba4842ff739d5f46b393/rich-14.3.3.tar.gz", hash = "sha256:b8daa0b9e4eef54dd8cf7c86c03713f53241884e814f4e2f5fb342fe520f639b", size = 230582, upload-time = "2026-02-19T17:23:12.474Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/19/71/39c7c0d87f8d4e6c020a393182060eaefeeae6c01dab6a84ec346f2567df/rich-13.9.4-py3-none-any.whl", hash = "sha256:6049d5e6ec054bf2779ab3358186963bac2ea89175919d699e378b99738c2a90", size = 242424, upload-time = "2024-11-01T16:43:55.817Z" }, + { url = "https://files.pythonhosted.org/packages/14/25/b208c5683343959b670dc001595f2f3737e051da617f66c31f7c4fa93abc/rich-14.3.3-py3-none-any.whl", hash = "sha256:793431c1f8619afa7d3b52b2cdec859562b950ea0d4b6b505397612db8d5362d", size = 310458, upload-time = "2026-02-19T17:23:13.732Z" }, ] [[package]] name = "rpds-py" -version = "0.29.0" +version = "0.30.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/98/33/23b3b3419b6a3e0f559c7c0d2ca8fc1b9448382b25245033788785921332/rpds_py-0.29.0.tar.gz", hash = "sha256:fe55fe686908f50154d1dc599232016e50c243b438c3b7432f24e2895b0e5359", size = 69359, upload-time = "2025-11-16T14:50:39.532Z" } +sdist = { url = "https://files.pythonhosted.org/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84", size = 69469, upload-time = "2025-11-30T20:24:38.837Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/ab/7fb95163a53ab122c74a7c42d2d2f012819af2cf3deb43fb0d5acf45cc1a/rpds_py-0.29.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:9b9c764a11fd637e0322a488560533112837f5334ffeb48b1be20f6d98a7b437", size = 372344, upload-time = "2025-11-16T14:47:57.279Z" }, - { url = "https://files.pythonhosted.org/packages/b3/45/f3c30084c03b0d0f918cb4c5ae2c20b0a148b51ba2b3f6456765b629bedd/rpds_py-0.29.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3fd2164d73812026ce970d44c3ebd51e019d2a26a4425a5dcbdfa93a34abc383", size = 363041, upload-time = "2025-11-16T14:47:58.908Z" }, - { url = "https://files.pythonhosted.org/packages/e3/e9/4d044a1662608c47a87cbb37b999d4d5af54c6d6ebdda93a4d8bbf8b2a10/rpds_py-0.29.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a097b7f7f7274164566ae90a221fd725363c0e9d243e2e9ed43d195ccc5495c", size = 391775, upload-time = "2025-11-16T14:48:00.197Z" }, - { url = "https://files.pythonhosted.org/packages/50/c9/7616d3ace4e6731aeb6e3cd85123e03aec58e439044e214b9c5c60fd8eb1/rpds_py-0.29.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:7cdc0490374e31cedefefaa1520d5fe38e82fde8748cbc926e7284574c714d6b", size = 405624, upload-time = "2025-11-16T14:48:01.496Z" }, - { url = "https://files.pythonhosted.org/packages/c2/e2/6d7d6941ca0843609fd2d72c966a438d6f22617baf22d46c3d2156c31350/rpds_py-0.29.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:89ca2e673ddd5bde9b386da9a0aac0cab0e76f40c8f0aaf0d6311b6bbf2aa311", size = 527894, upload-time = "2025-11-16T14:48:03.167Z" }, - { url = "https://files.pythonhosted.org/packages/8d/f7/aee14dc2db61bb2ae1e3068f134ca9da5f28c586120889a70ff504bb026f/rpds_py-0.29.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a5d9da3ff5af1ca1249b1adb8ef0573b94c76e6ae880ba1852f033bf429d4588", size = 412720, upload-time = "2025-11-16T14:48:04.413Z" }, - { url = "https://files.pythonhosted.org/packages/2f/e2/2293f236e887c0360c2723d90c00d48dee296406994d6271faf1712e94ec/rpds_py-0.29.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8238d1d310283e87376c12f658b61e1ee23a14c0e54c7c0ce953efdbdc72deed", size = 392945, upload-time = "2025-11-16T14:48:06.252Z" }, - { url = "https://files.pythonhosted.org/packages/14/cd/ceea6147acd3bd1fd028d1975228f08ff19d62098078d5ec3eed49703797/rpds_py-0.29.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:2d6fb2ad1c36f91c4646989811e84b1ea5e0c3cf9690b826b6e32b7965853a63", size = 406385, upload-time = "2025-11-16T14:48:07.575Z" }, - { url = "https://files.pythonhosted.org/packages/52/36/fe4dead19e45eb77a0524acfdbf51e6cda597b26fc5b6dddbff55fbbb1a5/rpds_py-0.29.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:534dc9df211387547267ccdb42253aa30527482acb38dd9b21c5c115d66a96d2", size = 423943, upload-time = "2025-11-16T14:48:10.175Z" }, - { url = "https://files.pythonhosted.org/packages/a1/7b/4551510803b582fa4abbc8645441a2d15aa0c962c3b21ebb380b7e74f6a1/rpds_py-0.29.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d456e64724a075441e4ed648d7f154dc62e9aabff29bcdf723d0c00e9e1d352f", size = 574204, upload-time = "2025-11-16T14:48:11.499Z" }, - { url = "https://files.pythonhosted.org/packages/64/ba/071ccdd7b171e727a6ae079f02c26f75790b41555f12ca8f1151336d2124/rpds_py-0.29.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:a738f2da2f565989401bd6fd0b15990a4d1523c6d7fe83f300b7e7d17212feca", size = 600587, upload-time = "2025-11-16T14:48:12.822Z" }, - { url = "https://files.pythonhosted.org/packages/03/09/96983d48c8cf5a1e03c7d9cc1f4b48266adfb858ae48c7c2ce978dbba349/rpds_py-0.29.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a110e14508fd26fd2e472bb541f37c209409876ba601cf57e739e87d8a53cf95", size = 562287, upload-time = "2025-11-16T14:48:14.108Z" }, - { url = "https://files.pythonhosted.org/packages/40/f0/8c01aaedc0fa92156f0391f39ea93b5952bc0ec56b897763858f95da8168/rpds_py-0.29.0-cp311-cp311-win32.whl", hash = "sha256:923248a56dd8d158389a28934f6f69ebf89f218ef96a6b216a9be6861804d3f4", size = 221394, upload-time = "2025-11-16T14:48:15.374Z" }, - { url = "https://files.pythonhosted.org/packages/7e/a5/a8b21c54c7d234efdc83dc034a4d7cd9668e3613b6316876a29b49dece71/rpds_py-0.29.0-cp311-cp311-win_amd64.whl", hash = "sha256:539eb77eb043afcc45314d1be09ea6d6cafb3addc73e0547c171c6d636957f60", size = 235713, upload-time = "2025-11-16T14:48:16.636Z" }, - { url = "https://files.pythonhosted.org/packages/a7/1f/df3c56219523947b1be402fa12e6323fe6d61d883cf35d6cb5d5bb6db9d9/rpds_py-0.29.0-cp311-cp311-win_arm64.whl", hash = "sha256:bdb67151ea81fcf02d8f494703fb728d4d34d24556cbff5f417d74f6f5792e7c", size = 229157, upload-time = "2025-11-16T14:48:17.891Z" }, - { url = "https://files.pythonhosted.org/packages/3c/50/bc0e6e736d94e420df79be4deb5c9476b63165c87bb8f19ef75d100d21b3/rpds_py-0.29.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a0891cfd8db43e085c0ab93ab7e9b0c8fee84780d436d3b266b113e51e79f954", size = 376000, upload-time = "2025-11-16T14:48:19.141Z" }, - { url = "https://files.pythonhosted.org/packages/3e/3a/46676277160f014ae95f24de53bed0e3b7ea66c235e7de0b9df7bd5d68ba/rpds_py-0.29.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3897924d3f9a0361472d884051f9a2460358f9a45b1d85a39a158d2f8f1ad71c", size = 360575, upload-time = "2025-11-16T14:48:20.443Z" }, - { url = "https://files.pythonhosted.org/packages/75/ba/411d414ed99ea1afdd185bbabeeaac00624bd1e4b22840b5e9967ade6337/rpds_py-0.29.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2a21deb8e0d1571508c6491ce5ea5e25669b1dd4adf1c9d64b6314842f708b5d", size = 392159, upload-time = "2025-11-16T14:48:22.12Z" }, - { url = "https://files.pythonhosted.org/packages/8f/b1/e18aa3a331f705467a48d0296778dc1fea9d7f6cf675bd261f9a846c7e90/rpds_py-0.29.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9efe71687d6427737a0a2de9ca1c0a216510e6cd08925c44162be23ed7bed2d5", size = 410602, upload-time = "2025-11-16T14:48:23.563Z" }, - { url = "https://files.pythonhosted.org/packages/2f/6c/04f27f0c9f2299274c76612ac9d2c36c5048bb2c6c2e52c38c60bf3868d9/rpds_py-0.29.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:40f65470919dc189c833e86b2c4bd21bd355f98436a2cef9e0a9a92aebc8e57e", size = 515808, upload-time = "2025-11-16T14:48:24.949Z" }, - { url = "https://files.pythonhosted.org/packages/83/56/a8412aa464fb151f8bc0d91fb0bb888adc9039bd41c1c6ba8d94990d8cf8/rpds_py-0.29.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:def48ff59f181130f1a2cb7c517d16328efac3ec03951cca40c1dc2049747e83", size = 416015, upload-time = "2025-11-16T14:48:26.782Z" }, - { url = "https://files.pythonhosted.org/packages/04/4c/f9b8a05faca3d9e0a6397c90d13acb9307c9792b2bff621430c58b1d6e76/rpds_py-0.29.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ad7bd570be92695d89285a4b373006930715b78d96449f686af422debb4d3949", size = 395325, upload-time = "2025-11-16T14:48:28.055Z" }, - { url = "https://files.pythonhosted.org/packages/34/60/869f3bfbf8ed7b54f1ad9a5543e0fdffdd40b5a8f587fe300ee7b4f19340/rpds_py-0.29.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:5a572911cd053137bbff8e3a52d31c5d2dba51d3a67ad902629c70185f3f2181", size = 410160, upload-time = "2025-11-16T14:48:29.338Z" }, - { url = "https://files.pythonhosted.org/packages/91/aa/e5b496334e3aba4fe4c8a80187b89f3c1294c5c36f2a926da74338fa5a73/rpds_py-0.29.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d583d4403bcbf10cffc3ab5cee23d7643fcc960dff85973fd3c2d6c86e8dbb0c", size = 425309, upload-time = "2025-11-16T14:48:30.691Z" }, - { url = "https://files.pythonhosted.org/packages/85/68/4e24a34189751ceb6d66b28f18159922828dd84155876551f7ca5b25f14f/rpds_py-0.29.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:070befbb868f257d24c3bb350dbd6e2f645e83731f31264b19d7231dd5c396c7", size = 574644, upload-time = "2025-11-16T14:48:31.964Z" }, - { url = "https://files.pythonhosted.org/packages/8c/cf/474a005ea4ea9c3b4f17b6108b6b13cebfc98ebaff11d6e1b193204b3a93/rpds_py-0.29.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:fc935f6b20b0c9f919a8ff024739174522abd331978f750a74bb68abd117bd19", size = 601605, upload-time = "2025-11-16T14:48:33.252Z" }, - { url = "https://files.pythonhosted.org/packages/f4/b1/c56f6a9ab8c5f6bb5c65c4b5f8229167a3a525245b0773f2c0896686b64e/rpds_py-0.29.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8c5a8ecaa44ce2d8d9d20a68a2483a74c07f05d72e94a4dff88906c8807e77b0", size = 564593, upload-time = "2025-11-16T14:48:34.643Z" }, - { url = "https://files.pythonhosted.org/packages/b3/13/0494cecce4848f68501e0a229432620b4b57022388b071eeff95f3e1e75b/rpds_py-0.29.0-cp312-cp312-win32.whl", hash = "sha256:ba5e1aeaf8dd6d8f6caba1f5539cddda87d511331714b7b5fc908b6cfc3636b7", size = 223853, upload-time = "2025-11-16T14:48:36.419Z" }, - { url = "https://files.pythonhosted.org/packages/1f/6a/51e9aeb444a00cdc520b032a28b07e5f8dc7bc328b57760c53e7f96997b4/rpds_py-0.29.0-cp312-cp312-win_amd64.whl", hash = "sha256:b5f6134faf54b3cb83375db0f113506f8b7770785be1f95a631e7e2892101977", size = 239895, upload-time = "2025-11-16T14:48:37.956Z" }, - { url = "https://files.pythonhosted.org/packages/d1/d4/8bce56cdad1ab873e3f27cb31c6a51d8f384d66b022b820525b879f8bed1/rpds_py-0.29.0-cp312-cp312-win_arm64.whl", hash = "sha256:b016eddf00dca7944721bf0cd85b6af7f6c4efaf83ee0b37c4133bd39757a8c7", size = 230321, upload-time = "2025-11-16T14:48:39.71Z" }, - { url = "https://files.pythonhosted.org/packages/f2/ac/b97e80bf107159e5b9ba9c91df1ab95f69e5e41b435f27bdd737f0d583ac/rpds_py-0.29.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:acd82a9e39082dc5f4492d15a6b6c8599aa21db5c35aaf7d6889aea16502c07d", size = 373963, upload-time = "2025-11-16T14:50:16.205Z" }, - { url = "https://files.pythonhosted.org/packages/40/5a/55e72962d5d29bd912f40c594e68880d3c7a52774b0f75542775f9250712/rpds_py-0.29.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:715b67eac317bf1c7657508170a3e011a1ea6ccb1c9d5f296e20ba14196be6b3", size = 364644, upload-time = "2025-11-16T14:50:18.22Z" }, - { url = "https://files.pythonhosted.org/packages/99/2a/6b6524d0191b7fc1351c3c0840baac42250515afb48ae40c7ed15499a6a2/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3b1b87a237cb2dba4db18bcfaaa44ba4cd5936b91121b62292ff21df577fc43", size = 393847, upload-time = "2025-11-16T14:50:20.012Z" }, - { url = "https://files.pythonhosted.org/packages/1c/b8/c5692a7df577b3c0c7faed7ac01ee3c608b81750fc5d89f84529229b6873/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1c3c3e8101bb06e337c88eb0c0ede3187131f19d97d43ea0e1c5407ea74c0cbf", size = 407281, upload-time = "2025-11-16T14:50:21.64Z" }, - { url = "https://files.pythonhosted.org/packages/f0/57/0546c6f84031b7ea08b76646a8e33e45607cc6bd879ff1917dc077bb881e/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2b8e54d6e61f3ecd3abe032065ce83ea63417a24f437e4a3d73d2f85ce7b7cfe", size = 529213, upload-time = "2025-11-16T14:50:23.219Z" }, - { url = "https://files.pythonhosted.org/packages/fa/c1/01dd5f444233605555bc11fe5fed6a5c18f379f02013870c176c8e630a23/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3fbd4e9aebf110473a420dea85a238b254cf8a15acb04b22a5a6b5ce8925b760", size = 413808, upload-time = "2025-11-16T14:50:25.262Z" }, - { url = "https://files.pythonhosted.org/packages/aa/0a/60f98b06156ea2a7af849fb148e00fbcfdb540909a5174a5ed10c93745c7/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80fdf53d36e6c72819993e35d1ebeeb8e8fc688d0c6c2b391b55e335b3afba5a", size = 394600, upload-time = "2025-11-16T14:50:26.956Z" }, - { url = "https://files.pythonhosted.org/packages/37/f1/dc9312fc9bec040ece08396429f2bd9e0977924ba7a11c5ad7056428465e/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_31_riscv64.whl", hash = "sha256:ea7173df5d86f625f8dde6d5929629ad811ed8decda3b60ae603903839ac9ac0", size = 408634, upload-time = "2025-11-16T14:50:28.989Z" }, - { url = "https://files.pythonhosted.org/packages/ed/41/65024c9fd40c89bb7d604cf73beda4cbdbcebe92d8765345dd65855b6449/rpds_py-0.29.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:76054d540061eda273274f3d13a21a4abdde90e13eaefdc205db37c05230efce", size = 426064, upload-time = "2025-11-16T14:50:30.674Z" }, - { url = "https://files.pythonhosted.org/packages/a2/e0/cf95478881fc88ca2fdbf56381d7df36567cccc39a05394beac72182cd62/rpds_py-0.29.0-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:9f84c549746a5be3bc7415830747a3a0312573afc9f95785eb35228bb17742ec", size = 575871, upload-time = "2025-11-16T14:50:33.428Z" }, - { url = "https://files.pythonhosted.org/packages/ea/c0/df88097e64339a0218b57bd5f9ca49898e4c394db756c67fccc64add850a/rpds_py-0.29.0-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:0ea962671af5cb9a260489e311fa22b2e97103e3f9f0caaea6f81390af96a9ed", size = 601702, upload-time = "2025-11-16T14:50:36.051Z" }, - { url = "https://files.pythonhosted.org/packages/87/f4/09ffb3ebd0cbb9e2c7c9b84d252557ecf434cd71584ee1e32f66013824df/rpds_py-0.29.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:f7728653900035fb7b8d06e1e5900545d8088efc9d5d4545782da7df03ec803f", size = 564054, upload-time = "2025-11-16T14:50:37.733Z" }, -] - -[[package]] -name = "rsa" -version = "4.9.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyasn1" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/da/8a/22b7beea3ee0d44b1916c0c1cb0ee3af23b700b6da9f04991899d0c555d4/rsa-4.9.1.tar.gz", hash = "sha256:e7bdbfdb5497da4c07dfd35530e1a902659db6ff241e39d9953cad06ebd0ae75", size = 29034, upload-time = "2025-04-16T09:51:18.218Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/64/8d/0133e4eb4beed9e425d9a98ed6e081a55d195481b7632472be1af08d2f6b/rsa-4.9.1-py3-none-any.whl", hash = "sha256:68635866661c6836b8d39430f97a996acbd61bfa49406748ea243539fe239762", size = 34696, upload-time = "2025-04-16T09:51:17.142Z" }, + { url = "https://files.pythonhosted.org/packages/4d/6e/f964e88b3d2abee2a82c1ac8366da848fce1c6d834dc2132c3fda3970290/rpds_py-0.30.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:a2bffea6a4ca9f01b3f8e548302470306689684e61602aa3d141e34da06cf425", size = 370157, upload-time = "2025-11-30T20:21:53.789Z" }, + { url = "https://files.pythonhosted.org/packages/94/ba/24e5ebb7c1c82e74c4e4f33b2112a5573ddc703915b13a073737b59b86e0/rpds_py-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dc4f992dfe1e2bc3ebc7444f6c7051b4bc13cd8e33e43511e8ffd13bf407010d", size = 359676, upload-time = "2025-11-30T20:21:55.475Z" }, + { url = "https://files.pythonhosted.org/packages/84/86/04dbba1b087227747d64d80c3b74df946b986c57af0a9f0c98726d4d7a3b/rpds_py-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:422c3cb9856d80b09d30d2eb255d0754b23e090034e1deb4083f8004bd0761e4", size = 389938, upload-time = "2025-11-30T20:21:57.079Z" }, + { url = "https://files.pythonhosted.org/packages/42/bb/1463f0b1722b7f45431bdd468301991d1328b16cffe0b1c2918eba2c4eee/rpds_py-0.30.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:07ae8a593e1c3c6b82ca3292efbe73c30b61332fd612e05abee07c79359f292f", size = 402932, upload-time = "2025-11-30T20:21:58.47Z" }, + { url = "https://files.pythonhosted.org/packages/99/ee/2520700a5c1f2d76631f948b0736cdf9b0acb25abd0ca8e889b5c62ac2e3/rpds_py-0.30.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12f90dd7557b6bd57f40abe7747e81e0c0b119bef015ea7726e69fe550e394a4", size = 525830, upload-time = "2025-11-30T20:21:59.699Z" }, + { url = "https://files.pythonhosted.org/packages/e0/ad/bd0331f740f5705cc555a5e17fdf334671262160270962e69a2bdef3bf76/rpds_py-0.30.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:99b47d6ad9a6da00bec6aabe5a6279ecd3c06a329d4aa4771034a21e335c3a97", size = 412033, upload-time = "2025-11-30T20:22:00.991Z" }, + { url = "https://files.pythonhosted.org/packages/f8/1e/372195d326549bb51f0ba0f2ecb9874579906b97e08880e7a65c3bef1a99/rpds_py-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:33f559f3104504506a44bb666b93a33f5d33133765b0c216a5bf2f1e1503af89", size = 390828, upload-time = "2025-11-30T20:22:02.723Z" }, + { url = "https://files.pythonhosted.org/packages/ab/2b/d88bb33294e3e0c76bc8f351a3721212713629ffca1700fa94979cb3eae8/rpds_py-0.30.0-cp311-cp311-manylinux_2_31_riscv64.whl", hash = "sha256:946fe926af6e44f3697abbc305ea168c2c31d3e3ef1058cf68f379bf0335a78d", size = 404683, upload-time = "2025-11-30T20:22:04.367Z" }, + { url = "https://files.pythonhosted.org/packages/50/32/c759a8d42bcb5289c1fac697cd92f6fe01a018dd937e62ae77e0e7f15702/rpds_py-0.30.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:495aeca4b93d465efde585977365187149e75383ad2684f81519f504f5c13038", size = 421583, upload-time = "2025-11-30T20:22:05.814Z" }, + { url = "https://files.pythonhosted.org/packages/2b/81/e729761dbd55ddf5d84ec4ff1f47857f4374b0f19bdabfcf929164da3e24/rpds_py-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:d9a0ca5da0386dee0655b4ccdf46119df60e0f10da268d04fe7cc87886872ba7", size = 572496, upload-time = "2025-11-30T20:22:07.713Z" }, + { url = "https://files.pythonhosted.org/packages/14/f6/69066a924c3557c9c30baa6ec3a0aa07526305684c6f86c696b08860726c/rpds_py-0.30.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:8d6d1cc13664ec13c1b84241204ff3b12f9bb82464b8ad6e7a5d3486975c2eed", size = 598669, upload-time = "2025-11-30T20:22:09.312Z" }, + { url = "https://files.pythonhosted.org/packages/5f/48/905896b1eb8a05630d20333d1d8ffd162394127b74ce0b0784ae04498d32/rpds_py-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3896fa1be39912cf0757753826bc8bdc8ca331a28a7c4ae46b7a21280b06bb85", size = 561011, upload-time = "2025-11-30T20:22:11.309Z" }, + { url = "https://files.pythonhosted.org/packages/22/16/cd3027c7e279d22e5eb431dd3c0fbc677bed58797fe7581e148f3f68818b/rpds_py-0.30.0-cp311-cp311-win32.whl", hash = "sha256:55f66022632205940f1827effeff17c4fa7ae1953d2b74a8581baaefb7d16f8c", size = 221406, upload-time = "2025-11-30T20:22:13.101Z" }, + { url = "https://files.pythonhosted.org/packages/fa/5b/e7b7aa136f28462b344e652ee010d4de26ee9fd16f1bfd5811f5153ccf89/rpds_py-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:a51033ff701fca756439d641c0ad09a41d9242fa69121c7d8769604a0a629825", size = 236024, upload-time = "2025-11-30T20:22:14.853Z" }, + { url = "https://files.pythonhosted.org/packages/14/a6/364bba985e4c13658edb156640608f2c9e1d3ea3c81b27aa9d889fff0e31/rpds_py-0.30.0-cp311-cp311-win_arm64.whl", hash = "sha256:47b0ef6231c58f506ef0b74d44e330405caa8428e770fec25329ed2cb971a229", size = 229069, upload-time = "2025-11-30T20:22:16.577Z" }, + { url = "https://files.pythonhosted.org/packages/03/e7/98a2f4ac921d82f33e03f3835f5bf3a4a40aa1bfdc57975e74a97b2b4bdd/rpds_py-0.30.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a161f20d9a43006833cd7068375a94d035714d73a172b681d8881820600abfad", size = 375086, upload-time = "2025-11-30T20:22:17.93Z" }, + { url = "https://files.pythonhosted.org/packages/4d/a1/bca7fd3d452b272e13335db8d6b0b3ecde0f90ad6f16f3328c6fb150c889/rpds_py-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6abc8880d9d036ecaafe709079969f56e876fcf107f7a8e9920ba6d5a3878d05", size = 359053, upload-time = "2025-11-30T20:22:19.297Z" }, + { url = "https://files.pythonhosted.org/packages/65/1c/ae157e83a6357eceff62ba7e52113e3ec4834a84cfe07fa4b0757a7d105f/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca28829ae5f5d569bb62a79512c842a03a12576375d5ece7d2cadf8abe96ec28", size = 390763, upload-time = "2025-11-30T20:22:21.661Z" }, + { url = "https://files.pythonhosted.org/packages/d4/36/eb2eb8515e2ad24c0bd43c3ee9cd74c33f7ca6430755ccdb240fd3144c44/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1010ed9524c73b94d15919ca4d41d8780980e1765babf85f9a2f90d247153dd", size = 408951, upload-time = "2025-11-30T20:22:23.408Z" }, + { url = "https://files.pythonhosted.org/packages/d6/65/ad8dc1784a331fabbd740ef6f71ce2198c7ed0890dab595adb9ea2d775a1/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d1736cfb49381ba528cd5baa46f82fdc65c06e843dab24dd70b63d09121b3f", size = 514622, upload-time = "2025-11-30T20:22:25.16Z" }, + { url = "https://files.pythonhosted.org/packages/63/8e/0cfa7ae158e15e143fe03993b5bcd743a59f541f5952e1546b1ac1b5fd45/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d948b135c4693daff7bc2dcfc4ec57237a29bd37e60c2fabf5aff2bbacf3e2f1", size = 414492, upload-time = "2025-11-30T20:22:26.505Z" }, + { url = "https://files.pythonhosted.org/packages/60/1b/6f8f29f3f995c7ffdde46a626ddccd7c63aefc0efae881dc13b6e5d5bb16/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47f236970bccb2233267d89173d3ad2703cd36a0e2a6e92d0560d333871a3d23", size = 394080, upload-time = "2025-11-30T20:22:27.934Z" }, + { url = "https://files.pythonhosted.org/packages/6d/d5/a266341051a7a3ca2f4b750a3aa4abc986378431fc2da508c5034d081b70/rpds_py-0.30.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:2e6ecb5a5bcacf59c3f912155044479af1d0b6681280048b338b28e364aca1f6", size = 408680, upload-time = "2025-11-30T20:22:29.341Z" }, + { url = "https://files.pythonhosted.org/packages/10/3b/71b725851df9ab7a7a4e33cf36d241933da66040d195a84781f49c50490c/rpds_py-0.30.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a8fa71a2e078c527c3e9dc9fc5a98c9db40bcc8a92b4e8858e36d329f8684b51", size = 423589, upload-time = "2025-11-30T20:22:31.469Z" }, + { url = "https://files.pythonhosted.org/packages/00/2b/e59e58c544dc9bd8bd8384ecdb8ea91f6727f0e37a7131baeff8d6f51661/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:73c67f2db7bc334e518d097c6d1e6fed021bbc9b7d678d6cc433478365d1d5f5", size = 573289, upload-time = "2025-11-30T20:22:32.997Z" }, + { url = "https://files.pythonhosted.org/packages/da/3e/a18e6f5b460893172a7d6a680e86d3b6bc87a54c1f0b03446a3c8c7b588f/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5ba103fb455be00f3b1c2076c9d4264bfcb037c976167a6047ed82f23153f02e", size = 599737, upload-time = "2025-11-30T20:22:34.419Z" }, + { url = "https://files.pythonhosted.org/packages/5c/e2/714694e4b87b85a18e2c243614974413c60aa107fd815b8cbc42b873d1d7/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee9c752c0364588353e627da8a7e808a66873672bcb5f52890c33fd965b394", size = 563120, upload-time = "2025-11-30T20:22:35.903Z" }, + { url = "https://files.pythonhosted.org/packages/6f/ab/d5d5e3bcedb0a77f4f613706b750e50a5a3ba1c15ccd3665ecc636c968fd/rpds_py-0.30.0-cp312-cp312-win32.whl", hash = "sha256:1ab5b83dbcf55acc8b08fc62b796ef672c457b17dbd7820a11d6c52c06839bdf", size = 223782, upload-time = "2025-11-30T20:22:37.271Z" }, + { url = "https://files.pythonhosted.org/packages/39/3b/f786af9957306fdc38a74cef405b7b93180f481fb48453a114bb6465744a/rpds_py-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:a090322ca841abd453d43456ac34db46e8b05fd9b3b4ac0c78bcde8b089f959b", size = 240463, upload-time = "2025-11-30T20:22:39.021Z" }, + { url = "https://files.pythonhosted.org/packages/f3/d2/b91dc748126c1559042cfe41990deb92c4ee3e2b415f6b5234969ffaf0cc/rpds_py-0.30.0-cp312-cp312-win_arm64.whl", hash = "sha256:669b1805bd639dd2989b281be2cfd951c6121b65e729d9b843e9639ef1fd555e", size = 230868, upload-time = "2025-11-30T20:22:40.493Z" }, + { url = "https://files.pythonhosted.org/packages/69/71/3f34339ee70521864411f8b6992e7ab13ac30d8e4e3309e07c7361767d91/rpds_py-0.30.0-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:c2262bdba0ad4fc6fb5545660673925c2d2a5d9e2e0fb603aad545427be0fc58", size = 372292, upload-time = "2025-11-30T20:24:16.537Z" }, + { url = "https://files.pythonhosted.org/packages/57/09/f183df9b8f2d66720d2ef71075c59f7e1b336bec7ee4c48f0a2b06857653/rpds_py-0.30.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:ee6af14263f25eedc3bb918a3c04245106a42dfd4f5c2285ea6f997b1fc3f89a", size = 362128, upload-time = "2025-11-30T20:24:18.086Z" }, + { url = "https://files.pythonhosted.org/packages/7a/68/5c2594e937253457342e078f0cc1ded3dd7b2ad59afdbf2d354869110a02/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3adbb8179ce342d235c31ab8ec511e66c73faa27a47e076ccc92421add53e2bb", size = 391542, upload-time = "2025-11-30T20:24:20.092Z" }, + { url = "https://files.pythonhosted.org/packages/49/5c/31ef1afd70b4b4fbdb2800249f34c57c64beb687495b10aec0365f53dfc4/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:250fa00e9543ac9b97ac258bd37367ff5256666122c2d0f2bc97577c60a1818c", size = 404004, upload-time = "2025-11-30T20:24:22.231Z" }, + { url = "https://files.pythonhosted.org/packages/e3/63/0cfbea38d05756f3440ce6534d51a491d26176ac045e2707adc99bb6e60a/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9854cf4f488b3d57b9aaeb105f06d78e5529d3145b1e4a41750167e8c213c6d3", size = 527063, upload-time = "2025-11-30T20:24:24.302Z" }, + { url = "https://files.pythonhosted.org/packages/42/e6/01e1f72a2456678b0f618fc9a1a13f882061690893c192fcad9f2926553a/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:993914b8e560023bc0a8bf742c5f303551992dcb85e247b1e5c7f4a7d145bda5", size = 413099, upload-time = "2025-11-30T20:24:25.916Z" }, + { url = "https://files.pythonhosted.org/packages/b8/25/8df56677f209003dcbb180765520c544525e3ef21ea72279c98b9aa7c7fb/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58edca431fb9b29950807e301826586e5bbf24163677732429770a697ffe6738", size = 392177, upload-time = "2025-11-30T20:24:27.834Z" }, + { url = "https://files.pythonhosted.org/packages/4a/b4/0a771378c5f16f8115f796d1f437950158679bcd2a7c68cf251cfb00ed5b/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_31_riscv64.whl", hash = "sha256:dea5b552272a944763b34394d04577cf0f9bd013207bc32323b5a89a53cf9c2f", size = 406015, upload-time = "2025-11-30T20:24:29.457Z" }, + { url = "https://files.pythonhosted.org/packages/36/d8/456dbba0af75049dc6f63ff295a2f92766b9d521fa00de67a2bd6427d57a/rpds_py-0.30.0-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:ba3af48635eb83d03f6c9735dfb21785303e73d22ad03d489e88adae6eab8877", size = 423736, upload-time = "2025-11-30T20:24:31.22Z" }, + { url = "https://files.pythonhosted.org/packages/13/64/b4d76f227d5c45a7e0b796c674fd81b0a6c4fbd48dc29271857d8219571c/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_aarch64.whl", hash = "sha256:dff13836529b921e22f15cb099751209a60009731a68519630a24d61f0b1b30a", size = 573981, upload-time = "2025-11-30T20:24:32.934Z" }, + { url = "https://files.pythonhosted.org/packages/20/91/092bacadeda3edf92bf743cc96a7be133e13a39cdbfd7b5082e7ab638406/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_i686.whl", hash = "sha256:1b151685b23929ab7beec71080a8889d4d6d9fa9a983d213f07121205d48e2c4", size = 599782, upload-time = "2025-11-30T20:24:35.169Z" }, + { url = "https://files.pythonhosted.org/packages/d1/b7/b95708304cd49b7b6f82fdd039f1748b66ec2b21d6a45180910802f1abf1/rpds_py-0.30.0-pp311-pypy311_pp73-musllinux_1_2_x86_64.whl", hash = "sha256:ac37f9f516c51e5753f27dfdef11a88330f04de2d564be3991384b2f3535d02e", size = 562191, upload-time = "2025-11-30T20:24:36.853Z" }, ] [[package]] name = "ruff" -version = "0.14.6" +version = "0.15.7" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/f0/62b5a1a723fe183650109407fa56abb433b00aa1c0b9ba555f9c4efec2c6/ruff-0.14.6.tar.gz", hash = "sha256:6f0c742ca6a7783a736b867a263b9a7a80a45ce9bee391eeda296895f1b4e1cc", size = 5669501, upload-time = "2025-11-21T14:26:17.903Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/22/9e4f66ee588588dc6c9af6a994e12d26e19efbe874d1a909d09a6dac7a59/ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac", size = 4601277, upload-time = "2026-03-19T16:26:22.605Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/d2/7dd544116d107fffb24a0064d41a5d2ed1c9d6372d142f9ba108c8e39207/ruff-0.14.6-py3-none-linux_armv6l.whl", hash = "sha256:d724ac2f1c240dbd01a2ae98db5d1d9a5e1d9e96eba999d1c48e30062df578a3", size = 13326119, upload-time = "2025-11-21T14:25:24.2Z" }, - { url = "https://files.pythonhosted.org/packages/36/6a/ad66d0a3315d6327ed6b01f759d83df3c4d5f86c30462121024361137b6a/ruff-0.14.6-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:9f7539ea257aa4d07b7ce87aed580e485c40143f2473ff2f2b75aee003186004", size = 13526007, upload-time = "2025-11-21T14:25:26.906Z" }, - { url = "https://files.pythonhosted.org/packages/a3/9d/dae6db96df28e0a15dea8e986ee393af70fc97fd57669808728080529c37/ruff-0.14.6-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7f6007e55b90a2a7e93083ba48a9f23c3158c433591c33ee2e99a49b889c6332", size = 12676572, upload-time = "2025-11-21T14:25:29.826Z" }, - { url = "https://files.pythonhosted.org/packages/76/a4/f319e87759949062cfee1b26245048e92e2acce900ad3a909285f9db1859/ruff-0.14.6-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a8e7b9d73d8728b68f632aa8e824ef041d068d231d8dbc7808532d3629a6bef", size = 13140745, upload-time = "2025-11-21T14:25:32.788Z" }, - { url = "https://files.pythonhosted.org/packages/95/d3/248c1efc71a0a8ed4e8e10b4b2266845d7dfc7a0ab64354afe049eaa1310/ruff-0.14.6-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:d50d45d4553a3ebcbd33e7c5e0fe6ca4aafd9a9122492de357205c2c48f00775", size = 13076486, upload-time = "2025-11-21T14:25:35.601Z" }, - { url = "https://files.pythonhosted.org/packages/a5/19/b68d4563fe50eba4b8c92aa842149bb56dd24d198389c0ed12e7faff4f7d/ruff-0.14.6-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:118548dd121f8a21bfa8ab2c5b80e5b4aed67ead4b7567790962554f38e598ce", size = 13727563, upload-time = "2025-11-21T14:25:38.514Z" }, - { url = "https://files.pythonhosted.org/packages/47/ac/943169436832d4b0e867235abbdb57ce3a82367b47e0280fa7b4eabb7593/ruff-0.14.6-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:57256efafbfefcb8748df9d1d766062f62b20150691021f8ab79e2d919f7c11f", size = 15199755, upload-time = "2025-11-21T14:25:41.516Z" }, - { url = "https://files.pythonhosted.org/packages/c9/b9/288bb2399860a36d4bb0541cb66cce3c0f4156aaff009dc8499be0c24bf2/ruff-0.14.6-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ff18134841e5c68f8e5df1999a64429a02d5549036b394fafbe410f886e1989d", size = 14850608, upload-time = "2025-11-21T14:25:44.428Z" }, - { url = "https://files.pythonhosted.org/packages/ee/b1/a0d549dd4364e240f37e7d2907e97ee80587480d98c7799d2d8dc7a2f605/ruff-0.14.6-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:29c4b7ec1e66a105d5c27bd57fa93203637d66a26d10ca9809dc7fc18ec58440", size = 14118754, upload-time = "2025-11-21T14:25:47.214Z" }, - { url = "https://files.pythonhosted.org/packages/13/ac/9b9fe63716af8bdfddfacd0882bc1586f29985d3b988b3c62ddce2e202c3/ruff-0.14.6-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:167843a6f78680746d7e226f255d920aeed5e4ad9c03258094a2d49d3028b105", size = 13949214, upload-time = "2025-11-21T14:25:50.002Z" }, - { url = "https://files.pythonhosted.org/packages/12/27/4dad6c6a77fede9560b7df6802b1b697e97e49ceabe1f12baf3ea20862e9/ruff-0.14.6-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:16a33af621c9c523b1ae006b1b99b159bf5ac7e4b1f20b85b2572455018e0821", size = 14106112, upload-time = "2025-11-21T14:25:52.841Z" }, - { url = "https://files.pythonhosted.org/packages/6a/db/23e322d7177873eaedea59a7932ca5084ec5b7e20cb30f341ab594130a71/ruff-0.14.6-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1432ab6e1ae2dc565a7eea707d3b03a0c234ef401482a6f1621bc1f427c2ff55", size = 13035010, upload-time = "2025-11-21T14:25:55.536Z" }, - { url = "https://files.pythonhosted.org/packages/a8/9c/20e21d4d69dbb35e6a1df7691e02f363423658a20a2afacf2a2c011800dc/ruff-0.14.6-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:4c55cfbbe7abb61eb914bfd20683d14cdfb38a6d56c6c66efa55ec6570ee4e71", size = 13054082, upload-time = "2025-11-21T14:25:58.625Z" }, - { url = "https://files.pythonhosted.org/packages/66/25/906ee6a0464c3125c8d673c589771a974965c2be1a1e28b5c3b96cb6ef88/ruff-0.14.6-py3-none-musllinux_1_2_i686.whl", hash = "sha256:efea3c0f21901a685fff4befda6d61a1bf4cb43de16da87e8226a281d614350b", size = 13303354, upload-time = "2025-11-21T14:26:01.816Z" }, - { url = "https://files.pythonhosted.org/packages/4c/58/60577569e198d56922b7ead07b465f559002b7b11d53f40937e95067ca1c/ruff-0.14.6-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:344d97172576d75dc6afc0e9243376dbe1668559c72de1864439c4fc95f78185", size = 14054487, upload-time = "2025-11-21T14:26:05.058Z" }, - { url = "https://files.pythonhosted.org/packages/67/0b/8e4e0639e4cc12547f41cb771b0b44ec8225b6b6a93393176d75fe6f7d40/ruff-0.14.6-py3-none-win32.whl", hash = "sha256:00169c0c8b85396516fdd9ce3446c7ca20c2a8f90a77aa945ba6b8f2bfe99e85", size = 13013361, upload-time = "2025-11-21T14:26:08.152Z" }, - { url = "https://files.pythonhosted.org/packages/fb/02/82240553b77fd1341f80ebb3eaae43ba011c7a91b4224a9f317d8e6591af/ruff-0.14.6-py3-none-win_amd64.whl", hash = "sha256:390e6480c5e3659f8a4c8d6a0373027820419ac14fa0d2713bd8e6c3e125b8b9", size = 14432087, upload-time = "2025-11-21T14:26:10.891Z" }, - { url = "https://files.pythonhosted.org/packages/a5/1f/93f9b0fad9470e4c829a5bb678da4012f0c710d09331b860ee555216f4ea/ruff-0.14.6-py3-none-win_arm64.whl", hash = "sha256:d43c81fbeae52cfa8728d8766bbf46ee4298c888072105815b392da70ca836b2", size = 13520930, upload-time = "2025-11-21T14:26:13.951Z" }, + { url = "https://files.pythonhosted.org/packages/41/2f/0b08ced94412af091807b6119ca03755d651d3d93a242682bf020189db94/ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e", size = 10489037, upload-time = "2026-03-19T16:26:32.47Z" }, + { url = "https://files.pythonhosted.org/packages/91/4a/82e0fa632e5c8b1eba5ee86ecd929e8ff327bbdbfb3c6ac5d81631bef605/ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477", size = 10955433, upload-time = "2026-03-19T16:27:00.205Z" }, + { url = "https://files.pythonhosted.org/packages/ab/10/12586735d0ff42526ad78c049bf51d7428618c8b5c467e72508c694119df/ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e", size = 10269302, upload-time = "2026-03-19T16:26:26.183Z" }, + { url = "https://files.pythonhosted.org/packages/eb/5d/32b5c44ccf149a26623671df49cbfbd0a0ae511ff3df9d9d2426966a8d57/ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf", size = 10607625, upload-time = "2026-03-19T16:27:03.263Z" }, + { url = "https://files.pythonhosted.org/packages/5d/f1/f0001cabe86173aaacb6eb9bb734aa0605f9a6aa6fa7d43cb49cbc4af9c9/ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85", size = 10324743, upload-time = "2026-03-19T16:27:09.791Z" }, + { url = "https://files.pythonhosted.org/packages/7a/87/b8a8f3d56b8d848008559e7c9d8bf367934d5367f6d932ba779456e2f73b/ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0", size = 11138536, upload-time = "2026-03-19T16:27:06.101Z" }, + { url = "https://files.pythonhosted.org/packages/e4/f2/4fd0d05aab0c5934b2e1464784f85ba2eab9d54bffc53fb5430d1ed8b829/ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912", size = 11994292, upload-time = "2026-03-19T16:26:48.718Z" }, + { url = "https://files.pythonhosted.org/packages/64/22/fc4483871e767e5e95d1622ad83dad5ebb830f762ed0420fde7dfa9d9b08/ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036", size = 11398981, upload-time = "2026-03-19T16:26:54.513Z" }, + { url = "https://files.pythonhosted.org/packages/b0/99/66f0343176d5eab02c3f7fcd2de7a8e0dd7a41f0d982bee56cd1c24db62b/ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5", size = 11242422, upload-time = "2026-03-19T16:26:29.277Z" }, + { url = "https://files.pythonhosted.org/packages/5d/3a/a7060f145bfdcce4c987ea27788b30c60e2c81d6e9a65157ca8afe646328/ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12", size = 11232158, upload-time = "2026-03-19T16:26:42.321Z" }, + { url = "https://files.pythonhosted.org/packages/a7/53/90fbb9e08b29c048c403558d3cdd0adf2668b02ce9d50602452e187cd4af/ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c", size = 10577861, upload-time = "2026-03-19T16:26:57.459Z" }, + { url = "https://files.pythonhosted.org/packages/2f/aa/5f486226538fe4d0f0439e2da1716e1acf895e2a232b26f2459c55f8ddad/ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4", size = 10327310, upload-time = "2026-03-19T16:26:35.909Z" }, + { url = "https://files.pythonhosted.org/packages/99/9e/271afdffb81fe7bfc8c43ba079e9d96238f674380099457a74ccb3863857/ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d", size = 10840752, upload-time = "2026-03-19T16:26:45.723Z" }, + { url = "https://files.pythonhosted.org/packages/bf/29/a4ae78394f76c7759953c47884eb44de271b03a66634148d9f7d11e721bd/ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580", size = 11336961, upload-time = "2026-03-19T16:26:39.076Z" }, + { url = "https://files.pythonhosted.org/packages/26/6b/8786ba5736562220d588a2f6653e6c17e90c59ced34a2d7b512ef8956103/ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de", size = 10582538, upload-time = "2026-03-19T16:26:15.992Z" }, + { url = "https://files.pythonhosted.org/packages/2b/e9/346d4d3fffc6871125e877dae8d9a1966b254fbd92a50f8561078b88b099/ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1", size = 11755839, upload-time = "2026-03-19T16:26:19.897Z" }, + { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, ] [[package]] name = "s3transfer" -version = "0.10.4" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/0a/1cdbabf9edd0ea7747efdf6c9ab4e7061b085aa7f9bfc36bb1601563b069/s3transfer-0.10.4.tar.gz", hash = "sha256:29edc09801743c21eb5ecbc617a152df41d3c287f67b615f73e5f750583666a7", size = 145287, upload-time = "2024-11-20T21:06:05.981Z" } +sdist = { url = "https://files.pythonhosted.org/packages/05/04/74127fc843314818edfa81b5540e26dd537353b123a4edc563109d8f17dd/s3transfer-0.16.0.tar.gz", hash = "sha256:8e990f13268025792229cd52fa10cb7163744bf56e719e0b9cb925ab79abf920", size = 153827, upload-time = "2025-12-01T02:30:59.114Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/05/7957af15543b8c9799209506df4660cba7afc4cf94bfb60513827e96bed6/s3transfer-0.10.4-py3-none-any.whl", hash = "sha256:244a76a24355363a68164241438de1b72f8781664920260c48465896b712a41e", size = 83175, upload-time = "2024-11-20T21:06:03.961Z" }, + { url = "https://files.pythonhosted.org/packages/fc/51/727abb13f44c1fcf6d145979e1535a35794db0f6e450a0cb46aa24732fe2/s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:18e25d66fed509e3868dc1572b3f427ff947dd2c56f844a5bf09481ad3f3b2fe", size = 86830, upload-time = "2025-12-01T02:30:57.729Z" }, ] [[package]] @@ -6384,109 +6524,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5d/e6/ec8471c8072382cb91233ba7267fd931219753bb43814cbc71757bfd4dab/safetensors-0.7.0-cp38-abi3-win_amd64.whl", hash = "sha256:d1239932053f56f3456f32eb9625590cc7582e905021f94636202a864d470755", size = 341380, upload-time = "2025-11-19T15:18:44.427Z" }, ] -[[package]] -name = "scikit-network" -version = "0.33.5" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, - { name = "scipy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/46/6d/28b00fbef9ff7d8ba31861bf16705a1a74a1696fb65aab2a7c584f966bec/scikit_network-0.33.5.tar.gz", hash = "sha256:ae2149d9a280fdc4bbadd5f8a7b17c8af61c054bc3f834792bc61483e6783c12", size = 1784205, upload-time = "2025-11-19T09:45:14.402Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/d4/ed/4a14f48ae7eceeb4bc2085c9467e7dda487b8728261a2556a3d2fdc99b0e/scikit_network-0.33.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:38a5a55f125b9ff574b085994e6374f783469faf1542d140c1630ad2c14127b9", size = 2854152, upload-time = "2025-11-19T09:44:43.049Z" }, - { url = "https://files.pythonhosted.org/packages/30/f9/71e225866bedaf6256ba3cd720c5bdbbe4828df7574d3c9cb741789b0f85/scikit_network-0.33.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6d796382dd914ccaa7d3fb990e148f5f095817690f07119d614f4d5bd63b347e", size = 2840146, upload-time = "2025-11-19T09:44:45.04Z" }, - { url = "https://files.pythonhosted.org/packages/a0/ac/ed64fbc2a21074aa399d3e527695f4d4bb35be3346b5e09edf0637d58080/scikit_network-0.33.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:406af38a07ee1d631b62616496013fc5d941fffb5221fdb9e9091b87352a3f0c", size = 8005762, upload-time = "2025-11-19T09:44:47.186Z" }, - { url = "https://files.pythonhosted.org/packages/d5/8e/9631ea79d6144d746e9a73e79d392656ae0c1d2989b8183f4f13134ad253/scikit_network-0.33.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d001988f07abe03ef4d8189da148968ad520f8cb8666d9ecee0d51c277bd466", size = 8061688, upload-time = "2025-11-19T09:44:49.034Z" }, - { url = "https://files.pythonhosted.org/packages/05/d2/26d21c25245204ef0ac6dfd047e8b6181bdc40d486dd9daadaadfed1c0e6/scikit_network-0.33.5-cp311-cp311-win_amd64.whl", hash = "sha256:8b9338feb0b2bae0f32ddd77453125b4e6c2d365cfb1bb1334b016888466e42f", size = 2738782, upload-time = "2025-11-19T09:44:50.881Z" }, - { url = "https://files.pythonhosted.org/packages/72/32/f092fca9a3ae256e0608ec6a4d8830023ea4ae478c2e27e94cf5802824f0/scikit_network-0.33.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:ced625228be1632595d11adaf22d61d1d4c788909cd4d8c364e720b2814aac7f", size = 2874698, upload-time = "2025-11-19T09:44:52.765Z" }, - { url = "https://files.pythonhosted.org/packages/4b/c7/5b2ce72f93b422af48a2b755fbcbab271bf980a6b46484f754d63978f1ff/scikit_network-0.33.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:808b625d28005c24b47cbdf65f780a3a355aa4080992e6b78f03434e873d06e6", size = 2854224, upload-time = "2025-11-19T09:44:55.574Z" }, - { url = "https://files.pythonhosted.org/packages/86/02/974ae67f493ccf988108894e8a9dedfd00ca5113d2848e0b9fc2a4d18824/scikit_network-0.33.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9f2d98059cd79bdb935ff6a638f1c3a12b0b1cb7ace9e1d3fb35476eeeaabaa", size = 7924650, upload-time = "2025-11-19T09:44:57.704Z" }, - { url = "https://files.pythonhosted.org/packages/8d/34/b67e48e111916a6f09fe29a971a9716c14f78525e0ab7c46e6a6538cf2f6/scikit_network-0.33.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aac2d3bc14214f02dac300624f5ec2650af9a98b52f304d300f0ec2813a0e544", size = 8012322, upload-time = "2025-11-19T09:45:00.079Z" }, - { url = "https://files.pythonhosted.org/packages/c3/a2/49293b53b837b3248d19fffd41a9fd73dcb2eeabbab890cc4c8aa237b545/scikit_network-0.33.5-cp312-cp312-win_amd64.whl", hash = "sha256:2866b16aed9ef25ba42cb2f2e44ef2ad079337f336ce48d0604b55fa4af87688", size = 2746491, upload-time = "2025-11-19T09:45:01.997Z" }, -] - -[[package]] -name = "scipy" -version = "1.17.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0", size = 30573822, upload-time = "2026-02-23T00:26:24.851Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/df/75/b4ce781849931fef6fd529afa6b63711d5a733065722d0c3e2724af9e40a/scipy-1.17.1-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:1f95b894f13729334fb990162e911c9e5dc1ab390c58aa6cbecb389c5b5e28ec", size = 31613675, upload-time = "2026-02-23T00:16:00.13Z" }, - { url = "https://files.pythonhosted.org/packages/f7/58/bccc2861b305abdd1b8663d6130c0b3d7cc22e8d86663edbc8401bfd40d4/scipy-1.17.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:e18f12c6b0bc5a592ed23d3f7b891f68fd7f8241d69b7883769eb5d5dfb52696", size = 28162057, upload-time = "2026-02-23T00:16:09.456Z" }, - { url = "https://files.pythonhosted.org/packages/6d/ee/18146b7757ed4976276b9c9819108adbc73c5aad636e5353e20746b73069/scipy-1.17.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:a3472cfbca0a54177d0faa68f697d8ba4c80bbdc19908c3465556d9f7efce9ee", size = 20334032, upload-time = "2026-02-23T00:16:17.358Z" }, - { url = "https://files.pythonhosted.org/packages/ec/e6/cef1cf3557f0c54954198554a10016b6a03b2ec9e22a4e1df734936bd99c/scipy-1.17.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:766e0dc5a616d026a3a1cffa379af959671729083882f50307e18175797b3dfd", size = 22709533, upload-time = "2026-02-23T00:16:25.791Z" }, - { url = "https://files.pythonhosted.org/packages/4d/60/8804678875fc59362b0fb759ab3ecce1f09c10a735680318ac30da8cd76b/scipy-1.17.1-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:744b2bf3640d907b79f3fd7874efe432d1cf171ee721243e350f55234b4cec4c", size = 33062057, upload-time = "2026-02-23T00:16:36.931Z" }, - { url = "https://files.pythonhosted.org/packages/09/7d/af933f0f6e0767995b4e2d705a0665e454d1c19402aa7e895de3951ebb04/scipy-1.17.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43af8d1f3bea642559019edfe64e9b11192a8978efbd1539d7bc2aaa23d92de4", size = 35349300, upload-time = "2026-02-23T00:16:49.108Z" }, - { url = "https://files.pythonhosted.org/packages/b4/3d/7ccbbdcbb54c8fdc20d3b6930137c782a163fa626f0aef920349873421ba/scipy-1.17.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:cd96a1898c0a47be4520327e01f874acfd61fb48a9420f8aa9f6483412ffa444", size = 35127333, upload-time = "2026-02-23T00:17:01.293Z" }, - { url = "https://files.pythonhosted.org/packages/e8/19/f926cb11c42b15ba08e3a71e376d816ac08614f769b4f47e06c3580c836a/scipy-1.17.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4eb6c25dd62ee8d5edf68a8e1c171dd71c292fdae95d8aeb3dd7d7de4c364082", size = 37741314, upload-time = "2026-02-23T00:17:12.576Z" }, - { url = "https://files.pythonhosted.org/packages/95/da/0d1df507cf574b3f224ccc3d45244c9a1d732c81dcb26b1e8a766ae271a8/scipy-1.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:d30e57c72013c2a4fe441c2fcb8e77b14e152ad48b5464858e07e2ad9fbfceff", size = 36607512, upload-time = "2026-02-23T00:17:23.424Z" }, - { url = "https://files.pythonhosted.org/packages/68/7f/bdd79ceaad24b671543ffe0ef61ed8e659440eb683b66f033454dcee90eb/scipy-1.17.1-cp311-cp311-win_arm64.whl", hash = "sha256:9ecb4efb1cd6e8c4afea0daa91a87fbddbce1b99d2895d151596716c0b2e859d", size = 24599248, upload-time = "2026-02-23T00:17:34.561Z" }, - { url = "https://files.pythonhosted.org/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8", size = 31610954, upload-time = "2026-02-23T00:17:49.855Z" }, - { url = "https://files.pythonhosted.org/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76", size = 28172662, upload-time = "2026-02-23T00:18:01.64Z" }, - { url = "https://files.pythonhosted.org/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086", size = 20344366, upload-time = "2026-02-23T00:18:12.015Z" }, - { url = "https://files.pythonhosted.org/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b", size = 22704017, upload-time = "2026-02-23T00:18:21.502Z" }, - { url = "https://files.pythonhosted.org/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21", size = 32927842, upload-time = "2026-02-23T00:18:35.367Z" }, - { url = "https://files.pythonhosted.org/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458", size = 35235890, upload-time = "2026-02-23T00:18:49.188Z" }, - { url = "https://files.pythonhosted.org/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb", size = 35003557, upload-time = "2026-02-23T00:18:54.74Z" }, - { url = "https://files.pythonhosted.org/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea", size = 37625856, upload-time = "2026-02-23T00:19:00.307Z" }, - { url = "https://files.pythonhosted.org/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87", size = 36549682, upload-time = "2026-02-23T00:19:07.67Z" }, - { url = "https://files.pythonhosted.org/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3", size = 24547340, upload-time = "2026-02-23T00:19:12.024Z" }, -] - [[package]] name = "scipy-stubs" -version = "1.16.3.1" +version = "1.17.1.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "optype", extra = ["numpy"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0b/3e/8baf960c68f012b8297930d4686b235813974833a417db8d0af798b0b93d/scipy_stubs-1.16.3.1.tar.gz", hash = "sha256:0738d55a7f8b0c94cdb8063f711d53330ebefe166f7d48dec9ffd932a337226d", size = 359990, upload-time = "2025-11-23T23:05:21.274Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/59/59c6cc3f9970154b9ed6b1aff42a0185cdd60cef54adc0404b9e77972221/scipy_stubs-1.17.1.3.tar.gz", hash = "sha256:5eb87a8d23d726706259b012ebe76a4a96a9ae9e141fc59bf55fc8eac2ed9e0f", size = 392185, upload-time = "2026-03-22T22:11:58.34Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0c/39/e2a69866518f88dc01940c9b9b044db97c3387f2826bd2a173e49a5c0469/scipy_stubs-1.16.3.1-py3-none-any.whl", hash = "sha256:69bc52ef6c3f8e09208abdfaf32291eb51e9ddf8fa4389401ccd9473bdd2a26d", size = 560397, upload-time = "2025-11-23T23:05:19.432Z" }, -] - -[[package]] -name = "secretstorage" -version = "3.5.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "cryptography", marker = "sys_platform == 'linux'" }, - { name = "jeepney", marker = "sys_platform == 'linux'" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/1c/03/e834bcd866f2f8a49a85eaff47340affa3bfa391ee9912a952a1faa68c7b/secretstorage-3.5.0.tar.gz", hash = "sha256:f04b8e4689cbce351744d5537bf6b1329c6fc68f91fa666f60a380edddcd11be", size = 19884, upload-time = "2025-11-23T19:02:53.191Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/46/f5af3402b579fd5e11573ce652019a67074317e18c1935cc0b4ba9b35552/secretstorage-3.5.0-py3-none-any.whl", hash = "sha256:0ce65888c0725fcb2c5bc0fdb8e5438eece02c523557ea40ce0703c266248137", size = 15554, upload-time = "2025-11-23T19:02:51.545Z" }, + { url = "https://files.pythonhosted.org/packages/2c/d4/94304532c0a75a55526119043dd44a9bd1541a21e14483cbb54261c527d2/scipy_stubs-1.17.1.3-py3-none-any.whl", hash = "sha256:7b91d3f05aa47da06fbca14eb6c5bb4c28994e9245fd250cc847e375bab31297", size = 597933, upload-time = "2026-03-22T22:11:56.525Z" }, ] [[package]] name = "sendgrid" -version = "6.12.5" +version = "6.12.4" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "cryptography" }, + { name = "ecdsa" }, { name = "python-http-client" }, { name = "werkzeug" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/da/fa/f718b2b953f99c1f0085811598ac7e31ccbd4229a81ec2a5290be868187a/sendgrid-6.12.5.tar.gz", hash = "sha256:ea9aae30cd55c332e266bccd11185159482edfc07c149b6cd15cf08869fabdb7", size = 50310, upload-time = "2025-09-19T06:23:09.229Z" } +sdist = { url = "https://files.pythonhosted.org/packages/11/31/62e00433878dccf33edf07f8efa417b9030a2464eb3b04bbd797a11b4447/sendgrid-6.12.4.tar.gz", hash = "sha256:9e88b849daf0fa4bdf256c3b5da9f5a3272402c0c2fd6b1928c9de440db0a03d", size = 50271, upload-time = "2025-06-12T10:29:37.213Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bd/55/b3c3880a77082e8f7374954e0074aafafaa9bc78bdf9c8f5a92c2e7afc6a/sendgrid-6.12.5-py3-none-any.whl", hash = "sha256:96f92cc91634bf552fdb766b904bbb53968018da7ae41fdac4d1090dc0311ca8", size = 102173, upload-time = "2025-09-19T06:23:07.93Z" }, + { url = "https://files.pythonhosted.org/packages/c2/9c/45d068fd831a65e6ed1e2ab3233de58784842afdc62fdcdd0a01bbb6b39d/sendgrid-6.12.4-py3-none-any.whl", hash = "sha256:9a211b96241e63bd5b9ed9afcc8608f4bcac426e4a319b3920ab877c8426e92c", size = 102122, upload-time = "2025-06-12T10:29:35.457Z" }, ] [[package]] name = "sentry-sdk" -version = "2.28.0" +version = "2.55.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/5e/bb/6a41b2e0e9121bed4d2ec68d50568ab95c49f4744156a9bbb789c866c66d/sentry_sdk-2.28.0.tar.gz", hash = "sha256:14d2b73bc93afaf2a9412490329099e6217761cbab13b6ee8bc0e82927e1504e", size = 325052, upload-time = "2025-05-12T07:53:12.785Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/b8/285293dc60fc198fffc3fcdbc7c6d4e646e0f74e61461c355d40faa64ceb/sentry_sdk-2.55.0.tar.gz", hash = "sha256:3774c4d8820720ca4101548131b9c162f4c9426eb7f4d24aca453012a7470f69", size = 424505, upload-time = "2026-03-17T14:15:51.707Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/4e/b1575833094c088dfdef63fbca794518860fcbc8002aadf51ebe8b6a387f/sentry_sdk-2.28.0-py2.py3-none-any.whl", hash = "sha256:51496e6cb3cb625b99c8e08907c67a9112360259b0ef08470e532c3ab184a232", size = 341693, upload-time = "2025-05-12T07:53:10.882Z" }, + { url = "https://files.pythonhosted.org/packages/9a/66/20465097782d7e1e742d846407ea7262d338c6e876ddddad38ca8907b38f/sentry_sdk-2.55.0-py2.py3-none-any.whl", hash = "sha256:97026981cb15699394474a196b88503a393cbc58d182ece0d3abe12b9bd978d4", size = 449284, upload-time = "2026-03-17T14:15:49.604Z" }, ] [package.optional-dependencies] @@ -6498,38 +6572,11 @@ flask = [ [[package]] name = "setuptools" -version = "80.9.0" +version = "82.0.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/5d/3bf57dcd21979b887f014ea83c24ae194cfcd12b9e0fda66b957c69d1fca/setuptools-80.9.0.tar.gz", hash = "sha256:f36b47402ecde768dbfafc46e8e4207b4360c654f1f3bb84475f0a28628fb19c", size = 1319958, upload-time = "2025-05-27T00:56:51.443Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4f/db/cfac1baf10650ab4d1c111714410d2fbb77ac5a616db26775db562c8fab2/setuptools-82.0.1.tar.gz", hash = "sha256:7d872682c5d01cfde07da7bccc7b65469d3dca203318515ada1de5eda35efbf9", size = 1152316, upload-time = "2026-03-09T12:47:17.221Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a3/dc/17031897dae0efacfea57dfd3a82fdd2a2aeb58e0ff71b77b87e44edc772/setuptools-80.9.0-py3-none-any.whl", hash = "sha256:062d34222ad13e0cc312a4c02d73f059e86a4acbfbdea8f8f76b28c99f306922", size = 1201486, upload-time = "2025-05-27T00:56:49.664Z" }, -] - -[[package]] -name = "shapely" -version = "2.1.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/4d/bc/0989043118a27cccb4e906a46b7565ce36ca7b57f5a18b78f4f1b0f72d9d/shapely-2.1.2.tar.gz", hash = "sha256:2ed4ecb28320a433db18a5bf029986aa8afcfd740745e78847e330d5d94922a9", size = 315489, upload-time = "2025-09-24T13:51:41.432Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8f/8d/1ff672dea9ec6a7b5d422eb6d095ed886e2e523733329f75fdcb14ee1149/shapely-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:91121757b0a36c9aac3427a651a7e6567110a4a67c97edf04f8d55d4765f6618", size = 1820038, upload-time = "2025-09-24T13:50:15.628Z" }, - { url = "https://files.pythonhosted.org/packages/4f/ce/28fab8c772ce5db23a0d86bf0adaee0c4c79d5ad1db766055fa3dab442e2/shapely-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:16a9c722ba774cf50b5d4541242b4cce05aafd44a015290c82ba8a16931ff63d", size = 1626039, upload-time = "2025-09-24T13:50:16.881Z" }, - { url = "https://files.pythonhosted.org/packages/70/8b/868b7e3f4982f5006e9395c1e12343c66a8155c0374fdc07c0e6a1ab547d/shapely-2.1.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cc4f7397459b12c0b196c9efe1f9d7e92463cbba142632b4cc6d8bbbbd3e2b09", size = 3001519, upload-time = "2025-09-24T13:50:18.606Z" }, - { url = "https://files.pythonhosted.org/packages/13/02/58b0b8d9c17c93ab6340edd8b7308c0c5a5b81f94ce65705819b7416dba5/shapely-2.1.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:136ab87b17e733e22f0961504d05e77e7be8c9b5a8184f685b4a91a84efe3c26", size = 3110842, upload-time = "2025-09-24T13:50:21.77Z" }, - { url = "https://files.pythonhosted.org/packages/af/61/8e389c97994d5f331dcffb25e2fa761aeedfb52b3ad9bcdd7b8671f4810a/shapely-2.1.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:16c5d0fc45d3aa0a69074979f4f1928ca2734fb2e0dde8af9611e134e46774e7", size = 4021316, upload-time = "2025-09-24T13:50:23.626Z" }, - { url = "https://files.pythonhosted.org/packages/d3/d4/9b2a9fe6039f9e42ccf2cb3e84f219fd8364b0c3b8e7bbc857b5fbe9c14c/shapely-2.1.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6ddc759f72b5b2b0f54a7e7cde44acef680a55019eb52ac63a7af2cf17cb9cd2", size = 4178586, upload-time = "2025-09-24T13:50:25.443Z" }, - { url = "https://files.pythonhosted.org/packages/16/f6/9840f6963ed4decf76b08fd6d7fed14f8779fb7a62cb45c5617fa8ac6eab/shapely-2.1.2-cp311-cp311-win32.whl", hash = "sha256:2fa78b49485391224755a856ed3b3bd91c8455f6121fee0db0e71cefb07d0ef6", size = 1543961, upload-time = "2025-09-24T13:50:26.968Z" }, - { url = "https://files.pythonhosted.org/packages/38/1e/3f8ea46353c2a33c1669eb7327f9665103aa3a8dfe7f2e4ef714c210b2c2/shapely-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:c64d5c97b2f47e3cd9b712eaced3b061f2b71234b3fc263e0fcf7d889c6559dc", size = 1722856, upload-time = "2025-09-24T13:50:28.497Z" }, - { url = "https://files.pythonhosted.org/packages/24/c0/f3b6453cf2dfa99adc0ba6675f9aaff9e526d2224cbd7ff9c1a879238693/shapely-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fe2533caae6a91a543dec62e8360fe86ffcdc42a7c55f9dfd0128a977a896b94", size = 1833550, upload-time = "2025-09-24T13:50:30.019Z" }, - { url = "https://files.pythonhosted.org/packages/86/07/59dee0bc4b913b7ab59ab1086225baca5b8f19865e6101db9ebb7243e132/shapely-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ba4d1333cc0bc94381d6d4308d2e4e008e0bd128bdcff5573199742ee3634359", size = 1643556, upload-time = "2025-09-24T13:50:32.291Z" }, - { url = "https://files.pythonhosted.org/packages/26/29/a5397e75b435b9895cd53e165083faed5d12fd9626eadec15a83a2411f0f/shapely-2.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0bd308103340030feef6c111d3eb98d50dc13feea33affc8a6f9fa549e9458a3", size = 2988308, upload-time = "2025-09-24T13:50:33.862Z" }, - { url = "https://files.pythonhosted.org/packages/b9/37/e781683abac55dde9771e086b790e554811a71ed0b2b8a1e789b7430dd44/shapely-2.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1e7d4d7ad262a48bb44277ca12c7c78cb1b0f56b32c10734ec9a1d30c0b0c54b", size = 3099844, upload-time = "2025-09-24T13:50:35.459Z" }, - { url = "https://files.pythonhosted.org/packages/d8/f3/9876b64d4a5a321b9dc482c92bb6f061f2fa42131cba643c699f39317cb9/shapely-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e9eddfe513096a71896441a7c37db72da0687b34752c4e193577a145c71736fc", size = 3988842, upload-time = "2025-09-24T13:50:37.478Z" }, - { url = "https://files.pythonhosted.org/packages/d1/a0/704c7292f7014c7e74ec84eddb7b109e1fbae74a16deae9c1504b1d15565/shapely-2.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:980c777c612514c0cf99bc8a9de6d286f5e186dcaf9091252fcd444e5638193d", size = 4152714, upload-time = "2025-09-24T13:50:39.9Z" }, - { url = "https://files.pythonhosted.org/packages/53/46/319c9dc788884ad0785242543cdffac0e6530e4d0deb6c4862bc4143dcf3/shapely-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9111274b88e4d7b54a95218e243282709b330ef52b7b86bc6aaf4f805306f454", size = 1542745, upload-time = "2025-09-24T13:50:41.414Z" }, - { url = "https://files.pythonhosted.org/packages/ec/bf/cb6c1c505cb31e818e900b9312d514f381fbfa5c4363edfce0fcc4f8c1a4/shapely-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:743044b4cfb34f9a67205cee9279feaf60ba7d02e69febc2afc609047cb49179", size = 1722861, upload-time = "2025-09-24T13:50:43.35Z" }, + { url = "https://files.pythonhosted.org/packages/9d/76/f789f7a86709c6b087c5a2f52f911838cad707cc613162401badc665acfe/setuptools-82.0.1-py3-none-any.whl", hash = "sha256:a59e362652f08dcd477c78bb6e7bd9d80a7995bc73ce773050228a348ce2e5bb", size = 1006223, upload-time = "2026-03-09T12:47:15.026Z" }, ] [[package]] @@ -6551,12 +6598,24 @@ wheels = [ ] [[package]] -name = "smmap" -version = "5.0.2" +name = "smart-open" +version = "7.5.1" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/44/cd/a040c4b3119bbe532e5b0732286f805445375489fceaec1f48306068ee3b/smmap-5.0.2.tar.gz", hash = "sha256:26ea65a03958fa0c8a1c7e8c7a58fdc77221b8910f6be2131affade476898ad5", size = 22329, upload-time = "2025-01-02T07:14:40.909Z" } +dependencies = [ + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e8/be/a66598b305763861a9ab15ff0f2fbc44e47b1ce7a776797337a4eef37c66/smart_open-7.5.1.tar.gz", hash = "sha256:3f08e16827c4733699e6b2cc40328a3568f900cb12ad9a3ad233ba6c872d9fe7", size = 54034, upload-time = "2026-02-23T11:01:28.979Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/04/be/d09147ad1ec7934636ad912901c5fd7667e1c858e19d355237db0d0cd5e4/smmap-5.0.2-py3-none-any.whl", hash = "sha256:b30115f0def7d7531d22a0fb6502488d879e75b260a9db4d0819cfb25403af5e", size = 24303, upload-time = "2025-01-02T07:14:38.724Z" }, + { url = "https://files.pythonhosted.org/packages/5e/ea/dcdecd68acebb49d3fd560473a43499b1635076f7f1ae8641c060fe7ce74/smart_open-7.5.1-py3-none-any.whl", hash = "sha256:3e07cbbd9c8a908bcb8e25d48becf1a5cbb4886fa975e9f34c672ed171df2318", size = 64108, upload-time = "2026-02-23T11:01:27.429Z" }, +] + +[[package]] +name = "smmap" +version = "5.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/ea/49c993d6dfdd7338c9b1000a0f36817ed7ec84577ae2e52f890d1a4ff909/smmap-5.0.3.tar.gz", hash = "sha256:4d9debb8b99007ae47165abc08670bd74cb74b5227dda7f643eccc4e9eb5642c", size = 22506, upload-time = "2026-03-09T03:43:26.1Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl", hash = "sha256:c106e05d5a61449cf6ba9a1e650227ecfb141590d2a98412103ff35d89fc7b2f", size = 24390, upload-time = "2026-03-09T03:43:24.361Z" }, ] [[package]] @@ -6588,85 +6647,164 @@ wheels = [ [[package]] name = "soupsieve" -version = "2.8" +version = "2.8.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6d/e6/21ccce3262dd4889aa3332e5a119a3491a95e8f60939870a3a035aabac0d/soupsieve-2.8.tar.gz", hash = "sha256:e2dd4a40a628cb5f28f6d4b0db8800b8f581b65bb380b97de22ba5ca8d72572f", size = 103472, upload-time = "2025-08-27T15:39:51.78Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/14/a0/bb38d3b76b8cae341dad93a2dd83ab7462e6dbcdd84d43f54ee60a8dc167/soupsieve-2.8-py3-none-any.whl", hash = "sha256:0cc76456a30e20f5d7f2e14a98a4ae2ee4e5abdc7c5ea0aafe795f344bc7984c", size = 36679, upload-time = "2025-08-27T15:39:50.179Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, +] + +[[package]] +name = "spacy" +version = "3.8.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "catalogue" }, + { name = "cymem" }, + { name = "jinja2" }, + { name = "murmurhash" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "preshed" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "spacy-legacy" }, + { name = "spacy-loggers" }, + { name = "srsly" }, + { name = "thinc" }, + { name = "tqdm" }, + { name = "typer-slim" }, + { name = "wasabi" }, + { name = "weasel" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/59/9f/424244b0e2656afc9ff82fb7a96931a47397bfce5ba382213827b198312a/spacy-3.8.11.tar.gz", hash = "sha256:54e1e87b74a2f9ea807ffd606166bf29ac45e2bd81ff7f608eadc7b05787d90d", size = 1326804, upload-time = "2025-11-17T20:40:03.079Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/74/d3/0c795e6f31ee3535b6e70d08e89fc22247b95b61f94fc8334a01d39bf871/spacy-3.8.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a12d83e8bfba07563300ae5e0086548e41aa4bfe3734c97dda87e0eec813df0d", size = 6487958, upload-time = "2025-11-17T20:38:40.378Z" }, + { url = "https://files.pythonhosted.org/packages/4e/2a/83ca9b4d0a2b31adcf0ced49fa667212d12958f75d4e238618a60eb50b10/spacy-3.8.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e07a50b69500ef376326545353a470f00d1ed7203c76341b97242af976e3681a", size = 6148078, upload-time = "2025-11-17T20:38:42.524Z" }, + { url = "https://files.pythonhosted.org/packages/2c/f0/ff520df18a6152ba2dbf808c964014308e71a48feb4c7563f2a6cd6e668d/spacy-3.8.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:718b7bb5e83c76cb841ed6e407f7b40255d0b46af7101a426c20e04af3afd64e", size = 32056451, upload-time = "2025-11-17T20:38:44.92Z" }, + { url = "https://files.pythonhosted.org/packages/9d/3a/6c44c0b9b6a70595888b8d021514ded065548a5b10718ac253bd39f9fd73/spacy-3.8.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f860f9d51c1aeb2d61852442b232576e4ca4d239cb3d1b40ac452118b8eb2c68", size = 32302908, upload-time = "2025-11-17T20:38:47.672Z" }, + { url = "https://files.pythonhosted.org/packages/db/77/00e99e00efd4c2456772befc48400c2e19255140660d663e16b6924a0f2e/spacy-3.8.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ff8d928ce70d751b7bb27f60ee5e3a308216efd4ab4517291e6ff05d9b194840", size = 32280936, upload-time = "2025-11-17T20:38:50.893Z" }, + { url = "https://files.pythonhosted.org/packages/d8/da/692b51e9e5be2766d2d1fb9a7c8122cfd99c337570e621f09c40ce94ad17/spacy-3.8.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3f3cb91d7d42fafd92b8d5bf9f696571170d2f0747f85724a2c5b997753e33c9", size = 33117270, upload-time = "2025-11-17T20:38:53.596Z" }, + { url = "https://files.pythonhosted.org/packages/9b/13/a542ac9b61d071f3328fda1fd8087b523fb7a4f2c340010bc70b1f762485/spacy-3.8.11-cp311-cp311-win_amd64.whl", hash = "sha256:745c190923584935272188c604e0cc170f4179aace1025814a25d92ee90cf3de", size = 15348350, upload-time = "2025-11-17T20:38:56.833Z" }, + { url = "https://files.pythonhosted.org/packages/23/53/975c16514322f6385d6caa5929771613d69f5458fb24f03e189ba533f279/spacy-3.8.11-cp311-cp311-win_arm64.whl", hash = "sha256:27535d81d9dee0483b66660cadd93d14c1668f55e4faf4386aca4a11a41a8b97", size = 14701913, upload-time = "2025-11-17T20:38:59.507Z" }, + { url = "https://files.pythonhosted.org/packages/51/fb/01eadf4ba70606b3054702dc41fc2ccf7d70fb14514b3cd57f0ff78ebea8/spacy-3.8.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:aa1ee8362074c30098feaaf2dd888c829a1a79c4311eec1b117a0a61f16fa6dd", size = 6073726, upload-time = "2025-11-17T20:39:01.679Z" }, + { url = "https://files.pythonhosted.org/packages/3a/f8/07b03a2997fc2621aaeafae00af50f55522304a7da6926b07027bb6d0709/spacy-3.8.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:75a036d04c2cf11d6cb566c0a689860cc5a7a75b439e8fea1b3a6b673dabf25d", size = 5724702, upload-time = "2025-11-17T20:39:03.486Z" }, + { url = "https://files.pythonhosted.org/packages/13/0c/c4fa0f379dbe3258c305d2e2df3760604a9fcd71b34f8f65c23e43f4cf55/spacy-3.8.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7cb599d2747d4a59a5f90e8a453c149b13db382a8297925cf126333141dbc4f7", size = 32727774, upload-time = "2025-11-17T20:39:05.894Z" }, + { url = "https://files.pythonhosted.org/packages/ce/8e/6a4ba82bed480211ebdf5341b0f89e7271b454307525ac91b5e447825914/spacy-3.8.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:94632e302ad2fb79dc285bf1e9e4d4a178904d5c67049e0e02b7fb4a77af85c4", size = 33215053, upload-time = "2025-11-17T20:39:08.588Z" }, + { url = "https://files.pythonhosted.org/packages/a6/bc/44d863d248e9d7358c76a0aa8b3f196b8698df520650ed8de162e18fbffb/spacy-3.8.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:aeca6cf34009d48cda9fb1bbfb532469e3d643817241a73e367b34ab99a5806f", size = 32074195, upload-time = "2025-11-17T20:39:11.601Z" }, + { url = "https://files.pythonhosted.org/packages/6f/7d/0b115f3f16e1dd2d3f99b0f89497867fc11c41aed94f4b7a4367b4b54136/spacy-3.8.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:368a79b8df925b15d89dccb5e502039446fb2ce93cf3020e092d5b962c3349b9", size = 32996143, upload-time = "2025-11-17T20:39:14.705Z" }, + { url = "https://files.pythonhosted.org/packages/7d/48/7e9581b476df76aaf9ee182888d15322e77c38b0bbbd5e80160ba0bddd4c/spacy-3.8.11-cp312-cp312-win_amd64.whl", hash = "sha256:88d65941a87f58d75afca1785bd64d01183a92f7269dcbcf28bd9d6f6a77d1a7", size = 14217511, upload-time = "2025-11-17T20:39:17.316Z" }, + { url = "https://files.pythonhosted.org/packages/7b/1f/307a16f32f90aa5ee7ad8d29ff8620a57132b80a4c8c536963d46d192e1a/spacy-3.8.11-cp312-cp312-win_arm64.whl", hash = "sha256:97b865d6d3658e2ab103a67d6c8a2d678e193e84a07f40d9938565b669ceee39", size = 13614446, upload-time = "2025-11-17T20:39:19.748Z" }, +] + +[[package]] +name = "spacy-legacy" +version = "3.0.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d9/79/91f9d7cc8db5642acad830dcc4b49ba65a7790152832c4eceb305e46d681/spacy-legacy-3.0.12.tar.gz", hash = "sha256:b37d6e0c9b6e1d7ca1cf5bc7152ab64a4c4671f59c85adaf7a3fcb870357a774", size = 23806, upload-time = "2023-01-23T09:04:15.104Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/55/12e842c70ff8828e34e543a2c7176dac4da006ca6901c9e8b43efab8bc6b/spacy_legacy-3.0.12-py2.py3-none-any.whl", hash = "sha256:476e3bd0d05f8c339ed60f40986c07387c0a71479245d6d0f4298dbd52cda55f", size = 29971, upload-time = "2023-01-23T09:04:13.45Z" }, +] + +[[package]] +name = "spacy-loggers" +version = "1.0.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/67/3d/926db774c9c98acf66cb4ed7faf6c377746f3e00b84b700d0868b95d0712/spacy-loggers-1.0.5.tar.gz", hash = "sha256:d60b0bdbf915a60e516cc2e653baeff946f0cfc461b452d11a4d5458c6fe5f24", size = 20811, upload-time = "2023-09-11T12:26:52.323Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/33/78/d1a1a026ef3af911159398c939b1509d5c36fe524c7b644f34a5146c4e16/spacy_loggers-1.0.5-py3-none-any.whl", hash = "sha256:196284c9c446cc0cdb944005384270d775fdeaf4f494d8e269466cfa497ef645", size = 22343, upload-time = "2023-09-11T12:26:50.586Z" }, ] [[package]] name = "sqlalchemy" -version = "2.0.44" +version = "2.0.48" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "greenlet", marker = "platform_machine == 'AMD64' or platform_machine == 'WIN32' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'ppc64le' or platform_machine == 'win32' or platform_machine == 'x86_64'" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f0/f2/840d7b9496825333f532d2e3976b8eadbf52034178aac53630d09fe6e1ef/sqlalchemy-2.0.44.tar.gz", hash = "sha256:0ae7454e1ab1d780aee69fd2aae7d6b8670a581d8847f2d1e0f7ddfbf47e5a22", size = 9819830, upload-time = "2025-10-10T14:39:12.935Z" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/73/b4a9737255583b5fa858e0bb8e116eb94b88c910164ed2ed719147bde3de/sqlalchemy-2.0.48.tar.gz", hash = "sha256:5ca74f37f3369b45e1f6b7b06afb182af1fd5dde009e4ffd831830d98cbe5fe7", size = 9886075, upload-time = "2026-03-02T15:28:51.474Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e3/81/15d7c161c9ddf0900b076b55345872ed04ff1ed6a0666e5e94ab44b0163c/sqlalchemy-2.0.44-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0fe3917059c7ab2ee3f35e77757062b1bea10a0b6ca633c58391e3f3c6c488dd", size = 2140517, upload-time = "2025-10-10T15:36:15.64Z" }, - { url = "https://files.pythonhosted.org/packages/d4/d5/4abd13b245c7d91bdf131d4916fd9e96a584dac74215f8b5bc945206a974/sqlalchemy-2.0.44-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:de4387a354ff230bc979b46b2207af841dc8bf29847b6c7dbe60af186d97aefa", size = 2130738, upload-time = "2025-10-10T15:36:16.91Z" }, - { url = "https://files.pythonhosted.org/packages/cb/3c/8418969879c26522019c1025171cefbb2a8586b6789ea13254ac602986c0/sqlalchemy-2.0.44-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3678a0fb72c8a6a29422b2732fe423db3ce119c34421b5f9955873eb9b62c1e", size = 3304145, upload-time = "2025-10-10T15:34:19.569Z" }, - { url = "https://files.pythonhosted.org/packages/94/2d/fdb9246d9d32518bda5d90f4b65030b9bf403a935cfe4c36a474846517cb/sqlalchemy-2.0.44-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3cf6872a23601672d61a68f390e44703442639a12ee9dd5a88bbce52a695e46e", size = 3304511, upload-time = "2025-10-10T15:47:05.088Z" }, - { url = "https://files.pythonhosted.org/packages/7d/fb/40f2ad1da97d5c83f6c1269664678293d3fe28e90ad17a1093b735420549/sqlalchemy-2.0.44-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:329aa42d1be9929603f406186630135be1e7a42569540577ba2c69952b7cf399", size = 3235161, upload-time = "2025-10-10T15:34:21.193Z" }, - { url = "https://files.pythonhosted.org/packages/95/cb/7cf4078b46752dca917d18cf31910d4eff6076e5b513c2d66100c4293d83/sqlalchemy-2.0.44-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:70e03833faca7166e6a9927fbee7c27e6ecde436774cd0b24bbcc96353bce06b", size = 3261426, upload-time = "2025-10-10T15:47:07.196Z" }, - { url = "https://files.pythonhosted.org/packages/f8/3b/55c09b285cb2d55bdfa711e778bdffdd0dc3ffa052b0af41f1c5d6e582fa/sqlalchemy-2.0.44-cp311-cp311-win32.whl", hash = "sha256:253e2f29843fb303eca6b2fc645aca91fa7aa0aa70b38b6950da92d44ff267f3", size = 2105392, upload-time = "2025-10-10T15:38:20.051Z" }, - { url = "https://files.pythonhosted.org/packages/c7/23/907193c2f4d680aedbfbdf7bf24c13925e3c7c292e813326c1b84a0b878e/sqlalchemy-2.0.44-cp311-cp311-win_amd64.whl", hash = "sha256:7a8694107eb4308a13b425ca8c0e67112f8134c846b6e1f722698708741215d5", size = 2130293, upload-time = "2025-10-10T15:38:21.601Z" }, - { url = "https://files.pythonhosted.org/packages/62/c4/59c7c9b068e6813c898b771204aad36683c96318ed12d4233e1b18762164/sqlalchemy-2.0.44-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:72fea91746b5890f9e5e0997f16cbf3d53550580d76355ba2d998311b17b2250", size = 2139675, upload-time = "2025-10-10T16:03:31.064Z" }, - { url = "https://files.pythonhosted.org/packages/d6/ae/eeb0920537a6f9c5a3708e4a5fc55af25900216bdb4847ec29cfddf3bf3a/sqlalchemy-2.0.44-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:585c0c852a891450edbb1eaca8648408a3cc125f18cf433941fa6babcc359e29", size = 2127726, upload-time = "2025-10-10T16:03:35.934Z" }, - { url = "https://files.pythonhosted.org/packages/d8/d5/2ebbabe0379418eda8041c06b0b551f213576bfe4c2f09d77c06c07c8cc5/sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b94843a102efa9ac68a7a30cd46df3ff1ed9c658100d30a725d10d9c60a2f44", size = 3327603, upload-time = "2025-10-10T15:35:28.322Z" }, - { url = "https://files.pythonhosted.org/packages/45/e5/5aa65852dadc24b7d8ae75b7efb8d19303ed6ac93482e60c44a585930ea5/sqlalchemy-2.0.44-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:119dc41e7a7defcefc57189cfa0e61b1bf9c228211aba432b53fb71ef367fda1", size = 3337842, upload-time = "2025-10-10T15:43:45.431Z" }, - { url = "https://files.pythonhosted.org/packages/41/92/648f1afd3f20b71e880ca797a960f638d39d243e233a7082c93093c22378/sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0765e318ee9179b3718c4fd7ba35c434f4dd20332fbc6857a5e8df17719c24d7", size = 3264558, upload-time = "2025-10-10T15:35:29.93Z" }, - { url = "https://files.pythonhosted.org/packages/40/cf/e27d7ee61a10f74b17740918e23cbc5bc62011b48282170dc4c66da8ec0f/sqlalchemy-2.0.44-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2e7b5b079055e02d06a4308d0481658e4f06bc7ef211567edc8f7d5dce52018d", size = 3301570, upload-time = "2025-10-10T15:43:48.407Z" }, - { url = "https://files.pythonhosted.org/packages/3b/3d/3116a9a7b63e780fb402799b6da227435be878b6846b192f076d2f838654/sqlalchemy-2.0.44-cp312-cp312-win32.whl", hash = "sha256:846541e58b9a81cce7dee8329f352c318de25aa2f2bbe1e31587eb1f057448b4", size = 2103447, upload-time = "2025-10-10T15:03:21.678Z" }, - { url = "https://files.pythonhosted.org/packages/25/83/24690e9dfc241e6ab062df82cc0df7f4231c79ba98b273fa496fb3dd78ed/sqlalchemy-2.0.44-cp312-cp312-win_amd64.whl", hash = "sha256:7cbcb47fd66ab294703e1644f78971f6f2f1126424d2b300678f419aa73c7b6e", size = 2130912, upload-time = "2025-10-10T15:03:24.656Z" }, - { url = "https://files.pythonhosted.org/packages/9c/5e/6a29fa884d9fb7ddadf6b69490a9d45fded3b38541713010dad16b77d015/sqlalchemy-2.0.44-py3-none-any.whl", hash = "sha256:19de7ca1246fbef9f9d1bff8f1ab25641569df226364a0e40457dc5457c54b05", size = 1928718, upload-time = "2025-10-10T15:29:45.32Z" }, -] - -[package.optional-dependencies] -asyncio = [ - { name = "greenlet" }, + { url = "https://files.pythonhosted.org/packages/d7/6d/b8b78b5b80f3c3ab3f7fa90faa195ec3401f6d884b60221260fd4d51864c/sqlalchemy-2.0.48-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1b4c575df7368b3b13e0cebf01d4679f9a28ed2ae6c1cd0b1d5beffb6b2007dc", size = 2157184, upload-time = "2026-03-02T15:38:28.161Z" }, + { url = "https://files.pythonhosted.org/packages/21/4b/4f3d4a43743ab58b95b9ddf5580a265b593d017693df9e08bd55780af5bb/sqlalchemy-2.0.48-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e83e3f959aaa1c9df95c22c528096d94848a1bc819f5d0ebf7ee3df0ca63db6c", size = 3313555, upload-time = "2026-03-02T15:58:57.21Z" }, + { url = "https://files.pythonhosted.org/packages/21/dd/3b7c53f1dbbf736fd27041aee68f8ac52226b610f914085b1652c2323442/sqlalchemy-2.0.48-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6f7b7243850edd0b8b97043f04748f31de50cf426e939def5c16bedb540698f7", size = 3313057, upload-time = "2026-03-02T15:52:29.366Z" }, + { url = "https://files.pythonhosted.org/packages/d9/cc/3e600a90ae64047f33313d7d32e5ad025417f09d2ded487e8284b5e21a15/sqlalchemy-2.0.48-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:82745b03b4043e04600a6b665cb98697c4339b24e34d74b0a2ac0a2488b6f94d", size = 3265431, upload-time = "2026-03-02T15:58:59.096Z" }, + { url = "https://files.pythonhosted.org/packages/8b/19/780138dacfe3f5024f4cf96e4005e91edf6653d53d3673be4844578faf1d/sqlalchemy-2.0.48-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e5e088bf43f6ee6fec7dbf1ef7ff7774a616c236b5c0cb3e00662dd71a56b571", size = 3287646, upload-time = "2026-03-02T15:52:31.569Z" }, + { url = "https://files.pythonhosted.org/packages/40/fd/f32ced124f01a23151f4777e4c705f3a470adc7bd241d9f36a7c941a33bf/sqlalchemy-2.0.48-cp311-cp311-win32.whl", hash = "sha256:9c7d0a77e36b5f4b01ca398482230ab792061d243d715299b44a0b55c89fe617", size = 2116956, upload-time = "2026-03-02T15:46:54.535Z" }, + { url = "https://files.pythonhosted.org/packages/58/d5/dd767277f6feef12d05651538f280277e661698f617fa4d086cce6055416/sqlalchemy-2.0.48-cp311-cp311-win_amd64.whl", hash = "sha256:583849c743e0e3c9bb7446f5b5addeacedc168d657a69b418063dfdb2d90081c", size = 2141627, upload-time = "2026-03-02T15:46:55.849Z" }, + { url = "https://files.pythonhosted.org/packages/ef/91/a42ae716f8925e9659df2da21ba941f158686856107a61cc97a95e7647a3/sqlalchemy-2.0.48-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:348174f228b99f33ca1f773e85510e08927620caa59ffe7803b37170df30332b", size = 2155737, upload-time = "2026-03-02T15:49:13.207Z" }, + { url = "https://files.pythonhosted.org/packages/b9/52/f75f516a1f3888f027c1cfb5d22d4376f4b46236f2e8669dcb0cddc60275/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53667b5f668991e279d21f94ccfa6e45b4e3f4500e7591ae59a8012d0f010dcb", size = 3337020, upload-time = "2026-03-02T15:50:34.547Z" }, + { url = "https://files.pythonhosted.org/packages/37/9a/0c28b6371e0cdcb14f8f1930778cb3123acfcbd2c95bb9cf6b4a2ba0cce3/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34634e196f620c7a61d18d5cf7dc841ca6daa7961aed75d532b7e58b309ac894", size = 3349983, upload-time = "2026-03-02T15:53:25.542Z" }, + { url = "https://files.pythonhosted.org/packages/1c/46/0aee8f3ff20b1dcbceb46ca2d87fcc3d48b407925a383ff668218509d132/sqlalchemy-2.0.48-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:546572a1793cc35857a2ffa1fe0e58571af1779bcc1ffa7c9fb0839885ed69a9", size = 3279690, upload-time = "2026-03-02T15:50:36.277Z" }, + { url = "https://files.pythonhosted.org/packages/ce/8c/a957bc91293b49181350bfd55e6dfc6e30b7f7d83dc6792d72043274a390/sqlalchemy-2.0.48-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:07edba08061bc277bfdc772dd2a1a43978f5a45994dd3ede26391b405c15221e", size = 3314738, upload-time = "2026-03-02T15:53:27.519Z" }, + { url = "https://files.pythonhosted.org/packages/4b/44/1d257d9f9556661e7bdc83667cc414ba210acfc110c82938cb3611eea58f/sqlalchemy-2.0.48-cp312-cp312-win32.whl", hash = "sha256:908a3fa6908716f803b86896a09a2c4dde5f5ce2bb07aacc71ffebb57986ce99", size = 2115546, upload-time = "2026-03-02T15:54:31.591Z" }, + { url = "https://files.pythonhosted.org/packages/f2/af/c3c7e1f3a2b383155a16454df62ae8c62a30dd238e42e68c24cebebbfae6/sqlalchemy-2.0.48-cp312-cp312-win_amd64.whl", hash = "sha256:68549c403f79a8e25984376480959975212a670405e3913830614432b5daa07a", size = 2142484, upload-time = "2026-03-02T15:54:34.072Z" }, + { url = "https://files.pythonhosted.org/packages/46/2c/9664130905f03db57961b8980b05cab624afd114bf2be2576628a9f22da4/sqlalchemy-2.0.48-py3-none-any.whl", hash = "sha256:a66fe406437dd65cacd96a72689a3aaaecaebbcd62d81c5ac1c0fdbeac835096", size = 1940202, upload-time = "2026-03-02T15:52:43.285Z" }, ] [[package]] name = "sqlglot" -version = "28.0.0" +version = "30.0.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/8d/9ce5904aca760b81adf821c77a1dcf07c98f9caaa7e3b5c991c541ff89d2/sqlglot-28.0.0.tar.gz", hash = "sha256:cc9a651ef4182e61dac58aa955e5fb21845a5865c6a4d7d7b5a7857450285ad4", size = 5520798, upload-time = "2025-11-17T10:34:57.016Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5b/7a/3b6a8853fc2fa166f785a8ea4fecde46f70588e35471bc7811373da31a49/sqlglot-30.0.3.tar.gz", hash = "sha256:35ba7514c132b54f87fd1732a65a73615efa9fd83f6e1eed0a315bc9ee3e1027", size = 5802632, upload-time = "2026-03-19T16:51:39.166Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/56/6d/86de134f40199105d2fee1b066741aa870b3ce75ee74018d9c8508bbb182/sqlglot-28.0.0-py3-none-any.whl", hash = "sha256:ac1778e7fa4812f4f7e5881b260632fc167b00ca4c1226868891fb15467122e4", size = 536127, upload-time = "2025-11-17T10:34:55.192Z" }, + { url = "https://files.pythonhosted.org/packages/c6/13/b57ab75b0f60b5ee8cb8924bc01a5c419ed3221e00f8f11f8c059a707eb7/sqlglot-30.0.3-py3-none-any.whl", hash = "sha256:5489cc98b5666f1fafc21e0304ca286e513e142aa054ee5760806a2139d07a05", size = 651853, upload-time = "2026-03-19T16:51:36.241Z" }, ] [[package]] name = "sqlparse" -version = "0.5.4" +version = "0.5.5" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/67/701f86b28d63b2086de47c942eccf8ca2208b3be69715a1119a4e384415a/sqlparse-0.5.4.tar.gz", hash = "sha256:4396a7d3cf1cd679c1be976cf3dc6e0a51d0111e87787e7a8d780e7d5a998f9e", size = 120112, upload-time = "2025-11-28T07:10:18.377Z" } +sdist = { url = "https://files.pythonhosted.org/packages/90/76/437d71068094df0726366574cf3432a4ed754217b436eb7429415cf2d480/sqlparse-0.5.5.tar.gz", hash = "sha256:e20d4a9b0b8585fdf63b10d30066c7c94c5d7a7ec47c889a2d83a3caa93ff28e", size = 120815, upload-time = "2025-12-19T07:17:45.073Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/25/70/001ee337f7aa888fb2e3f5fd7592a6afc5283adb1ed44ce8df5764070f22/sqlparse-0.5.4-py3-none-any.whl", hash = "sha256:99a9f0314977b76d776a0fcb8554de91b9bb8a18560631d6bc48721d07023dcb", size = 45933, upload-time = "2025-11-28T07:10:19.73Z" }, + { url = "https://files.pythonhosted.org/packages/49/4b/359f28a903c13438ef59ebeee215fb25da53066db67b305c125f1c6d2a25/sqlparse-0.5.5-py3-none-any.whl", hash = "sha256:12a08b3bf3eec877c519589833aed092e2444e68240a3577e8e26148acc7b1ba", size = 46138, upload-time = "2025-12-19T07:17:46.573Z" }, +] + +[[package]] +name = "srsly" +version = "2.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "catalogue" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cf/77/5633c4ba65e3421b72b5b4bd93aa328360b351b3a1e5bf3c90eb224668e5/srsly-2.5.2.tar.gz", hash = "sha256:4092bc843c71b7595c6c90a0302a197858c5b9fe43067f62ae6a45bc3baa1c19", size = 492055, upload-time = "2025-11-17T14:11:02.543Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/6e/2e3d07b38c1c2e98487f0af92f93b392c6741062d85c65cdc18c7b77448a/srsly-2.5.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e7e07babdcece2405b32c9eea25ef415749f214c889545e38965622bb66837ce", size = 655286, upload-time = "2025-11-17T14:09:52.468Z" }, + { url = "https://files.pythonhosted.org/packages/a1/e7/587bcade6b72f919133e587edf60e06039d88049aef9015cd0bdea8df189/srsly-2.5.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1718fe40b73e5cc73b14625233f57e15fb23643d146f53193e8fe653a49e9a0f", size = 653094, upload-time = "2025-11-17T14:09:53.837Z" }, + { url = "https://files.pythonhosted.org/packages/8d/24/5c3aabe292cb4eb906c828f2866624e3a65603ef0a73e964e486ff146b84/srsly-2.5.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d7b07e6103db7dd3199c0321935b0c8b9297fd6e018a66de97dc836068440111", size = 1141286, upload-time = "2025-11-17T14:09:55.535Z" }, + { url = "https://files.pythonhosted.org/packages/2a/fe/2cbdcef2495e0c40dafb96da205d9ab3b9e59f64938277800bf65f923281/srsly-2.5.2-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f2dedf03b2ae143dd70039f097d128fb901deba2482c3a749ac0a985ac735aad", size = 1144667, upload-time = "2025-11-17T14:09:57.24Z" }, + { url = "https://files.pythonhosted.org/packages/91/7c/9a2c9d8141daf7b7a6f092c2be403421a0ab280e7c03cc62c223f37fdf47/srsly-2.5.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:9d5be1d8b79a4c4180073461425cb49c8924a184ab49d976c9c81a7bf87731d9", size = 1103935, upload-time = "2025-11-17T14:09:58.576Z" }, + { url = "https://files.pythonhosted.org/packages/f1/ad/8ae727430368fedbb1a7fa41b62d7a86237558bc962c5c5a9aa8bfa82548/srsly-2.5.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c8e42d6bcddda2e6fc1a8438cc050c4a36d0e457a63bcc7117d23c5175dfedec", size = 1117985, upload-time = "2025-11-17T14:10:00.348Z" }, + { url = "https://files.pythonhosted.org/packages/60/69/d6afaef1a8d5192fd802752115c7c3cc104493a7d604b406112b8bc2b610/srsly-2.5.2-cp311-cp311-win_amd64.whl", hash = "sha256:e7362981e687eead00248525c3ef3b8ddd95904c93362c481988d91b26b6aeef", size = 654148, upload-time = "2025-11-17T14:10:01.772Z" }, + { url = "https://files.pythonhosted.org/packages/8f/1c/21f658d98d602a559491b7886c7ca30245c2cd8987ff1b7709437c0f74b1/srsly-2.5.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6f92b4f883e6be4ca77f15980b45d394d310f24903e25e1b2c46df783c7edcce", size = 656161, upload-time = "2025-11-17T14:10:03.181Z" }, + { url = "https://files.pythonhosted.org/packages/2f/a2/bc6fd484ed703857043ae9abd6c9aea9152f9480a6961186ee6c1e0c49e8/srsly-2.5.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ac4790a54b00203f1af5495b6b8ac214131139427f30fcf05cf971dde81930eb", size = 653237, upload-time = "2025-11-17T14:10:04.636Z" }, + { url = "https://files.pythonhosted.org/packages/ab/ea/e3895da29a15c8d325e050ad68a0d1238eece1d2648305796adf98dcba66/srsly-2.5.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ce5c6b016050857a7dd365c9dcdd00d96e7ac26317cfcb175db387e403de05bf", size = 1174418, upload-time = "2025-11-17T14:10:05.945Z" }, + { url = "https://files.pythonhosted.org/packages/a6/a5/21996231f53ee97191d0746c3a672ba33a4d86a19ffad85a1c0096c91c5f/srsly-2.5.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:539c6d0016e91277b5e9be31ebed03f03c32580d49c960e4a92c9003baecf69e", size = 1183089, upload-time = "2025-11-17T14:10:07.335Z" }, + { url = "https://files.pythonhosted.org/packages/7b/df/eb17aa8e4a828e8df7aa7dc471295529d9126e6b710f1833ebe0d8568a8e/srsly-2.5.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:9f24b2c4f4c29da04083f09158543eb3f8893ba0ac39818693b3b259ee8044f0", size = 1122594, upload-time = "2025-11-17T14:10:08.899Z" }, + { url = "https://files.pythonhosted.org/packages/80/74/1654a80e6c8ec3ee32370ea08a78d3651e0ba1c4d6e6be31c9efdb9a2d10/srsly-2.5.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:d34675047460a3f6999e43478f40d9b43917ea1e93a75c41d05bf7648f3e872d", size = 1139594, upload-time = "2025-11-17T14:10:10.286Z" }, + { url = "https://files.pythonhosted.org/packages/73/aa/8393344ca7f0e81965febba07afc5cad68335ed0426408d480b861ab915b/srsly-2.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:81fd133ba3c66c07f0e3a889d2b4c852984d71ea833a665238a9d47d8e051ba5", size = 654750, upload-time = "2025-11-17T14:10:11.637Z" }, ] [[package]] name = "sseclient-py" -version = "1.8.0" +version = "1.9.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/e8/ed/3df5ab8bb0c12f86c28d0cadb11ed1de44a92ed35ce7ff4fd5518a809325/sseclient-py-1.8.0.tar.gz", hash = "sha256:c547c5c1a7633230a38dc599a21a2dc638f9b5c297286b48b46b935c71fac3e8", size = 7791, upload-time = "2023-09-01T19:39:20.45Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/49/58/97655efdfeb5b4eeab85b1fc5d3fa1023661246c2ab2a26ea8e47402d4f2/sseclient_py-1.8.0-py2.py3-none-any.whl", hash = "sha256:4ecca6dc0b9f963f8384e9d7fd529bf93dd7d708144c4fb5da0e0a1a926fee83", size = 8828, upload-time = "2023-09-01T19:39:17.627Z" }, + { url = "https://files.pythonhosted.org/packages/4d/2e/59920f7d66b7f9932a3d83dd0ec53fab001be1e058bf582606fe414a5198/sseclient_py-1.9.0-py3-none-any.whl", hash = "sha256:340062b1587fc2880892811e2ab5b176d98ef3eee98b3672ff3a3ba1e8ed0f6f", size = 8351, upload-time = "2026-01-02T23:39:30.995Z" }, ] [[package]] name = "starlette" -version = "0.49.1" +version = "1.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1b/3f/507c21db33b66fb027a332f2cb3abbbe924cc3a79ced12f01ed8645955c9/starlette-0.49.1.tar.gz", hash = "sha256:481a43b71e24ed8c43b11ea02f5353d77840e01480881b8cb5a26b8cae64a8cb", size = 2654703, upload-time = "2025-10-28T17:34:10.928Z" } +sdist = { url = "https://files.pythonhosted.org/packages/81/69/17425771797c36cded50b7fe44e850315d039f28b15901ab44839e70b593/starlette-1.0.0.tar.gz", hash = "sha256:6a4beaf1f81bb472fd19ea9b918b50dc3a77a6f2e190a12954b25e6ed5eea149", size = 2655289, upload-time = "2026-03-22T18:29:46.779Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/51/da/545b75d420bb23b5d494b0517757b351963e974e79933f01e05c929f20a6/starlette-0.49.1-py3-none-any.whl", hash = "sha256:d92ce9f07e4a3caa3ac13a79523bd18e3bc0042bb8ff2d759a8e7dd0e1859875", size = 74175, upload-time = "2025-10-28T17:34:09.13Z" }, + { url = "https://files.pythonhosted.org/packages/0b/c9/584bc9651441b4ba60cc4d557d8a547b5aff901af35bda3a4ee30c819b82/starlette-1.0.0-py3-none-any.whl", hash = "sha256:d3ec55e0bb321692d275455ddfd3df75fff145d009685eb40dc91fc66b03d38b", size = 72651, upload-time = "2026-03-22T18:29:45.111Z" }, ] [[package]] @@ -6750,7 +6888,7 @@ wheels = [ [[package]] name = "tablestore" -version = "6.3.7" +version = "6.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -6763,9 +6901,9 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f1/39/47a3ec8e42fe74dd05af1dfed9c3b02b8f8adfdd8656b2c5d4f95f975c9f/tablestore-6.3.7.tar.gz", hash = "sha256:990682dbf6b602f317a2d359b4281dcd054b4326081e7a67b73dbbe95407be51", size = 117440, upload-time = "2025-10-29T02:57:57.415Z" } +sdist = { url = "https://files.pythonhosted.org/packages/62/00/53f8eeb0016e7ad518f92b085de8855891d10581b42f86d15d1df7a56d33/tablestore-6.4.1.tar.gz", hash = "sha256:005c6939832f2ecd403e01220b7045de45f2e53f1ffaf0c2efc435810885fffb", size = 120319, upload-time = "2026-02-13T06:58:37.267Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/fe/55/1b24d8c369204a855ac652712f815e88a4909802094e613fe3742a2d80e3/tablestore-6.3.7-py3-none-any.whl", hash = "sha256:38dcc55085912ab2515e183afd4532a58bb628a763590a99fc1bd2a4aba6855c", size = 139041, upload-time = "2025-10-29T02:57:55.727Z" }, + { url = "https://files.pythonhosted.org/packages/cc/96/a132bdecb753dc9dc34124a53019da29672baaa34485c8c504895897ea96/tablestore-6.4.1-py3-none-any.whl", hash = "sha256:616898d294dfe22f0d427463c241c6788374cdb2ace9aaf85673ce2c2a18d7e0", size = 141556, upload-time = "2026-02-13T06:58:35.579Z" }, ] [[package]] @@ -6791,7 +6929,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/20/81/be13f417065200182 [[package]] name = "tcvectordb" -version = "1.6.4" +version = "2.0.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -6804,23 +6942,23 @@ dependencies = [ { name = "ujson" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/ec/c80579aff1539257aafcf8dc3f3c13630171f299d65b33b68440e166f27c/tcvectordb-1.6.4.tar.gz", hash = "sha256:6fb18e15ccc6744d5147e9bbd781f84df3d66112de7d9cc615878b3f72d3a29a", size = 75188, upload-time = "2025-03-05T09:14:19.925Z" } +sdist = { url = "https://files.pythonhosted.org/packages/16/21/3bcd466df20ac69408c0228b1c5e793cf3283085238d3ef5d352c556b6ad/tcvectordb-2.0.0.tar.gz", hash = "sha256:38c6ed17931b9bd702138941ca6cfe10b2b60301424ffa36b64a3c2686318941", size = 82209, upload-time = "2025-12-27T07:55:27.376Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/68/bf/f38d9f629324ecffca8fe934e8df47e1233a9021b0739447e59e9fb248f9/tcvectordb-1.6.4-py3-none-any.whl", hash = "sha256:06ef13e7edb4575b04615065fc90e1a28374e318ada305f3786629aec5c9318a", size = 88917, upload-time = "2025-03-05T09:14:17.494Z" }, + { url = "https://files.pythonhosted.org/packages/af/10/e807b273348edef3b321194bc13b67d2cd4df64e22f0404b9e39082415c7/tcvectordb-2.0.0-py3-none-any.whl", hash = "sha256:1731d9c6c0d17a4199872747ddfb1dd3feb26f14ffe7a657f8a5ac3af4ddcdd1", size = 96256, upload-time = "2025-12-27T07:55:24.362Z" }, ] [[package]] name = "tenacity" -version = "8.5.0" +version = "9.1.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a3/4d/6a19536c50b849338fcbe9290d562b52cbdcf30d8963d3588a68a4107df1/tenacity-8.5.0.tar.gz", hash = "sha256:8bc6c0c8a09b31e6cad13c47afbed1a567518250a9a171418582ed8d9c20ca78", size = 47309, upload-time = "2024-07-05T07:25:31.836Z" } +sdist = { url = "https://files.pythonhosted.org/packages/47/c6/ee486fd809e357697ee8a44d3d69222b344920433d3b6666ccd9b374630c/tenacity-9.1.4.tar.gz", hash = "sha256:adb31d4c263f2bd041081ab33b498309a57c77f9acf2db65aadf0898179cf93a", size = 49413, upload-time = "2026-02-07T10:45:33.841Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/3f/8ba87d9e287b9d385a02a7114ddcef61b26f86411e121c9003eb509a1773/tenacity-8.5.0-py3-none-any.whl", hash = "sha256:b594c2a5945830c267ce6b79a166228323ed52718f30302c1359836112346687", size = 28165, upload-time = "2024-07-05T07:25:29.591Z" }, + { url = "https://files.pythonhosted.org/packages/d7/c1/eb8f9debc45d3b7918a32ab756658a0904732f75e555402972246b0b8e71/tenacity-9.1.4-py3-none-any.whl", hash = "sha256:6095a360c919085f28c6527de529e76a06ad89b23659fa881ae0649b867a9d55", size = 28926, upload-time = "2026-02-07T10:45:32.24Z" }, ] [[package]] name = "testcontainers" -version = "4.13.3" +version = "4.14.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docker" }, @@ -6829,70 +6967,111 @@ dependencies = [ { name = "urllib3" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fc/b3/c272537f3ea2f312555efeb86398cc382cd07b740d5f3c730918c36e64e1/testcontainers-4.13.3.tar.gz", hash = "sha256:9d82a7052c9a53c58b69e1dc31da8e7a715e8b3ec1c4df5027561b47e2efe646", size = 79064, upload-time = "2025-11-14T05:08:47.584Z" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/ac/a597c3a0e02b26cbed6dd07df68be1e57684766fd1c381dee9b170a99690/testcontainers-4.14.2.tar.gz", hash = "sha256:1340ccf16fe3acd9389a6c9e1d9ab21d9fe99a8afdf8165f89c3e69c1967d239", size = 166841, upload-time = "2026-03-18T05:19:16.696Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/73/27/c2f24b19dafa197c514abe70eda69bc031c5152c6b1f1e5b20099e2ceedd/testcontainers-4.13.3-py3-none-any.whl", hash = "sha256:063278c4805ffa6dd85e56648a9da3036939e6c0ac1001e851c9276b19b05970", size = 124784, upload-time = "2025-11-14T05:08:46.053Z" }, + { url = "https://files.pythonhosted.org/packages/13/2d/26b8b30067d94339afee62c3edc9b803a6eb9332f521ba77d8aaab5de873/testcontainers-4.14.2-py3-none-any.whl", hash = "sha256:0d0522c3cd8f8d9627cda41f7a6b51b639fa57bdc492923c045117933c668d68", size = 125712, upload-time = "2026-03-18T05:19:15.29Z" }, +] + +[[package]] +name = "thinc" +version = "8.3.11" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "blis" }, + { name = "catalogue" }, + { name = "confection" }, + { name = "cymem" }, + { name = "murmurhash" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "preshed" }, + { name = "pydantic" }, + { name = "setuptools" }, + { name = "srsly" }, + { name = "wasabi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/90/93/383073cf130108c7e4e5866a2745ada8acae9a6f000f5a5124153a196eeb/thinc-8.3.11.tar.gz", hash = "sha256:440ad9030ecee31e4a6ba25bd760fa272fb48768a6f6cce3f734595de5a8b0f4", size = 194557, upload-time = "2026-03-20T09:25:22.868Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7a/04/6160eb79415084dcbc1a85feff27e585ae8798028e0e17911bf1f0c1e483/thinc-8.3.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:12ffca57cbccf4017b1259da016f25013bc70046985c8e8163e7bb61364187fd", size = 843039, upload-time = "2026-03-20T09:24:27.816Z" }, + { url = "https://files.pythonhosted.org/packages/6c/cd/141a5017a79c654227ccf09c91225ced00a5b32be996d1ee8736eccf57e1/thinc-8.3.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8e3548fc28a0d9637feab31fd3a0636e6ada27bd71a35154e107ed0275ffcc5", size = 810757, upload-time = "2026-03-20T09:24:29.557Z" }, + { url = "https://files.pythonhosted.org/packages/e6/55/a586d7133788a3731e6b09d7a5f5f9f323f5b47c194fac3cbf99814946d8/thinc-8.3.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2ff49be22ba3fd3b35bf98f2698330e024a0b385f63c22aba1503abaf46812c4", size = 4095565, upload-time = "2026-03-20T09:24:31.364Z" }, + { url = "https://files.pythonhosted.org/packages/ab/93/2bf3c2c57039dcd23f8ea81bcb4b866a8a99ef694e21771ad386b45754fe/thinc-8.3.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2827bbd80bca036739613e7f7d61244d79bea29774ac0fdfc31c9fd4f9b9982e", size = 4126116, upload-time = "2026-03-20T09:24:32.852Z" }, + { url = "https://files.pythonhosted.org/packages/ef/45/9a3368f505f8d3caabd598f0cc8d283f66924fd07fbf2ccd356b8f1f3ae0/thinc-8.3.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8e7d395c0f07500ecd72f7af38196c5978b809806999baaa4540dd3bded7c809", size = 5094816, upload-time = "2026-03-20T09:24:34.44Z" }, + { url = "https://files.pythonhosted.org/packages/f1/58/2075003d1154d8c88d96db1826ef8d537e3d91f6ff7f4c21a27dd0114d05/thinc-8.3.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b9700665e057dff3176f96815b20817e9e7ae83db58d45973cc3d45ad95fe140", size = 5256607, upload-time = "2026-03-20T09:24:36.077Z" }, + { url = "https://files.pythonhosted.org/packages/fd/27/cbe70789c7c01241879672dd2e664d3787af78a7e503992c4c7713817992/thinc-8.3.11-cp311-cp311-win_amd64.whl", hash = "sha256:977a99b65bc5182a7bfd4164b026224d120ea19de6ccce835bf7066517fb8368", size = 1792838, upload-time = "2026-03-20T09:24:37.747Z" }, + { url = "https://files.pythonhosted.org/packages/ed/7c/5fea21fc261e5b04480c150db86baeb55e399551ffecff73f96c9bfe98dd/thinc-8.3.11-cp311-cp311-win_arm64.whl", hash = "sha256:e8ef857f65af2ad6f7147b360853c3b4cf9121939d7360e89c4949a3bf5e42a9", size = 1719115, upload-time = "2026-03-20T09:24:39.721Z" }, + { url = "https://files.pythonhosted.org/packages/d7/27/3ffb080640ff3cf6e689ff2349c2915adac09a288b487c4c7b723089d8ee/thinc-8.3.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cf66abb78d5720d915628fe326a8a6b9fa045388422b0920d6c7cb480d084eae", size = 820142, upload-time = "2026-03-20T09:24:41.942Z" }, + { url = "https://files.pythonhosted.org/packages/ec/07/4ea5c0f9833d9d3daf820680c2ed550417bc32fcdbd7bb39b9973296094c/thinc-8.3.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:542ec7db49b24cd2c9ee79a1f99dd2d9cb31f62b0a3ea164da4ea7e6224bce9d", size = 790920, upload-time = "2026-03-20T09:24:43.829Z" }, + { url = "https://files.pythonhosted.org/packages/89/80/dd291c8dab49a98656c378883c07cfb707b9aeaa3cdfb2504f8e0a2997ab/thinc-8.3.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b158fb1c3579fa8d07f8ab2c462906b99a46555f657aba721e4004f9d5dc16dd", size = 3843924, upload-time = "2026-03-20T09:24:45.586Z" }, + { url = "https://files.pythonhosted.org/packages/7f/e4/592ac8e8538bcb6fee3f97f92509918e61630566ace31d4fd703e3efb2fc/thinc-8.3.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7addbd7684bb599fcd76341a080010795067889f5c68ed2900b503f860505de7", size = 3898974, upload-time = "2026-03-20T09:24:47.211Z" }, + { url = "https://files.pythonhosted.org/packages/2b/36/de82f9e1660f1780dcf987184c94fe4903f3b09d528515dc0060e977630e/thinc-8.3.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:983830873dbcbe965665d7a1f99887e7c5a03eec7691eee4ef144e7213f959d9", size = 4826963, upload-time = "2026-03-20T09:24:48.804Z" }, + { url = "https://files.pythonhosted.org/packages/5a/4b/b60e0bd9f76fc9d1fbb6dcfcc174a744ba86c343445f3f5f08b7ef6e2714/thinc-8.3.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:bcfd84ca787e56d6b8167265268acaa8c755decb39e0a5e97fa9f0740c42969e", size = 5022850, upload-time = "2026-03-20T09:24:50.476Z" }, + { url = "https://files.pythonhosted.org/packages/60/40/d35297c3688368772e3a3bbfc854a0acea29ccbe757cdaf1fbc234588c5e/thinc-8.3.11-cp312-cp312-win_amd64.whl", hash = "sha256:0d0775b22ac345342d33cdbdf5a69c032c8560065545ff3ad3f19c79300f7cd7", size = 1719187, upload-time = "2026-03-20T09:24:52.408Z" }, + { url = "https://files.pythonhosted.org/packages/36/8e/d5e97b58365fee10798e792f6de6b5c5817107acc3bc70b3c182de2115ad/thinc-8.3.11-cp312-cp312-win_arm64.whl", hash = "sha256:6f6cc64c60462165e0fae98b32c9f52592b7227720d26476de8c40e044628c56", size = 1643965, upload-time = "2026-03-20T09:24:53.99Z" }, ] [[package]] name = "tidb-vector" -version = "0.0.9" +version = "0.0.15" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1a/98/ab324fdfbbf064186ca621e21aa3871ddf886ecb78358a9864509241e802/tidb_vector-0.0.9.tar.gz", hash = "sha256:e10680872532808e1bcffa7a92dd2b05bb65d63982f833edb3c6cd590dec7709", size = 16948, upload-time = "2024-05-08T07:54:36.955Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b1/55/6247b3b8dd0c0ec05a7b0dd7d4f016d03337d6f089db9cc221a31de1308c/tidb_vector-0.0.15.tar.gz", hash = "sha256:dfd16b31b06f025737f5c7432a08e04265dde8a7c9c67d037e6e694c8125f6f5", size = 20702, upload-time = "2025-07-15T09:48:07.423Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/bb/0f3b7b4d31537e90f4dd01f50fa58daef48807c789c1c1bdd610204ff103/tidb_vector-0.0.9-py3-none-any.whl", hash = "sha256:db060ee1c981326d3882d0810e0b8b57811f278668f9381168997b360c4296c2", size = 17026, upload-time = "2024-05-08T07:54:34.849Z" }, + { url = "https://files.pythonhosted.org/packages/24/27/5a4aeeae058f75c1925646ff82215551903688ec33acc64ca46135eac631/tidb_vector-0.0.15-py3-none-any.whl", hash = "sha256:2bc7d02f5508ba153c8d67d049ab1e661c850e09e3a29286dc8b19945e512ad8", size = 21924, upload-time = "2025-07-15T09:48:05.834Z" }, ] [[package]] name = "tiktoken" -version = "0.9.0" +version = "0.12.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "regex" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/ea/cf/756fedf6981e82897f2d570dd25fa597eb3f4459068ae0572d7e888cfd6f/tiktoken-0.9.0.tar.gz", hash = "sha256:d02a5ca6a938e0490e1ff957bc48c8b078c88cb83977be1625b1fd8aac792c5d", size = 35991, upload-time = "2025-02-14T06:03:01.003Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931", size = 37806, upload-time = "2025-10-06T20:22:45.419Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/ae/4613a59a2a48e761c5161237fc850eb470b4bb93696db89da51b79a871f1/tiktoken-0.9.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:f32cc56168eac4851109e9b5d327637f15fd662aa30dd79f964b7c39fbadd26e", size = 1065987, upload-time = "2025-02-14T06:02:14.174Z" }, - { url = "https://files.pythonhosted.org/packages/3f/86/55d9d1f5b5a7e1164d0f1538a85529b5fcba2b105f92db3622e5d7de6522/tiktoken-0.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:45556bc41241e5294063508caf901bf92ba52d8ef9222023f83d2483a3055348", size = 1009155, upload-time = "2025-02-14T06:02:15.384Z" }, - { url = "https://files.pythonhosted.org/packages/03/58/01fb6240df083b7c1916d1dcb024e2b761213c95d576e9f780dfb5625a76/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03935988a91d6d3216e2ec7c645afbb3d870b37bcb67ada1943ec48678e7ee33", size = 1142898, upload-time = "2025-02-14T06:02:16.666Z" }, - { url = "https://files.pythonhosted.org/packages/b1/73/41591c525680cd460a6becf56c9b17468d3711b1df242c53d2c7b2183d16/tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b3d80aad8d2c6b9238fc1a5524542087c52b860b10cbf952429ffb714bc1136", size = 1197535, upload-time = "2025-02-14T06:02:18.595Z" }, - { url = "https://files.pythonhosted.org/packages/7d/7c/1069f25521c8f01a1a182f362e5c8e0337907fae91b368b7da9c3e39b810/tiktoken-0.9.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b2a21133be05dc116b1d0372af051cd2c6aa1d2188250c9b553f9fa49301b336", size = 1259548, upload-time = "2025-02-14T06:02:20.729Z" }, - { url = "https://files.pythonhosted.org/packages/6f/07/c67ad1724b8e14e2b4c8cca04b15da158733ac60136879131db05dda7c30/tiktoken-0.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:11a20e67fdf58b0e2dea7b8654a288e481bb4fc0289d3ad21291f8d0849915fb", size = 893895, upload-time = "2025-02-14T06:02:22.67Z" }, - { url = "https://files.pythonhosted.org/packages/cf/e5/21ff33ecfa2101c1bb0f9b6df750553bd873b7fb532ce2cb276ff40b197f/tiktoken-0.9.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:e88f121c1c22b726649ce67c089b90ddda8b9662545a8aeb03cfef15967ddd03", size = 1065073, upload-time = "2025-02-14T06:02:24.768Z" }, - { url = "https://files.pythonhosted.org/packages/8e/03/a95e7b4863ee9ceec1c55983e4cc9558bcfd8f4f80e19c4f8a99642f697d/tiktoken-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a6600660f2f72369acb13a57fb3e212434ed38b045fd8cc6cdd74947b4b5d210", size = 1008075, upload-time = "2025-02-14T06:02:26.92Z" }, - { url = "https://files.pythonhosted.org/packages/40/10/1305bb02a561595088235a513ec73e50b32e74364fef4de519da69bc8010/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:95e811743b5dfa74f4b227927ed86cbc57cad4df859cb3b643be797914e41794", size = 1140754, upload-time = "2025-02-14T06:02:28.124Z" }, - { url = "https://files.pythonhosted.org/packages/1b/40/da42522018ca496432ffd02793c3a72a739ac04c3794a4914570c9bb2925/tiktoken-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:99376e1370d59bcf6935c933cb9ba64adc29033b7e73f5f7569f3aad86552b22", size = 1196678, upload-time = "2025-02-14T06:02:29.845Z" }, - { url = "https://files.pythonhosted.org/packages/5c/41/1e59dddaae270ba20187ceb8aa52c75b24ffc09f547233991d5fd822838b/tiktoken-0.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:badb947c32739fb6ddde173e14885fb3de4d32ab9d8c591cbd013c22b4c31dd2", size = 1259283, upload-time = "2025-02-14T06:02:33.838Z" }, - { url = "https://files.pythonhosted.org/packages/5b/64/b16003419a1d7728d0d8c0d56a4c24325e7b10a21a9dd1fc0f7115c02f0a/tiktoken-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a62d7a25225bafed786a524c1b9f0910a1128f4232615bf3f8257a73aaa3b16", size = 894897, upload-time = "2025-02-14T06:02:36.265Z" }, + { url = "https://files.pythonhosted.org/packages/de/46/21ea696b21f1d6d1efec8639c204bdf20fde8bafb351e1355c72c5d7de52/tiktoken-0.12.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:6e227c7f96925003487c33b1b32265fad2fbcec2b7cf4817afb76d416f40f6bb", size = 1051565, upload-time = "2025-10-06T20:21:44.566Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d9/35c5d2d9e22bb2a5f74ba48266fb56c63d76ae6f66e02feb628671c0283e/tiktoken-0.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c06cf0fcc24c2cb2adb5e185c7082a82cba29c17575e828518c2f11a01f445aa", size = 995284, upload-time = "2025-10-06T20:21:45.622Z" }, + { url = "https://files.pythonhosted.org/packages/01/84/961106c37b8e49b9fdcf33fe007bb3a8fdcc380c528b20cc7fbba80578b8/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:f18f249b041851954217e9fd8e5c00b024ab2315ffda5ed77665a05fa91f42dc", size = 1129201, upload-time = "2025-10-06T20:21:47.074Z" }, + { url = "https://files.pythonhosted.org/packages/6a/d0/3d9275198e067f8b65076a68894bb52fd253875f3644f0a321a720277b8a/tiktoken-0.12.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:47a5bc270b8c3db00bb46ece01ef34ad050e364b51d406b6f9730b64ac28eded", size = 1152444, upload-time = "2025-10-06T20:21:48.139Z" }, + { url = "https://files.pythonhosted.org/packages/78/db/a58e09687c1698a7c592e1038e01c206569b86a0377828d51635561f8ebf/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:508fa71810c0efdcd1b898fda574889ee62852989f7c1667414736bcb2b9a4bd", size = 1195080, upload-time = "2025-10-06T20:21:49.246Z" }, + { url = "https://files.pythonhosted.org/packages/9e/1b/a9e4d2bf91d515c0f74afc526fd773a812232dd6cda33ebea7f531202325/tiktoken-0.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a1af81a6c44f008cba48494089dd98cccb8b313f55e961a52f5b222d1e507967", size = 1255240, upload-time = "2025-10-06T20:21:50.274Z" }, + { url = "https://files.pythonhosted.org/packages/9d/15/963819345f1b1fb0809070a79e9dd96938d4ca41297367d471733e79c76c/tiktoken-0.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:3e68e3e593637b53e56f7237be560f7a394451cb8c11079755e80ae64b9e6def", size = 879422, upload-time = "2025-10-06T20:21:51.734Z" }, + { url = "https://files.pythonhosted.org/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8", size = 1050728, upload-time = "2025-10-06T20:21:52.756Z" }, + { url = "https://files.pythonhosted.org/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b", size = 994049, upload-time = "2025-10-06T20:21:53.782Z" }, + { url = "https://files.pythonhosted.org/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37", size = 1129008, upload-time = "2025-10-06T20:21:54.832Z" }, + { url = "https://files.pythonhosted.org/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad", size = 1152665, upload-time = "2025-10-06T20:21:56.129Z" }, + { url = "https://files.pythonhosted.org/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5", size = 1194230, upload-time = "2025-10-06T20:21:57.546Z" }, + { url = "https://files.pythonhosted.org/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3", size = 1254688, upload-time = "2025-10-06T20:21:58.619Z" }, + { url = "https://files.pythonhosted.org/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd", size = 878694, upload-time = "2025-10-06T20:21:59.876Z" }, ] [[package]] name = "tokenizers" -version = "0.22.1" +version = "0.22.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "huggingface-hub" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1c/46/fb6854cec3278fbfa4a75b50232c77622bc517ac886156e6afbfa4d8fc6e/tokenizers-0.22.1.tar.gz", hash = "sha256:61de6522785310a309b3407bac22d99c4db5dba349935e99e4d15ea2226af2d9", size = 363123, upload-time = "2025-09-19T09:49:23.424Z" } +sdist = { url = "https://files.pythonhosted.org/packages/73/6f/f80cfef4a312e1fb34baf7d85c72d4411afde10978d4657f8cdd811d3ccc/tokenizers-0.22.2.tar.gz", hash = "sha256:473b83b915e547aa366d1eee11806deaf419e17be16310ac0a14077f1e28f917", size = 372115, upload-time = "2026-01-05T10:45:15.988Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/bf/33/f4b2d94ada7ab297328fc671fed209368ddb82f965ec2224eb1892674c3a/tokenizers-0.22.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:59fdb013df17455e5f950b4b834a7b3ee2e0271e6378ccb33aa74d178b513c73", size = 3069318, upload-time = "2025-09-19T09:49:11.848Z" }, - { url = "https://files.pythonhosted.org/packages/1c/58/2aa8c874d02b974990e89ff95826a4852a8b2a273c7d1b4411cdd45a4565/tokenizers-0.22.1-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:8d4e484f7b0827021ac5f9f71d4794aaef62b979ab7608593da22b1d2e3c4edc", size = 2926478, upload-time = "2025-09-19T09:49:09.759Z" }, - { url = "https://files.pythonhosted.org/packages/1e/3b/55e64befa1e7bfea963cf4b787b2cea1011362c4193f5477047532ce127e/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:19d2962dd28bc67c1f205ab180578a78eef89ac60ca7ef7cbe9635a46a56422a", size = 3256994, upload-time = "2025-09-19T09:48:56.701Z" }, - { url = "https://files.pythonhosted.org/packages/71/0b/fbfecf42f67d9b7b80fde4aabb2b3110a97fac6585c9470b5bff103a80cb/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:38201f15cdb1f8a6843e6563e6e79f4abd053394992b9bbdf5213ea3469b4ae7", size = 3153141, upload-time = "2025-09-19T09:48:59.749Z" }, - { url = "https://files.pythonhosted.org/packages/17/a9/b38f4e74e0817af8f8ef925507c63c6ae8171e3c4cb2d5d4624bf58fca69/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d1cbe5454c9a15df1b3443c726063d930c16f047a3cc724b9e6e1a91140e5a21", size = 3508049, upload-time = "2025-09-19T09:49:05.868Z" }, - { url = "https://files.pythonhosted.org/packages/d2/48/dd2b3dac46bb9134a88e35d72e1aa4869579eacc1a27238f1577270773ff/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e7d094ae6312d69cc2a872b54b91b309f4f6fbce871ef28eb27b52a98e4d0214", size = 3710730, upload-time = "2025-09-19T09:49:01.832Z" }, - { url = "https://files.pythonhosted.org/packages/93/0e/ccabc8d16ae4ba84a55d41345207c1e2ea88784651a5a487547d80851398/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:afd7594a56656ace95cdd6df4cca2e4059d294c5cfb1679c57824b605556cb2f", size = 3412560, upload-time = "2025-09-19T09:49:03.867Z" }, - { url = "https://files.pythonhosted.org/packages/d0/c6/dc3a0db5a6766416c32c034286d7c2d406da1f498e4de04ab1b8959edd00/tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2ef6063d7a84994129732b47e7915e8710f27f99f3a3260b8a38fc7ccd083f4", size = 3250221, upload-time = "2025-09-19T09:49:07.664Z" }, - { url = "https://files.pythonhosted.org/packages/d7/a6/2c8486eef79671601ff57b093889a345dd3d576713ef047776015dc66de7/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:ba0a64f450b9ef412c98f6bcd2a50c6df6e2443b560024a09fa6a03189726879", size = 9345569, upload-time = "2025-09-19T09:49:14.214Z" }, - { url = "https://files.pythonhosted.org/packages/6b/16/32ce667f14c35537f5f605fe9bea3e415ea1b0a646389d2295ec348d5657/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:331d6d149fa9c7d632cde4490fb8bbb12337fa3a0232e77892be656464f4b446", size = 9271599, upload-time = "2025-09-19T09:49:16.639Z" }, - { url = "https://files.pythonhosted.org/packages/51/7c/a5f7898a3f6baa3fc2685c705e04c98c1094c523051c805cdd9306b8f87e/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:607989f2ea68a46cb1dfbaf3e3aabdf3f21d8748312dbeb6263d1b3b66c5010a", size = 9533862, upload-time = "2025-09-19T09:49:19.146Z" }, - { url = "https://files.pythonhosted.org/packages/36/65/7e75caea90bc73c1dd8d40438adf1a7bc26af3b8d0a6705ea190462506e1/tokenizers-0.22.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:a0f307d490295717726598ef6fa4f24af9d484809223bbc253b201c740a06390", size = 9681250, upload-time = "2025-09-19T09:49:21.501Z" }, - { url = "https://files.pythonhosted.org/packages/30/2c/959dddef581b46e6209da82df3b78471e96260e2bc463f89d23b1bf0e52a/tokenizers-0.22.1-cp39-abi3-win32.whl", hash = "sha256:b5120eed1442765cd90b903bb6cfef781fd8fe64e34ccaecbae4c619b7b12a82", size = 2472003, upload-time = "2025-09-19T09:49:27.089Z" }, - { url = "https://files.pythonhosted.org/packages/b3/46/e33a8c93907b631a99377ef4c5f817ab453d0b34f93529421f42ff559671/tokenizers-0.22.1-cp39-abi3-win_amd64.whl", hash = "sha256:65fd6e3fb11ca1e78a6a93602490f134d1fdeb13bcef99389d5102ea318ed138", size = 2674684, upload-time = "2025-09-19T09:49:24.953Z" }, + { url = "https://files.pythonhosted.org/packages/92/97/5dbfabf04c7e348e655e907ed27913e03db0923abb5dfdd120d7b25630e1/tokenizers-0.22.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:544dd704ae7238755d790de45ba8da072e9af3eea688f698b137915ae959281c", size = 3100275, upload-time = "2026-01-05T10:41:02.158Z" }, + { url = "https://files.pythonhosted.org/packages/2e/47/174dca0502ef88b28f1c9e06b73ce33500eedfac7a7692108aec220464e7/tokenizers-0.22.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:1e418a55456beedca4621dbab65a318981467a2b188e982a23e117f115ce5001", size = 2981472, upload-time = "2026-01-05T10:41:00.276Z" }, + { url = "https://files.pythonhosted.org/packages/d6/84/7990e799f1309a8b87af6b948f31edaa12a3ed22d11b352eaf4f4b2e5753/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2249487018adec45d6e3554c71d46eb39fa8ea67156c640f7513eb26f318cec7", size = 3290736, upload-time = "2026-01-05T10:40:32.165Z" }, + { url = "https://files.pythonhosted.org/packages/78/59/09d0d9ba94dcd5f4f1368d4858d24546b4bdc0231c2354aa31d6199f0399/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25b85325d0815e86e0bac263506dd114578953b7b53d7de09a6485e4a160a7dd", size = 3168835, upload-time = "2026-01-05T10:40:38.847Z" }, + { url = "https://files.pythonhosted.org/packages/47/50/b3ebb4243e7160bda8d34b731e54dd8ab8b133e50775872e7a434e524c28/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:bfb88f22a209ff7b40a576d5324bf8286b519d7358663db21d6246fb17eea2d5", size = 3521673, upload-time = "2026-01-05T10:40:56.614Z" }, + { url = "https://files.pythonhosted.org/packages/e0/fa/89f4cb9e08df770b57adb96f8cbb7e22695a4cb6c2bd5f0c4f0ebcf33b66/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1c774b1276f71e1ef716e5486f21e76333464f47bece56bbd554485982a9e03e", size = 3724818, upload-time = "2026-01-05T10:40:44.507Z" }, + { url = "https://files.pythonhosted.org/packages/64/04/ca2363f0bfbe3b3d36e95bf67e56a4c88c8e3362b658e616d1ac185d47f2/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:df6c4265b289083bf710dff49bc51ef252f9d5be33a45ee2bed151114a56207b", size = 3379195, upload-time = "2026-01-05T10:40:51.139Z" }, + { url = "https://files.pythonhosted.org/packages/2e/76/932be4b50ef6ccedf9d3c6639b056a967a86258c6d9200643f01269211ca/tokenizers-0.22.2-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:369cc9fc8cc10cb24143873a0d95438bb8ee257bb80c71989e3ee290e8d72c67", size = 3274982, upload-time = "2026-01-05T10:40:58.331Z" }, + { url = "https://files.pythonhosted.org/packages/1d/28/5f9f5a4cc211b69e89420980e483831bcc29dade307955cc9dc858a40f01/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:29c30b83d8dcd061078b05ae0cb94d3c710555fbb44861139f9f83dcca3dc3e4", size = 9478245, upload-time = "2026-01-05T10:41:04.053Z" }, + { url = "https://files.pythonhosted.org/packages/6c/fb/66e2da4704d6aadebf8cb39f1d6d1957df667ab24cff2326b77cda0dcb85/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:37ae80a28c1d3265bb1f22464c856bd23c02a05bb211e56d0c5301a435be6c1a", size = 9560069, upload-time = "2026-01-05T10:45:10.673Z" }, + { url = "https://files.pythonhosted.org/packages/16/04/fed398b05caa87ce9b1a1bb5166645e38196081b225059a6edaff6440fac/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:791135ee325f2336f498590eb2f11dc5c295232f288e75c99a36c5dbce63088a", size = 9899263, upload-time = "2026-01-05T10:45:12.559Z" }, + { url = "https://files.pythonhosted.org/packages/05/a1/d62dfe7376beaaf1394917e0f8e93ee5f67fea8fcf4107501db35996586b/tokenizers-0.22.2-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:38337540fbbddff8e999d59970f3c6f35a82de10053206a7562f1ea02d046fa5", size = 10033429, upload-time = "2026-01-05T10:45:14.333Z" }, + { url = "https://files.pythonhosted.org/packages/fd/18/a545c4ea42af3df6effd7d13d250ba77a0a86fb20393143bbb9a92e434d4/tokenizers-0.22.2-cp39-abi3-win32.whl", hash = "sha256:a6bf3f88c554a2b653af81f3204491c818ae2ac6fbc09e76ef4773351292bc92", size = 2502363, upload-time = "2026-01-05T10:45:20.593Z" }, + { url = "https://files.pythonhosted.org/packages/65/71/0670843133a43d43070abeb1949abfdef12a86d490bea9cd9e18e37c5ff7/tokenizers-0.22.2-cp39-abi3-win_amd64.whl", hash = "sha256:c9ea31edff2968b44a88f97d784c2f16dc0729b8b143ed004699ebca91f05c48", size = 2747786, upload-time = "2026-01-05T10:45:18.411Z" }, + { url = "https://files.pythonhosted.org/packages/72/f4/0de46cfa12cdcbcd464cc59fde36912af405696f687e53a091fb432f694c/tokenizers-0.22.2-cp39-abi3-win_arm64.whl", hash = "sha256:9ce725d22864a1e965217204946f830c37876eee3b2ba6fc6255e8e903d5fcbc", size = 2612133, upload-time = "2026-01-05T10:45:17.232Z" }, ] [[package]] @@ -6906,27 +7085,29 @@ wheels = [ [[package]] name = "tomli" -version = "2.3.0" +version = "2.4.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/52/ed/3f73f72945444548f33eba9a87fc7a6e969915e7b1acc8260b30e1f76a2f/tomli-2.3.0.tar.gz", hash = "sha256:64be704a875d2a59753d80ee8a533c3fe183e3f06807ff7dc2232938ccb01549", size = 17392, upload-time = "2025-10-08T22:01:47.119Z" } +sdist = { url = "https://files.pythonhosted.org/packages/82/30/31573e9457673ab10aa432461bee537ce6cef177667deca369efb79df071/tomli-2.4.0.tar.gz", hash = "sha256:aa89c3f6c277dd275d8e243ad24f3b5e701491a860d5121f2cdd399fbb31fc9c", size = 17477, upload-time = "2026-01-11T11:22:38.165Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/2e/299f62b401438d5fe1624119c723f5d877acc86a4c2492da405626665f12/tomli-2.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:88bd15eb972f3664f5ed4b57c1634a97153b4bac4479dcb6a495f41921eb7f45", size = 153236, upload-time = "2025-10-08T22:01:00.137Z" }, - { url = "https://files.pythonhosted.org/packages/86/7f/d8fffe6a7aefdb61bced88fcb5e280cfd71e08939da5894161bd71bea022/tomli-2.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:883b1c0d6398a6a9d29b508c331fa56adbcdff647f6ace4dfca0f50e90dfd0ba", size = 148084, upload-time = "2025-10-08T22:01:01.63Z" }, - { url = "https://files.pythonhosted.org/packages/47/5c/24935fb6a2ee63e86d80e4d3b58b222dafaf438c416752c8b58537c8b89a/tomli-2.3.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d1381caf13ab9f300e30dd8feadb3de072aeb86f1d34a8569453ff32a7dea4bf", size = 234832, upload-time = "2025-10-08T22:01:02.543Z" }, - { url = "https://files.pythonhosted.org/packages/89/da/75dfd804fc11e6612846758a23f13271b76d577e299592b4371a4ca4cd09/tomli-2.3.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0e285d2649b78c0d9027570d4da3425bdb49830a6156121360b3f8511ea3441", size = 242052, upload-time = "2025-10-08T22:01:03.836Z" }, - { url = "https://files.pythonhosted.org/packages/70/8c/f48ac899f7b3ca7eb13af73bacbc93aec37f9c954df3c08ad96991c8c373/tomli-2.3.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0a154a9ae14bfcf5d8917a59b51ffd5a3ac1fd149b71b47a3a104ca4edcfa845", size = 239555, upload-time = "2025-10-08T22:01:04.834Z" }, - { url = "https://files.pythonhosted.org/packages/ba/28/72f8afd73f1d0e7829bfc093f4cb98ce0a40ffc0cc997009ee1ed94ba705/tomli-2.3.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:74bf8464ff93e413514fefd2be591c3b0b23231a77f901db1eb30d6f712fc42c", size = 245128, upload-time = "2025-10-08T22:01:05.84Z" }, - { url = "https://files.pythonhosted.org/packages/b6/eb/a7679c8ac85208706d27436e8d421dfa39d4c914dcf5fa8083a9305f58d9/tomli-2.3.0-cp311-cp311-win32.whl", hash = "sha256:00b5f5d95bbfc7d12f91ad8c593a1659b6387b43f054104cda404be6bda62456", size = 96445, upload-time = "2025-10-08T22:01:06.896Z" }, - { url = "https://files.pythonhosted.org/packages/0a/fe/3d3420c4cb1ad9cb462fb52967080575f15898da97e21cb6f1361d505383/tomli-2.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:4dc4ce8483a5d429ab602f111a93a6ab1ed425eae3122032db7e9acf449451be", size = 107165, upload-time = "2025-10-08T22:01:08.107Z" }, - { url = "https://files.pythonhosted.org/packages/ff/b7/40f36368fcabc518bb11c8f06379a0fd631985046c038aca08c6d6a43c6e/tomli-2.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d7d86942e56ded512a594786a5ba0a5e521d02529b3826e7761a05138341a2ac", size = 154891, upload-time = "2025-10-08T22:01:09.082Z" }, - { url = "https://files.pythonhosted.org/packages/f9/3f/d9dd692199e3b3aab2e4e4dd948abd0f790d9ded8cd10cbaae276a898434/tomli-2.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:73ee0b47d4dad1c5e996e3cd33b8a76a50167ae5f96a2607cbe8cc773506ab22", size = 148796, upload-time = "2025-10-08T22:01:10.266Z" }, - { url = "https://files.pythonhosted.org/packages/60/83/59bff4996c2cf9f9387a0f5a3394629c7efa5ef16142076a23a90f1955fa/tomli-2.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:792262b94d5d0a466afb5bc63c7daa9d75520110971ee269152083270998316f", size = 242121, upload-time = "2025-10-08T22:01:11.332Z" }, - { url = "https://files.pythonhosted.org/packages/45/e5/7c5119ff39de8693d6baab6c0b6dcb556d192c165596e9fc231ea1052041/tomli-2.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f195fe57ecceac95a66a75ac24d9d5fbc98ef0962e09b2eddec5d39375aae52", size = 250070, upload-time = "2025-10-08T22:01:12.498Z" }, - { url = "https://files.pythonhosted.org/packages/45/12/ad5126d3a278f27e6701abde51d342aa78d06e27ce2bb596a01f7709a5a2/tomli-2.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e31d432427dcbf4d86958c184b9bfd1e96b5b71f8eb17e6d02531f434fd335b8", size = 245859, upload-time = "2025-10-08T22:01:13.551Z" }, - { url = "https://files.pythonhosted.org/packages/fb/a1/4d6865da6a71c603cfe6ad0e6556c73c76548557a8d658f9e3b142df245f/tomli-2.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b0882799624980785240ab732537fcfc372601015c00f7fc367c55308c186f6", size = 250296, upload-time = "2025-10-08T22:01:14.614Z" }, - { url = "https://files.pythonhosted.org/packages/a0/b7/a7a7042715d55c9ba6e8b196d65d2cb662578b4d8cd17d882d45322b0d78/tomli-2.3.0-cp312-cp312-win32.whl", hash = "sha256:ff72b71b5d10d22ecb084d345fc26f42b5143c5533db5e2eaba7d2d335358876", size = 97124, upload-time = "2025-10-08T22:01:15.629Z" }, - { url = "https://files.pythonhosted.org/packages/06/1e/f22f100db15a68b520664eb3328fb0ae4e90530887928558112c8d1f4515/tomli-2.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:1cb4ed918939151a03f33d4242ccd0aa5f11b3547d0cf30f7c74a408a5b99878", size = 107698, upload-time = "2025-10-08T22:01:16.51Z" }, - { url = "https://files.pythonhosted.org/packages/77/b8/0135fadc89e73be292b473cb820b4f5a08197779206b33191e801feeae40/tomli-2.3.0-py3-none-any.whl", hash = "sha256:e95b1af3c5b07d9e643909b5abbec77cd9f1217e6d0bca72b0234736b9fb1f1b", size = 14408, upload-time = "2025-10-08T22:01:46.04Z" }, + { url = "https://files.pythonhosted.org/packages/3c/d9/3dc2289e1f3b32eb19b9785b6a006b28ee99acb37d1d47f78d4c10e28bf8/tomli-2.4.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b5ef256a3fd497d4973c11bf142e9ed78b150d36f5773f1ca6088c230ffc5867", size = 153663, upload-time = "2026-01-11T11:21:45.27Z" }, + { url = "https://files.pythonhosted.org/packages/51/32/ef9f6845e6b9ca392cd3f64f9ec185cc6f09f0a2df3db08cbe8809d1d435/tomli-2.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5572e41282d5268eb09a697c89a7bee84fae66511f87533a6f88bd2f7b652da9", size = 148469, upload-time = "2026-01-11T11:21:46.873Z" }, + { url = "https://files.pythonhosted.org/packages/d6/c2/506e44cce89a8b1b1e047d64bd495c22c9f71f21e05f380f1a950dd9c217/tomli-2.4.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:551e321c6ba03b55676970b47cb1b73f14a0a4dce6a3e1a9458fd6d921d72e95", size = 236039, upload-time = "2026-01-11T11:21:48.503Z" }, + { url = "https://files.pythonhosted.org/packages/b3/40/e1b65986dbc861b7e986e8ec394598187fa8aee85b1650b01dd925ca0be8/tomli-2.4.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5e3f639a7a8f10069d0e15408c0b96a2a828cfdec6fca05296ebcdcc28ca7c76", size = 243007, upload-time = "2026-01-11T11:21:49.456Z" }, + { url = "https://files.pythonhosted.org/packages/9c/6f/6e39ce66b58a5b7ae572a0f4352ff40c71e8573633deda43f6a379d56b3e/tomli-2.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1b168f2731796b045128c45982d3a4874057626da0e2ef1fdd722848b741361d", size = 240875, upload-time = "2026-01-11T11:21:50.755Z" }, + { url = "https://files.pythonhosted.org/packages/aa/ad/cb089cb190487caa80204d503c7fd0f4d443f90b95cf4ef5cf5aa0f439b0/tomli-2.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:133e93646ec4300d651839d382d63edff11d8978be23da4cc106f5a18b7d0576", size = 246271, upload-time = "2026-01-11T11:21:51.81Z" }, + { url = "https://files.pythonhosted.org/packages/0b/63/69125220e47fd7a3a27fd0de0c6398c89432fec41bc739823bcc66506af6/tomli-2.4.0-cp311-cp311-win32.whl", hash = "sha256:b6c78bdf37764092d369722d9946cb65b8767bfa4110f902a1b2542d8d173c8a", size = 96770, upload-time = "2026-01-11T11:21:52.647Z" }, + { url = "https://files.pythonhosted.org/packages/1e/0d/a22bb6c83f83386b0008425a6cd1fa1c14b5f3dd4bad05e98cf3dbbf4a64/tomli-2.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:d3d1654e11d724760cdb37a3d7691f0be9db5fbdaef59c9f532aabf87006dbaa", size = 107626, upload-time = "2026-01-11T11:21:53.459Z" }, + { url = "https://files.pythonhosted.org/packages/2f/6d/77be674a3485e75cacbf2ddba2b146911477bd887dda9d8c9dfb2f15e871/tomli-2.4.0-cp311-cp311-win_arm64.whl", hash = "sha256:cae9c19ed12d4e8f3ebf46d1a75090e4c0dc16271c5bce1c833ac168f08fb614", size = 94842, upload-time = "2026-01-11T11:21:54.831Z" }, + { url = "https://files.pythonhosted.org/packages/3c/43/7389a1869f2f26dba52404e1ef13b4784b6b37dac93bac53457e3ff24ca3/tomli-2.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:920b1de295e72887bafa3ad9f7a792f811847d57ea6b1215154030cf131f16b1", size = 154894, upload-time = "2026-01-11T11:21:56.07Z" }, + { url = "https://files.pythonhosted.org/packages/e9/05/2f9bf110b5294132b2edf13fe6ca6ae456204f3d749f623307cbb7a946f2/tomli-2.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7d6d9a4aee98fac3eab4952ad1d73aee87359452d1c086b5ceb43ed02ddb16b8", size = 149053, upload-time = "2026-01-11T11:21:57.467Z" }, + { url = "https://files.pythonhosted.org/packages/e8/41/1eda3ca1abc6f6154a8db4d714a4d35c4ad90adc0bcf700657291593fbf3/tomli-2.4.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36b9d05b51e65b254ea6c2585b59d2c4cb91c8a3d91d0ed0f17591a29aaea54a", size = 243481, upload-time = "2026-01-11T11:21:58.661Z" }, + { url = "https://files.pythonhosted.org/packages/d2/6d/02ff5ab6c8868b41e7d4b987ce2b5f6a51d3335a70aa144edd999e055a01/tomli-2.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1c8a885b370751837c029ef9bc014f27d80840e48bac415f3412e6593bbc18c1", size = 251720, upload-time = "2026-01-11T11:22:00.178Z" }, + { url = "https://files.pythonhosted.org/packages/7b/57/0405c59a909c45d5b6f146107c6d997825aa87568b042042f7a9c0afed34/tomli-2.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8768715ffc41f0008abe25d808c20c3d990f42b6e2e58305d5da280ae7d1fa3b", size = 247014, upload-time = "2026-01-11T11:22:01.238Z" }, + { url = "https://files.pythonhosted.org/packages/2c/0e/2e37568edd944b4165735687cbaf2fe3648129e440c26d02223672ee0630/tomli-2.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7b438885858efd5be02a9a133caf5812b8776ee0c969fea02c45e8e3f296ba51", size = 251820, upload-time = "2026-01-11T11:22:02.727Z" }, + { url = "https://files.pythonhosted.org/packages/5a/1c/ee3b707fdac82aeeb92d1a113f803cf6d0f37bdca0849cb489553e1f417a/tomli-2.4.0-cp312-cp312-win32.whl", hash = "sha256:0408e3de5ec77cc7f81960c362543cbbd91ef883e3138e81b729fc3eea5b9729", size = 97712, upload-time = "2026-01-11T11:22:03.777Z" }, + { url = "https://files.pythonhosted.org/packages/69/13/c07a9177d0b3bab7913299b9278845fc6eaaca14a02667c6be0b0a2270c8/tomli-2.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:685306e2cc7da35be4ee914fd34ab801a6acacb061b6a7abca922aaf9ad368da", size = 108296, upload-time = "2026-01-11T11:22:04.86Z" }, + { url = "https://files.pythonhosted.org/packages/18/27/e267a60bbeeee343bcc279bb9e8fbed0cbe224bc7b2a3dc2975f22809a09/tomli-2.4.0-cp312-cp312-win_arm64.whl", hash = "sha256:5aa48d7c2356055feef06a43611fc401a07337d5b006be13a30f6c58f869e3c3", size = 94553, upload-time = "2026-01-11T11:22:05.854Z" }, + { url = "https://files.pythonhosted.org/packages/23/d1/136eb2cb77520a31e1f64cbae9d33ec6df0d78bdf4160398e86eec8a8754/tomli-2.4.0-py3-none-any.whl", hash = "sha256:1f776e7d669ebceb01dee46484485f43a4048746235e683bcdffacdf1fb4785a", size = 14477, upload-time = "2026-01-11T11:22:37.446Z" }, ] [[package]] @@ -6945,70 +7126,61 @@ sdist = { url = "https://files.pythonhosted.org/packages/9a/b3/13451226f564f88d9 [[package]] name = "tqdm" -version = "4.67.1" +version = "4.67.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "colorama", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737, upload-time = "2024-11-24T20:12:22.481Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/a9/6ba95a270c6f1fbcd8dac228323f2777d886cb206987444e4bce66338dd4/tqdm-4.67.3.tar.gz", hash = "sha256:7d825f03f89244ef73f1d4ce193cb1774a8179fd96f31d7e1dcde62092b960bb", size = 169598, upload-time = "2026-02-03T17:35:53.048Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/d0/30/dc54f88dd4a2b5dc8a0279bdd7270e735851848b762aeb1c1184ed1f6b14/tqdm-4.67.1-py3-none-any.whl", hash = "sha256:26445eca388f82e72884e0d580d5464cd801a3ea01e63e5601bdff9ba6a48de2", size = 78540, upload-time = "2024-11-24T20:12:19.698Z" }, + { url = "https://files.pythonhosted.org/packages/16/e1/3079a9ff9b8e11b846c6ac5c8b5bfb7ff225eee721825310c91b3b50304f/tqdm-4.67.3-py3-none-any.whl", hash = "sha256:ee1e4c0e59148062281c49d80b25b67771a127c85fc9676d3be5f243206826bf", size = 78374, upload-time = "2026-02-03T17:35:50.982Z" }, ] [[package]] name = "transformers" -version = "4.56.2" +version = "5.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "filelock" }, { name = "huggingface-hub" }, { name = "numpy" }, { name = "packaging" }, { name = "pyyaml" }, { name = "regex" }, - { name = "requests" }, { name = "safetensors" }, { name = "tokenizers" }, { name = "tqdm" }, + { name = "typer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e5/82/0bcfddd134cdf53440becb5e738257cc3cf34cf229d63b57bfd288e6579f/transformers-4.56.2.tar.gz", hash = "sha256:5e7c623e2d7494105c726dd10f6f90c2c99a55ebe86eef7233765abd0cb1c529", size = 9844296, upload-time = "2025-09-19T15:16:26.778Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fc/1a/70e830d53ecc96ce69cfa8de38f163712d2b43ac52fbd743f39f56025c31/transformers-5.3.0.tar.gz", hash = "sha256:009555b364029da9e2946d41f1c5de9f15e6b1df46b189b7293f33a161b9c557", size = 8830831, upload-time = "2026-03-04T17:41:46.119Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/70/26/2591b48412bde75e33bfd292034103ffe41743cacd03120e3242516cd143/transformers-4.56.2-py3-none-any.whl", hash = "sha256:79c03d0e85b26cb573c109ff9eafa96f3c8d4febfd8a0774e8bba32702dd6dde", size = 11608055, upload-time = "2025-09-19T15:16:23.736Z" }, -] - -[[package]] -name = "twine" -version = "5.1.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "importlib-metadata" }, - { name = "keyring" }, - { name = "pkginfo" }, - { name = "readme-renderer" }, - { name = "requests" }, - { name = "requests-toolbelt" }, - { name = "rfc3986" }, - { name = "rich" }, - { name = "urllib3" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/77/68/bd982e5e949ef8334e6f7dcf76ae40922a8750aa2e347291ae1477a4782b/twine-5.1.1.tar.gz", hash = "sha256:9aa0825139c02b3434d913545c7b847a21c835e11597f5255842d457da2322db", size = 225531, upload-time = "2024-06-26T15:00:46.58Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/ec/00f9d5fd040ae29867355e559a94e9a8429225a0284a3f5f091a3878bfc0/twine-5.1.1-py3-none-any.whl", hash = "sha256:215dbe7b4b94c2c50a7315c0275d2258399280fbb7d04182c7e55e24b5f93997", size = 38650, upload-time = "2024-06-26T15:00:43.825Z" }, + { url = "https://files.pythonhosted.org/packages/b8/88/ae8320064e32679a5429a2c9ebbc05c2bf32cefb6e076f9b07f6d685a9b4/transformers-5.3.0-py3-none-any.whl", hash = "sha256:50ac8c89c3c7033444fb3f9f53138096b997ebb70d4b5e50a2e810bf12d3d29a", size = 10661827, upload-time = "2026-03-04T17:41:42.722Z" }, ] [[package]] name = "typer" -version = "0.20.0" +version = "0.24.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "annotated-doc" }, { name = "click" }, { name = "rich" }, { name = "shellingham" }, - { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8f/28/7c85c8032b91dbe79725b6f17d2fffc595dff06a35c7a30a37bef73a1ab4/typer-0.20.0.tar.gz", hash = "sha256:1aaf6494031793e4876fb0bacfa6a912b551cf43c1e63c800df8b1a866720c37", size = 106492, upload-time = "2025-10-20T17:03:49.445Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/78/64/7713ffe4b5983314e9d436a90d5bd4f63b6054e2aca783a3cfc44cb95bbf/typer-0.20.0-py3-none-any.whl", hash = "sha256:5b463df6793ec1dca6213a3cf4c0f03bc6e322ac5e16e13ddd622a889489784a", size = 47028, upload-time = "2025-10-20T17:03:47.617Z" }, + { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" }, +] + +[[package]] +name = "typer-slim" +version = "0.24.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typer" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a7/a7/e6aecc4b4eb59598829a3b5076a93aff291b4fdaa2ded25efc4e1f4d219c/typer_slim-0.24.0.tar.gz", hash = "sha256:f0ed36127183f52ae6ced2ecb2521789995992c521a46083bfcdbb652d22ad34", size = 4776, upload-time = "2026-02-16T22:08:51.2Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a7/24/5480c20380dfd18cf33d14784096dca45a24eae6102e91d49a718d3b6855/typer_slim-0.24.0-py3-none-any.whl", hash = "sha256:d5d7ee1ee2834d5020c7c616ed5e0d0f29b9a4b1dd283bdebae198ec09778d0e", size = 3394, upload-time = "2026-02-16T22:08:49.92Z" }, ] [[package]] @@ -7022,11 +7194,11 @@ wheels = [ [[package]] name = "types-awscrt" -version = "0.29.0" +version = "0.31.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6e/77/c25c0fbdd3b269b13139c08180bcd1521957c79bd133309533384125810c/types_awscrt-0.29.0.tar.gz", hash = "sha256:7f81040846095cbaf64e6b79040434750d4f2f487544d7748b778c349d393510", size = 17715, upload-time = "2025-11-21T21:01:24.223Z" } +sdist = { url = "https://files.pythonhosted.org/packages/76/26/0aa563e229c269c528a3b8c709fc671ac2a5c564732fab0852ac6ee006cf/types_awscrt-0.31.3.tar.gz", hash = "sha256:09d3eaf00231e0f47e101bd9867e430873bc57040050e2a3bd8305cb4fc30865", size = 18178, upload-time = "2026-03-08T02:31:14.569Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/37/a9/6b7a0ceb8e6f2396cc290ae2f1520a1598842119f09b943d83d6ff01bc49/types_awscrt-0.29.0-py3-none-any.whl", hash = "sha256:ece1906d5708b51b6603b56607a702ed1e5338a2df9f31950e000f03665ac387", size = 42343, upload-time = "2025-11-21T21:01:22.979Z" }, + { url = "https://files.pythonhosted.org/packages/3e/e5/47a573bbbd0a790f8f9fe452f7188ea72b212d21c9be57d5fc0cbc442075/types_awscrt-0.31.3-py3-none-any.whl", hash = "sha256:e5ce65a00a2ab4f35eacc1e3d700d792338d56e4823ee7b4dbe017f94cfc4458", size = 43340, upload-time = "2026-03-08T02:31:13.38Z" }, ] [[package]] @@ -7043,23 +7215,23 @@ wheels = [ [[package]] name = "types-cachetools" -version = "5.5.0.20240820" +version = "6.2.0.20260317" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c2/7e/ad6ba4a56b2a994e0f0a04a61a50466b60ee88a13d10a18c83ac14a66c61/types-cachetools-5.5.0.20240820.tar.gz", hash = "sha256:b888ab5c1a48116f7799cd5004b18474cd82b5463acb5ffb2db2fc9c7b053bc0", size = 4198, upload-time = "2024-08-20T02:30:07.525Z" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/7f/16a4d8344c28193a5a74358028c2d2f753f0d9658dd98b9e1967c50045a2/types_cachetools-6.2.0.20260317.tar.gz", hash = "sha256:6d91855bcc944665897c125e720aa3c80aace929b77a64e796343701df4f61c6", size = 9812, upload-time = "2026-03-17T04:06:32.007Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/27/4d/fd7cc050e2d236d5570c4d92531c0396573a1e14b31735870e849351c717/types_cachetools-5.5.0.20240820-py3-none-any.whl", hash = "sha256:efb2ed8bf27a4b9d3ed70d33849f536362603a90b8090a328acf0cd42fda82e2", size = 4149, upload-time = "2024-08-20T02:30:06.461Z" }, + { url = "https://files.pythonhosted.org/packages/17/9a/b00b23054934c4d569c19f7278c4fb32746cd36a64a175a216d3073a4713/types_cachetools-6.2.0.20260317-py3-none-any.whl", hash = "sha256:92fa9bc50e4629e31fca67ceb3fb1de71791e314fa16c0a0d2728724dc222c8b", size = 9346, upload-time = "2026-03-17T04:06:31.184Z" }, ] [[package]] name = "types-cffi" -version = "1.17.0.20250915" +version = "2.0.0.20260316" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/2a/98/ea454cea03e5f351323af6a482c65924f3c26c515efd9090dede58f2b4b6/types_cffi-1.17.0.20250915.tar.gz", hash = "sha256:4362e20368f78dabd5c56bca8004752cc890e07a71605d9e0d9e069dbaac8c06", size = 17229, upload-time = "2025-09-15T03:01:25.31Z" } +sdist = { url = "https://files.pythonhosted.org/packages/07/4c/805b40b094eb3fd60f8d17fa7b3c58a33781311a95d0e6a74da0751ce294/types_cffi-2.0.0.20260316.tar.gz", hash = "sha256:8fb06ed4709675c999853689941133affcd2250cd6121cc11fd22c0d81ad510c", size = 17399, upload-time = "2026-03-16T07:54:43.059Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/ec/092f2b74b49ec4855cdb53050deb9699f7105b8fda6fe034c0781b8687f3/types_cffi-1.17.0.20250915-py3-none-any.whl", hash = "sha256:cef4af1116c83359c11bb4269283c50f0688e9fc1d7f0eeb390f3661546da52c", size = 20112, upload-time = "2025-09-15T03:01:24.187Z" }, + { url = "https://files.pythonhosted.org/packages/81/5e/9f1a709225ad9d0e1d7a6e4366ff285f0113c749e882d6cbeb40eab32e75/types_cffi-2.0.0.20260316-py3-none-any.whl", hash = "sha256:dd504698029db4c580385f679324621cc64d886e6a23e9821d52bc5169251302", size = 20096, upload-time = "2026-03-16T07:54:41.994Z" }, ] [[package]] @@ -7082,32 +7254,32 @@ wheels = [ [[package]] name = "types-deprecated" -version = "1.2.15.20250304" +version = "1.3.1.20260130" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0e/67/eeefaaabb03b288aad85483d410452c8bbcbf8b2bd876b0e467ebd97415b/types_deprecated-1.2.15.20250304.tar.gz", hash = "sha256:c329030553029de5cc6cb30f269c11f4e00e598c4241290179f63cda7d33f719", size = 8015, upload-time = "2025-03-04T02:48:17.894Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/97/9924e496f88412788c432891cacd041e542425fe0bffff4143a7c1c89ac4/types_deprecated-1.3.1.20260130.tar.gz", hash = "sha256:726b05e5e66d42359b1d6631835b15de62702588c8a59b877aa4b1e138453450", size = 8455, upload-time = "2026-01-30T03:58:17.401Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/e3/c18aa72ab84e0bc127a3a94e93be1a6ac2cb281371d3a45376ab7cfdd31c/types_deprecated-1.2.15.20250304-py3-none-any.whl", hash = "sha256:86a65aa550ea8acf49f27e226b8953288cd851de887970fbbdf2239c116c3107", size = 8553, upload-time = "2025-03-04T02:48:16.666Z" }, + { url = "https://files.pythonhosted.org/packages/d2/b2/6f920582af7efcd37165cd6321707f3ad5839dd24565a8a982f2bd9c6fd1/types_deprecated-1.3.1.20260130-py3-none-any.whl", hash = "sha256:593934d85c38ca321a9d301f00c42ffe13e4cf830b71b10579185ba0ce172d9a", size = 9077, upload-time = "2026-01-30T03:58:16.633Z" }, ] [[package]] name = "types-docutils" -version = "0.21.0.20250809" +version = "0.22.3.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/be/9b/f92917b004e0a30068e024e8925c7d9b10440687b96d91f26d8762f4b68c/types_docutils-0.21.0.20250809.tar.gz", hash = "sha256:cc2453c87dc729b5aae499597496e4f69b44aa5fccb27051ed8bb55b0bd5e31b", size = 54770, upload-time = "2025-08-09T03:15:42.752Z" } +sdist = { url = "https://files.pythonhosted.org/packages/44/bb/243a87fc1605a4a94c2c343d6dbddbf0d7ef7c0b9550f360b8cda8e82c39/types_docutils-0.22.3.20260322.tar.gz", hash = "sha256:e2450bb997283c3141ec5db3e436b91f0aa26efe35eb9165178ca976ccb4930b", size = 57311, upload-time = "2026-03-22T04:08:44.064Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7e/a9/46bc12e4c918c4109b67401bf87fd450babdffbebd5dbd7833f5096f42a5/types_docutils-0.21.0.20250809-py3-none-any.whl", hash = "sha256:af02c82327e8ded85f57dd85c8ebf93b6a0b643d85a44c32d471e3395604ea50", size = 89598, upload-time = "2025-08-09T03:15:41.503Z" }, + { url = "https://files.pythonhosted.org/packages/c6/4a/22c090cd4615a16917dff817cbe7c5956da376c961e024c241cd962d2c3d/types_docutils-0.22.3.20260322-py3-none-any.whl", hash = "sha256:681d4510ce9b80a0c6a593f0f9843d81f8caa786db7b39ba04d9fd5480ac4442", size = 91978, upload-time = "2026-03-22T04:08:43.117Z" }, ] [[package]] name = "types-flask-cors" -version = "5.0.0.20250413" +version = "6.0.0.20250809" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "flask" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/f3/dd2f0d274ecb77772d3ce83735f75ad14713461e8cf7e6d61a7c272037b1/types_flask_cors-5.0.0.20250413.tar.gz", hash = "sha256:b346d052f4ef3b606b73faf13e868e458f1efdbfedcbe1aba739eb2f54a6cf5f", size = 9921, upload-time = "2025-04-13T04:04:15.515Z" } +sdist = { url = "https://files.pythonhosted.org/packages/45/e0/e5dd841bf475765fb61cb04c1e70d2fd0675a0d4ddfacd50a333eafe7267/types_flask_cors-6.0.0.20250809.tar.gz", hash = "sha256:24380a2b82548634c0931d50b9aafab214eea9f85dcc04f15ab1518752a7e6aa", size = 9951, upload-time = "2025-08-09T03:16:37.454Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/34/7d64eb72d80bfd5b9e6dd31e7fe351a1c9a735f5c01e85b1d3b903a9d656/types_flask_cors-5.0.0.20250413-py3-none-any.whl", hash = "sha256:8183fdba764d45a5b40214468a1d5daa0e86c4ee6042d13f38cc428308f27a64", size = 9982, upload-time = "2025-04-13T04:04:14.27Z" }, + { url = "https://files.pythonhosted.org/packages/9f/5e/1e60c29eb5796233d4d627ca4979c4ae8da962fd0aae0cdb6e3e6a807bbc/types_flask_cors-6.0.0.20250809-py3-none-any.whl", hash = "sha256:f6d660dddab946779f4263cb561bffe275d86cb8747ce02e9fec8d340780131b", size = 9971, upload-time = "2025-08-09T03:16:36.593Z" }, ] [[package]] @@ -7125,15 +7297,15 @@ wheels = [ [[package]] name = "types-gevent" -version = "25.9.0.20251102" +version = "25.9.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-greenlet" }, { name = "types-psutil" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4c/21/552d818a475e1a31780fb7ae50308feb64211a05eb403491d1a34df95e5f/types_gevent-25.9.0.20251102.tar.gz", hash = "sha256:76f93513af63f4577bb4178c143676dd6c4780abc305f405a4e8ff8f1fa177f8", size = 38096, upload-time = "2025-11-02T03:07:42.112Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/f0/14a99ddcaa69b559fa7cec8c9de880b792bebb0b848ae865d94ea9058533/types_gevent-25.9.0.20260322.tar.gz", hash = "sha256:91257920845762f09753c08aa20fad1743ac13d2de8bcf23f4b8fe967d803732", size = 38241, upload-time = "2026-03-22T04:08:55.213Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/60/a1/776d2de31a02123f225aaa790641113ae47f738f6e8e3091d3012240a88e/types_gevent-25.9.0.20251102-py3-none-any.whl", hash = "sha256:0f14b9977cb04bf3d94444b5ae6ec5d78ac30f74c4df83483e0facec86f19d8b", size = 55592, upload-time = "2025-11-02T03:07:41.003Z" }, + { url = "https://files.pythonhosted.org/packages/89/0f/964440b57eb4ddb4aca03479a4093852e1ce79010d1c5967234e6f5d6bd9/types_gevent-25.9.0.20260322-py3-none-any.whl", hash = "sha256:21b3c269b3a20ecb0e4668289c63b97d21694d84a004ab059c1e32ab970eacc2", size = 55500, upload-time = "2026-03-22T04:08:54.103Z" }, ] [[package]] @@ -7159,23 +7331,23 @@ wheels = [ [[package]] name = "types-jmespath" -version = "1.0.2.20250809" +version = "1.1.0.20260124" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d5/ff/6848b1603ca47fff317b44dfff78cc1fb0828262f840b3ab951b619d5a22/types_jmespath-1.0.2.20250809.tar.gz", hash = "sha256:e194efec21c0aeae789f701ae25f17c57c25908e789b1123a5c6f8d915b4adff", size = 10248, upload-time = "2025-08-09T03:14:57.996Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2b/ca/c8d7fc6e450c2f8fc6f510cb194754c43b17f933f2dcabcfc6985cbb97a8/types_jmespath-1.1.0.20260124.tar.gz", hash = "sha256:29d86868e72c0820914577077b27d167dcab08b1fc92157a29d537ff7153fdfe", size = 10709, upload-time = "2026-01-24T03:18:46.557Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/0e/6a/65c8be6b6555beaf1a654ae1c2308c2e19a610c0b318a9730e691b79ac79/types_jmespath-1.0.2.20250809-py3-none-any.whl", hash = "sha256:4147d17cc33454f0dac7e78b4e18e532a1330c518d85f7f6d19e5818ab83da21", size = 11494, upload-time = "2025-08-09T03:14:57.292Z" }, + { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = "sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, ] [[package]] name = "types-jsonschema" -version = "4.23.0.20250516" +version = "4.26.0.20260202" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "referencing" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a0/ec/27ea5bffdb306bf261f6677a98b6993d93893b2c2e30f7ecc1d2c99d32e7/types_jsonschema-4.23.0.20250516.tar.gz", hash = "sha256:9ace09d9d35c4390a7251ccd7d833b92ccc189d24d1b347f26212afce361117e", size = 14911, upload-time = "2025-05-16T03:09:33.728Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e6/48/73ae8b388e19fc4a2a8060d0876325ec7310cfd09b53a2185186fd35959f/types_jsonschema-4.23.0.20250516-py3-none-any.whl", hash = "sha256:e7d0dd7db7e59e63c26e3230e26ffc64c4704cc5170dc21270b366a35ead1618", size = 15027, upload-time = "2025-05-16T03:09:32.499Z" }, + { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, ] [[package]] @@ -7189,11 +7361,11 @@ wheels = [ [[package]] name = "types-oauthlib" -version = "3.2.0.20250516" +version = "3.3.0.20250822" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b1/2c/dba2c193ccff2d1e2835589d4075b230d5627b9db363e9c8de153261d6ec/types_oauthlib-3.2.0.20250516.tar.gz", hash = "sha256:56bf2cffdb8443ae718d4e83008e3fbd5f861230b4774e6d7799527758119d9a", size = 24683, upload-time = "2025-05-16T03:07:42.484Z" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/6e/d08033f562053c459322333c46baa8cf8d2d8c18f30d46dd898c8fd8df77/types_oauthlib-3.3.0.20250822.tar.gz", hash = "sha256:2cd41587dd80c199e4230e3f086777e9ae525e89579c64afe5e0039ab09be9de", size = 25700, upload-time = "2025-08-22T03:02:41.378Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b8/54/cdd62283338616fd2448f534b29110d79a42aaabffaf5f45e7aed365a366/types_oauthlib-3.2.0.20250516-py3-none-any.whl", hash = "sha256:5799235528bc9bd262827149a1633ff55ae6e5a5f5f151f4dae74359783a31b3", size = 45671, upload-time = "2025-05-16T03:07:41.268Z" }, + { url = "https://files.pythonhosted.org/packages/18/4b/00593b8b5d055550e1fcb9af2c42fa11b0a90bf16a94759a77bc1c3c0c72/types_oauthlib-3.3.0.20250822-py3-none-any.whl", hash = "sha256:b7f4c9b9eed0e020f454e0af800b10e93dd2efd196da65744b76910cce7e70d6", size = 48800, upload-time = "2025-08-22T03:02:40.427Z" }, ] [[package]] @@ -7216,29 +7388,29 @@ wheels = [ [[package]] name = "types-openpyxl" -version = "3.1.5.20250919" +version = "3.1.5.20260322" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c4/12/8bc4a25d49f1e4b7bbca868daa3ee80b1983d8137b4986867b5b65ab2ecd/types_openpyxl-3.1.5.20250919.tar.gz", hash = "sha256:232b5906773eebace1509b8994cdadda043f692cfdba9bfbb86ca921d54d32d7", size = 100880, upload-time = "2025-09-19T02:54:39.997Z" } +sdist = { url = "https://files.pythonhosted.org/packages/77/bf/15240de4d68192d2a1f385ef2f6f1ecb29b85d2f3791dd2e2d5b980be30f/types_openpyxl-3.1.5.20260322.tar.gz", hash = "sha256:a61d66ebe1e49697853c6db8e0929e1cda2c96755e71fb676ed7fc48dfdcf697", size = 101325, upload-time = "2026-03-22T04:08:40.426Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/36/3c/d49cf3f4489a10e9ddefde18fd258f120754c5825d06d145d9a0aaac770b/types_openpyxl-3.1.5.20250919-py3-none-any.whl", hash = "sha256:bd06f18b12fd5e1c9f0b666ee6151d8140216afa7496f7ebb9fe9d33a1a3ce99", size = 166078, upload-time = "2025-09-19T02:54:38.657Z" }, + { url = "https://files.pythonhosted.org/packages/bf/b4/c14191b30bcb266365b124b2bb4e67ecd68425a78ba77ee026f33667daa9/types_openpyxl-3.1.5.20260322-py3-none-any.whl", hash = "sha256:2f515f0b0bbfb04bfb587de34f7522d90b5151a8da7bbbd11ecec4ca40f64238", size = 166102, upload-time = "2026-03-22T04:08:39.174Z" }, ] [[package]] name = "types-pexpect" -version = "4.9.0.20250916" +version = "4.9.0.20260127" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/0c/e6/cc43e306dc7de14ec7861c24ac4957f688741ae39ae685049695d796b587/types_pexpect-4.9.0.20250916.tar.gz", hash = "sha256:69e5fed6199687a730a572de780a5749248a4c5df2ff1521e194563475c9928d", size = 13322, upload-time = "2025-09-16T02:49:25.61Z" } +sdist = { url = "https://files.pythonhosted.org/packages/2e/32/7e03a07e16f79a404d6200ed6bdfcc320d0fb833436a5c6895a1403dedb7/types_pexpect-4.9.0.20260127.tar.gz", hash = "sha256:f8d43efc24251a8e533c71ea9be03d19bb5d08af096d561611697af9720cba7f", size = 13461, upload-time = "2026-01-27T03:28:30.923Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/6d/7740e235a9fb2570968da7d386d7feb511ce68cd23472402ff8cdf7fc78f/types_pexpect-4.9.0.20250916-py3-none-any.whl", hash = "sha256:7fa43cb96042ac58bc74f7c28e5d85782be0ee01344149886849e9d90936fe8a", size = 17057, upload-time = "2025-09-16T02:49:24.546Z" }, + { url = "https://files.pythonhosted.org/packages/8a/d9/7ac5c9aa5a89a1a64cd835ae348227f4939406d826e461b85b690a8ba1c2/types_pexpect-4.9.0.20260127-py3-none-any.whl", hash = "sha256:69216c0ebf0fe45ad2900823133959b027e9471e24fc3f2e4c7b00605555da5f", size = 17078, upload-time = "2026-01-27T03:28:29.848Z" }, ] [[package]] name = "types-protobuf" -version = "5.29.1.20250403" +version = "6.32.1.20260221" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/78/6d/62a2e73b966c77609560800004dd49a926920dd4976a9fdd86cf998e7048/types_protobuf-5.29.1.20250403.tar.gz", hash = "sha256:7ff44f15022119c9d7558ce16e78b2d485bf7040b4fadced4dd069bb5faf77a2", size = 59413, upload-time = "2025-04-02T10:07:17.138Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5f/e2/9aa4a3b2469508bd7b4e2ae11cbedaf419222a09a1b94daffcd5efca4023/types_protobuf-6.32.1.20260221.tar.gz", hash = "sha256:6d5fb060a616bfb076cbb61b4b3c3969f5fc8bec5810f9a2f7e648ee5cbcbf6e", size = 64408, upload-time = "2026-02-21T03:55:13.916Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/e3/b74dcc2797b21b39d5a4f08a8b08e20369b4ca250d718df7af41a60dd9f0/types_protobuf-5.29.1.20250403-py3-none-any.whl", hash = "sha256:c71de04106a2d54e5b2173d0a422058fae0ef2d058d70cf369fb797bf61ffa59", size = 73874, upload-time = "2025-04-02T10:07:15.755Z" }, + { url = "https://files.pythonhosted.org/packages/2e/e8/1fd38926f9cf031188fbc5a96694203ea6f24b0e34bd64a225ec6f6291ba/types_protobuf-6.32.1.20260221-py3-none-any.whl", hash = "sha256:da7cdd947975964a93c30bfbcc2c6841ee646b318d3816b033adc2c4eb6448e4", size = 77956, upload-time = "2026-02-21T03:55:12.894Z" }, ] [[package]] @@ -7252,11 +7424,11 @@ wheels = [ [[package]] name = "types-psycopg2" -version = "2.9.21.20251012" +version = "2.9.21.20260223" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9b/b3/2d09eaf35a084cffd329c584970a3fa07101ca465c13cad1576d7c392587/types_psycopg2-2.9.21.20251012.tar.gz", hash = "sha256:4cdafd38927da0cfde49804f39ab85afd9c6e9c492800e42f1f0c1a1b0312935", size = 26710, upload-time = "2025-10-12T02:55:39.5Z" } +sdist = { url = "https://files.pythonhosted.org/packages/55/1f/4daff0ce5e8e191844e65aaa793ed1b9cb40027dc2700906ecf2b6bcc0ed/types_psycopg2-2.9.21.20260223.tar.gz", hash = "sha256:78ed70de2e56bc6b5c26c8c1da8e9af54e49fdc3c94d1504609f3519e2b84f02", size = 27090, upload-time = "2026-02-23T04:11:18.177Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ec/0c/05feaf8cb51159f2c0af04b871dab7e98a2f83a3622f5f216331d2dd924c/types_psycopg2-2.9.21.20251012-py3-none-any.whl", hash = "sha256:712bad5c423fe979e357edbf40a07ca40ef775d74043de72bd4544ca328cc57e", size = 24883, upload-time = "2025-10-12T02:55:38.439Z" }, + { url = "https://files.pythonhosted.org/packages/8d/e7/c566df58410bc0728348b514e718f0b38fa0d248b5c10599a11494ba25d2/types_psycopg2-2.9.21.20260223-py3-none-any.whl", hash = "sha256:c6228ade72d813b0624f4c03feeb89471950ac27cd0506b5debed6f053086bc8", size = 24919, upload-time = "2026-02-23T04:11:17.214Z" }, ] [[package]] @@ -7273,11 +7445,11 @@ wheels = [ [[package]] name = "types-pymysql" -version = "1.1.0.20250916" +version = "1.1.0.20251220" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/1f/12/bda1d977c07e0e47502bede1c44a986dd45946494d89e005e04cdeb0f8de/types_pymysql-1.1.0.20250916.tar.gz", hash = "sha256:98d75731795fcc06723a192786662bdfa760e1e00f22809c104fbb47bac5e29b", size = 22131, upload-time = "2025-09-16T02:49:22.039Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d3/59/e959dd6d2f8e3b3c3f058d79ac9ece328922a5a8770c707fe9c3a757481c/types_pymysql-1.1.0.20251220.tar.gz", hash = "sha256:ae1c3df32a777489431e2e9963880a0df48f6591e0aa2fd3a6fabd9dee6eca54", size = 22184, upload-time = "2025-12-20T03:07:38.689Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/21/eb/a225e32a6e7b196af67ab2f1b07363595f63255374cc3b88bfdab53b4ee8/types_pymysql-1.1.0.20250916-py3-none-any.whl", hash = "sha256:873eb9836bb5e3de4368cc7010ca72775f86e9692a5c7810f8c7f48da082e55b", size = 23063, upload-time = "2025-09-16T02:49:20.933Z" }, + { url = "https://files.pythonhosted.org/packages/8b/fa/4f4d3bfca9ef6dd17d69ed18b96564c53b32d3ce774132308d0bee849f10/types_pymysql-1.1.0.20251220-py3-none-any.whl", hash = "sha256:fa1082af7dea6c53b6caa5784241924b1296ea3a8d3bd060417352c5e10c0618", size = 23067, upload-time = "2025-12-20T03:07:37.766Z" }, ] [[package]] @@ -7295,11 +7467,11 @@ wheels = [ [[package]] name = "types-python-dateutil" -version = "2.9.0.20251115" +version = "2.9.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/36/06d01fb52c0d57e9ad0c237654990920fa41195e4b3d640830dabf9eeb2f/types_python_dateutil-2.9.0.20251115.tar.gz", hash = "sha256:8a47f2c3920f52a994056b8786309b43143faa5a64d4cbb2722d6addabdf1a58", size = 16363, upload-time = "2025-11-15T03:00:13.717Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/02/f72df9ef5ffc4f959b83cb80c8aa03eb8718a43e563ecd99ccffe265fa89/types_python_dateutil-2.9.0.20260323.tar.gz", hash = "sha256:a107aef5841db41ace381dbbbd7e4945220fc940f7a72172a0be5a92d9ab7164", size = 16897, upload-time = "2026-03-23T04:15:14.829Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/43/0b/56961d3ba517ed0df9b3a27bfda6514f3d01b28d499d1bce9068cfe4edd1/types_python_dateutil-2.9.0.20251115-py3-none-any.whl", hash = "sha256:9cf9c1c582019753b8639a081deefd7e044b9fa36bd8217f565c6c4e36ee0624", size = 18251, upload-time = "2025-11-15T03:00:12.317Z" }, + { url = "https://files.pythonhosted.org/packages/92/c1/b661838b97453e699a215451f2e22cee750eaaf4ea4619b34bdaf01221a4/types_python_dateutil-2.9.0.20260323-py3-none-any.whl", hash = "sha256:a23a50a07f6eb87e729d4cb0c2eb511c81761eeb3f505db2c1413be94aae8335", size = 18433, upload-time = "2026-03-23T04:15:13.683Z" }, ] [[package]] @@ -7311,22 +7483,13 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/85/4f/b88274658cf489e35175be8571c970e9a1219713bafd8fc9e166d7351ecb/types_python_http_client-3.3.7.20250708-py3-none-any.whl", hash = "sha256:e2fc253859decab36713d82fc7f205868c3ddeaee79dbb55956ad9ca77abe12b", size = 8890, upload-time = "2025-07-08T03:14:35.506Z" }, ] -[[package]] -name = "types-pytz" -version = "2025.2.0.20251108" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/40/ff/c047ddc68c803b46470a357454ef76f4acd8c1088f5cc4891cdd909bfcf6/types_pytz-2025.2.0.20251108.tar.gz", hash = "sha256:fca87917836ae843f07129567b74c1929f1870610681b4c92cb86a3df5817bdb", size = 10961, upload-time = "2025-11-08T02:55:57.001Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e7/c1/56ef16bf5dcd255155cc736d276efa6ae0a5c26fd685e28f0412a4013c01/types_pytz-2025.2.0.20251108-py3-none-any.whl", hash = "sha256:0f1c9792cab4eb0e46c52f8845c8f77cf1e313cb3d68bf826aa867fe4717d91c", size = 10116, upload-time = "2025-11-08T02:55:56.194Z" }, -] - [[package]] name = "types-pywin32" -version = "310.0.0.20250516" +version = "311.0.0.20260323" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6c/bc/c7be2934a37cc8c645c945ca88450b541e482c4df3ac51e5556377d34811/types_pywin32-310.0.0.20250516.tar.gz", hash = "sha256:91e5bfc033f65c9efb443722eff8101e31d690dd9a540fa77525590d3da9cc9d", size = 328459, upload-time = "2025-05-16T03:07:57.411Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/cc/f03ddb7412ac2fc2238358b617c2d5919ba96812dff8d3081f3b2754bb83/types_pywin32-311.0.0.20260323.tar.gz", hash = "sha256:2e8dc6a59fedccbc51b241651ce1e8aa58488934f517debf23a9c6d0ff329b4b", size = 332263, upload-time = "2026-03-23T04:15:20.004Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/72/469e4cc32399dbe6c843e38fdb6d04fee755e984e137c0da502f74d3ac59/types_pywin32-310.0.0.20250516-py3-none-any.whl", hash = "sha256:f9ef83a1ec3e5aae2b0e24c5f55ab41272b5dfeaabb9a0451d33684c9545e41a", size = 390411, upload-time = "2025-05-16T03:07:56.282Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/d786d5d8b846e3cbe1ee52da8945560b111c789b42c3771b2129b312ab94/types_pywin32-311.0.0.20260323-py3-none-any.whl", hash = "sha256:2f2b03fc72ae77ccbb0ee258da0f181c3a38bd8602f6e332e42587b3b0d5f095", size = 395435, upload-time = "2026-03-23T04:15:18.76Z" }, ] [[package]] @@ -7353,41 +7516,41 @@ wheels = [ [[package]] name = "types-regex" -version = "2024.11.6.20250403" +version = "2026.2.28.20260301" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/c7/75/012b90c8557d3abb3b58a9073a94d211c8f75c9b2e26bf0d8af7ecf7bc78/types_regex-2024.11.6.20250403.tar.gz", hash = "sha256:3fdf2a70bbf830de4b3a28e9649a52d43dabb57cdb18fbfe2252eefb53666665", size = 12394, upload-time = "2025-04-03T02:54:35.379Z" } +sdist = { url = "https://files.pythonhosted.org/packages/3a/ed/106958cb686316113b748ed4209fa363fd92b15759d5409c3930fed36606/types_regex-2026.2.28.20260301.tar.gz", hash = "sha256:644c231db3f368908320170c14905731a7ae5fabdac0f60f5d6d12ecdd3bc8dd", size = 13157, upload-time = "2026-03-01T04:11:13.559Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/49/67200c4708f557be6aa4ecdb1fa212d67a10558c5240251efdc799cca22f/types_regex-2024.11.6.20250403-py3-none-any.whl", hash = "sha256:e22c0f67d73f4b4af6086a340f387b6f7d03bed8a0bb306224b75c51a29b0001", size = 10396, upload-time = "2025-04-03T02:54:34.555Z" }, + { url = "https://files.pythonhosted.org/packages/c7/bb/9bc26fcf5155bd25efeca35f8ba6bffb8b3c9da2baac8bf40067606418f3/types_regex-2026.2.28.20260301-py3-none-any.whl", hash = "sha256:7da7a1fe67528238176a5844fd435ca90617cf605341308686afbc579fdea5c0", size = 11130, upload-time = "2026-03-01T04:11:11.454Z" }, ] [[package]] name = "types-requests" -version = "2.32.4.20250913" +version = "2.32.4.20260107" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/36/27/489922f4505975b11de2b5ad07b4fe1dca0bca9be81a703f26c5f3acfce5/types_requests-2.32.4.20250913.tar.gz", hash = "sha256:abd6d4f9ce3a9383f269775a9835a4c24e5cd6b9f647d64f88aa4613c33def5d", size = 23113, upload-time = "2025-09-13T02:40:02.309Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/f3/a0663907082280664d745929205a89d41dffb29e89a50f753af7d57d0a96/types_requests-2.32.4.20260107.tar.gz", hash = "sha256:018a11ac158f801bfa84857ddec1650750e393df8a004a8a9ae2a9bec6fcb24f", size = 23165, upload-time = "2026-01-07T03:20:54.091Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2a/20/9a227ea57c1285986c4cf78400d0a91615d25b24e257fd9e2969606bdfae/types_requests-2.32.4.20250913-py3-none-any.whl", hash = "sha256:78c9c1fffebbe0fa487a418e0fa5252017e9c60d1a2da394077f1780f655d7e1", size = 20658, upload-time = "2025-09-13T02:40:01.115Z" }, + { url = "https://files.pythonhosted.org/packages/1c/12/709ea261f2bf91ef0a26a9eed20f2623227a8ed85610c1e54c5805692ecb/types_requests-2.32.4.20260107-py3-none-any.whl", hash = "sha256:b703fe72f8ce5b31ef031264fe9395cac8f46a04661a79f7ed31a80fb308730d", size = 20676, upload-time = "2026-01-07T03:20:52.929Z" }, ] [[package]] name = "types-s3transfer" -version = "0.15.0" +version = "0.16.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/79/bf/b00dcbecb037c4999b83c8109b8096fe78f87f1266cadc4f95d4af196292/types_s3transfer-0.15.0.tar.gz", hash = "sha256:43a523e0c43a88e447dfda5f4f6b63bf3da85316fdd2625f650817f2b170b5f7", size = 14236, upload-time = "2025-11-21T21:16:26.553Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fe/64/42689150509eb3e6e82b33ee3d89045de1592488842ddf23c56957786d05/types_s3transfer-0.16.0.tar.gz", hash = "sha256:b4636472024c5e2b62278c5b759661efeb52a81851cde5f092f24100b1ecb443", size = 13557, upload-time = "2025-12-08T08:13:09.928Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/39/39a322d7209cc259e3e27c4d498129e9583a2f3a8aea57eb1a9941cb5e9e/types_s3transfer-0.15.0-py3-none-any.whl", hash = "sha256:1e617b14a9d3ce5be565f4b187fafa1d96075546b52072121f8fda8e0a444aed", size = 19702, upload-time = "2025-11-21T21:16:25.146Z" }, + { url = "https://files.pythonhosted.org/packages/98/27/e88220fe6274eccd3bdf95d9382918716d312f6f6cef6a46332d1ee2feff/types_s3transfer-0.16.0-py3-none-any.whl", hash = "sha256:1c0cd111ecf6e21437cb410f5cddb631bfb2263b77ad973e79b9c6d0cb24e0ef", size = 19247, upload-time = "2025-12-08T08:13:08.426Z" }, ] [[package]] name = "types-setuptools" -version = "80.9.0.20250822" +version = "82.0.0.20260210" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/19/bd/1e5f949b7cb740c9f0feaac430e301b8f1c5f11a81e26324299ea671a237/types_setuptools-80.9.0.20250822.tar.gz", hash = "sha256:070ea7716968ec67a84c7f7768d9952ff24d28b65b6594797a464f1b3066f965", size = 41296, upload-time = "2025-08-22T03:02:08.771Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4b/90/796ac8c774a7f535084aacbaa6b7053d16fff5c630eff87c3ecff7896c37/types_setuptools-82.0.0.20260210.tar.gz", hash = "sha256:d9719fbbeb185254480ade1f25327c4654f8c00efda3fec36823379cebcdee58", size = 44768, upload-time = "2026-02-10T04:22:02.107Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b6/2d/475bf15c1cdc172e7a0d665b6e373ebfb1e9bf734d3f2f543d668b07a142/types_setuptools-80.9.0.20250822-py3-none-any.whl", hash = "sha256:53bf881cb9d7e46ed12c76ef76c0aaf28cfe6211d3fab12e0b83620b1a8642c3", size = 63179, upload-time = "2025-08-22T03:02:07.643Z" }, + { url = "https://files.pythonhosted.org/packages/3e/54/3489432b1d9bc713c9d8aa810296b8f5b0088403662959fb63a8acdbd4fc/types_setuptools-82.0.0.20260210-py3-none-any.whl", hash = "sha256:5124a7daf67f195c6054e0f00f1d97c69caad12fdcf9113eba33eff0bce8cd2b", size = 68433, upload-time = "2026-02-10T04:22:00.876Z" }, ] [[package]] @@ -7422,28 +7585,28 @@ wheels = [ [[package]] name = "types-tensorflow" -version = "2.18.0.20251008" +version = "2.18.0.20260322" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, { name = "types-protobuf" }, { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0d/0a/13bde03fb5a23faaadcca2d6914f865e444334133902310ea05e6ade780c/types_tensorflow-2.18.0.20251008.tar.gz", hash = "sha256:8db03d4dd391a362e2ea796ffdbccb03c082127606d4d852edb7ed9504745933", size = 257550, upload-time = "2025-10-08T02:51:51.104Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/cb/81dfaa2680031a6e087bcdfaf1c0556371098e229aee541e21c81a381065/types_tensorflow-2.18.0.20260322.tar.gz", hash = "sha256:135dc6ca06cc647a002e1bca5c5c99516fde51efd08e46c48a9b1916fc5df07f", size = 259030, upload-time = "2026-03-22T04:09:14.069Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/cc/e50e49db621b0cf03c1f3d10be47389de41a02dc9924c3a83a9c1a55bf28/types_tensorflow-2.18.0.20251008-py3-none-any.whl", hash = "sha256:d6b0dd4d81ac6d9c5af803ebcc8ce0f65c5850c063e8b9789dc828898944b5f4", size = 329023, upload-time = "2025-10-08T02:51:50.024Z" }, + { url = "https://files.pythonhosted.org/packages/5b/0c/a178061450b640e53577e2c423ad22bf5d3f692f6bfeeb12156d02b531ef/types_tensorflow-2.18.0.20260322-py3-none-any.whl", hash = "sha256:d8776b6daacdb279e64f105f9dcbc0b8e3544b9a2f2eb71ec6ea5955081f65e6", size = 329771, upload-time = "2026-03-22T04:09:12.844Z" }, ] [[package]] name = "types-tqdm" -version = "4.67.0.20250809" +version = "4.67.3.20260303" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "types-requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fb/d0/cf498fc630d9fdaf2428b93e60b0e67b08008fec22b78716b8323cf644dc/types_tqdm-4.67.0.20250809.tar.gz", hash = "sha256:02bf7ab91256080b9c4c63f9f11b519c27baaf52718e5fdab9e9606da168d500", size = 17200, upload-time = "2025-08-09T03:17:43.489Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e1/64/3e7cb0f40c4bf9578098b6873df33a96f7e0de90f3a039e614d22bfde40a/types_tqdm-4.67.3.20260303.tar.gz", hash = "sha256:7bfddb506a75aedb4030fabf4f05c5638c9a3bbdf900d54ec6c82be9034bfb96", size = 18117, upload-time = "2026-03-03T04:03:49.679Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3f/13/3ff0781445d7c12730befce0fddbbc7a76e56eb0e7029446f2853238360a/types_tqdm-4.67.0.20250809-py3-none-any.whl", hash = "sha256:1a73053b31fcabf3c1f3e2a9d5ecdba0f301bde47a418cd0e0bdf774827c5c57", size = 24020, upload-time = "2025-08-09T03:17:42.453Z" }, + { url = "https://files.pythonhosted.org/packages/37/32/e4a1fce59155c74082f1a42d0ffafa59652bfb8cff35b04d56333877748e/types_tqdm-4.67.3.20260303-py3-none-any.whl", hash = "sha256:459decf677e4b05cef36f9012ef8d6e20578edefb6b78c15bd0b546247eda62d", size = 24572, upload-time = "2026-03-03T04:03:48.913Z" }, ] [[package]] @@ -7500,11 +7663,11 @@ wheels = [ [[package]] name = "tzdata" -version = "2025.2" +version = "2025.3" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/32/1a225d6164441be760d75c2c42e2780dc0873fe382da3e98a2e1e48361e5/tzdata-2025.2.tar.gz", hash = "sha256:b60a638fcc0daffadf82fe0f57e53d06bdec2f36c4df66280ae79bce6bd6f2b9", size = 196380, upload-time = "2025-03-23T13:54:43.652Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5e/a7/c202b344c5ca7daf398f3b8a477eeb205cf3b6f32e7ec3a6bac0629ca975/tzdata-2025.3.tar.gz", hash = "sha256:de39c2ca5dc7b0344f2eba86f49d614019d29f060fc4ebc8a417896a620b56a7", size = 196772, upload-time = "2025-12-13T17:45:35.667Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5c/23/c7abc0ca0a1526a0774eca151daeb8de62ec457e77262b66b359c3c7679e/tzdata-2025.2-py2.py3-none-any.whl", hash = "sha256:1a403fada01ff9221ca8044d701868fa132215d84beb92242d9acd2147f667a8", size = 347839, upload-time = "2025-03-23T13:54:41.845Z" }, + { url = "https://files.pythonhosted.org/packages/c7/b0/003792df09decd6849a5e39c28b513c06e84436a54440380862b5aeff25d/tzdata-2025.3-py2.py3-none-any.whl", hash = "sha256:06a47e5700f3081aab02b2e513160914ff0694bce9947d6b76ebd6bf57cfc5d1", size = 348521, upload-time = "2025-12-13T17:45:33.889Z" }, ] [[package]] @@ -7521,47 +7684,59 @@ wheels = [ [[package]] name = "ujson" -version = "5.9.0" +version = "5.12.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6e/54/6f2bdac7117e89a47de4511c9f01732a283457ab1bf856e1e51aa861619e/ujson-5.9.0.tar.gz", hash = "sha256:89cc92e73d5501b8a7f48575eeb14ad27156ad092c2e9fc7e3cf949f07e75532", size = 7154214, upload-time = "2023-12-10T22:50:34.812Z" } +sdist = { url = "https://files.pythonhosted.org/packages/cb/3e/c35530c5ffc25b71c59ae0cd7b8f99df37313daa162ce1e2f7925f7c2877/ujson-5.12.0.tar.gz", hash = "sha256:14b2e1eb528d77bc0f4c5bd1a7ebc05e02b5b41beefb7e8567c9675b8b13bcf4", size = 7158451, upload-time = "2026-03-11T22:19:30.397Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/ca/ae3a6ca5b4f82ce654d6ac3dde5e59520537e20939592061ba506f4e569a/ujson-5.9.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:3b23bbb46334ce51ddb5dded60c662fbf7bb74a37b8f87221c5b0fec1ec6454b", size = 57753, upload-time = "2023-12-10T22:49:03.939Z" }, - { url = "https://files.pythonhosted.org/packages/34/5f/c27fa9a1562c96d978c39852b48063c3ca480758f3088dcfc0f3b09f8e93/ujson-5.9.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6974b3a7c17bbf829e6c3bfdc5823c67922e44ff169851a755eab79a3dd31ec0", size = 54092, upload-time = "2023-12-10T22:49:05.194Z" }, - { url = "https://files.pythonhosted.org/packages/19/f3/1431713de9e5992e5e33ba459b4de28f83904233958855d27da820a101f9/ujson-5.9.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b5964ea916edfe24af1f4cc68488448fbb1ec27a3ddcddc2b236da575c12c8ae", size = 51675, upload-time = "2023-12-10T22:49:06.449Z" }, - { url = "https://files.pythonhosted.org/packages/d3/93/de6fff3ae06351f3b1c372f675fe69bc180f93d237c9e496c05802173dd6/ujson-5.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ba7cac47dd65ff88571eceeff48bf30ed5eb9c67b34b88cb22869b7aa19600d", size = 53246, upload-time = "2023-12-10T22:49:07.691Z" }, - { url = "https://files.pythonhosted.org/packages/26/73/db509fe1d7da62a15c0769c398cec66bdfc61a8bdffaf7dfa9d973e3d65c/ujson-5.9.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6bbd91a151a8f3358c29355a491e915eb203f607267a25e6ab10531b3b157c5e", size = 58182, upload-time = "2023-12-10T22:49:08.89Z" }, - { url = "https://files.pythonhosted.org/packages/fc/a8/6be607fa3e1fa3e1c9b53f5de5acad33b073b6cc9145803e00bcafa729a8/ujson-5.9.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:829a69d451a49c0de14a9fecb2a2d544a9b2c884c2b542adb243b683a6f15908", size = 584493, upload-time = "2023-12-10T22:49:11.043Z" }, - { url = "https://files.pythonhosted.org/packages/c8/c7/33822c2f1a8175e841e2bc378ffb2c1109ce9280f14cedb1b2fa0caf3145/ujson-5.9.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:a807ae73c46ad5db161a7e883eec0fbe1bebc6a54890152ccc63072c4884823b", size = 656038, upload-time = "2023-12-10T22:49:12.651Z" }, - { url = "https://files.pythonhosted.org/packages/51/b8/5309fbb299d5fcac12bbf3db20896db5178392904abe6b992da233dc69d6/ujson-5.9.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8fc2aa18b13d97b3c8ccecdf1a3c405f411a6e96adeee94233058c44ff92617d", size = 597643, upload-time = "2023-12-10T22:49:14.883Z" }, - { url = "https://files.pythonhosted.org/packages/5f/64/7b63043b95dd78feed401b9973958af62645a6d19b72b6e83d1ea5af07e0/ujson-5.9.0-cp311-cp311-win32.whl", hash = "sha256:70e06849dfeb2548be48fdd3ceb53300640bc8100c379d6e19d78045e9c26120", size = 38342, upload-time = "2023-12-10T22:49:16.854Z" }, - { url = "https://files.pythonhosted.org/packages/7a/13/a3cd1fc3a1126d30b558b6235c05e2d26eeaacba4979ee2fd2b5745c136d/ujson-5.9.0-cp311-cp311-win_amd64.whl", hash = "sha256:7309d063cd392811acc49b5016728a5e1b46ab9907d321ebbe1c2156bc3c0b99", size = 41923, upload-time = "2023-12-10T22:49:17.983Z" }, - { url = "https://files.pythonhosted.org/packages/16/7e/c37fca6cd924931fa62d615cdbf5921f34481085705271696eff38b38867/ujson-5.9.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:20509a8c9f775b3a511e308bbe0b72897ba6b800767a7c90c5cca59d20d7c42c", size = 57834, upload-time = "2023-12-10T22:49:19.799Z" }, - { url = "https://files.pythonhosted.org/packages/fb/44/2753e902ee19bf6ccaf0bda02f1f0037f92a9769a5d31319905e3de645b4/ujson-5.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b28407cfe315bd1b34f1ebe65d3bd735d6b36d409b334100be8cdffae2177b2f", size = 54119, upload-time = "2023-12-10T22:49:21.039Z" }, - { url = "https://files.pythonhosted.org/packages/d2/06/2317433e394450bc44afe32b6c39d5a51014da4c6f6cfc2ae7bf7b4a2922/ujson-5.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9d302bd17989b6bd90d49bade66943c78f9e3670407dbc53ebcf61271cadc399", size = 51658, upload-time = "2023-12-10T22:49:22.494Z" }, - { url = "https://files.pythonhosted.org/packages/5b/3a/2acf0da085d96953580b46941504aa3c91a1dd38701b9e9bfa43e2803467/ujson-5.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f21315f51e0db8ee245e33a649dd2d9dce0594522de6f278d62f15f998e050e", size = 53370, upload-time = "2023-12-10T22:49:24.045Z" }, - { url = "https://files.pythonhosted.org/packages/03/32/737e6c4b1841720f88ae88ec91f582dc21174bd40742739e1fa16a0c9ffa/ujson-5.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5635b78b636a54a86fdbf6f027e461aa6c6b948363bdf8d4fbb56a42b7388320", size = 58278, upload-time = "2023-12-10T22:49:25.261Z" }, - { url = "https://files.pythonhosted.org/packages/8a/dc/3fda97f1ad070ccf2af597fb67dde358bc698ffecebe3bc77991d60e4fe5/ujson-5.9.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:82b5a56609f1235d72835ee109163c7041b30920d70fe7dac9176c64df87c164", size = 584418, upload-time = "2023-12-10T22:49:27.573Z" }, - { url = "https://files.pythonhosted.org/packages/d7/57/e4083d774fcd8ff3089c0ff19c424abe33f23e72c6578a8172bf65131992/ujson-5.9.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:5ca35f484622fd208f55041b042d9d94f3b2c9c5add4e9af5ee9946d2d30db01", size = 656126, upload-time = "2023-12-10T22:49:29.509Z" }, - { url = "https://files.pythonhosted.org/packages/0d/c3/8c6d5f6506ca9fcedd5a211e30a7d5ee053dc05caf23dae650e1f897effb/ujson-5.9.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:829b824953ebad76d46e4ae709e940bb229e8999e40881338b3cc94c771b876c", size = 597795, upload-time = "2023-12-10T22:49:31.029Z" }, - { url = "https://files.pythonhosted.org/packages/34/5a/a231f0cd305a34cf2d16930304132db3a7a8c3997b367dd38fc8f8dfae36/ujson-5.9.0-cp312-cp312-win32.whl", hash = "sha256:25fa46e4ff0a2deecbcf7100af3a5d70090b461906f2299506485ff31d9ec437", size = 38495, upload-time = "2023-12-10T22:49:33.2Z" }, - { url = "https://files.pythonhosted.org/packages/30/b7/18b841b44760ed298acdb150608dccdc045c41655e0bae4441f29bcab872/ujson-5.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:60718f1720a61560618eff3b56fd517d107518d3c0160ca7a5a66ac949c6cf1c", size = 42088, upload-time = "2023-12-10T22:49:34.921Z" }, + { url = "https://files.pythonhosted.org/packages/10/22/fd22e2f6766bae934d3050517ca47d463016bd8688508d1ecc1baa18a7ad/ujson-5.12.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58a11cb49482f1a095a2bd9a1d81dd7c8fb5d2357f959ece85db4e46a825fd00", size = 56139, upload-time = "2026-03-11T22:18:04.591Z" }, + { url = "https://files.pythonhosted.org/packages/c6/fd/6839adff4fc0164cbcecafa2857ba08a6eaeedd7e098d6713cb899a91383/ujson-5.12.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9b3cf13facf6f77c283af0e1713e5e8c47a0fe295af81326cb3cb4380212e797", size = 53836, upload-time = "2026-03-11T22:18:05.662Z" }, + { url = "https://files.pythonhosted.org/packages/f9/b0/0c19faac62d68ceeffa83a08dc3d71b8462cf5064d0e7e0b15ba19898dad/ujson-5.12.0-cp311-cp311-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:fb94245a715b4d6e24689de12772b85329a1f9946cbf6187923a64ecdea39e65", size = 57851, upload-time = "2026-03-11T22:18:06.744Z" }, + { url = "https://files.pythonhosted.org/packages/04/f6/e7fd283788de73b86e99e08256726bb385923249c21dcd306e59d532a1a1/ujson-5.12.0-cp311-cp311-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:0fe6b8b8968e11dd9b2348bd508f0f57cf49ab3512064b36bc4117328218718e", size = 59906, upload-time = "2026-03-11T22:18:07.791Z" }, + { url = "https://files.pythonhosted.org/packages/d7/3a/b100735a2b43ee6e8fe4c883768e362f53576f964d4ea841991060aeaf35/ujson-5.12.0-cp311-cp311-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:89e302abd3749f6d6699691747969a5d85f7c73081d5ed7e2624c7bd9721a2ab", size = 57409, upload-time = "2026-03-11T22:18:08.79Z" }, + { url = "https://files.pythonhosted.org/packages/5c/fa/f97cc20c99ca304662191b883ae13ae02912ca7244710016ba0cb8a5be34/ujson-5.12.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:0727363b05ab05ee737a28f6200dc4078bce6b0508e10bd8aab507995a15df61", size = 1037339, upload-time = "2026-03-11T22:18:10.424Z" }, + { url = "https://files.pythonhosted.org/packages/10/7a/53ddeda0ffe1420db2f9999897b3cbb920fbcff1849d1f22b196d0f34785/ujson-5.12.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b62cb9a7501e1f5c9ffe190485501349c33e8862dde4377df774e40b8166871f", size = 1196625, upload-time = "2026-03-11T22:18:11.82Z" }, + { url = "https://files.pythonhosted.org/packages/0d/1a/4c64a6bef522e9baf195dd5be151bc815cd4896c50c6e2489599edcda85f/ujson-5.12.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:a6ec5bf6bc361f2f0f9644907a36ce527715b488988a8df534120e5c34eeda94", size = 1089669, upload-time = "2026-03-11T22:18:13.343Z" }, + { url = "https://files.pythonhosted.org/packages/18/11/8ccb109f5777ec0d9fb826695a9e2ac36ae94c1949fc8b1e4d23a5bd067a/ujson-5.12.0-cp311-cp311-win32.whl", hash = "sha256:006428d3813b87477d72d306c40c09f898a41b968e57b15a7d88454ecc42a3fb", size = 39648, upload-time = "2026-03-11T22:18:14.785Z" }, + { url = "https://files.pythonhosted.org/packages/6f/e3/87fc4c27b20d5125cff7ce52d17ea7698b22b74426da0df238e3efcb0cf2/ujson-5.12.0-cp311-cp311-win_amd64.whl", hash = "sha256:40aa43a7a3a8d2f05e79900858053d697a88a605e3887be178b43acbcd781161", size = 43876, upload-time = "2026-03-11T22:18:15.768Z" }, + { url = "https://files.pythonhosted.org/packages/9e/21/324f0548a8c8c48e3e222eaed15fb6d48c796593002b206b4a28a89e445f/ujson-5.12.0-cp311-cp311-win_arm64.whl", hash = "sha256:561f89cc82deeae82e37d4a4764184926fb432f740a9691563a391b13f7339a4", size = 38553, upload-time = "2026-03-11T22:18:17.251Z" }, + { url = "https://files.pythonhosted.org/packages/84/f6/ac763d2108d28f3a40bb3ae7d2fafab52ca31b36c2908a4ad02cd3ceba2a/ujson-5.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:09b4beff9cc91d445d5818632907b85fb06943b61cb346919ce202668bf6794a", size = 56326, upload-time = "2026-03-11T22:18:18.467Z" }, + { url = "https://files.pythonhosted.org/packages/25/46/d0b3af64dcdc549f9996521c8be6d860ac843a18a190ffc8affeb7259687/ujson-5.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ca0c7ce828bb76ab78b3991904b477c2fd0f711d7815c252d1ef28ff9450b052", size = 53910, upload-time = "2026-03-11T22:18:19.502Z" }, + { url = "https://files.pythonhosted.org/packages/9a/10/853c723bcabc3e9825a079019055fc99e71b85c6bae600607a2b9d31d18d/ujson-5.12.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a2d79c6635ccffcbfc1d5c045874ba36b594589be81d50d43472570bb8de9c57", size = 57754, upload-time = "2026-03-11T22:18:20.874Z" }, + { url = "https://files.pythonhosted.org/packages/f9/c6/6e024830d988f521f144ead641981c1f7a82c17ad1927c22de3242565f5c/ujson-5.12.0-cp312-cp312-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:7e07f6f644d2c44d53b7a320a084eef98063651912c1b9449b5f45fcbdc6ccd2", size = 59936, upload-time = "2026-03-11T22:18:21.924Z" }, + { url = "https://files.pythonhosted.org/packages/34/c9/c5f236af5abe06b720b40b88819d00d10182d2247b1664e487b3ed9229cf/ujson-5.12.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:085b6ce182cdd6657481c7c4003a417e0655c4f6e58b76f26ee18f0ae21db827", size = 57463, upload-time = "2026-03-11T22:18:22.924Z" }, + { url = "https://files.pythonhosted.org/packages/ae/04/41342d9ef68e793a87d84e4531a150c2b682f3bcedfe59a7a5e3f73e9213/ujson-5.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:16b4fe9c97dc605f5e1887a9e1224287291e35c56cbc379f8aa44b6b7bcfe2bb", size = 1037239, upload-time = "2026-03-11T22:18:24.04Z" }, + { url = "https://files.pythonhosted.org/packages/d4/81/dc2b7617d5812670d4ff4a42f6dd77926430ee52df0dedb2aec7990b2034/ujson-5.12.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0d2e8db5ade3736a163906154ca686203acc7d1d30736cbf577c730d13653d84", size = 1196713, upload-time = "2026-03-11T22:18:25.391Z" }, + { url = "https://files.pythonhosted.org/packages/b6/9c/80acff0504f92459ed69e80a176286e32ca0147ac6a8252cd0659aad3227/ujson-5.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:93bc91fdadcf046da37a214eaa714574e7e9b1913568e93bb09527b2ceb7f759", size = 1089742, upload-time = "2026-03-11T22:18:26.738Z" }, + { url = "https://files.pythonhosted.org/packages/e3/f0/123ffaac17e45ef2b915e3e3303f8f4ea78bb8d42afad828844e08622b1e/ujson-5.12.0-cp312-cp312-win32.whl", hash = "sha256:2a248750abce1c76fbd11b2e1d88b95401e72819295c3b851ec73399d6849b3d", size = 39773, upload-time = "2026-03-11T22:18:28.244Z" }, + { url = "https://files.pythonhosted.org/packages/b5/20/f3bd2b069c242c2b22a69e033bfe224d1d15d3649e6cd7cc7085bb1412ff/ujson-5.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:1b5c6ceb65fecd28a1d20d1eba9dbfa992612b86594e4b6d47bb580d2dd6bcb3", size = 44040, upload-time = "2026-03-11T22:18:29.236Z" }, + { url = "https://files.pythonhosted.org/packages/f0/a7/01b5a0bcded14cd2522b218f2edc3533b0fcbccdea01f3e14a2b699071aa/ujson-5.12.0-cp312-cp312-win_arm64.whl", hash = "sha256:9a5fcbe7b949f2e95c47ea8a80b410fcdf2da61c98553b45a4ee875580418b68", size = 38526, upload-time = "2026-03-11T22:18:30.551Z" }, + { url = "https://files.pythonhosted.org/packages/95/3c/5ee154d505d1aad2debc4ba38b1a60ae1949b26cdb5fa070e85e320d6b64/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_10_13_x86_64.whl", hash = "sha256:bf85a00ac3b56a1e7a19c5be7b02b5180a0895ac4d3c234d717a55e86960691c", size = 54494, upload-time = "2026-03-11T22:19:13.035Z" }, + { url = "https://files.pythonhosted.org/packages/ce/b3/9496ec399ec921e434a93b340bd5052999030b7ac364be4cbe5365ac6b20/ujson-5.12.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:64df53eef4ac857eb5816a56e2885ccf0d7dff6333c94065c93b39c51063e01d", size = 57999, upload-time = "2026-03-11T22:19:14.385Z" }, + { url = "https://files.pythonhosted.org/packages/0e/da/e9ae98133336e7c0d50b43626c3f2327937cecfa354d844e02ac17379ed1/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6c0aed6a4439994c9666fb8a5b6c4eac94d4ef6ddc95f9b806a599ef83547e3b", size = 54518, upload-time = "2026-03-11T22:19:15.4Z" }, + { url = "https://files.pythonhosted.org/packages/58/10/978d89dded6bb1558cd46ba78f4351198bd2346db8a8ee1a94119022ce40/ujson-5.12.0-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efae5df7a8cc8bdb1037b0f786b044ce281081441df5418c3a0f0e1f86fe7bb3", size = 55736, upload-time = "2026-03-11T22:19:16.496Z" }, + { url = "https://files.pythonhosted.org/packages/80/25/1df8e6217c92e57a1266bf5be750b1dddc126ee96e53fe959d5693503bc6/ujson-5.12.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:8712b61eb1b74a4478cfd1c54f576056199e9f093659334aeb5c4a6b385338e5", size = 44615, upload-time = "2026-03-11T22:19:17.53Z" }, + { url = "https://files.pythonhosted.org/packages/19/fa/f4a957dddb99bd68c8be91928c0b6fefa7aa8aafc92c93f5d1e8b32f6702/ujson-5.12.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:871c0e5102e47995b0e37e8df7819a894a6c3da0d097545cd1f9f1f7d7079927", size = 52145, upload-time = "2026-03-11T22:19:18.566Z" }, + { url = "https://files.pythonhosted.org/packages/55/6e/50b5cf612de1ca06c7effdc5a5d7e815774dee85a5858f1882c425553b82/ujson-5.12.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:56ba3f7abbd6b0bb282a544dc38406d1a188d8bb9164f49fdb9c2fee62cb29da", size = 49577, upload-time = "2026-03-11T22:19:19.627Z" }, + { url = "https://files.pythonhosted.org/packages/6e/24/b6713fa9897774502cd4c2d6955bb4933349f7d84c3aa805531c382a4209/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c5a52987a990eb1bae55f9000994f1afdb0326c154fb089992f839ab3c30688", size = 50807, upload-time = "2026-03-11T22:19:20.778Z" }, + { url = "https://files.pythonhosted.org/packages/1f/b6/c0e0f7901180ef80d16f3a4bccb5dc8b01515a717336a62928963a07b80b/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_i686.manylinux_2_28_i686.whl", hash = "sha256:adf28d13a33f9d750fe7a78fb481cac298fa257d8863d8727b2ea4455ea41235", size = 56972, upload-time = "2026-03-11T22:19:21.84Z" }, + { url = "https://files.pythonhosted.org/packages/02/a9/05d91b4295ea7239151eb08cf240e5a2ba969012fda50bc27bcb1ea9cd71/ujson-5.12.0-pp311-pypy311_pp73-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:51acc750ec7a2df786cdc868fb16fa04abd6269a01d58cf59bafc57978773d8e", size = 52045, upload-time = "2026-03-11T22:19:22.879Z" }, + { url = "https://files.pythonhosted.org/packages/e3/7a/92047d32bf6f2d9db64605fc32e8eb0e0dd68b671eaafc12a464f69c4af4/ujson-5.12.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:ab9056d94e5db513d9313b34394f3a3b83e6301a581c28ad67773434f3faccab", size = 44053, upload-time = "2026-03-11T22:19:23.918Z" }, ] [[package]] name = "unstructured" -version = "0.18.31" +version = "0.21.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "backoff" }, { name = "beautifulsoup4" }, { name = "charset-normalizer" }, - { name = "dataclasses-json" }, { name = "emoji" }, + { name = "filelock" }, { name = "filetype" }, { name = "html5lib" }, + { name = "installer" }, { name = "langdetect" }, { name = "lxml" }, - { name = "nltk" }, { name = "numba" }, { name = "numpy" }, { name = "psutil" }, @@ -7569,15 +7744,17 @@ dependencies = [ { name = "python-magic" }, { name = "python-oxmsg" }, { name = "rapidfuzz" }, + { name = "regex" }, { name = "requests" }, + { name = "spacy" }, { name = "tqdm" }, { name = "typing-extensions" }, { name = "unstructured-client" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b3/e6/fbef61517d130af1def3b81681e253a5679f19de2f04e439afbbf1f021e0/unstructured-0.21.5.tar.gz", hash = "sha256:3e220d0c2b9c8ec12c99767162b95ab0acfca75e979b82c66c15ca15caa60139", size = 1501811, upload-time = "2026-02-24T15:29:27.84Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" }, + { url = "https://files.pythonhosted.org/packages/b8/b6/7e6dd60bde81d5a4d4ddf426f566a5d1b4c30490053caed69e47f55c676f/unstructured-0.21.5-py3-none-any.whl", hash = "sha256:d88a277c368462b69a8843b9cb22476f3cc4d0a58455536520359387224b3366", size = 1554925, upload-time = "2026-02-24T15:29:26.009Z" }, ] [package.optional-dependencies] @@ -7585,7 +7762,7 @@ docx = [ { name = "python-docx" }, ] epub = [ - { name = "pypandoc" }, + { name = "pypandoc-binary" }, ] md = [ { name = "markdown" }, @@ -7599,7 +7776,7 @@ pptx = [ [[package]] name = "unstructured-client" -version = "0.42.4" +version = "0.42.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiofiles" }, @@ -7608,23 +7785,24 @@ dependencies = [ { name = "httpx" }, { name = "pydantic" }, { name = "pypdf" }, + { name = "pypdfium2" }, { name = "requests-toolbelt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/a4/8f/43c9a936a153e62f18e7629128698feebd81d2cfff2835febc85377b8eb8/unstructured_client-0.42.4.tar.gz", hash = "sha256:144ecd231a11d091cdc76acf50e79e57889269b8c9d8b9df60e74cf32ac1ba5e", size = 91404, upload-time = "2025-11-14T16:59:25.131Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f7/3e/dd81a2065e50b5b013c9d12a0b6346f86b3252d43a65269a72761e234bcb/unstructured_client-0.42.10.tar.gz", hash = "sha256:e516299c27178865dbd4e2bbd6f00a820ddd40323b2578f303106732fc576217", size = 94726, upload-time = "2026-02-03T18:01:50.776Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5e/6c/7c69e4353e5bdd05fc247c2ec1d840096eb928975697277b015c49405b0f/unstructured_client-0.42.4-py3-none-any.whl", hash = "sha256:fc6341344dd2f2e2aed793636b5f4e6204cad741ff2253d5a48ff2f2bccb8e9a", size = 207863, upload-time = "2025-11-14T16:59:23.674Z" }, + { url = "https://files.pythonhosted.org/packages/c1/f9/bb9b9e7df245549e2daae58b54fdd612f016111c5b06df3c66965ac8545e/unstructured_client-0.42.10-py3-none-any.whl", hash = "sha256:0034ddcd988e17db83080db26fb36f23c24ace34afedeb267dab245029f8f7a2", size = 220161, upload-time = "2026-02-03T18:01:49.487Z" }, ] [[package]] name = "upstash-vector" -version = "0.6.0" +version = "0.8.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "httpx" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/a6/a9178fef247687917701a60eb66542eb5361c58af40c033ba8174ff7366d/upstash_vector-0.6.0.tar.gz", hash = "sha256:a716ed4d0251362208518db8b194158a616d37d1ccbb1155f619df690599e39b", size = 15075, upload-time = "2024-09-27T12:02:13.533Z" } +sdist = { url = "https://files.pythonhosted.org/packages/65/22/1b9161b82ef52addc2b71ffca9498cb745b34b2e43e77ef1c921d96fb3f1/upstash_vector-0.8.0.tar.gz", hash = "sha256:cdeeeeabe08c813f0f525d9b6ceefbf17abb720bd30190cd6df88b9f2c318334", size = 18565, upload-time = "2025-02-27T11:52:38.14Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5d/45/95073b83b7fd7b83f10ea314f197bae3989bfe022e736b90145fe9ea4362/upstash_vector-0.6.0-py3-none-any.whl", hash = "sha256:d0bdad7765b8a7f5c205b7a9c81ca4b9a4cee3ee4952afc7d5ea5fb76c3f3c3c", size = 15061, upload-time = "2024-09-27T12:02:12.041Z" }, + { url = "https://files.pythonhosted.org/packages/ab/ce/1528e6e37d4a1ba7a333ebca7191b638986f4ba9f73ba17458b45c4d36e2/upstash_vector-0.8.0-py3-none-any.whl", hash = "sha256:e8a7560e6e80e22ff2a4d95ff0b08723b22bafaae7dab38eddce51feb30c5785", size = 18480, upload-time = "2025-02-27T11:52:36.189Z" }, ] [[package]] @@ -7645,6 +7823,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, ] +[[package]] +name = "uuid-utils" +version = "0.14.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/d1/38a573f0c631c062cf42fa1f5d021d4dd3c31fb23e4376e4b56b0c9fbbed/uuid_utils-0.14.1.tar.gz", hash = "sha256:9bfc95f64af80ccf129c604fb6b8ca66c6f256451e32bc4570f760e4309c9b69", size = 22195, upload-time = "2026-02-20T22:50:38.833Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/b7/add4363039a34506a58457d96d4aa2126061df3a143eb4d042aedd6a2e76/uuid_utils-0.14.1-cp39-abi3-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:93a3b5dc798a54a1feb693f2d1cb4cf08258c32ff05ae4929b5f0a2ca624a4f0", size = 604679, upload-time = "2026-02-20T22:50:27.469Z" }, + { url = "https://files.pythonhosted.org/packages/dd/84/d1d0bef50d9e66d31b2019997c741b42274d53dde2e001b7a83e9511c339/uuid_utils-0.14.1-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:ccd65a4b8e83af23eae5e56d88034b2fe7264f465d3e830845f10d1591b81741", size = 309346, upload-time = "2026-02-20T22:50:31.857Z" }, + { url = "https://files.pythonhosted.org/packages/ef/ed/b6d6fd52a6636d7c3eddf97d68da50910bf17cd5ac221992506fb56cf12e/uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b56b0cacd81583834820588378e432b0696186683b813058b707aedc1e16c4b1", size = 344714, upload-time = "2026-02-20T22:50:42.642Z" }, + { url = "https://files.pythonhosted.org/packages/a8/a7/a19a1719fb626fe0b31882db36056d44fe904dc0cf15b06fdf56b2679cf7/uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:bb3cf14de789097320a3c56bfdfdd51b1225d11d67298afbedee7e84e3837c96", size = 350914, upload-time = "2026-02-20T22:50:36.487Z" }, + { url = "https://files.pythonhosted.org/packages/1d/fc/f6690e667fdc3bb1a73f57951f97497771c56fe23e3d302d7404be394d4f/uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:60e0854a90d67f4b0cc6e54773deb8be618f4c9bad98d3326f081423b5d14fae", size = 482609, upload-time = "2026-02-20T22:50:37.511Z" }, + { url = "https://files.pythonhosted.org/packages/54/6e/dcd3fa031320921a12ec7b4672dea3bd1dd90ddffa363a91831ba834d559/uuid_utils-0.14.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ce6743ba194de3910b5feb1a62590cd2587e33a73ab6af8a01b642ceb5055862", size = 345699, upload-time = "2026-02-20T22:50:46.87Z" }, + { url = "https://files.pythonhosted.org/packages/04/28/e5220204b58b44ac0047226a9d016a113fde039280cc8732d9e6da43b39f/uuid_utils-0.14.1-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:043fb58fde6cf1620a6c066382f04f87a8e74feb0f95a585e4ed46f5d44af57b", size = 372205, upload-time = "2026-02-20T22:50:28.438Z" }, + { url = "https://files.pythonhosted.org/packages/c7/d9/3d2eb98af94b8dfffc82b6a33b4dfc87b0a5de2c68a28f6dde0db1f8681b/uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c915d53f22945e55fe0d3d3b0b87fd965a57f5fd15666fd92d6593a73b1dd297", size = 521836, upload-time = "2026-02-20T22:50:23.057Z" }, + { url = "https://files.pythonhosted.org/packages/a8/15/0eb106cc6fe182f7577bc0ab6e2f0a40be247f35c5e297dbf7bbc460bd02/uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:0972488e3f9b449e83f006ead5a0e0a33ad4a13e4462e865b7c286ab7d7566a3", size = 625260, upload-time = "2026-02-20T22:50:25.949Z" }, + { url = "https://files.pythonhosted.org/packages/3c/17/f539507091334b109e7496830af2f093d9fc8082411eafd3ece58af1f8ba/uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:1c238812ae0c8ffe77d8d447a32c6dfd058ea4631246b08b5a71df586ff08531", size = 587824, upload-time = "2026-02-20T22:50:35.225Z" }, + { url = "https://files.pythonhosted.org/packages/2e/c2/d37a7b2e41f153519367d4db01f0526e0d4b06f1a4a87f1c5dfca5d70a8b/uuid_utils-0.14.1-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:bec8f8ef627af86abf8298e7ec50926627e29b34fa907fcfbedb45aaa72bca43", size = 551407, upload-time = "2026-02-20T22:50:44.915Z" }, + { url = "https://files.pythonhosted.org/packages/65/36/2d24b2cbe78547c6532da33fb8613debd3126eccc33a6374ab788f5e46e9/uuid_utils-0.14.1-cp39-abi3-win32.whl", hash = "sha256:b54d6aa6252d96bac1fdbc80d26ba71bad9f220b2724d692ad2f2310c22ef523", size = 183476, upload-time = "2026-02-20T22:50:32.745Z" }, + { url = "https://files.pythonhosted.org/packages/83/92/2d7e90df8b1a69ec4cff33243ce02b7a62f926ef9e2f0eca5a026889cd73/uuid_utils-0.14.1-cp39-abi3-win_amd64.whl", hash = "sha256:fc27638c2ce267a0ce3e06828aff786f91367f093c80625ee21dad0208e0f5ba", size = 187147, upload-time = "2026-02-20T22:50:45.807Z" }, + { url = "https://files.pythonhosted.org/packages/d9/26/529f4beee17e5248e37e0bc17a2761d34c0fa3b1e5729c88adb2065bae6e/uuid_utils-0.14.1-cp39-abi3-win_arm64.whl", hash = "sha256:b04cb49b42afbc4ff8dbc60cf054930afc479d6f4dd7f1ec3bbe5dbfdde06b7a", size = 188132, upload-time = "2026-02-20T22:50:41.718Z" }, + { url = "https://files.pythonhosted.org/packages/91/f9/6c64bdbf71f58ccde7919e00491812556f446a5291573af92c49a5e9aaef/uuid_utils-0.14.1-pp311-pypy311_pp73-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:b197cd5424cf89fb019ca7f53641d05bfe34b1879614bed111c9c313b5574cd8", size = 591617, upload-time = "2026-02-20T22:50:24.532Z" }, + { url = "https://files.pythonhosted.org/packages/d0/f0/758c3b0fb0c4871c7704fef26a5bc861de4f8a68e4831669883bebe07b0f/uuid_utils-0.14.1-pp311-pypy311_pp73-macosx_10_12_x86_64.whl", hash = "sha256:12c65020ba6cb6abe1d57fcbfc2d0ea0506c67049ee031714057f5caf0f9bc9c", size = 303702, upload-time = "2026-02-20T22:50:40.687Z" }, + { url = "https://files.pythonhosted.org/packages/85/89/d91862b544c695cd58855efe3201f83894ed82fffe34500774238ab8eba7/uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b5d2ad28063d422ccc2c28d46471d47b61a58de885d35113a8f18cb547e25bf", size = 337678, upload-time = "2026-02-20T22:50:39.768Z" }, + { url = "https://files.pythonhosted.org/packages/ee/6b/cf342ba8a898f1de024be0243fac67c025cad530c79ea7f89c4ce718891a/uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:da2234387b45fde40b0fedfee64a0ba591caeea9c48c7698ab6e2d85c7991533", size = 343711, upload-time = "2026-02-20T22:50:43.965Z" }, + { url = "https://files.pythonhosted.org/packages/b3/20/049418d094d396dfa6606b30af925cc68a6670c3b9103b23e6990f84b589/uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:50fffc2827348c1e48972eed3d1c698959e63f9d030aa5dd82ba451113158a62", size = 476731, upload-time = "2026-02-20T22:50:30.589Z" }, + { url = "https://files.pythonhosted.org/packages/77/a1/0857f64d53a90321e6a46a3d4cc394f50e1366132dcd2ae147f9326ca98b/uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c1dbe718765f70f5b7f9b7f66b6a937802941b1cc56bcf642ce0274169741e01", size = 338902, upload-time = "2026-02-20T22:50:33.927Z" }, + { url = "https://files.pythonhosted.org/packages/ed/d0/5bf7cbf1ac138c92b9ac21066d18faf4d7e7f651047b700eb192ca4b9fdb/uuid_utils-0.14.1-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:258186964039a8e36db10810c1ece879d229b01331e09e9030bc5dcabe231bd2", size = 364700, upload-time = "2026-02-20T22:50:21.732Z" }, +] + [[package]] name = "uuid6" version = "2025.0.1" @@ -7656,15 +7863,15 @@ wheels = [ [[package]] name = "uvicorn" -version = "0.38.0" +version = "0.42.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "h11" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cb/ce/f06b84e2697fef4688ca63bdb2fdf113ca0a3be33f94488f2cadb690b0cf/uvicorn-0.38.0.tar.gz", hash = "sha256:fd97093bdd120a2609fc0d3afe931d4d4ad688b6e75f0f929fde1bc36fe0e91d", size = 80605, upload-time = "2025-10-18T13:46:44.63Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e3/ad/4a96c425be6fb67e0621e62d86c402b4a17ab2be7f7c055d9bd2f638b9e2/uvicorn-0.42.0.tar.gz", hash = "sha256:9b1f190ce15a2dd22e7758651d9b6d12df09a13d51ba5bf4fc33c383a48e1775", size = 85393, upload-time = "2026-03-16T06:19:50.077Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ee/d9/d88e73ca598f4f6ff671fb5fde8a32925c2e08a637303a1d12883c7305fa/uvicorn-0.38.0-py3-none-any.whl", hash = "sha256:48c0afd214ceb59340075b4a052ea1ee91c16fbc2a9b1469cca0e54566977b02", size = 68109, upload-time = "2025-10-18T13:46:42.958Z" }, + { url = "https://files.pythonhosted.org/packages/0a/89/f8827ccff89c1586027a105e5630ff6139a64da2515e24dafe860bd9ae4d/uvicorn-0.42.0-py3-none-any.whl", hash = "sha256:96c30f5c7abe6f74ae8900a70e92b85ad6613b745d4879eb9b16ccad15645359", size = 68830, upload-time = "2026-03-16T06:19:48.325Z" }, ] [package.optional-dependencies] @@ -7716,20 +7923,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/03/ff/7c0c86c43b3cbb927e0ccc0255cb4057ceba4799cd44ae95174ce8e8b5b2/vine-5.1.0-py3-none-any.whl", hash = "sha256:40fdf3c48b2cfe1c38a49e9ae2da6fda88e4794c810050a728bd7413811fb1dc", size = 9636, upload-time = "2023-11-05T08:46:51.205Z" }, ] -[[package]] -name = "virtualenv" -version = "20.36.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "distlib" }, - { name = "filelock" }, - { name = "platformdirs" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/aa/a3/4d310fa5f00863544e1d0f4de93bddec248499ccf97d4791bc3122c9d4f3/virtualenv-20.36.1.tar.gz", hash = "sha256:8befb5c81842c641f8ee658481e42641c68b5eab3521d8e092d18320902466ba", size = 6032239, upload-time = "2026-01-09T18:21:01.296Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6a/2a/dc2228b2888f51192c7dc766106cd475f1b768c10caaf9727659726f7391/virtualenv-20.36.1-py3-none-any.whl", hash = "sha256:575a8d6b124ef88f6f51d56d656132389f961062a9177016a50e4f507bbcc19f", size = 6008258, upload-time = "2026-01-09T18:20:59.425Z" }, -] - [[package]] name = "volcengine-compat" version = "1.0.156" @@ -7750,7 +7943,7 @@ wheels = [ [[package]] name = "wandb" -version = "0.23.1" +version = "0.25.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -7764,17 +7957,29 @@ dependencies = [ { name = "sentry-sdk" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0a/cc/770ae3aa7ae44f6792f7ecb81c14c0e38b672deb35235719bb1006519487/wandb-0.23.1.tar.gz", hash = "sha256:f6fb1e3717949b29675a69359de0eeb01e67d3360d581947d5b3f98c273567d6", size = 44298053, upload-time = "2025-12-03T02:25:10.79Z" } +sdist = { url = "https://files.pythonhosted.org/packages/60/bb/eb579bf9abac70934a014a9d4e45346aab307994f3021d201bebe5fa25ec/wandb-0.25.1.tar.gz", hash = "sha256:b2a95cd777ecbe7499599a43158834983448a0048329bc7210ef46ca18d21994", size = 43983308, upload-time = "2026-03-10T23:51:44.227Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/0b/c3d7053dfd93fd259a63c7818d9c4ac2ba0642ff8dc8db98662ea0cf9cc0/wandb-0.23.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:358e15471d19b7d73fc464e37371c19d44d39e433252ac24df107aff993a286b", size = 21527293, upload-time = "2025-12-03T02:24:48.011Z" }, - { url = "https://files.pythonhosted.org/packages/ee/9f/059420fa0cb6c511dc5c5a50184122b6aca7b178cb2aa210139e354020da/wandb-0.23.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:110304407f4b38f163bdd50ed5c5225365e4df3092f13089c30171a75257b575", size = 22745926, upload-time = "2025-12-03T02:24:50.519Z" }, - { url = "https://files.pythonhosted.org/packages/96/b6/fd465827c14c64d056d30b4c9fcf4dac889a6969dba64489a88fc4ffa333/wandb-0.23.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:6cc984cf85feb2f8ee0451d76bc9fb7f39da94956bb8183e30d26284cf203b65", size = 21212973, upload-time = "2025-12-03T02:24:52.828Z" }, - { url = "https://files.pythonhosted.org/packages/5c/ee/9a8bb9a39cc1f09c3060456cc79565110226dc4099a719af5c63432da21d/wandb-0.23.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:67431cd3168d79fdb803e503bd669c577872ffd5dadfa86de733b3274b93088e", size = 22887885, upload-time = "2025-12-03T02:24:55.281Z" }, - { url = "https://files.pythonhosted.org/packages/6d/4d/8d9e75add529142e037b05819cb3ab1005679272950128d69d218b7e5b2e/wandb-0.23.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:07be70c0baa97ea25fadc4a9d0097f7371eef6dcacc5ceb525c82491a31e9244", size = 21250967, upload-time = "2025-12-03T02:24:57.603Z" }, - { url = "https://files.pythonhosted.org/packages/97/72/0b35cddc4e4168f03c759b96d9f671ad18aec8bdfdd84adfea7ecb3f5701/wandb-0.23.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:216c95b08e0a2ec6a6008373b056d597573d565e30b43a7a93c35a171485ee26", size = 22988382, upload-time = "2025-12-03T02:25:00.518Z" }, - { url = "https://files.pythonhosted.org/packages/c0/6d/e78093d49d68afb26f5261a70fc7877c34c114af5c2ee0ab3b1af85f5e76/wandb-0.23.1-py3-none-win32.whl", hash = "sha256:fb5cf0f85692f758a5c36ab65fea96a1284126de64e836610f92ddbb26df5ded", size = 22150756, upload-time = "2025-12-03T02:25:02.734Z" }, - { url = "https://files.pythonhosted.org/packages/05/27/4f13454b44c9eceaac3d6e4e4efa2230b6712d613ff9bf7df010eef4fd18/wandb-0.23.1-py3-none-win_amd64.whl", hash = "sha256:21c8c56e436eb707b7d54f705652e030d48e5cfcba24cf953823eb652e30e714", size = 22150760, upload-time = "2025-12-03T02:25:05.106Z" }, - { url = "https://files.pythonhosted.org/packages/30/20/6c091d451e2a07689bfbfaeb7592d488011420e721de170884fedd68c644/wandb-0.23.1-py3-none-win_arm64.whl", hash = "sha256:8aee7f3bb573f2c0acf860f497ca9c684f9b35f2ca51011ba65af3d4592b77c1", size = 20137463, upload-time = "2025-12-03T02:25:08.317Z" }, + { url = "https://files.pythonhosted.org/packages/e7/d8/873553b6818499d1b1de314067d528b892897baf0dc81fedc0e845abc2dd/wandb-0.25.1-py3-none-macosx_12_0_arm64.whl", hash = "sha256:9bb0679a3e2dcd96db9d9b6d3e17d046241d8d122974b24facb85cc93309a8c9", size = 23615900, upload-time = "2026-03-10T23:51:06.278Z" }, + { url = "https://files.pythonhosted.org/packages/71/ea/b131f319aaa5d0bf7572b6bfcff3dd89e1cf92b17eee443bbab71d12d74c/wandb-0.25.1-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:0fb13ed18914027523e7b4fc20380c520e0d10da0ee452f924a13f84509fbe12", size = 25576144, upload-time = "2026-03-10T23:51:11.527Z" }, + { url = "https://files.pythonhosted.org/packages/70/5f/81508581f0bb77b0495665c1c78e77606a48e66e855ca71ba7c8ae29efa4/wandb-0.25.1-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:cc4521eb5223429ddab5e8eee9b42fdf4caabdf0bc4e0e809042720e5fbef0ed", size = 23070425, upload-time = "2026-03-10T23:51:15.71Z" }, + { url = "https://files.pythonhosted.org/packages/f2/c7/445155ef010e2e35d190797d7c36ff441e062a5b566a6da4778e22233395/wandb-0.25.1-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:e73b4c55b947edae349232d5845204d30fac88e18eb4ad1d4b96bf7cf898405a", size = 25628142, upload-time = "2026-03-10T23:51:19.326Z" }, + { url = "https://files.pythonhosted.org/packages/d5/63/f5c55ee00cf481ef1ccd3c385a0585ad52e7840d08419d4f82ddbeeea959/wandb-0.25.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:22b84065aa398e1624d2e5ad79e08bc4d2af41a6db61697b03b3aaba332977c6", size = 23123172, upload-time = "2026-03-10T23:51:23.418Z" }, + { url = "https://files.pythonhosted.org/packages/3e/d9/19eb7974c0e9253bcbaee655222c0f0e1a52e63e9479ee711b4208f8ac31/wandb-0.25.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:005c4c6b5126ef8f4b4110e5372d950918b00637d6dc4b615ad17445f9739478", size = 25714479, upload-time = "2026-03-10T23:51:27.421Z" }, + { url = "https://files.pythonhosted.org/packages/11/19/466c1d03323a4a0ed7d4036a59b18d6b6f67cb5032e444205927e226b18d/wandb-0.25.1-py3-none-win32.whl", hash = "sha256:8f2d04f16b88d65bfba9d79fb945f6c64e2686215469a841936e0972be8ec6a5", size = 24967338, upload-time = "2026-03-10T23:51:31.833Z" }, + { url = "https://files.pythonhosted.org/packages/89/22/680d34c1587f3a979c701b66d71aa7c42b4ef2fdf0774f67034e618e834e/wandb-0.25.1-py3-none-win_amd64.whl", hash = "sha256:62db5166de14456156d7a85953a58733a631228e6d4248a753605f75f75fb845", size = 24967343, upload-time = "2026-03-10T23:51:36.026Z" }, + { url = "https://files.pythonhosted.org/packages/c4/e8/76836b75d401ff5912aaf513176e64557ceaec4c4946bfd38a698ff84d48/wandb-0.25.1-py3-none-win_arm64.whl", hash = "sha256:cc7c34b70cf4b7be4d395541e82e325fd9d2be978d62c9ec01f1a7141523b6bb", size = 22080774, upload-time = "2026-03-10T23:51:40.196Z" }, +] + +[[package]] +name = "wasabi" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ac/f9/054e6e2f1071e963b5e746b48d1e3727470b2a490834d18ad92364929db3/wasabi-1.1.3.tar.gz", hash = "sha256:4bb3008f003809db0c3e28b4daf20906ea871a2bb43f9914197d540f4f2e0878", size = 30391, upload-time = "2024-05-31T16:56:18.99Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/06/7c/34330a89da55610daa5f245ddce5aab81244321101614751e7537f125133/wasabi-1.1.3-py3-none-any.whl", hash = "sha256:f76e16e8f7e79f8c4c8be49b4024ac725713ab10cd7f19350ad18a8e3f71728c", size = 27880, upload-time = "2024-05-31T16:56:16.699Z" }, ] [[package]] @@ -7820,22 +8025,41 @@ wheels = [ [[package]] name = "wcwidth" -version = "0.2.14" +version = "0.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/24/30/6b0809f4510673dc723187aeaf24c7f5459922d01e2f794277a3dfb90345/wcwidth-0.2.14.tar.gz", hash = "sha256:4d478375d31bc5395a3c55c40ccdf3354688364cd61c4f6adacaa9215d0b3605", size = 102293, upload-time = "2025-09-22T16:29:53.023Z" } +sdist = { url = "https://files.pythonhosted.org/packages/35/a2/8e3becb46433538a38726c948d3399905a4c7cabd0df578ede5dc51f0ec2/wcwidth-0.6.0.tar.gz", hash = "sha256:cdc4e4262d6ef9a1a57e018384cbeb1208d8abbc64176027e2c2455c81313159", size = 159684, upload-time = "2026-02-06T19:19:40.919Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/b5/123f13c975e9f27ab9c0770f514345bd406d0e8d3b7a0723af9d43f710af/wcwidth-0.2.14-py2.py3-none-any.whl", hash = "sha256:a7bb560c8aee30f9957e5f9895805edd20602f2d7f720186dfd906e82b4982e1", size = 37286, upload-time = "2025-09-22T16:29:51.641Z" }, + { url = "https://files.pythonhosted.org/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad", size = 94189, upload-time = "2026-02-06T19:19:39.646Z" }, +] + +[[package]] +name = "weasel" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cloudpathlib" }, + { name = "confection" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "smart-open" }, + { name = "srsly" }, + { name = "typer-slim" }, + { name = "wasabi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/09/d7/edd9c24e60cf8e5de130aa2e8af3b01521f4d0216c371d01212f580d0d8e/weasel-0.4.3.tar.gz", hash = "sha256:f293d6174398e8f478c78481e00c503ee4b82ea7a3e6d0d6a01e46a6b1396845", size = 38733, upload-time = "2025-11-13T23:52:28.193Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/74/a148b41572656904a39dfcfed3f84dd1066014eed94e209223ae8e9d088d/weasel-0.4.3-py3-none-any.whl", hash = "sha256:08f65b5d0dbded4879e08a64882de9b9514753d9eaa4c4e2a576e33666ac12cf", size = 50757, upload-time = "2025-11-13T23:52:26.982Z" }, ] [[package]] name = "weave" -version = "0.52.17" +version = "0.52.25" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "diskcache" }, - { name = "eval-type-backport" }, - { name = "gql", extra = ["aiohttp", "requests"] }, + { name = "gql", extra = ["httpx"] }, { name = "jsonschema" }, { name = "packaging" }, { name = "polyfile-weave" }, @@ -7845,14 +8069,14 @@ dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, { name = "wandb" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/09/95/27e05d954972a83372a3ceb6b5db6136bc4f649fa69d8009b27c144ca111/weave-0.52.17.tar.gz", hash = "sha256:940aaf892b65c72c67cb893e97ed5339136a4b33a7ea85d52ed36671111826ef", size = 609149, upload-time = "2025-11-13T22:09:51.045Z" } +sdist = { url = "https://files.pythonhosted.org/packages/de/c1/3650fd0c1ebbe1bb7cfd4ae549de477def97b29c4632a0aacb8e76c5b632/weave-0.52.25.tar.gz", hash = "sha256:7e1260f5cd7eff0b97e5008ef191e68a5b7b611c07aeea8bc81626f10ee1bab8", size = 657154, upload-time = "2026-01-20T20:12:18.263Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ed/0b/ae7860d2b0c02e7efab26815a9a5286d3b0f9f4e0356446f2896351bf770/weave-0.52.17-py3-none-any.whl", hash = "sha256:5772ef82521a033829c921115c5779399581a7ae06d81dfd527126e2115d16d4", size = 765887, upload-time = "2025-11-13T22:09:49.161Z" }, + { url = "https://files.pythonhosted.org/packages/af/11/02d464838a6fa66228ae5ad4d29d68a9661675a0c787e53d1cd691a5067d/weave-0.52.25-py3-none-any.whl", hash = "sha256:5d0a302059ae507df8d3fd4e39f61a5236612b18272456065056f859bd2be1ee", size = 822409, upload-time = "2026-01-20T20:12:16.356Z" }, ] [[package]] name = "weaviate-client" -version = "4.17.0" +version = "4.20.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "authlib" }, @@ -7863,9 +8087,9 @@ dependencies = [ { name = "pydantic" }, { name = "validators" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/bd/0e/e4582b007427187a9fde55fa575db4b766c81929d2b43a3dd8becce50567/weaviate_client-4.17.0.tar.gz", hash = "sha256:731d58d84b0989df4db399b686357ed285fb95971a492ccca8dec90bb2343c51", size = 769019, upload-time = "2025-09-26T11:20:27.381Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/1c/82b560254f612f95b644849d86e092da6407f17965d61e22b583b30b72cf/weaviate_client-4.20.4.tar.gz", hash = "sha256:08703234b59e4e03739f39e740e9e88cb50cd0aa147d9408b88ea6ce995c37b6", size = 809529, upload-time = "2026-03-10T15:08:13.845Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/5b/c5/2da3a45866da7a935dab8ad07be05dcaee48b3ad4955144583b651929be7/weaviate_client-4.17.0-py3-none-any.whl", hash = "sha256:60e4a355b90537ee1e942ab0b76a94750897a13d9cf13c5a6decbd166d0ca8b5", size = 582763, upload-time = "2025-09-26T11:20:25.864Z" }, + { url = "https://files.pythonhosted.org/packages/1d/d7/9461c3e7d8c44080d2307078e33dc7fefefa3171c8f930f2b83a5cbf67f2/weaviate_client-4.20.4-py3-none-any.whl", hash = "sha256:7af3a213bebcb30dcf456b0db8b6225d8926106b835d7b883276de9dc1c301fe", size = 619517, upload-time = "2026-03-10T15:08:12.047Z" }, ] [[package]] @@ -7978,16 +8202,17 @@ wheels = [ [[package]] name = "xinference-client" -version = "1.2.2" +version = "2.3.1" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "aiohttp" }, { name = "pydantic" }, { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4b/cf/7f825a311b11d1e0f7947a94f88adcf1d31e707c54a6d76d61a5d98604ed/xinference-client-1.2.2.tar.gz", hash = "sha256:85d2ba0fcbaae616b06719c422364123cbac97f3e3c82e614095fe6d0e630ed0", size = 44824, upload-time = "2025-02-08T09:28:56.692Z" } +sdist = { url = "https://files.pythonhosted.org/packages/bc/7a/33aeef9cffdc331de0046c25412622c5a16226d1b4e0cca9ed512ad00b9a/xinference_client-2.3.1.tar.gz", hash = "sha256:23ae225f47ff9adf4c6f7718c54993d1be8c704d727509f6e5cb670de3e02c4d", size = 58414, upload-time = "2026-03-15T05:53:23.994Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/0f/fc58e062cf2f7506a33d2fe5446a1e88eb7f64914addffd7ed8b12749712/xinference_client-1.2.2-py3-none-any.whl", hash = "sha256:6941d87cf61283a9d6e81cee6cb2609a183d34c6b7d808c6ba0c33437520518f", size = 25723, upload-time = "2025-02-08T09:28:54.046Z" }, + { url = "https://files.pythonhosted.org/packages/74/8d/d9ab0a457718050a279b9bb6515b7245d114118dc5e275f190ef2628dd16/xinference_client-2.3.1-py3-none-any.whl", hash = "sha256:f7c4f0b56635b46be9cfd9b2affa8e15275491597ac9b958e14b13da5745133e", size = 40012, upload-time = "2026-03-15T05:53:22.797Z" }, ] [[package]] @@ -8010,11 +8235,11 @@ wheels = [ [[package]] name = "xmltodict" -version = "1.0.2" +version = "1.0.4" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/aa/917ceeed4dbb80d2f04dbd0c784b7ee7bba8ae5a54837ef0e5e062cd3cfb/xmltodict-1.0.2.tar.gz", hash = "sha256:54306780b7c2175a3967cad1db92f218207e5bc1aba697d887807c0fb68b7649", size = 25725, upload-time = "2025-09-17T21:59:26.459Z" } +sdist = { url = "https://files.pythonhosted.org/packages/19/70/80f3b7c10d2630aa66414bf23d210386700aa390547278c789afa994fd7e/xmltodict-1.0.4.tar.gz", hash = "sha256:6d94c9f834dd9e44514162799d344d815a3a4faec913717a9ecbfa5be1bb8e61", size = 26124, upload-time = "2026-02-22T02:21:22.074Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c0/20/69a0e6058bc5ea74892d089d64dfc3a62ba78917ec5e2cfa70f7c92ba3a5/xmltodict-1.0.2-py3-none-any.whl", hash = "sha256:62d0fddb0dcbc9f642745d8bbf4d81fd17d6dfaec5a15b5c1876300aad92af0d", size = 13893, upload-time = "2025-09-17T21:59:24.859Z" }, + { url = "https://files.pythonhosted.org/packages/38/34/98a2f52245f4d47be93b580dae5f9861ef58977d73a79eb47c58f1ad1f3a/xmltodict-1.0.4-py3-none-any.whl", hash = "sha256:a4a00d300b0e1c59fc2bfccb53d7b2e88c32f200df138a0dd2229f842497026a", size = 13580, upload-time = "2026-02-22T02:21:21.039Z" }, ] [[package]] @@ -8062,48 +8287,52 @@ wheels = [ [[package]] name = "yarl" -version = "1.18.3" +version = "1.23.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "idna" }, { name = "multidict" }, { name = "propcache" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b7/9d/4b94a8e6d2b51b599516a5cb88e5bc99b4d8d4583e468057eaa29d5f0918/yarl-1.18.3.tar.gz", hash = "sha256:ac1801c45cbf77b6c99242eeff4fffb5e4e73a800b5c4ad4fc0be5def634d2e1", size = 181062, upload-time = "2024-12-01T20:35:23.292Z" } +sdist = { url = "https://files.pythonhosted.org/packages/23/6e/beb1beec874a72f23815c1434518bfc4ed2175065173fb138c3705f658d4/yarl-1.23.0.tar.gz", hash = "sha256:53b1ea6ca88ebd4420379c330aea57e258408dd0df9af0992e5de2078dc9f5d5", size = 194676, upload-time = "2026-03-01T22:07:53.373Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/40/93/282b5f4898d8e8efaf0790ba6d10e2245d2c9f30e199d1a85cae9356098c/yarl-1.18.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8503ad47387b8ebd39cbbbdf0bf113e17330ffd339ba1144074da24c545f0069", size = 141555, upload-time = "2024-12-01T20:33:08.819Z" }, - { url = "https://files.pythonhosted.org/packages/6d/9c/0a49af78df099c283ca3444560f10718fadb8a18dc8b3edf8c7bd9fd7d89/yarl-1.18.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:02ddb6756f8f4517a2d5e99d8b2f272488e18dd0bfbc802f31c16c6c20f22193", size = 94351, upload-time = "2024-12-01T20:33:10.609Z" }, - { url = "https://files.pythonhosted.org/packages/5a/a1/205ab51e148fdcedad189ca8dd587794c6f119882437d04c33c01a75dece/yarl-1.18.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:67a283dd2882ac98cc6318384f565bffc751ab564605959df4752d42483ad889", size = 92286, upload-time = "2024-12-01T20:33:12.322Z" }, - { url = "https://files.pythonhosted.org/packages/ed/fe/88b690b30f3f59275fb674f5f93ddd4a3ae796c2b62e5bb9ece8a4914b83/yarl-1.18.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d980e0325b6eddc81331d3f4551e2a333999fb176fd153e075c6d1c2530aa8a8", size = 340649, upload-time = "2024-12-01T20:33:13.842Z" }, - { url = "https://files.pythonhosted.org/packages/07/eb/3b65499b568e01f36e847cebdc8d7ccb51fff716dbda1ae83c3cbb8ca1c9/yarl-1.18.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b643562c12680b01e17239be267bc306bbc6aac1f34f6444d1bded0c5ce438ca", size = 356623, upload-time = "2024-12-01T20:33:15.535Z" }, - { url = "https://files.pythonhosted.org/packages/33/46/f559dc184280b745fc76ec6b1954de2c55595f0ec0a7614238b9ebf69618/yarl-1.18.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c017a3b6df3a1bd45b9fa49a0f54005e53fbcad16633870104b66fa1a30a29d8", size = 354007, upload-time = "2024-12-01T20:33:17.518Z" }, - { url = "https://files.pythonhosted.org/packages/af/ba/1865d85212351ad160f19fb99808acf23aab9a0f8ff31c8c9f1b4d671fc9/yarl-1.18.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75674776d96d7b851b6498f17824ba17849d790a44d282929c42dbb77d4f17ae", size = 344145, upload-time = "2024-12-01T20:33:20.071Z" }, - { url = "https://files.pythonhosted.org/packages/94/cb/5c3e975d77755d7b3d5193e92056b19d83752ea2da7ab394e22260a7b824/yarl-1.18.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ccaa3a4b521b780a7e771cc336a2dba389a0861592bbce09a476190bb0c8b4b3", size = 336133, upload-time = "2024-12-01T20:33:22.515Z" }, - { url = "https://files.pythonhosted.org/packages/19/89/b77d3fd249ab52a5c40859815765d35c91425b6bb82e7427ab2f78f5ff55/yarl-1.18.3-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:2d06d3005e668744e11ed80812e61efd77d70bb7f03e33c1598c301eea20efbb", size = 347967, upload-time = "2024-12-01T20:33:24.139Z" }, - { url = "https://files.pythonhosted.org/packages/35/bd/f6b7630ba2cc06c319c3235634c582a6ab014d52311e7d7c22f9518189b5/yarl-1.18.3-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:9d41beda9dc97ca9ab0b9888cb71f7539124bc05df02c0cff6e5acc5a19dcc6e", size = 346397, upload-time = "2024-12-01T20:33:26.205Z" }, - { url = "https://files.pythonhosted.org/packages/18/1a/0b4e367d5a72d1f095318344848e93ea70da728118221f84f1bf6c1e39e7/yarl-1.18.3-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:ba23302c0c61a9999784e73809427c9dbedd79f66a13d84ad1b1943802eaaf59", size = 350206, upload-time = "2024-12-01T20:33:27.83Z" }, - { url = "https://files.pythonhosted.org/packages/b5/cf/320fff4367341fb77809a2d8d7fe75b5d323a8e1b35710aafe41fdbf327b/yarl-1.18.3-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:6748dbf9bfa5ba1afcc7556b71cda0d7ce5f24768043a02a58846e4a443d808d", size = 362089, upload-time = "2024-12-01T20:33:29.565Z" }, - { url = "https://files.pythonhosted.org/packages/57/cf/aadba261d8b920253204085268bad5e8cdd86b50162fcb1b10c10834885a/yarl-1.18.3-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:0b0cad37311123211dc91eadcb322ef4d4a66008d3e1bdc404808992260e1a0e", size = 366267, upload-time = "2024-12-01T20:33:31.449Z" }, - { url = "https://files.pythonhosted.org/packages/54/58/fb4cadd81acdee6dafe14abeb258f876e4dd410518099ae9a35c88d8097c/yarl-1.18.3-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:0fb2171a4486bb075316ee754c6d8382ea6eb8b399d4ec62fde2b591f879778a", size = 359141, upload-time = "2024-12-01T20:33:33.79Z" }, - { url = "https://files.pythonhosted.org/packages/9a/7a/4c571597589da4cd5c14ed2a0b17ac56ec9ee7ee615013f74653169e702d/yarl-1.18.3-cp311-cp311-win32.whl", hash = "sha256:61b1a825a13bef4a5f10b1885245377d3cd0bf87cba068e1d9a88c2ae36880e1", size = 84402, upload-time = "2024-12-01T20:33:35.689Z" }, - { url = "https://files.pythonhosted.org/packages/ae/7b/8600250b3d89b625f1121d897062f629883c2f45339623b69b1747ec65fa/yarl-1.18.3-cp311-cp311-win_amd64.whl", hash = "sha256:b9d60031cf568c627d028239693fd718025719c02c9f55df0a53e587aab951b5", size = 91030, upload-time = "2024-12-01T20:33:37.511Z" }, - { url = "https://files.pythonhosted.org/packages/33/85/bd2e2729752ff4c77338e0102914897512e92496375e079ce0150a6dc306/yarl-1.18.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1dd4bdd05407ced96fed3d7f25dbbf88d2ffb045a0db60dbc247f5b3c5c25d50", size = 142644, upload-time = "2024-12-01T20:33:39.204Z" }, - { url = "https://files.pythonhosted.org/packages/ff/74/1178322cc0f10288d7eefa6e4a85d8d2e28187ccab13d5b844e8b5d7c88d/yarl-1.18.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7c33dd1931a95e5d9a772d0ac5e44cac8957eaf58e3c8da8c1414de7dd27c576", size = 94962, upload-time = "2024-12-01T20:33:40.808Z" }, - { url = "https://files.pythonhosted.org/packages/be/75/79c6acc0261e2c2ae8a1c41cf12265e91628c8c58ae91f5ff59e29c0787f/yarl-1.18.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:25b411eddcfd56a2f0cd6a384e9f4f7aa3efee14b188de13048c25b5e91f1640", size = 92795, upload-time = "2024-12-01T20:33:42.322Z" }, - { url = "https://files.pythonhosted.org/packages/6b/32/927b2d67a412c31199e83fefdce6e645247b4fb164aa1ecb35a0f9eb2058/yarl-1.18.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:436c4fc0a4d66b2badc6c5fc5ef4e47bb10e4fd9bf0c79524ac719a01f3607c2", size = 332368, upload-time = "2024-12-01T20:33:43.956Z" }, - { url = "https://files.pythonhosted.org/packages/19/e5/859fca07169d6eceeaa4fde1997c91d8abde4e9a7c018e371640c2da2b71/yarl-1.18.3-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e35ef8683211db69ffe129a25d5634319a677570ab6b2eba4afa860f54eeaf75", size = 342314, upload-time = "2024-12-01T20:33:46.046Z" }, - { url = "https://files.pythonhosted.org/packages/08/75/76b63ccd91c9e03ab213ef27ae6add2e3400e77e5cdddf8ed2dbc36e3f21/yarl-1.18.3-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:84b2deecba4a3f1a398df819151eb72d29bfeb3b69abb145a00ddc8d30094512", size = 341987, upload-time = "2024-12-01T20:33:48.352Z" }, - { url = "https://files.pythonhosted.org/packages/1a/e1/a097d5755d3ea8479a42856f51d97eeff7a3a7160593332d98f2709b3580/yarl-1.18.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00e5a1fea0fd4f5bfa7440a47eff01d9822a65b4488f7cff83155a0f31a2ecba", size = 336914, upload-time = "2024-12-01T20:33:50.875Z" }, - { url = "https://files.pythonhosted.org/packages/0b/42/e1b4d0e396b7987feceebe565286c27bc085bf07d61a59508cdaf2d45e63/yarl-1.18.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0e883008013c0e4aef84dcfe2a0b172c4d23c2669412cf5b3371003941f72bb", size = 325765, upload-time = "2024-12-01T20:33:52.641Z" }, - { url = "https://files.pythonhosted.org/packages/7e/18/03a5834ccc9177f97ca1bbb245b93c13e58e8225276f01eedc4cc98ab820/yarl-1.18.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:5a3f356548e34a70b0172d8890006c37be92995f62d95a07b4a42e90fba54272", size = 344444, upload-time = "2024-12-01T20:33:54.395Z" }, - { url = "https://files.pythonhosted.org/packages/c8/03/a713633bdde0640b0472aa197b5b86e90fbc4c5bc05b727b714cd8a40e6d/yarl-1.18.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ccd17349166b1bee6e529b4add61727d3f55edb7babbe4069b5764c9587a8cc6", size = 340760, upload-time = "2024-12-01T20:33:56.286Z" }, - { url = "https://files.pythonhosted.org/packages/eb/99/f6567e3f3bbad8fd101886ea0276c68ecb86a2b58be0f64077396cd4b95e/yarl-1.18.3-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:b958ddd075ddba5b09bb0be8a6d9906d2ce933aee81100db289badbeb966f54e", size = 346484, upload-time = "2024-12-01T20:33:58.375Z" }, - { url = "https://files.pythonhosted.org/packages/8e/a9/84717c896b2fc6cb15bd4eecd64e34a2f0a9fd6669e69170c73a8b46795a/yarl-1.18.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:c7d79f7d9aabd6011004e33b22bc13056a3e3fb54794d138af57f5ee9d9032cb", size = 359864, upload-time = "2024-12-01T20:34:00.22Z" }, - { url = "https://files.pythonhosted.org/packages/1e/2e/d0f5f1bef7ee93ed17e739ec8dbcb47794af891f7d165fa6014517b48169/yarl-1.18.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4891ed92157e5430874dad17b15eb1fda57627710756c27422200c52d8a4e393", size = 364537, upload-time = "2024-12-01T20:34:03.54Z" }, - { url = "https://files.pythonhosted.org/packages/97/8a/568d07c5d4964da5b02621a517532adb8ec5ba181ad1687191fffeda0ab6/yarl-1.18.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ce1af883b94304f493698b00d0f006d56aea98aeb49d75ec7d98cd4a777e9285", size = 357861, upload-time = "2024-12-01T20:34:05.73Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e3/924c3f64b6b3077889df9a1ece1ed8947e7b61b0a933f2ec93041990a677/yarl-1.18.3-cp312-cp312-win32.whl", hash = "sha256:f91c4803173928a25e1a55b943c81f55b8872f0018be83e3ad4938adffb77dd2", size = 84097, upload-time = "2024-12-01T20:34:07.664Z" }, - { url = "https://files.pythonhosted.org/packages/34/45/0e055320daaabfc169b21ff6174567b2c910c45617b0d79c68d7ab349b02/yarl-1.18.3-cp312-cp312-win_amd64.whl", hash = "sha256:7e2ee16578af3b52ac2f334c3b1f92262f47e02cc6193c598502bd46f5cd1477", size = 90399, upload-time = "2024-12-01T20:34:09.61Z" }, - { url = "https://files.pythonhosted.org/packages/f5/4b/a06e0ec3d155924f77835ed2d167ebd3b211a7b0853da1cf8d8414d784ef/yarl-1.18.3-py3-none-any.whl", hash = "sha256:b57f4f58099328dfb26c6a771d09fb20dbbae81d20cfb66141251ea063bd101b", size = 45109, upload-time = "2024-12-01T20:35:20.834Z" }, + { url = "https://files.pythonhosted.org/packages/a2/aa/60da938b8f0997ba3a911263c40d82b6f645a67902a490b46f3355e10fae/yarl-1.23.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b35d13d549077713e4414f927cdc388d62e543987c572baee613bf82f11a4b99", size = 123641, upload-time = "2026-03-01T22:04:42.841Z" }, + { url = "https://files.pythonhosted.org/packages/24/84/e237607faf4e099dbb8a4f511cfd5efcb5f75918baad200ff7380635631b/yarl-1.23.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cbb0fef01f0c6b38cb0f39b1f78fc90b807e0e3c86a7ff3ce74ad77ce5c7880c", size = 86248, upload-time = "2026-03-01T22:04:44.757Z" }, + { url = "https://files.pythonhosted.org/packages/b2/0d/71ceabc14c146ba8ee3804ca7b3d42b1664c8440439de5214d366fec7d3a/yarl-1.23.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dc52310451fc7c629e13c4e061cbe2dd01684d91f2f8ee2821b083c58bd72432", size = 85988, upload-time = "2026-03-01T22:04:46.365Z" }, + { url = "https://files.pythonhosted.org/packages/8c/6c/4a90d59c572e46b270ca132aca66954f1175abd691f74c1ef4c6711828e2/yarl-1.23.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b2c6b50c7b0464165472b56b42d4c76a7b864597007d9c085e8b63e185cf4a7a", size = 100566, upload-time = "2026-03-01T22:04:47.639Z" }, + { url = "https://files.pythonhosted.org/packages/49/fb/c438fb5108047e629f6282a371e6e91cf3f97ee087c4fb748a1f32ceef55/yarl-1.23.0-cp311-cp311-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:aafe5dcfda86c8af00386d7781d4c2181b5011b7be3f2add5e99899ea925df05", size = 92079, upload-time = "2026-03-01T22:04:48.925Z" }, + { url = "https://files.pythonhosted.org/packages/d9/13/d269aa1aed3e4f50a5a103f96327210cc5fa5dd2d50882778f13c7a14606/yarl-1.23.0-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9ee33b875f0b390564c1fb7bc528abf18c8ee6073b201c6ae8524aca778e2d83", size = 108741, upload-time = "2026-03-01T22:04:50.838Z" }, + { url = "https://files.pythonhosted.org/packages/85/fb/115b16f22c37ea4437d323e472945bea97301c8ec6089868fa560abab590/yarl-1.23.0-cp311-cp311-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:4c41e021bc6d7affb3364dc1e1e5fa9582b470f283748784bd6ea0558f87f42c", size = 108099, upload-time = "2026-03-01T22:04:52.499Z" }, + { url = "https://files.pythonhosted.org/packages/9a/64/c53487d9f4968045b8afa51aed7ca44f58b2589e772f32745f3744476c82/yarl-1.23.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:99c8a9ed30f4164bc4c14b37a90208836cbf50d4ce2a57c71d0f52c7fb4f7598", size = 102678, upload-time = "2026-03-01T22:04:55.176Z" }, + { url = "https://files.pythonhosted.org/packages/85/59/cd98e556fbb2bf8fab29c1a722f67ad45c5f3447cac798ab85620d1e70af/yarl-1.23.0-cp311-cp311-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:f2af5c81a1f124609d5f33507082fc3f739959d4719b56877ab1ee7e7b3d602b", size = 100803, upload-time = "2026-03-01T22:04:56.588Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c0/b39770b56d4a9f0bb5f77e2f1763cd2d75cc2f6c0131e3b4c360348fcd65/yarl-1.23.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:6b41389c19b07c760c7e427a3462e8ab83c4bb087d127f0e854c706ce1b9215c", size = 100163, upload-time = "2026-03-01T22:04:58.492Z" }, + { url = "https://files.pythonhosted.org/packages/e7/64/6980f99ab00e1f0ff67cb84766c93d595b067eed07439cfccfc8fb28c1a6/yarl-1.23.0-cp311-cp311-musllinux_1_2_armv7l.whl", hash = "sha256:1dc702e42d0684f42d6519c8d581e49c96cefaaab16691f03566d30658ee8788", size = 93859, upload-time = "2026-03-01T22:05:00.268Z" }, + { url = "https://files.pythonhosted.org/packages/38/69/912e6c5e146793e5d4b5fe39ff5b00f4d22463dfd5a162bec565ac757673/yarl-1.23.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:0e40111274f340d32ebcc0a5668d54d2b552a6cca84c9475859d364b380e3222", size = 108202, upload-time = "2026-03-01T22:05:02.273Z" }, + { url = "https://files.pythonhosted.org/packages/59/97/35ca6767524687ad64e5f5c31ad54bc76d585585a9fcb40f649e7e82ffed/yarl-1.23.0-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:4764a6a7588561a9aef92f65bda2c4fb58fe7c675c0883862e6df97559de0bfb", size = 99866, upload-time = "2026-03-01T22:05:03.597Z" }, + { url = "https://files.pythonhosted.org/packages/d3/1c/1a3387ee6d73589f6f2a220ae06f2984f6c20b40c734989b0a44f5987308/yarl-1.23.0-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:03214408cfa590df47728b84c679ae4ef00be2428e11630277be0727eba2d7cc", size = 107852, upload-time = "2026-03-01T22:05:04.986Z" }, + { url = "https://files.pythonhosted.org/packages/a4/b8/35c0750fcd5a3f781058bfd954515dd4b1eab45e218cbb85cf11132215f1/yarl-1.23.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:170e26584b060879e29fac213e4228ef063f39128723807a312e5c7fec28eff2", size = 102919, upload-time = "2026-03-01T22:05:06.397Z" }, + { url = "https://files.pythonhosted.org/packages/e5/1c/9a1979aec4a81896d597bcb2177827f2dbee3f5b7cc48b2d0dadb644b41d/yarl-1.23.0-cp311-cp311-win32.whl", hash = "sha256:51430653db848d258336cfa0244427b17d12db63d42603a55f0d4546f50f25b5", size = 82602, upload-time = "2026-03-01T22:05:08.444Z" }, + { url = "https://files.pythonhosted.org/packages/93/22/b85eca6fa2ad9491af48c973e4c8cf6b103a73dbb271fe3346949449fca0/yarl-1.23.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf49a3ae946a87083ef3a34c8f677ae4243f5b824bfc4c69672e72b3d6719d46", size = 87461, upload-time = "2026-03-01T22:05:10.145Z" }, + { url = "https://files.pythonhosted.org/packages/93/95/07e3553fe6f113e6864a20bdc53a78113cda3b9ced8784ee52a52c9f80d8/yarl-1.23.0-cp311-cp311-win_arm64.whl", hash = "sha256:b39cb32a6582750b6cc77bfb3c49c0f8760dc18dc96ec9fb55fbb0f04e08b928", size = 82336, upload-time = "2026-03-01T22:05:11.554Z" }, + { url = "https://files.pythonhosted.org/packages/88/8a/94615bc31022f711add374097ad4144d569e95ff3c38d39215d07ac153a0/yarl-1.23.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1932b6b8bba8d0160a9d1078aae5838a66039e8832d41d2992daa9a3a08f7860", size = 124737, upload-time = "2026-03-01T22:05:12.897Z" }, + { url = "https://files.pythonhosted.org/packages/e3/6f/c6554045d59d64052698add01226bc867b52fe4a12373415d7991fdca95d/yarl-1.23.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:411225bae281f114067578891bc75534cfb3d92a3b4dfef7a6ca78ba354e6069", size = 87029, upload-time = "2026-03-01T22:05:14.376Z" }, + { url = "https://files.pythonhosted.org/packages/19/2a/725ecc166d53438bc88f76822ed4b1e3b10756e790bafd7b523fe97c322d/yarl-1.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13a563739ae600a631c36ce096615fe307f131344588b0bc0daec108cdb47b25", size = 86310, upload-time = "2026-03-01T22:05:15.71Z" }, + { url = "https://files.pythonhosted.org/packages/99/30/58260ed98e6ff7f90ba84442c1ddd758c9170d70327394a6227b310cd60f/yarl-1.23.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9cbf44c5cb4a7633d078788e1b56387e3d3cf2b8139a3be38040b22d6c3221c8", size = 97587, upload-time = "2026-03-01T22:05:17.384Z" }, + { url = "https://files.pythonhosted.org/packages/76/0a/8b08aac08b50682e65759f7f8dde98ae8168f72487e7357a5d684c581ef9/yarl-1.23.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:53ad387048f6f09a8969631e4de3f1bf70c50e93545d64af4f751b2498755072", size = 92528, upload-time = "2026-03-01T22:05:18.804Z" }, + { url = "https://files.pythonhosted.org/packages/52/07/0b7179101fe5f8385ec6c6bb5d0cb9f76bd9fb4a769591ab6fb5cdbfc69a/yarl-1.23.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4a59ba56f340334766f3a4442e0efd0af895fae9e2b204741ef885c446b3a1a8", size = 105339, upload-time = "2026-03-01T22:05:20.235Z" }, + { url = "https://files.pythonhosted.org/packages/d3/8a/36d82869ab5ec829ca8574dfcb92b51286fcfb1e9c7a73659616362dc880/yarl-1.23.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:803a3c3ce4acc62eaf01eaca1208dcf0783025ef27572c3336502b9c232005e7", size = 105061, upload-time = "2026-03-01T22:05:22.268Z" }, + { url = "https://files.pythonhosted.org/packages/66/3e/868e5c3364b6cee19ff3e1a122194fa4ce51def02c61023970442162859e/yarl-1.23.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3d2bff8f37f8d0f96c7ec554d16945050d54462d6e95414babaa18bfafc7f51", size = 100132, upload-time = "2026-03-01T22:05:23.638Z" }, + { url = "https://files.pythonhosted.org/packages/cf/26/9c89acf82f08a52cb52d6d39454f8d18af15f9d386a23795389d1d423823/yarl-1.23.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c75eb09e8d55bceb4367e83496ff8ef2bc7ea6960efb38e978e8073ea59ecb67", size = 99289, upload-time = "2026-03-01T22:05:25.749Z" }, + { url = "https://files.pythonhosted.org/packages/6f/54/5b0db00d2cb056922356104468019c0a132e89c8d3ab67d8ede9f4483d2a/yarl-1.23.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877b0738624280e34c55680d6054a307aa94f7d52fa0e3034a9cc6e790871da7", size = 96950, upload-time = "2026-03-01T22:05:27.318Z" }, + { url = "https://files.pythonhosted.org/packages/f6/40/10fa93811fd439341fad7e0718a86aca0de9548023bbb403668d6555acab/yarl-1.23.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b5405bb8f0e783a988172993cfc627e4d9d00432d6bbac65a923041edacf997d", size = 93960, upload-time = "2026-03-01T22:05:28.738Z" }, + { url = "https://files.pythonhosted.org/packages/bc/d2/8ae2e6cd77d0805f4526e30ec43b6f9a3dfc542d401ac4990d178e4bf0cf/yarl-1.23.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c3a3598a832590c5a3ce56ab5576361b5688c12cb1d39429cf5dba30b510760", size = 104703, upload-time = "2026-03-01T22:05:30.438Z" }, + { url = "https://files.pythonhosted.org/packages/2f/0c/b3ceacf82c3fe21183ce35fa2acf5320af003d52bc1fcf5915077681142e/yarl-1.23.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8419ebd326430d1cbb7efb5292330a2cf39114e82df5cc3d83c9a0d5ebeaf2f2", size = 98325, upload-time = "2026-03-01T22:05:31.835Z" }, + { url = "https://files.pythonhosted.org/packages/9d/e0/12900edd28bdab91a69bd2554b85ad7b151f64e8b521fe16f9ad2f56477a/yarl-1.23.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:be61f6fff406ca40e3b1d84716fde398fc08bc63dd96d15f3a14230a0973ed86", size = 105067, upload-time = "2026-03-01T22:05:33.358Z" }, + { url = "https://files.pythonhosted.org/packages/15/61/74bb1182cf79c9bbe4eb6b1f14a57a22d7a0be5e9cedf8e2d5c2086474c3/yarl-1.23.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ceb13c5c858d01321b5d9bb65e4cf37a92169ea470b70fec6f236b2c9dd7e34", size = 100285, upload-time = "2026-03-01T22:05:35.4Z" }, + { url = "https://files.pythonhosted.org/packages/69/7f/cd5ef733f2550de6241bd8bd8c3febc78158b9d75f197d9c7baa113436af/yarl-1.23.0-cp312-cp312-win32.whl", hash = "sha256:fffc45637bcd6538de8b85f51e3df3223e4ad89bccbfca0481c08c7fc8b7ed7d", size = 82359, upload-time = "2026-03-01T22:05:36.811Z" }, + { url = "https://files.pythonhosted.org/packages/f5/be/25216a49daeeb7af2bec0db22d5e7df08ed1d7c9f65d78b14f3b74fd72fc/yarl-1.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:f69f57305656a4852f2a7203efc661d8c042e6cc67f7acd97d8667fb448a426e", size = 87674, upload-time = "2026-03-01T22:05:38.171Z" }, + { url = "https://files.pythonhosted.org/packages/d2/35/aeab955d6c425b227d5b7247eafb24f2653fedc32f95373a001af5dfeb9e/yarl-1.23.0-cp312-cp312-win_arm64.whl", hash = "sha256:6e87a6e8735b44816e7db0b2fbc9686932df473c826b0d9743148432e10bb9b9", size = 81879, upload-time = "2026-03-01T22:05:40.006Z" }, + { url = "https://files.pythonhosted.org/packages/69/68/c8739671f5699c7dc470580a4f821ef37c32c4cb0b047ce223a7f115757f/yarl-1.23.0-py3-none-any.whl", hash = "sha256:a2df6afe50dea8ae15fa34c9f824a3ee958d785fd5d089063d960bae1daa0a3f", size = 48288, upload-time = "2026-03-01T22:07:51.388Z" }, ] [[package]] @@ -8126,22 +8355,22 @@ wheels = [ [[package]] name = "zope-interface" -version = "8.1.1" +version = "8.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/71/c9/5ec8679a04d37c797d343f650c51ad67d178f0001c363e44b6ac5f97a9da/zope_interface-8.1.1.tar.gz", hash = "sha256:51b10e6e8e238d719636a401f44f1e366146912407b58453936b781a19be19ec", size = 254748, upload-time = "2025-11-15T08:32:52.404Z" } +sdist = { url = "https://files.pythonhosted.org/packages/86/a4/77daa5ba398996d16bb43fc721599d27d03eae68fe3c799de1963c72e228/zope_interface-8.2.tar.gz", hash = "sha256:afb20c371a601d261b4f6edb53c3c418c249db1a9717b0baafc9a9bb39ba1224", size = 254019, upload-time = "2026-01-09T07:51:07.253Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/77/fc/d84bac27332bdefe8c03f7289d932aeb13a5fd6aeedba72b0aa5b18276ff/zope_interface-8.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:e8a0fdd5048c1bb733e4693eae9bc4145a19419ea6a1c95299318a93fe9f3d72", size = 207955, upload-time = "2025-11-15T08:36:45.902Z" }, - { url = "https://files.pythonhosted.org/packages/52/02/e1234eb08b10b5cf39e68372586acc7f7bbcd18176f6046433a8f6b8b263/zope_interface-8.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a4cb0ea75a26b606f5bc8524fbce7b7d8628161b6da002c80e6417ce5ec757c0", size = 208398, upload-time = "2025-11-15T08:36:47.016Z" }, - { url = "https://files.pythonhosted.org/packages/3c/be/aabda44d4bc490f9966c2b77fa7822b0407d852cb909b723f2d9e05d2427/zope_interface-8.1.1-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:c267b00b5a49a12743f5e1d3b4beef45479d696dab090f11fe3faded078a5133", size = 255079, upload-time = "2025-11-15T08:36:48.157Z" }, - { url = "https://files.pythonhosted.org/packages/d8/7f/4fbc7c2d7cb310e5a91b55db3d98e98d12b262014c1fcad9714fe33c2adc/zope_interface-8.1.1-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:e25d3e2b9299e7ec54b626573673bdf0d740cf628c22aef0a3afef85b438aa54", size = 259850, upload-time = "2025-11-15T08:36:49.544Z" }, - { url = "https://files.pythonhosted.org/packages/fe/2c/dc573fffe59cdbe8bbbdd2814709bdc71c4870893e7226700bc6a08c5e0c/zope_interface-8.1.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:63db1241804417aff95ac229c13376c8c12752b83cc06964d62581b493e6551b", size = 261033, upload-time = "2025-11-15T08:36:51.061Z" }, - { url = "https://files.pythonhosted.org/packages/0e/51/1ac50e5ee933d9e3902f3400bda399c128a5c46f9f209d16affe3d4facc5/zope_interface-8.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:9639bf4ed07b5277fb231e54109117c30d608254685e48a7104a34618bcbfc83", size = 212215, upload-time = "2025-11-15T08:36:52.553Z" }, - { url = "https://files.pythonhosted.org/packages/08/3d/f5b8dd2512f33bfab4faba71f66f6873603d625212206dd36f12403ae4ca/zope_interface-8.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:a16715808408db7252b8c1597ed9008bdad7bf378ed48eb9b0595fad4170e49d", size = 208660, upload-time = "2025-11-15T08:36:53.579Z" }, - { url = "https://files.pythonhosted.org/packages/e5/41/c331adea9b11e05ff9ac4eb7d3032b24c36a3654ae9f2bf4ef2997048211/zope_interface-8.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce6b58752acc3352c4aa0b55bbeae2a941d61537e6afdad2467a624219025aae", size = 208851, upload-time = "2025-11-15T08:36:54.854Z" }, - { url = "https://files.pythonhosted.org/packages/25/00/7a8019c3bb8b119c5f50f0a4869183a4b699ca004a7f87ce98382e6b364c/zope_interface-8.1.1-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:807778883d07177713136479de7fd566f9056a13aef63b686f0ab4807c6be259", size = 259292, upload-time = "2025-11-15T08:36:56.409Z" }, - { url = "https://files.pythonhosted.org/packages/1a/fc/b70e963bf89345edffdd5d16b61e789fdc09365972b603e13785360fea6f/zope_interface-8.1.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:50e5eb3b504a7d63dc25211b9298071d5b10a3eb754d6bf2f8ef06cb49f807ab", size = 264741, upload-time = "2025-11-15T08:36:57.675Z" }, - { url = "https://files.pythonhosted.org/packages/96/fe/7d0b5c0692b283901b34847f2b2f50d805bfff4b31de4021ac9dfb516d2a/zope_interface-8.1.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:eee6f93b2512ec9466cf30c37548fd3ed7bc4436ab29cd5943d7a0b561f14f0f", size = 264281, upload-time = "2025-11-15T08:36:58.968Z" }, - { url = "https://files.pythonhosted.org/packages/2b/2c/a7cebede1cf2757be158bcb151fe533fa951038cfc5007c7597f9f86804b/zope_interface-8.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:80edee6116d569883c58ff8efcecac3b737733d646802036dc337aa839a5f06b", size = 212327, upload-time = "2025-11-15T08:37:00.4Z" }, + { url = "https://files.pythonhosted.org/packages/98/97/9c2aa8caae79915ed64eb114e18816f178984c917aa9adf2a18345e4f2e5/zope_interface-8.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c65ade7ea85516e428651048489f5e689e695c79188761de8c622594d1e13322", size = 208081, upload-time = "2026-01-09T08:05:06.623Z" }, + { url = "https://files.pythonhosted.org/packages/34/86/4e2fcb01a8f6780ac84923748e450af0805531f47c0956b83065c99ab543/zope_interface-8.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a1ef4b43659e1348f35f38e7d1a6bbc1682efde239761f335ffc7e31e798b65b", size = 208522, upload-time = "2026-01-09T08:05:07.986Z" }, + { url = "https://files.pythonhosted.org/packages/f6/eb/08e277da32ddcd4014922854096cf6dcb7081fad415892c2da1bedefbf02/zope_interface-8.2-cp311-cp311-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:dfc4f44e8de2ff4eba20af4f0a3ca42d3c43ab24a08e49ccd8558b7a4185b466", size = 255198, upload-time = "2026-01-09T08:05:09.532Z" }, + { url = "https://files.pythonhosted.org/packages/ea/a1/b32484f3281a5dc83bc713ad61eca52c543735cdf204543172087a074a74/zope_interface-8.2-cp311-cp311-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:8f094bfb49179ec5dc9981cb769af1275702bd64720ef94874d9e34da1390d4c", size = 259970, upload-time = "2026-01-09T08:05:11.477Z" }, + { url = "https://files.pythonhosted.org/packages/f6/81/bca0e8ae1e487d4093a8a7cfed2118aa2d4758c8cfd66e59d2af09d71f1c/zope_interface-8.2-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d2bb8e7364e18f083bf6744ccf30433b2a5f236c39c95df8514e3c13007098ce", size = 261153, upload-time = "2026-01-09T08:05:13.402Z" }, + { url = "https://files.pythonhosted.org/packages/40/1e/e3ff2a708011e56b10b271b038d4cb650a8ad5b7d24352fe2edf6d6b187a/zope_interface-8.2-cp311-cp311-win_amd64.whl", hash = "sha256:6f4b4dfcfdfaa9177a600bb31cebf711fdb8c8e9ed84f14c61c420c6aa398489", size = 212330, upload-time = "2026-01-09T08:05:15.267Z" }, + { url = "https://files.pythonhosted.org/packages/e0/a0/1e1fabbd2e9c53ef92b69df6d14f4adc94ec25583b1380336905dc37e9a0/zope_interface-8.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:624b6787fc7c3e45fa401984f6add2c736b70a7506518c3b537ffaacc4b29d4c", size = 208785, upload-time = "2026-01-09T08:05:17.348Z" }, + { url = "https://files.pythonhosted.org/packages/c3/2a/88d098a06975c722a192ef1fb7d623d1b57c6a6997cf01a7aabb45ab1970/zope_interface-8.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:bc9ded9e97a0ed17731d479596ed1071e53b18e6fdb2fc33af1e43f5fd2d3aaa", size = 208976, upload-time = "2026-01-09T08:05:18.792Z" }, + { url = "https://files.pythonhosted.org/packages/e9/e8/757398549fdfd2f8c89f32c82ae4d2f0537ae2a5d2f21f4a2f711f5a059f/zope_interface-8.2-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:532367553e4420c80c0fc0cabcc2c74080d495573706f66723edee6eae53361d", size = 259411, upload-time = "2026-01-09T08:05:20.567Z" }, + { url = "https://files.pythonhosted.org/packages/91/af/502601f0395ce84dff622f63cab47488657a04d0065547df42bee3a680ff/zope_interface-8.2-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:2bf9cf275468bafa3c72688aad8cfcbe3d28ee792baf0b228a1b2d93bd1d541a", size = 264859, upload-time = "2026-01-09T08:05:22.234Z" }, + { url = "https://files.pythonhosted.org/packages/89/0c/d2f765b9b4814a368a7c1b0ac23b68823c6789a732112668072fe596945d/zope_interface-8.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0009d2d3c02ea783045d7804da4fd016245e5c5de31a86cebba66dd6914d59a2", size = 264398, upload-time = "2026-01-09T08:05:23.853Z" }, + { url = "https://files.pythonhosted.org/packages/4a/81/2f171fbc4222066957e6b9220c4fb9146792540102c37e6d94e5d14aad97/zope_interface-8.2-cp312-cp312-win_amd64.whl", hash = "sha256:845d14e580220ae4544bd4d7eb800f0b6034fe5585fc2536806e0a26c2ee6640", size = 212444, upload-time = "2026-01-09T08:05:25.148Z" }, ] [[package]] diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000000..54ac2a4b36 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,16 @@ +coverage: + status: + project: + default: + target: auto + +flags: + web: + paths: + - "web/" + carryforward: true + + api: + paths: + - "api/" + carryforward: true diff --git a/dev/pyrefly-check-local b/dev/pyrefly-check-local new file mode 100755 index 0000000000..8fa5f121fc --- /dev/null +++ b/dev/pyrefly-check-local @@ -0,0 +1,36 @@ +#!/bin/bash + +set -euo pipefail + +SCRIPT_DIR="$(dirname "$(realpath "$0")")" +REPO_ROOT="$SCRIPT_DIR/.." +cd "$REPO_ROOT" + +EXCLUDES_FILE="api/pyrefly-local-excludes.txt" + +pyrefly_args=( + "--summary=none" + "--use-ignore-files=false" + "--disable-project-excludes-heuristics=true" + "--project-excludes=.venv" + "--project-excludes=migrations/" + "--project-excludes=tests/" +) + +if [[ -f "$EXCLUDES_FILE" ]]; then + while IFS= read -r exclude; do + [[ -z "$exclude" || "${exclude:0:1}" == "#" ]] && continue + pyrefly_args+=("--project-excludes=$exclude") + done < "$EXCLUDES_FILE" +fi + +tmp_output="$(mktemp)" +set +e +uv run --directory api --dev pyrefly check "${pyrefly_args[@]}" >"$tmp_output" 2>&1 +pyrefly_status=$? +set -e + +uv run --directory api python libs/pyrefly_diagnostics.py < "$tmp_output" +rm -f "$tmp_output" + +exit "$pyrefly_status" diff --git a/dev/pytest/pytest_config_tests.py b/dev/pytest/pytest_config_tests.py index 1ec95deb09..1ae115f85c 100644 --- a/dev/pytest/pytest_config_tests.py +++ b/dev/pytest/pytest_config_tests.py @@ -38,7 +38,6 @@ BASE_API_AND_DOCKER_CONFIG_SET_DIFF = { "UPSTASH_VECTOR_URL", "USING_UGC_INDEX", "WEAVIATE_BATCH_SIZE", - "WEAVIATE_GRPC_ENABLED", } BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { @@ -86,7 +85,6 @@ BASE_API_AND_DOCKER_COMPOSE_CONFIG_SET_DIFF = { "VIKINGDB_CONNECTION_TIMEOUT", "VIKINGDB_SOCKET_TIMEOUT", "WEAVIATE_BATCH_SIZE", - "WEAVIATE_GRPC_ENABLED", } API_CONFIG_SET = set(dotenv_values(Path("api") / Path(".env.example")).keys()) diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index 3c11a079cc..126aebf7bd 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -21,3 +21,4 @@ pytest --timeout "${PYTEST_TIMEOUT}" api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/oceanbase \ api/tests/integration_tests/vdb/tidb_vector \ api/tests/integration_tests/vdb/huawei \ + api/tests/integration_tests/vdb/hologres \ diff --git a/dev/start-worker b/dev/start-worker index 0450851b56..8baa36f1ed 100755 --- a/dev/start-worker +++ b/dev/start-worker @@ -21,6 +21,7 @@ show_help() { echo "" echo "Available queues:" echo " dataset - RAG indexing and document processing" + echo " dataset_summary - LLM-heavy summary index generation (isolated from indexing)" echo " workflow - Workflow triggers (community edition)" echo " workflow_professional - Professional tier workflows (cloud edition)" echo " workflow_team - Team tier workflows (cloud edition)" @@ -106,10 +107,10 @@ if [[ -z "${QUEUES}" ]]; then # Configure queues based on edition if [[ "${EDITION}" == "CLOUD" ]]; then # Cloud edition: separate queues for dataset and trigger tasks - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + QUEUES="dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" else # Community edition (SELF_HOSTED): dataset and workflow have separate queues - QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" + QUEUES="dataset,dataset_summary,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution" fi echo "No queues specified, using edition-based defaults: ${QUEUES}" diff --git a/docker/.env.example b/docker/.env.example index 3d0009711d..9d6cd65318 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -349,6 +349,9 @@ REDIS_SSL_CERTFILE= REDIS_SSL_KEYFILE= # Path to client private key file for SSL authentication REDIS_DB=0 +# Optional: limit total Redis connections used by API/Worker (unset for default) +# Align with API's REDIS_MAX_CONNECTIONS in configs +REDIS_MAX_CONNECTIONS= # Whether to use Redis Sentinel mode. # If set to true, the application will automatically discover and connect to the master node through Sentinel. @@ -538,7 +541,7 @@ SUPABASE_URL=your-server-url # ------------------------------ # The type of vector store to use. -# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`. +# Supported values are `weaviate`, `oceanbase`, `seekdb`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`, `clickzetta`, `alibabacloud_mysql`, `iris`, `hologres`. VECTOR_STORE=weaviate # Prefix used to create collection name in vector database VECTOR_INDEX_NAME_PREFIX=Vector_index @@ -602,6 +605,20 @@ COUCHBASE_PASSWORD=password COUCHBASE_BUCKET_NAME=Embeddings COUCHBASE_SCOPE_NAME=_default +# Hologres configurations, only available when VECTOR_STORE is `hologres` +# access_key_id is used as the PG username, access_key_secret is used as the PG password +HOLOGRES_HOST= +HOLOGRES_PORT=80 +HOLOGRES_DATABASE= +HOLOGRES_ACCESS_KEY_ID= +HOLOGRES_ACCESS_KEY_SECRET= +HOLOGRES_SCHEMA=public +HOLOGRES_TOKENIZER=jieba +HOLOGRES_DISTANCE_METHOD=Cosine +HOLOGRES_BASE_QUANTIZATION_TYPE=rabitq +HOLOGRES_MAX_DEGREE=64 +HOLOGRES_EF_CONSTRUCTION=400 + # pgvector configurations, only available when VECTOR_STORE is `pgvector` PGVECTOR_HOST=pgvector PGVECTOR_PORT=5432 @@ -1529,24 +1546,25 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL=200 SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30 -# Redis URL used for PubSub between API and +# Redis URL used for event bus between API and # celery worker # defaults to url constructed from `REDIS_*` # configurations -PUBSUB_REDIS_URL= -# Pub/sub channel type for streaming events. -# valid options are: +EVENT_BUS_REDIS_URL= +# Event transport type. Options are: # -# - pubsub: for normal Pub/Sub -# - sharded: for sharded Pub/Sub +# - pubsub: normal Pub/Sub (at-most-once) +# - sharded: sharded Pub/Sub (at-most-once) +# - streams: Redis Streams (at-least-once, recommended to avoid subscriber races) # -# It's highly recommended to use sharded Pub/Sub AND redis cluster -# for large deployments. -PUBSUB_REDIS_CHANNEL_TYPE=pubsub -# Whether to use Redis cluster mode while running -# PubSub. +# Note: Before enabling 'streams' in production, estimate your expected event volume and retention needs. +# Configure Redis memory limits and stream trimming appropriately (e.g., MAXLEN and key expiry) to reduce +# the risk of data loss from Redis auto-eviction under memory pressure. +# Also accepts ENV: EVENT_BUS_REDIS_CHANNEL_TYPE. +EVENT_BUS_REDIS_CHANNEL_TYPE=pubsub +# Whether to use Redis cluster mode while use redis as event bus. # It's highly recommended to enable this for large deployments. -PUBSUB_REDIS_USE_CLUSTERS=false +EVENT_BUS_REDIS_USE_CLUSTERS=false # Whether to Enable human input timeout check task ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true diff --git a/docker/dify-env-sync.py b/docker/dify-env-sync.py new file mode 100755 index 0000000000..d7c762748c --- /dev/null +++ b/docker/dify-env-sync.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 + +# ================================================================ +# Dify Environment Variables Synchronization Script +# +# Features: +# - Synchronize latest settings from .env.example to .env +# - Preserve custom settings in existing .env +# - Add new environment variables +# - Detect removed environment variables +# - Create backup files +# ================================================================ + +import argparse +import re +import shutil +import sys +from datetime import datetime +from pathlib import Path + +# ANSI color codes +RED = "\033[0;31m" +GREEN = "\033[0;32m" +YELLOW = "\033[1;33m" +BLUE = "\033[0;34m" +NC = "\033[0m" # No Color + + +def supports_color() -> bool: + """Return True if the terminal supports ANSI color codes.""" + return hasattr(sys.stdout, "isatty") and sys.stdout.isatty() + + +def log_info(message: str) -> None: + """Print an informational message in blue.""" + if supports_color(): + print(f"{BLUE}[INFO]{NC} {message}") + else: + print(f"[INFO] {message}") + + +def log_success(message: str) -> None: + """Print a success message in green.""" + if supports_color(): + print(f"{GREEN}[SUCCESS]{NC} {message}") + else: + print(f"[SUCCESS] {message}") + + +def log_warning(message: str) -> None: + """Print a warning message in yellow to stderr.""" + if supports_color(): + print(f"{YELLOW}[WARNING]{NC} {message}", file=sys.stderr) + else: + print(f"[WARNING] {message}", file=sys.stderr) + + +def log_error(message: str) -> None: + """Print an error message in red to stderr.""" + if supports_color(): + print(f"{RED}[ERROR]{NC} {message}", file=sys.stderr) + else: + print(f"[ERROR] {message}", file=sys.stderr) + + +def parse_env_file(path: Path) -> dict[str, str]: + """Parse an .env-style file and return a mapping of key to raw value. + + Lines that are blank or start with '#' (after optional whitespace) are + skipped. Only lines containing '=' are considered variable definitions. + + Args: + path: Path to the .env file to parse. + + Returns: + Ordered dict mapping variable name to its value string. + """ + variables: dict[str, str] = {} + with path.open(encoding="utf-8") as fh: + for line in fh: + line = line.rstrip("\n") + # Skip blank lines and comment lines + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if "=" not in line: + continue + key, _, value = line.partition("=") + key = key.strip() + if key: + variables[key] = value.strip() + return variables + + +def check_files(work_dir: Path) -> None: + """Verify required files exist; create .env from .env.example if absent. + + Args: + work_dir: Directory that must contain .env.example (and optionally .env). + + Raises: + SystemExit: If .env.example does not exist. + """ + log_info("Checking required files...") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + if not example_file.exists(): + log_error(".env.example file not found") + sys.exit(1) + + if not env_file.exists(): + log_warning(".env file does not exist. Creating from .env.example.") + shutil.copy2(example_file, env_file) + log_success(".env file created") + + log_success("Required files verified") + + +def create_backup(work_dir: Path) -> None: + """Create a timestamped backup of the current .env file. + + Backups are placed in ``/env-backup/`` with the filename + ``.env.backup_``. + + Args: + work_dir: Directory containing the .env file to back up. + """ + env_file = work_dir / ".env" + if not env_file.exists(): + return + + backup_dir = work_dir / "env-backup" + if not backup_dir.exists(): + backup_dir.mkdir(parents=True) + log_info(f"Created backup directory: {backup_dir}") + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = backup_dir / f".env.backup_{timestamp}" + shutil.copy2(env_file, backup_file) + log_success(f"Backed up existing .env to {backup_file}") + + +def analyze_value_change(current: str, recommended: str) -> str | None: + """Analyse what kind of change occurred between two env values. + + Args: + current: Value currently set in .env. + recommended: Value present in .env.example. + + Returns: + A human-readable description string, or None when no analysis applies. + """ + use_colors = supports_color() + + def colorize(color: str, text: str) -> str: + return f"{color}{text}{NC}" if use_colors else text + + if not current and recommended: + return colorize(RED, " -> Setting from empty to recommended value") + if current and not recommended: + return colorize(RED, " -> Recommended value changed to empty") + + # Numeric comparison + if re.fullmatch(r"\d+", current) and re.fullmatch(r"\d+", recommended): + cur_int, rec_int = int(current), int(recommended) + if cur_int < rec_int: + return colorize(BLUE, f" -> Numeric increase ({current} < {recommended})") + if cur_int > rec_int: + return colorize(YELLOW, f" -> Numeric decrease ({current} > {recommended})") + return None + + # Boolean comparison + if current.lower() in {"true", "false"} and recommended.lower() in {"true", "false"}: + if current.lower() != recommended.lower(): + return colorize(BLUE, f" -> Boolean value change ({current} -> {recommended})") + return None + + # URL / endpoint + if current.startswith(("http://", "https://")) or recommended.startswith(("http://", "https://")): + return colorize(BLUE, " -> URL/endpoint change") + + # File path + if current.startswith("/") or recommended.startswith("/"): + return colorize(BLUE, " -> File path change") + + # String length + if len(current) != len(recommended): + return colorize(YELLOW, f" -> String length change ({len(current)} -> {len(recommended)} characters)") + + return None + + +def detect_differences(env_vars: dict[str, str], example_vars: dict[str, str]) -> dict[str, tuple[str, str]]: + """Find variables whose values differ between .env and .env.example. + + Only variables present in *both* files are compared; new or removed + variables are handled by separate functions. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Mapping of key -> (env_value, example_value) for every key whose + values differ. + """ + log_info("Detecting differences between .env and .env.example...") + + diffs: dict[str, tuple[str, str]] = {} + for key, example_value in example_vars.items(): + if key in env_vars and env_vars[key] != example_value: + diffs[key] = (env_vars[key], example_value) + + if diffs: + log_success(f"Detected differences in {len(diffs)} environment variables") + show_differences_detail(diffs) + else: + log_info("No differences detected") + + return diffs + + +def show_differences_detail(diffs: dict[str, tuple[str, str]]) -> None: + """Print a formatted table of differing environment variables. + + Args: + diffs: Mapping of key -> (current_value, recommended_value). + """ + use_colors = supports_color() + + log_info("") + log_info("=== Environment Variable Differences ===") + + if not diffs: + log_info("No differences to display") + return + + for count, (key, (env_value, example_value)) in enumerate(diffs.items(), start=1): + print() + if use_colors: + print(f"{YELLOW}[{count}] {key}{NC}") + print(f" {GREEN}.env (current){NC} : {env_value}") + print(f" {BLUE}.env.example (recommended){NC} : {example_value}") + else: + print(f"[{count}] {key}") + print(f" .env (current) : {env_value}") + print(f" .env.example (recommended) : {example_value}") + + analysis = analyze_value_change(env_value, example_value) + if analysis: + print(analysis) + + print() + log_info("=== Difference Analysis Complete ===") + log_info("Note: Consider changing to the recommended values above.") + log_info("Current implementation preserves .env values.") + print() + + +def detect_removed_variables(env_vars: dict[str, str], example_vars: dict[str, str]) -> list[str]: + """Identify variables present in .env but absent from .env.example. + + Args: + env_vars: Parsed key/value pairs from .env. + example_vars: Parsed key/value pairs from .env.example. + + Returns: + Sorted list of variable names that no longer appear in .env.example. + """ + log_info("Detecting removed environment variables...") + + removed = sorted(set(env_vars) - set(example_vars)) + + if removed: + log_warning("The following environment variables have been removed from .env.example:") + for var in removed: + log_warning(f" - {var}") + log_warning("Consider manually removing these variables from .env") + else: + log_success("No removed environment variables found") + + return removed + + +def sync_env_file(work_dir: Path, env_vars: dict[str, str], diffs: dict[str, tuple[str, str]]) -> None: + """Rewrite .env based on .env.example while preserving custom values. + + The output file follows the exact line structure of .env.example + (preserving comments, blank lines, and ordering). For every variable + that exists in .env with a different value from the example, the + current .env value is kept. Variables that are new in .env.example + (not present in .env at all) are added with the example's default. + + Args: + work_dir: Directory containing .env and .env.example. + env_vars: Parsed key/value pairs from the original .env. + diffs: Keys whose .env values differ from .env.example (to preserve). + """ + log_info("Starting partial synchronization of .env file...") + + example_file = work_dir / ".env.example" + new_env_file = work_dir / ".env.new" + + # Keys whose current .env value should override the example default + preserved_keys: set[str] = set(diffs.keys()) + + preserved_count = 0 + updated_count = 0 + + env_var_pattern = re.compile(r"^([A-Za-z_][A-Za-z0-9_]*)\s*=") + + with example_file.open(encoding="utf-8") as src, new_env_file.open("w", encoding="utf-8") as dst: + for line in src: + raw_line = line.rstrip("\n") + match = env_var_pattern.match(raw_line) + if match: + key = match.group(1) + if key in preserved_keys: + # Write the preserved value from .env + dst.write(f"{key}={env_vars[key]}\n") + log_info(f" Preserved: {key} (.env value)") + preserved_count += 1 + else: + # Use the example value (covers new vars and unchanged ones) + dst.write(line if line.endswith("\n") else raw_line + "\n") + updated_count += 1 + else: + # Blank line, comment, or non-variable line — keep as-is + dst.write(line if line.endswith("\n") else raw_line + "\n") + + # Atomically replace the original .env + try: + new_env_file.replace(work_dir / ".env") + except OSError as exc: + log_error(f"Failed to replace .env file: {exc}") + new_env_file.unlink(missing_ok=True) + sys.exit(1) + + log_success("Successfully created new .env file") + log_success("Partial synchronization of .env file completed") + log_info(f" Preserved .env values: {preserved_count}") + log_info(f" Updated to .env.example values: {updated_count}") + + +def show_statistics(work_dir: Path) -> None: + """Print a summary of variable counts from both env files. + + Args: + work_dir: Directory containing .env and .env.example. + """ + log_info("Synchronization statistics:") + + example_file = work_dir / ".env.example" + env_file = work_dir / ".env" + + example_count = len(parse_env_file(example_file)) if example_file.exists() else 0 + env_count = len(parse_env_file(env_file)) if env_file.exists() else 0 + + log_info(f" .env.example environment variables: {example_count}") + log_info(f" .env environment variables: {env_count}") + + +def build_arg_parser() -> argparse.ArgumentParser: + """Build and return the CLI argument parser. + + Returns: + Configured ArgumentParser instance. + """ + parser = argparse.ArgumentParser( + prog="dify-env-sync", + description=( + "Synchronize .env with .env.example: add new variables, " + "preserve custom values, and report removed variables." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=( + "Examples:\n" + " # Run from the docker/ directory (default)\n" + " python dify-env-sync.py\n\n" + " # Specify a custom working directory\n" + " python dify-env-sync.py --dir /path/to/docker\n" + ), + ) + parser.add_argument( + "--dir", + metavar="DIRECTORY", + default=".", + help="Working directory containing .env and .env.example (default: current directory)", + ) + parser.add_argument( + "--no-backup", + action="store_true", + default=False, + help="Skip creating a timestamped backup of the existing .env file", + ) + return parser + + +def main() -> None: + """Orchestrate the complete environment variable synchronization process.""" + parser = build_arg_parser() + args = parser.parse_args() + + work_dir = Path(args.dir).resolve() + + log_info("=== Dify Environment Variables Synchronization Script ===") + log_info(f"Execution started: {datetime.now()}") + log_info(f"Working directory: {work_dir}") + + # 1. Verify prerequisites + check_files(work_dir) + + # 2. Backup existing .env + if not args.no_backup: + create_backup(work_dir) + + # 3. Parse both files + env_vars = parse_env_file(work_dir / ".env") + example_vars = parse_env_file(work_dir / ".env.example") + + # 4. Report differences (values that changed in the example) + diffs = detect_differences(env_vars, example_vars) + + # 5. Report variables removed from the example + detect_removed_variables(env_vars, example_vars) + + # 6. Rewrite .env + sync_env_file(work_dir, env_vars, diffs) + + # 7. Print summary statistics + show_statistics(work_dir) + + log_success("=== Synchronization process completed successfully ===") + log_info(f"Execution finished: {datetime.now()}") + + +if __name__ == "__main__": + main() diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index fcd4800143..04bd2858ff 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.0 + image: langgenius/dify-web:1.13.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -269,7 +269,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.3-local + image: langgenius/dify-plugin-daemon:0.5.4-local restart: always environment: # Use the shared environment variables. diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 4a739bbbe0..2dca581903 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -123,7 +123,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.3-local + image: langgenius/dify-plugin-daemon:0.5.4-local restart: always env_file: - ./middleware.env diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 62421d7ec4..bf72a0f623 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -90,6 +90,7 @@ x-shared-env: &shared-api-worker-env REDIS_SSL_CERTFILE: ${REDIS_SSL_CERTFILE:-} REDIS_SSL_KEYFILE: ${REDIS_SSL_KEYFILE:-} REDIS_DB: ${REDIS_DB:-0} + REDIS_MAX_CONNECTIONS: ${REDIS_MAX_CONNECTIONS:-} REDIS_USE_SENTINEL: ${REDIS_USE_SENTINEL:-false} REDIS_SENTINELS: ${REDIS_SENTINELS:-} REDIS_SENTINEL_SERVICE_NAME: ${REDIS_SENTINEL_SERVICE_NAME:-} @@ -214,6 +215,17 @@ x-shared-env: &shared-api-worker-env COUCHBASE_PASSWORD: ${COUCHBASE_PASSWORD:-password} COUCHBASE_BUCKET_NAME: ${COUCHBASE_BUCKET_NAME:-Embeddings} COUCHBASE_SCOPE_NAME: ${COUCHBASE_SCOPE_NAME:-_default} + HOLOGRES_HOST: ${HOLOGRES_HOST:-} + HOLOGRES_PORT: ${HOLOGRES_PORT:-80} + HOLOGRES_DATABASE: ${HOLOGRES_DATABASE:-} + HOLOGRES_ACCESS_KEY_ID: ${HOLOGRES_ACCESS_KEY_ID:-} + HOLOGRES_ACCESS_KEY_SECRET: ${HOLOGRES_ACCESS_KEY_SECRET:-} + HOLOGRES_SCHEMA: ${HOLOGRES_SCHEMA:-public} + HOLOGRES_TOKENIZER: ${HOLOGRES_TOKENIZER:-jieba} + HOLOGRES_DISTANCE_METHOD: ${HOLOGRES_DISTANCE_METHOD:-Cosine} + HOLOGRES_BASE_QUANTIZATION_TYPE: ${HOLOGRES_BASE_QUANTIZATION_TYPE:-rabitq} + HOLOGRES_MAX_DEGREE: ${HOLOGRES_MAX_DEGREE:-64} + HOLOGRES_EF_CONSTRUCTION: ${HOLOGRES_EF_CONSTRUCTION:-400} PGVECTOR_HOST: ${PGVECTOR_HOST:-pgvector} PGVECTOR_PORT: ${PGVECTOR_PORT:-5432} PGVECTOR_USER: ${PGVECTOR_USER:-postgres} @@ -687,9 +699,9 @@ x-shared-env: &shared-api-worker-env SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE:-1000} SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL:-200} SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: ${SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS:-30} - PUBSUB_REDIS_URL: ${PUBSUB_REDIS_URL:-} - PUBSUB_REDIS_CHANNEL_TYPE: ${PUBSUB_REDIS_CHANNEL_TYPE:-pubsub} - PUBSUB_REDIS_USE_CLUSTERS: ${PUBSUB_REDIS_USE_CLUSTERS:-false} + EVENT_BUS_REDIS_URL: ${EVENT_BUS_REDIS_URL:-} + EVENT_BUS_REDIS_CHANNEL_TYPE: ${EVENT_BUS_REDIS_CHANNEL_TYPE:-pubsub} + EVENT_BUS_REDIS_USE_CLUSTERS: ${EVENT_BUS_REDIS_USE_CLUSTERS:-false} ENABLE_HUMAN_INPUT_TIMEOUT_TASK: ${ENABLE_HUMAN_INPUT_TIMEOUT_TASK:-true} HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: ${HUMAN_INPUT_TIMEOUT_TASK_INTERVAL:-1} SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: ${SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL:-90000} @@ -716,7 +728,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -758,7 +770,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -797,7 +809,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.0 + image: langgenius/dify-api:1.13.2 restart: always environment: # Use the shared environment variables. @@ -827,7 +839,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.0 + image: langgenius/dify-web:1.13.2 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -964,7 +976,7 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.3-local + image: langgenius/dify-plugin-daemon:0.5.4-local restart: always environment: # Use the shared environment variables. diff --git a/docker/middleware.env.example b/docker/middleware.env.example index c88dbe5511..7b28a77fe3 100644 --- a/docker/middleware.env.example +++ b/docker/middleware.env.example @@ -91,6 +91,9 @@ MYSQL_INNODB_FLUSH_LOG_AT_TRX_COMMIT=2 # ----------------------------- REDIS_HOST_VOLUME=./volumes/redis/data REDIS_PASSWORD=difyai123456 +# Optional: limit total Redis connections used by API/Worker (unset for default) +# Align with API's REDIS_MAX_CONNECTIONS in configs +REDIS_MAX_CONNECTIONS= # ------------------------------ # Environment Variables for sandbox Service diff --git a/docs/eu-ai-act-compliance.md b/docs/eu-ai-act-compliance.md new file mode 100644 index 0000000000..5fa29eed3f --- /dev/null +++ b/docs/eu-ai-act-compliance.md @@ -0,0 +1,186 @@ +# EU AI Act Compliance Guide for Dify Deployers + +Dify is an LLMOps platform for building RAG pipelines, agents, and AI workflows. If you deploy Dify in the EU — whether self-hosted or using a cloud provider — the EU AI Act applies to your deployment. This guide covers what the regulation requires and how Dify's architecture maps to those requirements. + +## Is your system in scope? + +The detailed obligations in Articles 12, 13, and 14 only apply to **high-risk AI systems** as defined in Annex III of the EU AI Act. A Dify application is high-risk if it is used for: + +- **Recruitment and HR** — screening candidates, evaluating employee performance, allocating tasks +- **Credit scoring and insurance** — assessing creditworthiness or setting premiums +- **Law enforcement** — profiling, criminal risk assessment, border control +- **Critical infrastructure** — managing energy, water, transport, or telecommunications systems +- **Education assessment** — grading students, determining admissions +- **Essential public services** — evaluating eligibility for benefits, housing, or emergency services + +Most Dify deployments (customer-facing chatbots, internal knowledge bases, content generation workflows) are **not** high-risk. If your Dify application does not fall into one of the categories above: + +- **Article 50** (end-user transparency) still applies if users interact with your application directly. See the [Article 50 section](#article-50-end-user-transparency) below. +- **GDPR** still applies if you process personal data. See the [GDPR section](#gdpr-considerations) below. +- The high-risk obligations (Articles 9-15) are less likely to apply, but risk classification is context-dependent. **Do not self-classify without legal review.** Focus on Article 50 (transparency) and GDPR (data protection) as your baseline obligations. + +If you are unsure whether your use case qualifies as high-risk, consult a qualified legal professional before proceeding. + +## Self-hosted vs cloud: different compliance profiles + +| Deployment | Your role | Dify's role | Who handles compliance? | +|-----------|----------|-------------|------------------------| +| **Self-hosted** | Provider and deployer | Framework provider — obligations under Article 25 apply only if Dify is placed on the market or put into service as part of a complete AI system bearing its name or trademark | You | +| **Dify Cloud** | Deployer | Provider and processor | Shared — Dify handles SOC 2 and GDPR for the platform; you handle AI Act obligations for your specific use case | + +Dify Cloud already has SOC 2 Type II and GDPR compliance for the platform itself. But the EU AI Act adds obligations specific to AI systems that SOC 2 does not cover: risk classification, technical documentation, transparency, and human oversight. + +## Supported providers and services + +Dify integrates with a broad range of AI providers and data stores. The following are the key ones relevant to compliance: + +- **AI providers:** HuggingFace (core), plus integrations with OpenAI, Anthropic, Google, and 100+ models via provider plugins +- **Model identifiers include:** gpt-4o, gpt-3.5-turbo, claude-3-opus, gemini-2.5-flash, whisper-1, and others +- **Vector database connections:** Extensive RAG infrastructure supporting numerous vector stores + +Dify's plugin architecture means actual provider usage depends on your configuration. Document which providers and models are active in your deployment. + +## Data flow diagram + +A typical Dify RAG deployment: + +```mermaid +graph LR + USER((User)) -->|query| DIFY[Dify Platform] + DIFY -->|prompts| LLM([LLM Provider]) + LLM -->|responses| DIFY + DIFY -->|documents| EMBED([Embedding Model]) + EMBED -->|vectors| DIFY + DIFY -->|store/retrieve| VS[(Vector Store)] + DIFY -->|knowledge| KB[(Knowledge Base)] + DIFY -->|response| USER + + classDef processor fill:#60a5fa,stroke:#1e40af,color:#000 + classDef controller fill:#4ade80,stroke:#166534,color:#000 + classDef app fill:#a78bfa,stroke:#5b21b6,color:#000 + classDef user fill:#f472b6,stroke:#be185d,color:#000 + + class USER user + class DIFY app + class LLM processor + class EMBED processor + class VS controller + class KB controller +``` + +**GDPR roles** (providers are typically processors for customer-submitted data, but the exact role depends on each provider's terms of service and processing purpose; deployers should review each provider's DPA): +- **Cloud LLM providers (OpenAI, Anthropic, Google)** typically act as processors — requires DPA. +- **Cloud embedding services** typically act as processors — requires DPA. +- **Self-hosted vector stores (Weaviate, Qdrant, pgvector):** Your organization remains the controller — no third-party transfer. +- **Cloud vector stores (Pinecone, Zilliz Cloud)** typically act as processors — requires DPA. +- **Knowledge base documents:** Your organization is the controller — stored in your infrastructure. + +## Article 11: Technical documentation + +High-risk systems need Annex IV documentation. For Dify deployments, key sections include: + +| Section | What Dify provides | What you must document | +|---------|-------------------|----------------------| +| General description | Platform capabilities, supported models | Your specific use case, intended users, deployment context | +| Development process | Dify's architecture, plugin system | Your RAG pipeline design, prompt engineering, knowledge base curation | +| Monitoring | Dify's built-in logging and analytics | Your monitoring plan, alert thresholds, incident response | +| Performance metrics | Dify's evaluation features | Your accuracy benchmarks, quality thresholds, bias testing | +| Risk management | — | Risk assessment for your specific use case | + +Some sections can be derived from Dify's architecture and your deployment configuration, as shown in the table above. The remaining sections require your input. + +## Article 12: Record-keeping + +Dify's built-in logging covers several Article 12 requirements: + +| Requirement | Dify Feature | Status | +|------------|-------------|--------| +| Conversation logs | Full conversation history with timestamps | **Covered** | +| Model tracking | Model name recorded per interaction | **Covered** | +| Token usage | Token counts per message | **Covered** | +| Cost tracking | Cost per conversation (if provider reports it) | **Partial** | +| Document retrieval | RAG source documents logged | **Covered** | +| User identification | User session tracking | **Covered** | +| Error logging | Failed generation logs | **Covered** | +| Data retention | Configurable | **Your responsibility** | + +**Retention periods:** The required retention period depends on your role under the Act. Article 18 requires **providers** of high-risk systems to retain logs and technical documentation for **10 years** after market placement. Article 26(6) requires **deployers** to retain logs for at least **6 months**. If you self-host Dify and have substantially modified the system, you may be classified as a provider rather than a deployer. Confirm the applicable retention period with legal counsel. + +## Article 13: Transparency to deployers + +Article 13 requires providers of high-risk AI systems to supply deployers with the information needed to understand and operate the system correctly. This is a **documentation obligation**, not a logging obligation. For Dify deployments, this means the upstream LLM and embedding providers must give you: + +- Instructions for use, including intended purpose and known limitations +- Accuracy metrics and performance benchmarks +- Known or foreseeable risks and residual risks after mitigation +- Technical specifications: input/output formats, training data characteristics, model architecture details + +As a deployer, collect model cards, system documentation, and accuracy reports from each AI provider your Dify application uses. Maintain these as part of your Annex IV technical documentation. + +Dify's platform features provide **supporting evidence** that can inform Article 13 documentation, but they do not satisfy Article 13 on their own: +- **Source attribution** — Dify's RAG citation feature shows which documents informed the response, supporting deployer-side auditing +- **Model identification** — Dify logs which LLM model generates responses, providing evidence for system documentation +- **Conversation logs** — execution history helps compile performance and behavior evidence + +You must independently produce system documentation covering how your specific Dify deployment uses AI, its intended purpose, performance characteristics, and residual risks. + +## Article 50: End-user transparency + +Article 50 requires deployers to inform end users that they are interacting with an AI system. This is a separate obligation from Article 13 and applies even to limited-risk systems. + +For Dify applications serving end users: + +1. **Disclose AI involvement** — tell users they are interacting with an AI system +2. **AI-generated content labeling** — identify AI-generated content as such (e.g., clear labeling in the UI) + +Dify's "citation" feature also supports end-user transparency by showing users which knowledge base documents informed the answer. + +> **Note:** Article 50 applies to chatbots and systems interacting directly with natural persons. It has a separate scope from the high-risk designation under Annex III — it applies even to limited-risk systems. + +## Article 14: Human oversight + +Article 14 requires that high-risk AI systems be designed so that natural persons can effectively oversee them. Dify provides **automated technical safeguards** that support human oversight, but they are not a substitute for it: + +| Dify Feature | What It Does | Oversight Role | +|-------------|-------------|----------------| +| Annotation/feedback system | Human review of AI outputs | **Direct oversight** — humans evaluate and correct AI responses | +| Content moderation | Built-in filtering before responses reach users | **Automated safeguard** — reduces harmful outputs but does not replace human judgment on edge cases | +| Rate limiting | Controls on API usage | **Automated safeguard** — bounds system behavior, supports overseer's ability to maintain control | +| Workflow control | Insert human review steps between AI generation and output | **Oversight enabler** — allows building approval gates into the pipeline | + +These automated controls are necessary building blocks, but Article 14 compliance requires **human oversight procedures** on top of them: +- **Escalation procedures** — define what happens when moderation triggers or edge cases arise (who is notified, what action is taken) +- **Human review pipeline** — for high-stakes decisions, route AI outputs to a qualified person before they take effect +- **Override mechanism** — a human must be able to halt AI responses or override the system's output +- **Competence requirements** — the human overseer must understand the system's capabilities, limitations, and the context of its outputs + +### Recommended pattern + +For high-risk use cases (HR, legal, medical), configure your Dify workflow to require human approval before the AI response is delivered to the end user or acted upon. + +## Knowledge base compliance + +Dify's knowledge base feature has specific compliance implications: + +1. **Data provenance:** Document where your knowledge base documents come from. Article 10 requires data governance for training data; knowledge bases are analogous. +2. **Update tracking:** When you add, remove, or update documents in the knowledge base, log the change. The AI system's behavior changes with its knowledge base. +3. **PII in documents:** If knowledge base documents contain personal data, GDPR applies to the entire RAG pipeline. Implement access controls and consider PII redaction before indexing. +4. **Copyright:** Ensure you have the right to use the documents in your knowledge base for AI-assisted generation. + +## GDPR considerations + +1. **Legal basis** (Article 6): Document why AI processing of user queries is necessary +2. **Data Processing Agreements** (Article 28): Required for each cloud LLM and embedding provider +3. **Data minimization:** Only include necessary context in prompts; avoid sending entire documents when a relevant excerpt suffices +4. **Right to erasure:** If a user requests deletion, ensure their conversations are removed from Dify's logs AND any vector store entries derived from their data +5. **Cross-border transfers:** Providers based outside the EEA — including US-based providers (OpenAI, Anthropic), and any other non-EEA providers you route to — require Standard Contractual Clauses (SCCs) or equivalent safeguards under Chapter V of the GDPR. Review each provider's transfer mechanism individually. + +## Resources + +- [EU AI Act full text](https://artificialintelligenceact.eu/) +- [Dify documentation](https://docs.dify.ai/) +- [Dify SOC 2 compliance](https://dify.ai/trust) + +--- + +*This is not legal advice. Consult a qualified professional for compliance decisions.* diff --git a/docs/tlh/README.md b/docs/tlh/README.md index a25849c443..e2acd7734c 100644 --- a/docs/tlh/README.md +++ b/docs/tlh/README.md @@ -61,7 +61,7 @@

langgenius%2Fdify | Trendshift

-Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production. Here's a list of the core features: +Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features:

**1. Workflow**: diff --git a/scripts/stress-test/common/config_helper.py b/scripts/stress-test/common/config_helper.py index 75fcbffa6f..fb34b43e26 100644 --- a/scripts/stress-test/common/config_helper.py +++ b/scripts/stress-test/common/config_helper.py @@ -6,6 +6,13 @@ from typing import Any class ConfigHelper: + _LEGACY_SECTION_MAP = { + "admin_config": "admin", + "token_config": "auth", + "app_config": "app", + "api_key_config": "api_key", + } + """Helper class for reading and writing configuration files.""" def __init__(self, base_dir: Path | None = None): @@ -50,14 +57,8 @@ class ConfigHelper: Dictionary containing config data, or None if file doesn't exist """ # Provide backward compatibility for old config names - if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: - section_map = { - "admin_config": "admin", - "token_config": "auth", - "app_config": "app", - "api_key_config": "api_key", - } - return self.get_state_section(section_map[filename]) + if filename in self._LEGACY_SECTION_MAP: + return self.get_state_section(self._LEGACY_SECTION_MAP[filename]) config_path = self.get_config_path(filename) @@ -85,14 +86,11 @@ class ConfigHelper: True if successful, False otherwise """ # Provide backward compatibility for old config names - if filename in ["admin_config", "token_config", "app_config", "api_key_config"]: - section_map = { - "admin_config": "admin", - "token_config": "auth", - "app_config": "app", - "api_key_config": "api_key", - } - return self.update_state_section(section_map[filename], data) + if filename in self._LEGACY_SECTION_MAP: + return self.update_state_section( + self._LEGACY_SECTION_MAP[filename], + data, + ) self.ensure_config_dir() config_path = self.get_config_path(filename) diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index afbb58fee1..7c8a293446 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -54,17 +54,22 @@ "publish:npm": "./scripts/publish.sh" }, "dependencies": { - "axios": "^1.13.2" + "axios": "^1.13.6" }, "devDependencies": { - "@eslint/js": "^9.39.2", - "@types/node": "^25.0.3", - "@typescript-eslint/eslint-plugin": "^8.50.1", - "@typescript-eslint/parser": "^8.50.1", - "@vitest/coverage-v8": "4.0.16", - "eslint": "^9.39.2", + "@eslint/js": "^10.0.1", + "@types/node": "^25.4.0", + "@typescript-eslint/eslint-plugin": "^8.57.0", + "@typescript-eslint/parser": "^8.57.0", + "@vitest/coverage-v8": "4.0.18", + "eslint": "^10.0.3", "tsup": "^8.5.1", "typescript": "^5.9.3", - "vitest": "^4.0.16" + "vitest": "^4.0.18" + }, + "pnpm": { + "overrides": { + "rollup@>=4.0.0,<4.59.0": "4.59.0" + } } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index 1923a0f063..c4b299cd73 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -4,41 +4,44 @@ settings: autoInstallPeers: true excludeLinksFromLockfile: false +overrides: + rollup@>=4.0.0,<4.59.0: 4.59.0 + importers: .: dependencies: axios: - specifier: ^1.13.2 - version: 1.13.5 + specifier: ^1.13.6 + version: 1.13.6 devDependencies: '@eslint/js': - specifier: ^9.39.2 - version: 9.39.2 + specifier: ^10.0.1 + version: 10.0.1(eslint@10.0.3) '@types/node': - specifier: ^25.0.3 - version: 25.0.3 + specifier: ^25.4.0 + version: 25.4.0 '@typescript-eslint/eslint-plugin': - specifier: ^8.50.1 - version: 8.50.1(@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3))(eslint@9.39.2)(typescript@5.9.3) + specifier: ^8.57.0 + version: 8.57.0(@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3))(eslint@10.0.3)(typescript@5.9.3) '@typescript-eslint/parser': - specifier: ^8.50.1 - version: 8.50.1(eslint@9.39.2)(typescript@5.9.3) + specifier: ^8.57.0 + version: 8.57.0(eslint@10.0.3)(typescript@5.9.3) '@vitest/coverage-v8': - specifier: 4.0.16 - version: 4.0.16(vitest@4.0.16(@types/node@25.0.3)) + specifier: 4.0.18 + version: 4.0.18(vitest@4.0.18(@types/node@25.4.0)) eslint: - specifier: ^9.39.2 - version: 9.39.2 + specifier: ^10.0.3 + version: 10.0.3 tsup: specifier: ^8.5.1 - version: 8.5.1(postcss@8.5.6)(typescript@5.9.3) + version: 8.5.1(postcss@8.5.8)(typescript@5.9.3) typescript: specifier: ^5.9.3 version: 5.9.3 vitest: - specifier: ^4.0.16 - version: 4.0.16(@types/node@25.0.3) + specifier: ^4.0.18 + version: 4.0.18(@types/node@25.4.0) packages: @@ -50,177 +53,177 @@ packages: resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==} engines: {node: '>=6.9.0'} - '@babel/parser@7.28.5': - resolution: {integrity: sha512-KKBU1VGYR7ORr3At5HAtUQ+TV3SzRCXmA/8OdDZiLDBIZxVyzXuztPjfLd3BV1PRAQGCMWWSHYhL0F8d5uHBDQ==} + '@babel/parser@7.29.0': + resolution: {integrity: sha512-IyDgFV5GeDUVX4YdF/3CPULtVGSXXMLh1xVIgdCgxApktqnQV0r7/8Nqthg+8YLGaAtdyIlo2qIdZrbCv4+7ww==} engines: {node: '>=6.0.0'} hasBin: true - '@babel/types@7.28.5': - resolution: {integrity: sha512-qQ5m48eI/MFLQ5PxQj4PFaprjyCTLI37ElWMmNs0K8Lk3dVeOdNpB3ks8jc7yM5CDmVC73eMVk/trk3fgmrUpA==} + '@babel/types@7.29.0': + resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==} engines: {node: '>=6.9.0'} '@bcoe/v8-coverage@1.0.2': resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} engines: {node: '>=18'} - '@esbuild/aix-ppc64@0.27.2': - resolution: {integrity: sha512-GZMB+a0mOMZs4MpDbj8RJp4cw+w1WV5NYD6xzgvzUJ5Ek2jerwfO2eADyI6ExDSUED+1X8aMbegahsJi+8mgpw==} + '@esbuild/aix-ppc64@0.27.3': + resolution: {integrity: sha512-9fJMTNFTWZMh5qwrBItuziu834eOCUcEqymSH7pY+zoMVEZg3gcPuBNxH1EvfVYe9h0x/Ptw8KBzv7qxb7l8dg==} engines: {node: '>=18'} cpu: [ppc64] os: [aix] - '@esbuild/android-arm64@0.27.2': - resolution: {integrity: sha512-pvz8ZZ7ot/RBphf8fv60ljmaoydPU12VuXHImtAs0XhLLw+EXBi2BLe3OYSBslR4rryHvweW5gmkKFwTiFy6KA==} + '@esbuild/android-arm64@0.27.3': + resolution: {integrity: sha512-YdghPYUmj/FX2SYKJ0OZxf+iaKgMsKHVPF1MAq/P8WirnSpCStzKJFjOjzsW0QQ7oIAiccHdcqjbHmJxRb/dmg==} engines: {node: '>=18'} cpu: [arm64] os: [android] - '@esbuild/android-arm@0.27.2': - resolution: {integrity: sha512-DVNI8jlPa7Ujbr1yjU2PfUSRtAUZPG9I1RwW4F4xFB1Imiu2on0ADiI/c3td+KmDtVKNbi+nffGDQMfcIMkwIA==} + '@esbuild/android-arm@0.27.3': + resolution: {integrity: sha512-i5D1hPY7GIQmXlXhs2w8AWHhenb00+GxjxRncS2ZM7YNVGNfaMxgzSGuO8o8SJzRc/oZwU2bcScvVERk03QhzA==} engines: {node: '>=18'} cpu: [arm] os: [android] - '@esbuild/android-x64@0.27.2': - resolution: {integrity: sha512-z8Ank4Byh4TJJOh4wpz8g2vDy75zFL0TlZlkUkEwYXuPSgX8yzep596n6mT7905kA9uHZsf/o2OJZubl2l3M7A==} + '@esbuild/android-x64@0.27.3': + resolution: {integrity: sha512-IN/0BNTkHtk8lkOM8JWAYFg4ORxBkZQf9zXiEOfERX/CzxW3Vg1ewAhU7QSWQpVIzTW+b8Xy+lGzdYXV6UZObQ==} engines: {node: '>=18'} cpu: [x64] os: [android] - '@esbuild/darwin-arm64@0.27.2': - resolution: {integrity: sha512-davCD2Zc80nzDVRwXTcQP/28fiJbcOwvdolL0sOiOsbwBa72kegmVU0Wrh1MYrbuCL98Omp5dVhQFWRKR2ZAlg==} + '@esbuild/darwin-arm64@0.27.3': + resolution: {integrity: sha512-Re491k7ByTVRy0t3EKWajdLIr0gz2kKKfzafkth4Q8A5n1xTHrkqZgLLjFEHVD+AXdUGgQMq+Godfq45mGpCKg==} engines: {node: '>=18'} cpu: [arm64] os: [darwin] - '@esbuild/darwin-x64@0.27.2': - resolution: {integrity: sha512-ZxtijOmlQCBWGwbVmwOF/UCzuGIbUkqB1faQRf5akQmxRJ1ujusWsb3CVfk/9iZKr2L5SMU5wPBi1UWbvL+VQA==} + '@esbuild/darwin-x64@0.27.3': + resolution: {integrity: sha512-vHk/hA7/1AckjGzRqi6wbo+jaShzRowYip6rt6q7VYEDX4LEy1pZfDpdxCBnGtl+A5zq8iXDcyuxwtv3hNtHFg==} engines: {node: '>=18'} cpu: [x64] os: [darwin] - '@esbuild/freebsd-arm64@0.27.2': - resolution: {integrity: sha512-lS/9CN+rgqQ9czogxlMcBMGd+l8Q3Nj1MFQwBZJyoEKI50XGxwuzznYdwcav6lpOGv5BqaZXqvBSiB/kJ5op+g==} + '@esbuild/freebsd-arm64@0.27.3': + resolution: {integrity: sha512-ipTYM2fjt3kQAYOvo6vcxJx3nBYAzPjgTCk7QEgZG8AUO3ydUhvelmhrbOheMnGOlaSFUoHXB6un+A7q4ygY9w==} engines: {node: '>=18'} cpu: [arm64] os: [freebsd] - '@esbuild/freebsd-x64@0.27.2': - resolution: {integrity: sha512-tAfqtNYb4YgPnJlEFu4c212HYjQWSO/w/h/lQaBK7RbwGIkBOuNKQI9tqWzx7Wtp7bTPaGC6MJvWI608P3wXYA==} + '@esbuild/freebsd-x64@0.27.3': + resolution: {integrity: sha512-dDk0X87T7mI6U3K9VjWtHOXqwAMJBNN2r7bejDsc+j03SEjtD9HrOl8gVFByeM0aJksoUuUVU9TBaZa2rgj0oA==} engines: {node: '>=18'} cpu: [x64] os: [freebsd] - '@esbuild/linux-arm64@0.27.2': - resolution: {integrity: sha512-hYxN8pr66NsCCiRFkHUAsxylNOcAQaxSSkHMMjcpx0si13t1LHFphxJZUiGwojB1a/Hd5OiPIqDdXONia6bhTw==} + '@esbuild/linux-arm64@0.27.3': + resolution: {integrity: sha512-sZOuFz/xWnZ4KH3YfFrKCf1WyPZHakVzTiqji3WDc0BCl2kBwiJLCXpzLzUBLgmp4veFZdvN5ChW4Eq/8Fc2Fg==} engines: {node: '>=18'} cpu: [arm64] os: [linux] - '@esbuild/linux-arm@0.27.2': - resolution: {integrity: sha512-vWfq4GaIMP9AIe4yj1ZUW18RDhx6EPQKjwe7n8BbIecFtCQG4CfHGaHuh7fdfq+y3LIA2vGS/o9ZBGVxIDi9hw==} + '@esbuild/linux-arm@0.27.3': + resolution: {integrity: sha512-s6nPv2QkSupJwLYyfS+gwdirm0ukyTFNl3KTgZEAiJDd+iHZcbTPPcWCcRYH+WlNbwChgH2QkE9NSlNrMT8Gfw==} engines: {node: '>=18'} cpu: [arm] os: [linux] - '@esbuild/linux-ia32@0.27.2': - resolution: {integrity: sha512-MJt5BRRSScPDwG2hLelYhAAKh9imjHK5+NE/tvnRLbIqUWa+0E9N4WNMjmp/kXXPHZGqPLxggwVhz7QP8CTR8w==} + '@esbuild/linux-ia32@0.27.3': + resolution: {integrity: sha512-yGlQYjdxtLdh0a3jHjuwOrxQjOZYD/C9PfdbgJJF3TIZWnm/tMd/RcNiLngiu4iwcBAOezdnSLAwQDPqTmtTYg==} engines: {node: '>=18'} cpu: [ia32] os: [linux] - '@esbuild/linux-loong64@0.27.2': - resolution: {integrity: sha512-lugyF1atnAT463aO6KPshVCJK5NgRnU4yb3FUumyVz+cGvZbontBgzeGFO1nF+dPueHD367a2ZXe1NtUkAjOtg==} + '@esbuild/linux-loong64@0.27.3': + resolution: {integrity: sha512-WO60Sn8ly3gtzhyjATDgieJNet/KqsDlX5nRC5Y3oTFcS1l0KWba+SEa9Ja1GfDqSF1z6hif/SkpQJbL63cgOA==} engines: {node: '>=18'} cpu: [loong64] os: [linux] - '@esbuild/linux-mips64el@0.27.2': - resolution: {integrity: sha512-nlP2I6ArEBewvJ2gjrrkESEZkB5mIoaTswuqNFRv/WYd+ATtUpe9Y09RnJvgvdag7he0OWgEZWhviS1OTOKixw==} + '@esbuild/linux-mips64el@0.27.3': + resolution: {integrity: sha512-APsymYA6sGcZ4pD6k+UxbDjOFSvPWyZhjaiPyl/f79xKxwTnrn5QUnXR5prvetuaSMsb4jgeHewIDCIWljrSxw==} engines: {node: '>=18'} cpu: [mips64el] os: [linux] - '@esbuild/linux-ppc64@0.27.2': - resolution: {integrity: sha512-C92gnpey7tUQONqg1n6dKVbx3vphKtTHJaNG2Ok9lGwbZil6DrfyecMsp9CrmXGQJmZ7iiVXvvZH6Ml5hL6XdQ==} + '@esbuild/linux-ppc64@0.27.3': + resolution: {integrity: sha512-eizBnTeBefojtDb9nSh4vvVQ3V9Qf9Df01PfawPcRzJH4gFSgrObw+LveUyDoKU3kxi5+9RJTCWlj4FjYXVPEA==} engines: {node: '>=18'} cpu: [ppc64] os: [linux] - '@esbuild/linux-riscv64@0.27.2': - resolution: {integrity: sha512-B5BOmojNtUyN8AXlK0QJyvjEZkWwy/FKvakkTDCziX95AowLZKR6aCDhG7LeF7uMCXEJqwa8Bejz5LTPYm8AvA==} + '@esbuild/linux-riscv64@0.27.3': + resolution: {integrity: sha512-3Emwh0r5wmfm3ssTWRQSyVhbOHvqegUDRd0WhmXKX2mkHJe1SFCMJhagUleMq+Uci34wLSipf8Lagt4LlpRFWQ==} engines: {node: '>=18'} cpu: [riscv64] os: [linux] - '@esbuild/linux-s390x@0.27.2': - resolution: {integrity: sha512-p4bm9+wsPwup5Z8f4EpfN63qNagQ47Ua2znaqGH6bqLlmJ4bx97Y9JdqxgGZ6Y8xVTixUnEkoKSHcpRlDnNr5w==} + '@esbuild/linux-s390x@0.27.3': + resolution: {integrity: sha512-pBHUx9LzXWBc7MFIEEL0yD/ZVtNgLytvx60gES28GcWMqil8ElCYR4kvbV2BDqsHOvVDRrOxGySBM9Fcv744hw==} engines: {node: '>=18'} cpu: [s390x] os: [linux] - '@esbuild/linux-x64@0.27.2': - resolution: {integrity: sha512-uwp2Tip5aPmH+NRUwTcfLb+W32WXjpFejTIOWZFw/v7/KnpCDKG66u4DLcurQpiYTiYwQ9B7KOeMJvLCu/OvbA==} + '@esbuild/linux-x64@0.27.3': + resolution: {integrity: sha512-Czi8yzXUWIQYAtL/2y6vogER8pvcsOsk5cpwL4Gk5nJqH5UZiVByIY8Eorm5R13gq+DQKYg0+JyQoytLQas4dA==} engines: {node: '>=18'} cpu: [x64] os: [linux] - '@esbuild/netbsd-arm64@0.27.2': - resolution: {integrity: sha512-Kj6DiBlwXrPsCRDeRvGAUb/LNrBASrfqAIok+xB0LxK8CHqxZ037viF13ugfsIpePH93mX7xfJp97cyDuTZ3cw==} + '@esbuild/netbsd-arm64@0.27.3': + resolution: {integrity: sha512-sDpk0RgmTCR/5HguIZa9n9u+HVKf40fbEUt+iTzSnCaGvY9kFP0YKBWZtJaraonFnqef5SlJ8/TiPAxzyS+UoA==} engines: {node: '>=18'} cpu: [arm64] os: [netbsd] - '@esbuild/netbsd-x64@0.27.2': - resolution: {integrity: sha512-HwGDZ0VLVBY3Y+Nw0JexZy9o/nUAWq9MlV7cahpaXKW6TOzfVno3y3/M8Ga8u8Yr7GldLOov27xiCnqRZf0tCA==} + '@esbuild/netbsd-x64@0.27.3': + resolution: {integrity: sha512-P14lFKJl/DdaE00LItAukUdZO5iqNH7+PjoBm+fLQjtxfcfFE20Xf5CrLsmZdq5LFFZzb5JMZ9grUwvtVYzjiA==} engines: {node: '>=18'} cpu: [x64] os: [netbsd] - '@esbuild/openbsd-arm64@0.27.2': - resolution: {integrity: sha512-DNIHH2BPQ5551A7oSHD0CKbwIA/Ox7+78/AWkbS5QoRzaqlev2uFayfSxq68EkonB+IKjiuxBFoV8ESJy8bOHA==} + '@esbuild/openbsd-arm64@0.27.3': + resolution: {integrity: sha512-AIcMP77AvirGbRl/UZFTq5hjXK+2wC7qFRGoHSDrZ5v5b8DK/GYpXW3CPRL53NkvDqb9D+alBiC/dV0Fb7eJcw==} engines: {node: '>=18'} cpu: [arm64] os: [openbsd] - '@esbuild/openbsd-x64@0.27.2': - resolution: {integrity: sha512-/it7w9Nb7+0KFIzjalNJVR5bOzA9Vay+yIPLVHfIQYG/j+j9VTH84aNB8ExGKPU4AzfaEvN9/V4HV+F+vo8OEg==} + '@esbuild/openbsd-x64@0.27.3': + resolution: {integrity: sha512-DnW2sRrBzA+YnE70LKqnM3P+z8vehfJWHXECbwBmH/CU51z6FiqTQTHFenPlHmo3a8UgpLyH3PT+87OViOh1AQ==} engines: {node: '>=18'} cpu: [x64] os: [openbsd] - '@esbuild/openharmony-arm64@0.27.2': - resolution: {integrity: sha512-LRBbCmiU51IXfeXk59csuX/aSaToeG7w48nMwA6049Y4J4+VbWALAuXcs+qcD04rHDuSCSRKdmY63sruDS5qag==} + '@esbuild/openharmony-arm64@0.27.3': + resolution: {integrity: sha512-NinAEgr/etERPTsZJ7aEZQvvg/A6IsZG/LgZy+81wON2huV7SrK3e63dU0XhyZP4RKGyTm7aOgmQk0bGp0fy2g==} engines: {node: '>=18'} cpu: [arm64] os: [openharmony] - '@esbuild/sunos-x64@0.27.2': - resolution: {integrity: sha512-kMtx1yqJHTmqaqHPAzKCAkDaKsffmXkPHThSfRwZGyuqyIeBvf08KSsYXl+abf5HDAPMJIPnbBfXvP2ZC2TfHg==} + '@esbuild/sunos-x64@0.27.3': + resolution: {integrity: sha512-PanZ+nEz+eWoBJ8/f8HKxTTD172SKwdXebZ0ndd953gt1HRBbhMsaNqjTyYLGLPdoWHy4zLU7bDVJztF5f3BHA==} engines: {node: '>=18'} cpu: [x64] os: [sunos] - '@esbuild/win32-arm64@0.27.2': - resolution: {integrity: sha512-Yaf78O/B3Kkh+nKABUF++bvJv5Ijoy9AN1ww904rOXZFLWVc5OLOfL56W+C8F9xn5JQZa3UX6m+IktJnIb1Jjg==} + '@esbuild/win32-arm64@0.27.3': + resolution: {integrity: sha512-B2t59lWWYrbRDw/tjiWOuzSsFh1Y/E95ofKz7rIVYSQkUYBjfSgf6oeYPNWHToFRr2zx52JKApIcAS/D5TUBnA==} engines: {node: '>=18'} cpu: [arm64] os: [win32] - '@esbuild/win32-ia32@0.27.2': - resolution: {integrity: sha512-Iuws0kxo4yusk7sw70Xa2E2imZU5HoixzxfGCdxwBdhiDgt9vX9VUCBhqcwY7/uh//78A1hMkkROMJq9l27oLQ==} + '@esbuild/win32-ia32@0.27.3': + resolution: {integrity: sha512-QLKSFeXNS8+tHW7tZpMtjlNb7HKau0QDpwm49u0vUp9y1WOF+PEzkU84y9GqYaAVW8aH8f3GcBck26jh54cX4Q==} engines: {node: '>=18'} cpu: [ia32] os: [win32] - '@esbuild/win32-x64@0.27.2': - resolution: {integrity: sha512-sRdU18mcKf7F+YgheI/zGf5alZatMUTKj/jNS6l744f9u3WFu4v7twcUI9vu4mknF4Y9aDlblIie0IM+5xxaqQ==} + '@esbuild/win32-x64@0.27.3': + resolution: {integrity: sha512-4uJGhsxuptu3OcpVAzli+/gWusVGwZZHTlS63hh++ehExkVT8SgiEf7/uC/PclrPPkLhZqGgCTjd0VWLo6xMqA==} engines: {node: '>=18'} cpu: [x64] os: [win32] - '@eslint-community/eslint-utils@4.9.0': - resolution: {integrity: sha512-ayVFHdtZ+hsq1t2Dy24wCmGXGe4q9Gu3smhLYALJrr473ZH27MsnSL+LKUlimp4BWJqMDMLmPpx/Q9R3OAlL4g==} + '@eslint-community/eslint-utils@4.9.1': + resolution: {integrity: sha512-phrYmNiYppR7znFEdqgfWHXR6NCkZEK7hwWDHZUjit/2/U0r6XvkDl0SYnoM51Hq7FhCGdLDT6zxCCOY1hexsQ==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} peerDependencies: eslint: ^6.0.0 || ^7.0.0 || >=8.0.0 @@ -229,33 +232,34 @@ packages: resolution: {integrity: sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==} engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} - '@eslint/config-array@0.21.1': - resolution: {integrity: sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/config-array@0.23.3': + resolution: {integrity: sha512-j+eEWmB6YYLwcNOdlwQ6L2OsptI/LO6lNBuLIqe5R7RetD658HLoF+Mn7LzYmAWWNNzdC6cqP+L6r8ujeYXWLw==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/config-helpers@0.4.2': - resolution: {integrity: sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/config-helpers@0.5.3': + resolution: {integrity: sha512-lzGN0onllOZCGroKJmRwY6QcEHxbjBw1gwB8SgRSqK8YbbtEXMvKynsXc3553ckIEBxsbMBU7oOZXKIPGZNeZw==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/core@0.17.0': - resolution: {integrity: sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/core@1.1.1': + resolution: {integrity: sha512-QUPblTtE51/7/Zhfv8BDwO0qkkzQL7P/aWWbqcf4xWLEYn1oKjdO0gglQBB4GAsu7u6wjijbCmzsUTy6mnk6oQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/eslintrc@3.3.3': - resolution: {integrity: sha512-Kr+LPIUVKz2qkx1HAMH8q1q6azbqBAsXJUxBl/ODDuVPX45Z9DfwB8tPjTi6nNZ8BuM3nbJxC5zCAg5elnBUTQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/js@10.0.1': + resolution: {integrity: sha512-zeR9k5pd4gxjZ0abRoIaxdc7I3nDktoXZk2qOv9gCNWx3mVwEn32VRhyLaRsDiJjTs0xq/T8mfPtyuXu7GWBcA==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} + peerDependencies: + eslint: ^10.0.0 + peerDependenciesMeta: + eslint: + optional: true - '@eslint/js@9.39.2': - resolution: {integrity: sha512-q1mjIoW1VX4IvSocvM/vbTiveKC4k9eLrajNEuSsmjymSDEbpGddtpfOoN7YGAqBK3NG+uqo8ia4PDTt8buCYA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/object-schema@3.0.3': + resolution: {integrity: sha512-iM869Pugn9Nsxbh/YHRqYiqd23AmIbxJOcpUMOuWCVNdoQJ5ZtwL6h3t0bcZzJUlC3Dq9jCFCESBZnX0GTv7iQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - '@eslint/object-schema@2.1.7': - resolution: {integrity: sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - - '@eslint/plugin-kit@0.4.1': - resolution: {integrity: sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + '@eslint/plugin-kit@0.6.1': + resolution: {integrity: sha512-iH1B076HoAshH1mLpHMgwdGeTs0CYwL0SPMkGuSebZrwBp16v415e9NZXg2jtrqPVQjf6IANe2Vtlr5KswtcZQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} '@humanfs/core@0.19.1': resolution: {integrity: sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==} @@ -286,113 +290,128 @@ packages: '@jridgewell/trace-mapping@0.3.31': resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} - '@rollup/rollup-android-arm-eabi@4.54.0': - resolution: {integrity: sha512-OywsdRHrFvCdvsewAInDKCNyR3laPA2mc9bRYJ6LBp5IyvF3fvXbbNR0bSzHlZVFtn6E0xw2oZlyjg4rKCVcng==} + '@rollup/rollup-android-arm-eabi@4.59.0': + resolution: {integrity: sha512-upnNBkA6ZH2VKGcBj9Fyl9IGNPULcjXRlg0LLeaioQWueH30p6IXtJEbKAgvyv+mJaMxSm1l6xwDXYjpEMiLMg==} cpu: [arm] os: [android] - '@rollup/rollup-android-arm64@4.54.0': - resolution: {integrity: sha512-Skx39Uv+u7H224Af+bDgNinitlmHyQX1K/atIA32JP3JQw6hVODX5tkbi2zof/E69M1qH2UoN3Xdxgs90mmNYw==} + '@rollup/rollup-android-arm64@4.59.0': + resolution: {integrity: sha512-hZ+Zxj3SySm4A/DylsDKZAeVg0mvi++0PYVceVyX7hemkw7OreKdCvW2oQ3T1FMZvCaQXqOTHb8qmBShoqk69Q==} cpu: [arm64] os: [android] - '@rollup/rollup-darwin-arm64@4.54.0': - resolution: {integrity: sha512-k43D4qta/+6Fq+nCDhhv9yP2HdeKeP56QrUUTW7E6PhZP1US6NDqpJj4MY0jBHlJivVJD5P8NxrjuobZBJTCRw==} + '@rollup/rollup-darwin-arm64@4.59.0': + resolution: {integrity: sha512-W2Psnbh1J8ZJw0xKAd8zdNgF9HRLkdWwwdWqubSVk0pUuQkoHnv7rx4GiF9rT4t5DIZGAsConRE3AxCdJ4m8rg==} cpu: [arm64] os: [darwin] - '@rollup/rollup-darwin-x64@4.54.0': - resolution: {integrity: sha512-cOo7biqwkpawslEfox5Vs8/qj83M/aZCSSNIWpVzfU2CYHa2G3P1UN5WF01RdTHSgCkri7XOlTdtk17BezlV3A==} + '@rollup/rollup-darwin-x64@4.59.0': + resolution: {integrity: sha512-ZW2KkwlS4lwTv7ZVsYDiARfFCnSGhzYPdiOU4IM2fDbL+QGlyAbjgSFuqNRbSthybLbIJ915UtZBtmuLrQAT/w==} cpu: [x64] os: [darwin] - '@rollup/rollup-freebsd-arm64@4.54.0': - resolution: {integrity: sha512-miSvuFkmvFbgJ1BevMa4CPCFt5MPGw094knM64W9I0giUIMMmRYcGW/JWZDriaw/k1kOBtsWh1z6nIFV1vPNtA==} + '@rollup/rollup-freebsd-arm64@4.59.0': + resolution: {integrity: sha512-EsKaJ5ytAu9jI3lonzn3BgG8iRBjV4LxZexygcQbpiU0wU0ATxhNVEpXKfUa0pS05gTcSDMKpn3Sx+QB9RlTTA==} cpu: [arm64] os: [freebsd] - '@rollup/rollup-freebsd-x64@4.54.0': - resolution: {integrity: sha512-KGXIs55+b/ZfZsq9aR026tmr/+7tq6VG6MsnrvF4H8VhwflTIuYh+LFUlIsRdQSgrgmtM3fVATzEAj4hBQlaqQ==} + '@rollup/rollup-freebsd-x64@4.59.0': + resolution: {integrity: sha512-d3DuZi2KzTMjImrxoHIAODUZYoUUMsuUiY4SRRcJy6NJoZ6iIqWnJu9IScV9jXysyGMVuW+KNzZvBLOcpdl3Vg==} cpu: [x64] os: [freebsd] - '@rollup/rollup-linux-arm-gnueabihf@4.54.0': - resolution: {integrity: sha512-EHMUcDwhtdRGlXZsGSIuXSYwD5kOT9NVnx9sqzYiwAc91wfYOE1g1djOEDseZJKKqtHAHGwnGPQu3kytmfaXLQ==} + '@rollup/rollup-linux-arm-gnueabihf@4.59.0': + resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] - '@rollup/rollup-linux-arm-musleabihf@4.54.0': - resolution: {integrity: sha512-+pBrqEjaakN2ySv5RVrj/qLytYhPKEUwk+e3SFU5jTLHIcAtqh2rLrd/OkbNuHJpsBgxsD8ccJt5ga/SeG0JmA==} + '@rollup/rollup-linux-arm-musleabihf@4.59.0': + resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] - '@rollup/rollup-linux-arm64-gnu@4.54.0': - resolution: {integrity: sha512-NSqc7rE9wuUaRBsBp5ckQ5CVz5aIRKCwsoa6WMF7G01sX3/qHUw/z4pv+D+ahL1EIKy6Enpcnz1RY8pf7bjwng==} + '@rollup/rollup-linux-arm64-gnu@4.59.0': + resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] - '@rollup/rollup-linux-arm64-musl@4.54.0': - resolution: {integrity: sha512-gr5vDbg3Bakga5kbdpqx81m2n9IX8M6gIMlQQIXiLTNeQW6CucvuInJ91EuCJ/JYvc+rcLLsDFcfAD1K7fMofg==} + '@rollup/rollup-linux-arm64-musl@4.59.0': + resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] - '@rollup/rollup-linux-loong64-gnu@4.54.0': - resolution: {integrity: sha512-gsrtB1NA3ZYj2vq0Rzkylo9ylCtW/PhpLEivlgWe0bpgtX5+9j9EZa0wtZiCjgu6zmSeZWyI/e2YRX1URozpIw==} + '@rollup/rollup-linux-loong64-gnu@4.59.0': + resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] - '@rollup/rollup-linux-ppc64-gnu@4.54.0': - resolution: {integrity: sha512-y3qNOfTBStmFNq+t4s7Tmc9hW2ENtPg8FeUD/VShI7rKxNW7O4fFeaYbMsd3tpFlIg1Q8IapFgy7Q9i2BqeBvA==} + '@rollup/rollup-linux-loong64-musl@4.59.0': + resolution: {integrity: sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} + cpu: [loong64] + os: [linux] + + '@rollup/rollup-linux-ppc64-gnu@4.59.0': + resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] - '@rollup/rollup-linux-riscv64-gnu@4.54.0': - resolution: {integrity: sha512-89sepv7h2lIVPsFma8iwmccN7Yjjtgz0Rj/Ou6fEqg3HDhpCa+Et+YSufy27i6b0Wav69Qv4WBNl3Rs6pwhebQ==} + '@rollup/rollup-linux-ppc64-musl@4.59.0': + resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} + cpu: [ppc64] + os: [linux] + + '@rollup/rollup-linux-riscv64-gnu@4.59.0': + resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] - '@rollup/rollup-linux-riscv64-musl@4.54.0': - resolution: {integrity: sha512-ZcU77ieh0M2Q8Ur7D5X7KvK+UxbXeDHwiOt/CPSBTI1fBmeDMivW0dPkdqkT4rOgDjrDDBUed9x4EgraIKoR2A==} + '@rollup/rollup-linux-riscv64-musl@4.59.0': + resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] - '@rollup/rollup-linux-s390x-gnu@4.54.0': - resolution: {integrity: sha512-2AdWy5RdDF5+4YfG/YesGDDtbyJlC9LHmL6rZw6FurBJ5n4vFGupsOBGfwMRjBYH7qRQowT8D/U4LoSvVwOhSQ==} + '@rollup/rollup-linux-s390x-gnu@4.59.0': + resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] - '@rollup/rollup-linux-x64-gnu@4.54.0': - resolution: {integrity: sha512-WGt5J8Ij/rvyqpFexxk3ffKqqbLf9AqrTBbWDk7ApGUzaIs6V+s2s84kAxklFwmMF/vBNGrVdYgbblCOFFezMQ==} + '@rollup/rollup-linux-x64-gnu@4.59.0': + resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} cpu: [x64] os: [linux] - '@rollup/rollup-linux-x64-musl@4.54.0': - resolution: {integrity: sha512-JzQmb38ATzHjxlPHuTH6tE7ojnMKM2kYNzt44LO/jJi8BpceEC8QuXYA908n8r3CNuG/B3BV8VR3Hi1rYtmPiw==} + '@rollup/rollup-linux-x64-musl@4.59.0': + resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} cpu: [x64] os: [linux] - '@rollup/rollup-openharmony-arm64@4.54.0': - resolution: {integrity: sha512-huT3fd0iC7jigGh7n3q/+lfPcXxBi+om/Rs3yiFxjvSxbSB6aohDFXbWvlspaqjeOh+hx7DDHS+5Es5qRkWkZg==} + '@rollup/rollup-openbsd-x64@4.59.0': + resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} + cpu: [x64] + os: [openbsd] + + '@rollup/rollup-openharmony-arm64@4.59.0': + resolution: {integrity: sha512-tt9KBJqaqp5i5HUZzoafHZX8b5Q2Fe7UjYERADll83O4fGqJ49O1FsL6LpdzVFQcpwvnyd0i+K/VSwu/o/nWlA==} cpu: [arm64] os: [openharmony] - '@rollup/rollup-win32-arm64-msvc@4.54.0': - resolution: {integrity: sha512-c2V0W1bsKIKfbLMBu/WGBz6Yci8nJ/ZJdheE0EwB73N3MvHYKiKGs3mVilX4Gs70eGeDaMqEob25Tw2Gb9Nqyw==} + '@rollup/rollup-win32-arm64-msvc@4.59.0': + resolution: {integrity: sha512-V5B6mG7OrGTwnxaNUzZTDTjDS7F75PO1ae6MJYdiMu60sq0CqN5CVeVsbhPxalupvTX8gXVSU9gq+Rx1/hvu6A==} cpu: [arm64] os: [win32] - '@rollup/rollup-win32-ia32-msvc@4.54.0': - resolution: {integrity: sha512-woEHgqQqDCkAzrDhvDipnSirm5vxUXtSKDYTVpZG3nUdW/VVB5VdCYA2iReSj/u3yCZzXID4kuKG7OynPnB3WQ==} + '@rollup/rollup-win32-ia32-msvc@4.59.0': + resolution: {integrity: sha512-UKFMHPuM9R0iBegwzKF4y0C4J9u8C6MEJgFuXTBerMk7EJ92GFVFYBfOZaSGLu6COf7FxpQNqhNS4c4icUPqxA==} cpu: [ia32] os: [win32] - '@rollup/rollup-win32-x64-gnu@4.54.0': - resolution: {integrity: sha512-dzAc53LOuFvHwbCEOS0rPbXp6SIhAf2txMP5p6mGyOXXw5mWY8NGGbPMPrs4P1WItkfApDathBj/NzMLUZ9rtQ==} + '@rollup/rollup-win32-x64-gnu@4.59.0': + resolution: {integrity: sha512-laBkYlSS1n2L8fSo1thDNGrCTQMmxjYY5G0WFWjFFYZkKPjsMBsgJfGf4TLxXrF6RyhI60L8TMOjBMvXiTcxeA==} cpu: [x64] os: [win32] - '@rollup/rollup-win32-x64-msvc@4.54.0': - resolution: {integrity: sha512-hYT5d3YNdSh3mbCU1gwQyPgQd3T2ne0A3KG8KSBdav5TiBg6eInVmV+TeR5uHufiIgSFg0XsOWGW5/RhNcSvPg==} + '@rollup/rollup-win32-x64-msvc@4.59.0': + resolution: {integrity: sha512-2HRCml6OztYXyJXAvdDXPKcawukWY2GpR5/nxKp4iBgiO3wcoEGkAaqctIbZcNB6KlUQBIqt8VYkNSj2397EfA==} cpu: [x64] os: [win32] @@ -405,88 +424,91 @@ packages: '@types/deep-eql@4.0.2': resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} + '@types/esrecurse@4.3.1': + resolution: {integrity: sha512-xJBAbDifo5hpffDBuHl0Y8ywswbiAp/Wi7Y/GtAgSlZyIABppyurxVueOPE8LUQOxdlgi6Zqce7uoEpqNTeiUw==} + '@types/estree@1.0.8': resolution: {integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} '@types/json-schema@7.0.15': resolution: {integrity: sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==} - '@types/node@25.0.3': - resolution: {integrity: sha512-W609buLVRVmeW693xKfzHeIV6nJGGz98uCPfeXI1ELMLXVeKYZ9m15fAMSaUPBHYLGFsVRcMmSCksQOrZV9BYA==} + '@types/node@25.4.0': + resolution: {integrity: sha512-9wLpoeWuBlcbBpOY3XmzSTG3oscB6xjBEEtn+pYXTfhyXhIxC5FsBer2KTopBlvKEiW9l13po9fq+SJY/5lkhw==} - '@typescript-eslint/eslint-plugin@8.50.1': - resolution: {integrity: sha512-PKhLGDq3JAg0Jk/aK890knnqduuI/Qj+udH7wCf0217IGi4gt+acgCyPVe79qoT+qKUvHMDQkwJeKW9fwl8Cyw==} + '@typescript-eslint/eslint-plugin@8.57.0': + resolution: {integrity: sha512-qeu4rTHR3/IaFORbD16gmjq9+rEs9fGKdX0kF6BKSfi+gCuG3RCKLlSBYzn/bGsY9Tj7KE/DAQStbp8AHJGHEQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - '@typescript-eslint/parser': ^8.50.1 - eslint: ^8.57.0 || ^9.0.0 + '@typescript-eslint/parser': ^8.57.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/parser@8.50.1': - resolution: {integrity: sha512-hM5faZwg7aVNa819m/5r7D0h0c9yC4DUlWAOvHAtISdFTc8xB86VmX5Xqabrama3wIPJ/q9RbGS1worb6JfnMg==} + '@typescript-eslint/parser@8.57.0': + resolution: {integrity: sha512-XZzOmihLIr8AD1b9hL9ccNMzEMWt/dE2u7NyTY9jJG6YNiNthaD5XtUHVF2uCXZ15ng+z2hT3MVuxnUYhq6k1g==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/project-service@8.50.1': - resolution: {integrity: sha512-E1ur1MCVf+YiP89+o4Les/oBAVzmSbeRB0MQLfSlYtbWU17HPxZ6Bhs5iYmKZRALvEuBoXIZMOIRRc/P++Ortg==} + '@typescript-eslint/project-service@8.57.0': + resolution: {integrity: sha512-pR+dK0BlxCLxtWfaKQWtYr7MhKmzqZxuii+ZjuFlZlIGRZm22HnXFqa2eY+90MUz8/i80YJmzFGDUsi8dMOV5w==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/scope-manager@8.50.1': - resolution: {integrity: sha512-mfRx06Myt3T4vuoHaKi8ZWNTPdzKPNBhiblze5N50//TSHOAQQevl/aolqA/BcqqbJ88GUnLqjjcBc8EWdBcVw==} + '@typescript-eslint/scope-manager@8.57.0': + resolution: {integrity: sha512-nvExQqAHF01lUM66MskSaZulpPL5pgy5hI5RfrxviLgzZVffB5yYzw27uK/ft8QnKXI2X0LBrHJFr1TaZtAibw==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/tsconfig-utils@8.50.1': - resolution: {integrity: sha512-ooHmotT/lCWLXi55G4mvaUF60aJa012QzvLK0Y+Mp4WdSt17QhMhWOaBWeGTFVkb2gDgBe19Cxy1elPXylslDw==} + '@typescript-eslint/tsconfig-utils@8.57.0': + resolution: {integrity: sha512-LtXRihc5ytjJIQEH+xqjB0+YgsV4/tW35XKX3GTZHpWtcC8SPkT/d4tqdf1cKtesryHm2bgp6l555NYcT2NLvA==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/type-utils@8.50.1': - resolution: {integrity: sha512-7J3bf022QZE42tYMO6SL+6lTPKFk/WphhRPe9Tw/el+cEwzLz1Jjz2PX3GtGQVxooLDKeMVmMt7fWpYRdG5Etg==} + '@typescript-eslint/type-utils@8.57.0': + resolution: {integrity: sha512-yjgh7gmDcJ1+TcEg8x3uWQmn8ifvSupnPfjP21twPKrDP/pTHlEQgmKcitzF/rzPSmv7QjJ90vRpN4U+zoUjwQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/types@8.50.1': - resolution: {integrity: sha512-v5lFIS2feTkNyMhd7AucE/9j/4V9v5iIbpVRncjk/K0sQ6Sb+Np9fgYS/63n6nwqahHQvbmujeBL7mp07Q9mlA==} + '@typescript-eslint/types@8.57.0': + resolution: {integrity: sha512-dTLI8PEXhjUC7B9Kre+u0XznO696BhXcTlOn0/6kf1fHaQW8+VjJAVHJ3eTI14ZapTxdkOmc80HblPQLaEeJdg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@typescript-eslint/typescript-estree@8.50.1': - resolution: {integrity: sha512-woHPdW+0gj53aM+cxchymJCrh0cyS7BTIdcDxWUNsclr9VDkOSbqC13juHzxOmQ22dDkMZEpZB+3X1WpUvzgVQ==} + '@typescript-eslint/typescript-estree@8.57.0': + resolution: {integrity: sha512-m7faHcyVg0BT3VdYTlX8GdJEM7COexXxS6KqGopxdtkQRvBanK377QDHr4W/vIPAR+ah9+B/RclSW5ldVniO1Q==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/utils@8.50.1': - resolution: {integrity: sha512-lCLp8H1T9T7gPbEuJSnHwnSuO9mDf8mfK/Nion5mZmiEaQD9sWf9W4dfeFqRyqRjF06/kBuTmAqcs9sewM2NbQ==} + '@typescript-eslint/utils@8.57.0': + resolution: {integrity: sha512-5iIHvpD3CZe06riAsbNxxreP+MuYgVUsV0n4bwLH//VJmgtt54sQeY2GszntJ4BjYCpMzrfVh2SBnUQTtys2lQ==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} peerDependencies: - eslint: ^8.57.0 || ^9.0.0 + eslint: ^8.57.0 || ^9.0.0 || ^10.0.0 typescript: '>=4.8.4 <6.0.0' - '@typescript-eslint/visitor-keys@8.50.1': - resolution: {integrity: sha512-IrDKrw7pCRUR94zeuCSUWQ+w8JEf5ZX5jl/e6AHGSLi1/zIr0lgutfn/7JpfCey+urpgQEdrZVYzCaVVKiTwhQ==} + '@typescript-eslint/visitor-keys@8.57.0': + resolution: {integrity: sha512-zm6xx8UT/Xy2oSr2ZXD0pZo7Jx2XsCoID2IUh9YSTFRu7z+WdwYTRk6LhUftm1crwqbuoF6I8zAFeCMw0YjwDg==} engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} - '@vitest/coverage-v8@4.0.16': - resolution: {integrity: sha512-2rNdjEIsPRzsdu6/9Eq0AYAzYdpP6Bx9cje9tL3FE5XzXRQF1fNU9pe/1yE8fCrS0HD+fBtt6gLPh6LI57tX7A==} + '@vitest/coverage-v8@4.0.18': + resolution: {integrity: sha512-7i+N2i0+ME+2JFZhfuz7Tg/FqKtilHjGyGvoHYQ6iLV0zahbsJ9sljC9OcFcPDbhYKCet+sG8SsVqlyGvPflZg==} peerDependencies: - '@vitest/browser': 4.0.16 - vitest: 4.0.16 + '@vitest/browser': 4.0.18 + vitest: 4.0.18 peerDependenciesMeta: '@vitest/browser': optional: true - '@vitest/expect@4.0.16': - resolution: {integrity: sha512-eshqULT2It7McaJkQGLkPjPjNph+uevROGuIMJdG3V+0BSR2w9u6J9Lwu+E8cK5TETlfou8GRijhafIMhXsimA==} + '@vitest/expect@4.0.18': + resolution: {integrity: sha512-8sCWUyckXXYvx4opfzVY03EOiYVxyNrHS5QxX3DAIi5dpJAAkyJezHCP77VMX4HKA2LDT/Jpfo8i2r5BE3GnQQ==} - '@vitest/mocker@4.0.16': - resolution: {integrity: sha512-yb6k4AZxJTB+q9ycAvsoxGn+j/po0UaPgajllBgt1PzoMAAmJGYFdDk0uCcRcxb3BrME34I6u8gHZTQlkqSZpg==} + '@vitest/mocker@4.0.18': + resolution: {integrity: sha512-HhVd0MDnzzsgevnOWCBj5Otnzobjy5wLBe4EdeeFGv8luMsGcYqDuFRMcttKWZA5vVO8RFjexVovXvAM4JoJDQ==} peerDependencies: msw: ^2.4.9 vite: ^6.0.0 || ^7.0.0-0 @@ -496,65 +518,57 @@ packages: vite: optional: true - '@vitest/pretty-format@4.0.16': - resolution: {integrity: sha512-eNCYNsSty9xJKi/UdVD8Ou16alu7AYiS2fCPRs0b1OdhJiV89buAXQLpTbe+X8V9L6qrs9CqyvU7OaAopJYPsA==} + '@vitest/pretty-format@4.0.18': + resolution: {integrity: sha512-P24GK3GulZWC5tz87ux0m8OADrQIUVDPIjjj65vBXYG17ZeU3qD7r+MNZ1RNv4l8CGU2vtTRqixrOi9fYk/yKw==} - '@vitest/runner@4.0.16': - resolution: {integrity: sha512-VWEDm5Wv9xEo80ctjORcTQRJ539EGPB3Pb9ApvVRAY1U/WkHXmmYISqU5E79uCwcW7xYUV38gwZD+RV755fu3Q==} + '@vitest/runner@4.0.18': + resolution: {integrity: sha512-rpk9y12PGa22Jg6g5M3UVVnTS7+zycIGk9ZNGN+m6tZHKQb7jrP7/77WfZy13Y/EUDd52NDsLRQhYKtv7XfPQw==} - '@vitest/snapshot@4.0.16': - resolution: {integrity: sha512-sf6NcrYhYBsSYefxnry+DR8n3UV4xWZwWxYbCJUt2YdvtqzSPR7VfGrY0zsv090DAbjFZsi7ZaMi1KnSRyK1XA==} + '@vitest/snapshot@4.0.18': + resolution: {integrity: sha512-PCiV0rcl7jKQjbgYqjtakly6T1uwv/5BQ9SwBLekVg/EaYeQFPiXcgrC2Y7vDMA8dM1SUEAEV82kgSQIlXNMvA==} - '@vitest/spy@4.0.16': - resolution: {integrity: sha512-4jIOWjKP0ZUaEmJm00E0cOBLU+5WE0BpeNr3XN6TEF05ltro6NJqHWxXD0kA8/Zc8Nh23AT8WQxwNG+WeROupw==} + '@vitest/spy@4.0.18': + resolution: {integrity: sha512-cbQt3PTSD7P2OARdVW3qWER5EGq7PHlvE+QfzSC0lbwO+xnt7+XH06ZzFjFRgzUX//JmpxrCu92VdwvEPlWSNw==} - '@vitest/utils@4.0.16': - resolution: {integrity: sha512-h8z9yYhV3e1LEfaQ3zdypIrnAg/9hguReGZoS7Gl0aBG5xgA410zBqECqmaF/+RkTggRsfnzc1XaAHA6bmUufA==} + '@vitest/utils@4.0.18': + resolution: {integrity: sha512-msMRKLMVLWygpK3u2Hybgi4MNjcYJvwTb0Ru09+fOyCXIgT5raYP041DRRdiJiI3k/2U6SEbAETB3YtBrUkCFA==} acorn-jsx@5.3.2: resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} peerDependencies: acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 - acorn@8.15.0: - resolution: {integrity: sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==} + acorn@8.16.0: + resolution: {integrity: sha512-UVJyE9MttOsBQIDKw1skb9nAwQuR5wuGD3+82K6JgJlm/Y+KI92oNsMNGZCYdDsVtRHSak0pcV5Dno5+4jh9sw==} engines: {node: '>=0.4.0'} hasBin: true - ajv@6.12.6: - resolution: {integrity: sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==} - - ansi-styles@4.3.0: - resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} - engines: {node: '>=8'} + ajv@6.14.0: + resolution: {integrity: sha512-IWrosm/yrn43eiKqkfkHis7QioDleaXQHdDVPKg0FSwwd/DuvyX79TZnFOnYpB7dcsFAMmtFztZuXPDvSePkFw==} any-promise@1.3.0: resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} - argparse@2.0.1: - resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} - assertion-error@2.0.1: resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} engines: {node: '>=12'} - ast-v8-to-istanbul@0.3.10: - resolution: {integrity: sha512-p4K7vMz2ZSk3wN8l5o3y2bJAoZXT3VuJI5OLTATY/01CYWumWvwkUw0SqDBnNq6IiTO3qDa1eSQDibAV8g7XOQ==} + ast-v8-to-istanbul@0.3.12: + resolution: {integrity: sha512-BRRC8VRZY2R4Z4lFIL35MwNXmwVqBityvOIwETtsCSwvjl0IdgFsy9NhdaA6j74nUdtJJlIypeRhpDam19Wq3g==} asynckit@0.4.0: resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} - axios@1.13.5: - resolution: {integrity: sha512-cz4ur7Vb0xS4/KUN0tPWe44eqxrIu31me+fbang3ijiNscE129POzipJJA6zniq2C/Z6sJCjMimjS8Lc/GAs8Q==} + axios@1.13.6: + resolution: {integrity: sha512-ChTCHMouEe2kn713WHbQGcuYrr6fXTBiu460OTwWrWob16g1bXn4vtz07Ope7ewMozJAnEquLk5lWQWtBig9DQ==} - balanced-match@1.0.2: - resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} + balanced-match@4.0.4: + resolution: {integrity: sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==} + engines: {node: 18 || 20 || >=22} - brace-expansion@1.1.12: - resolution: {integrity: sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==} - - brace-expansion@2.0.2: - resolution: {integrity: sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==} + brace-expansion@5.0.4: + resolution: {integrity: sha512-h+DEnpVvxmfVefa4jFbCf5HdH5YMDXRsmKflpf1pILZWRFlTbJpxeU55nJl4Smt5HQaGzg1o6RHFPJaOqnmBDg==} + engines: {node: 18 || 20 || >=22} bundle-require@5.1.0: resolution: {integrity: sha512-3WrrOuZiyaaZPWiEt4G3+IffISVC9HYlWueJEBWED4ZH4aIAC2PnkdnuRrR94M+w6yGWn4AglWtJtBI8YqvgoA==} @@ -570,29 +584,14 @@ packages: resolution: {integrity: sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==} engines: {node: '>= 0.4'} - callsites@3.1.0: - resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==} - engines: {node: '>=6'} - chai@6.2.2: resolution: {integrity: sha512-NUPRluOfOiTKBKvWPtSD4PhFvWCqOi0BGStNWs57X9js7XGTprSmFoz5F0tWhR4WPjNeR9jXqdC7/UpSJTnlRg==} engines: {node: '>=18'} - chalk@4.1.2: - resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} - engines: {node: '>=10'} - chokidar@4.0.3: resolution: {integrity: sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA==} engines: {node: '>= 14.16.0'} - color-convert@2.0.1: - resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} - engines: {node: '>=7.0.0'} - - color-name@1.1.4: - resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} - combined-stream@1.0.8: resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} engines: {node: '>= 0.8'} @@ -601,9 +600,6 @@ packages: resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==} engines: {node: '>= 6'} - concat-map@0.0.1: - resolution: {integrity: sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==} - confbox@0.1.8: resolution: {integrity: sha512-RMtmw0iFkeR4YV+fUOSucriAQNb9g8zFR52MWCtl+cCZOFRNL6zeB395vPzFhEjjn4fMxXudmELnl/KF/WrK6w==} @@ -654,8 +650,8 @@ packages: resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} engines: {node: '>= 0.4'} - esbuild@0.27.2: - resolution: {integrity: sha512-HyNQImnsOC7X9PMNaCIeAm4ISCQXs5a5YasTXVliKv4uuBo1dKrG0A+uQS8M5eXjVMnLg3WgXaKvprHlFJQffw==} + esbuild@0.27.3: + resolution: {integrity: sha512-8VwMnyGCONIs6cWue2IdpHxHnAjzxnw2Zr7MkVxB2vjmQ2ivqGFb4LEG3SMnv0Gb2F/G/2yA8zUaiL1gywDCCg==} engines: {node: '>=18'} hasBin: true @@ -663,21 +659,21 @@ packages: resolution: {integrity: sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==} engines: {node: '>=10'} - eslint-scope@8.4.0: - resolution: {integrity: sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + eslint-scope@9.1.2: + resolution: {integrity: sha512-xS90H51cKw0jltxmvmHy2Iai1LIqrfbw57b79w/J7MfvDfkIkFZ+kj6zC3BjtUwh150HsSSdxXZcsuv72miDFQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} eslint-visitor-keys@3.4.3: resolution: {integrity: sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==} engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} - eslint-visitor-keys@4.2.1: - resolution: {integrity: sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + eslint-visitor-keys@5.0.1: + resolution: {integrity: sha512-tD40eHxA35h0PEIZNeIjkHoDR4YjjJp34biM0mDvplBe//mB+IHCqHDGV7pxF+7MklTvighcCPPZC7ynWyjdTA==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - eslint@9.39.2: - resolution: {integrity: sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + eslint@10.0.3: + resolution: {integrity: sha512-COV33RzXZkqhG9P2rZCFl9ZmJ7WL+gQSCRzE7RhkbclbQPtLAWReL7ysA0Sh4c8Im2U9ynybdR56PV0XcKvqaQ==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} hasBin: true peerDependencies: jiti: '*' @@ -685,12 +681,12 @@ packages: jiti: optional: true - espree@10.4.0: - resolution: {integrity: sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==} - engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + espree@11.2.0: + resolution: {integrity: sha512-7p3DrVEIopW1B1avAGLuCSh1jubc01H2JHc8B4qqGblmg5gI9yumBgACjWo4JlIc04ufug4xJ3SQI8HkS/Rgzw==} + engines: {node: ^20.19.0 || ^22.13.0 || >=24} - esquery@1.6.0: - resolution: {integrity: sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==} + esquery@1.7.0: + resolution: {integrity: sha512-Ap6G0WQwcU/LHsvLwON1fAQX9Zp0A2Y6Y/cJBl9r/JbW90Zyg4/zbG6zzKa2OTALELarYHmKu0GhpM5EO+7T0g==} engines: {node: '>=0.10'} esrecurse@4.3.0: @@ -745,8 +741,8 @@ packages: resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} engines: {node: '>=16'} - flatted@3.3.3: - resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==} + flatted@3.4.1: + resolution: {integrity: sha512-IxfVbRFVlV8V/yRaGzk0UVIcsKKHMSfYw66T/u4nTwlWteQePsxe//LjudR1AMX4tZW3WFCh3Zqa/sjlqpbURQ==} follow-redirects@1.15.11: resolution: {integrity: sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==} @@ -781,10 +777,6 @@ packages: resolution: {integrity: sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==} engines: {node: '>=10.13.0'} - globals@14.0.0: - resolution: {integrity: sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==} - engines: {node: '>=18'} - gopd@1.2.0: resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==} engines: {node: '>= 0.4'} @@ -816,10 +808,6 @@ packages: resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} engines: {node: '>= 4'} - import-fresh@3.3.1: - resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==} - engines: {node: '>=6'} - imurmurhash@0.1.4: resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==} engines: {node: '>=0.8.19'} @@ -843,10 +831,6 @@ packages: resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==} engines: {node: '>=10'} - istanbul-lib-source-maps@5.0.6: - resolution: {integrity: sha512-yg2d+Em4KizZC5niWhQaIomgf5WlL4vOOjZ5xGCmF8SnPE/mDWWXgvRExdcpCgh9lLRRa1/fSYp2ymmbJ1pI+A==} - engines: {node: '>=10'} - istanbul-reports@3.2.0: resolution: {integrity: sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==} engines: {node: '>=8'} @@ -855,12 +839,8 @@ packages: resolution: {integrity: sha512-34wB/Y7MW7bzjKRjUKTa46I2Z7eV62Rkhva+KkopW7Qvv/OSWBqvkSY7vusOPrNuZcUG3tApvdVgNB8POj3SPw==} engines: {node: '>=10'} - js-tokens@9.0.1: - resolution: {integrity: sha512-mxa9E9ITFOt0ban3j6L5MpjwegGz6lBQmM1IJkWeBZGcMxto50+eWdjC/52xDbS2vy0k7vIMK0Fe2wfL9OQSpQ==} - - js-yaml@4.1.1: - resolution: {integrity: sha512-qQKT4zQxXl8lLwBtHMWwaTcGfFOZviOJet3Oy/xmGk2gZH677CJM9EvtfdSkgWcATZhj/55JZ0rmy3myCT5lsA==} - hasBin: true + js-tokens@10.0.0: + resolution: {integrity: sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==} json-buffer@3.0.1: resolution: {integrity: sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==} @@ -893,14 +873,11 @@ packages: resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==} engines: {node: '>=10'} - lodash.merge@4.6.2: - resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} - magic-string@0.30.21: resolution: {integrity: sha512-vd2F4YUyEXKGcLHoq+TEyCjxueSeHnFxyyjNp80yg0XV4vUhnDer/lvvlqM/arB5bXQN5K2/3oinyCRyx8T2CQ==} - magicast@0.5.1: - resolution: {integrity: sha512-xrHS24IxaLrvuo613F719wvOIv9xPHFWQHuvGUBmPnCA/3MQxKI3b+r7n1jAoDHmsbC5bRhTZYR77invLAxVnw==} + magicast@0.5.2: + resolution: {integrity: sha512-E3ZJh4J3S9KfwdjZhe2afj6R9lGIN5Pher1pF39UGrXRqq/VDaGVIGN13BjHd2u8B61hArAGOnso7nBOouW3TQ==} make-dir@4.0.0: resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==} @@ -918,15 +895,12 @@ packages: resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} engines: {node: '>= 0.6'} - minimatch@3.1.2: - resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==} + minimatch@10.2.4: + resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==} + engines: {node: 18 || 20 || >=22} - minimatch@9.0.5: - resolution: {integrity: sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==} - engines: {node: '>=16 || 14 >=14.17'} - - mlly@1.8.0: - resolution: {integrity: sha512-l8D9ODSRWLe2KHJSifWGwBqpTZXIXTeo8mlKjY+E2HAakaTeNpqAyBZ8GSqLzHgw4XmHmC8whvpjJNMbFZN7/g==} + mlly@1.8.1: + resolution: {integrity: sha512-SnL6sNutTwRWWR/vcmCYHSADjiEesp5TGQQ0pXyLhW5IoeibRlF/CbSLailbB3CNqJUk9cVJ9dUDnbD7GrcHBQ==} ms@2.1.3: resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} @@ -961,10 +935,6 @@ packages: resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} engines: {node: '>=10'} - parent-module@1.0.1: - resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==} - engines: {node: '>=6'} - path-exists@4.0.0: resolution: {integrity: sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==} engines: {node: '>=8'} @@ -1008,8 +978,8 @@ packages: yaml: optional: true - postcss@8.5.6: - resolution: {integrity: sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==} + postcss@8.5.8: + resolution: {integrity: sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==} engines: {node: ^10 || ^12 || >=14} prelude-ls@1.2.1: @@ -1027,21 +997,17 @@ packages: resolution: {integrity: sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==} engines: {node: '>= 14.18.0'} - resolve-from@4.0.0: - resolution: {integrity: sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==} - engines: {node: '>=4'} - resolve-from@5.0.0: resolution: {integrity: sha512-qYg9KP24dD5qka9J47d0aVky0N+b4fTU89LN9iDnjB5waksiC49rvMB0PrUJQGoTmH50XPiqOvAjDfaijGxYZw==} engines: {node: '>=8'} - rollup@4.54.0: - resolution: {integrity: sha512-3nk8Y3a9Ea8szgKhinMlGMhGMw89mqule3KWczxhIzqudyHdCIOHw8WJlj/r329fACjKLEh13ZSk7oE22kyeIw==} + rollup@4.59.0: + resolution: {integrity: sha512-2oMpl67a3zCH9H79LeMcbDhXW/UmWG/y2zuqnF2jQq5uq9TbM9TVyXvA4+t+ne2IIkBdrLpAaRQAvo7YI/Yyeg==} engines: {node: '>=18.0.0', npm: '>=8.0.0'} hasBin: true - semver@7.7.3: - resolution: {integrity: sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==} + semver@7.7.4: + resolution: {integrity: sha512-vFKC2IEtQnVhpT78h1Yp8wzwrf8CM+MzKMHGJZfBtzhZNycRFnXsHk6E5TxIkkMsgNS7mdX3AGB7x2QM2di4lA==} engines: {node: '>=10'} hasBin: true @@ -1070,10 +1036,6 @@ packages: std-env@3.10.0: resolution: {integrity: sha512-5GS12FdOZNliM5mAOxFRg7Ir0pWz8MdpYm6AY6VPkGpbA7ZzmbzNcBJQ0GPvvyWgcY7QAhCgf9Uy89I03faLkg==} - strip-json-comments@3.1.1: - resolution: {integrity: sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==} - engines: {node: '>=8'} - sucrase@3.35.1: resolution: {integrity: sha512-DhuTmvZWux4H1UOnWMB3sk0sbaCVOoQZjv8u1rDoTV0HTdGem9hkAZtl4JZy8P2z4Bg0nT+YMeOFyVr4zcG5Tw==} engines: {node: '>=16 || 14 >=14.17'} @@ -1112,8 +1074,8 @@ packages: resolution: {integrity: sha512-L0Orpi8qGpRG//Nd+H90vFB+3iHnue1zSSGmNOOCh1GLJ7rUKVwV2HvijphGQS2UmhUZewS9VgvxYIdgr+fG1A==} hasBin: true - ts-api-utils@2.1.0: - resolution: {integrity: sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==} + ts-api-utils@2.4.0: + resolution: {integrity: sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA==} engines: {node: '>=18.12'} peerDependencies: typescript: '>=4.8.4' @@ -1149,17 +1111,17 @@ packages: engines: {node: '>=14.17'} hasBin: true - ufo@1.6.1: - resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} + ufo@1.6.3: + resolution: {integrity: sha512-yDJTmhydvl5lJzBmy/hyOAA0d+aqCBuwl818haVdYCRrWV84o7YyeVm4QlVHStqNrrJSTb6jKuFAVqAFsr+K3Q==} - undici-types@7.16.0: - resolution: {integrity: sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==} + undici-types@7.18.2: + resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==} uri-js@4.4.1: resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} - vite@7.3.0: - resolution: {integrity: sha512-dZwN5L1VlUBewiP6H9s2+B3e3Jg96D0vzN+Ry73sOefebhYr9f94wwkMNN/9ouoU8pV1BqA1d1zGk8928cx0rg==} + vite@7.3.1: + resolution: {integrity: sha512-w+N7Hifpc3gRjZ63vYBXA56dvvRlNWRczTdmCBBa+CotUzAPf5b7YMdMR/8CQoeYE5LX3W4wj6RYTgonm1b9DA==} engines: {node: ^20.19.0 || >=22.12.0} hasBin: true peerDependencies: @@ -1198,18 +1160,18 @@ packages: yaml: optional: true - vitest@4.0.16: - resolution: {integrity: sha512-E4t7DJ9pESL6E3I8nFjPa4xGUd3PmiWDLsDztS2qXSJWfHtbQnwAWylaBvSNY48I3vr8PTqIZlyK8TE3V3CA4Q==} + vitest@4.0.18: + resolution: {integrity: sha512-hOQuK7h0FGKgBAas7v0mSAsnvrIgAvWmRFjmzpJ7SwFHH3g1k2u37JtYwOwmEKhK6ZO3v9ggDBBm0La1LCK4uQ==} engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} hasBin: true peerDependencies: '@edge-runtime/vm': '*' '@opentelemetry/api': ^1.9.0 '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 - '@vitest/browser-playwright': 4.0.16 - '@vitest/browser-preview': 4.0.16 - '@vitest/browser-webdriverio': 4.0.16 - '@vitest/ui': 4.0.16 + '@vitest/browser-playwright': 4.0.18 + '@vitest/browser-preview': 4.0.18 + '@vitest/browser-webdriverio': 4.0.18 + '@vitest/ui': 4.0.18 happy-dom: '*' jsdom: '*' peerDependenciesMeta: @@ -1256,139 +1218,127 @@ snapshots: '@babel/helper-validator-identifier@7.28.5': {} - '@babel/parser@7.28.5': + '@babel/parser@7.29.0': dependencies: - '@babel/types': 7.28.5 + '@babel/types': 7.29.0 - '@babel/types@7.28.5': + '@babel/types@7.29.0': dependencies: '@babel/helper-string-parser': 7.27.1 '@babel/helper-validator-identifier': 7.28.5 '@bcoe/v8-coverage@1.0.2': {} - '@esbuild/aix-ppc64@0.27.2': + '@esbuild/aix-ppc64@0.27.3': optional: true - '@esbuild/android-arm64@0.27.2': + '@esbuild/android-arm64@0.27.3': optional: true - '@esbuild/android-arm@0.27.2': + '@esbuild/android-arm@0.27.3': optional: true - '@esbuild/android-x64@0.27.2': + '@esbuild/android-x64@0.27.3': optional: true - '@esbuild/darwin-arm64@0.27.2': + '@esbuild/darwin-arm64@0.27.3': optional: true - '@esbuild/darwin-x64@0.27.2': + '@esbuild/darwin-x64@0.27.3': optional: true - '@esbuild/freebsd-arm64@0.27.2': + '@esbuild/freebsd-arm64@0.27.3': optional: true - '@esbuild/freebsd-x64@0.27.2': + '@esbuild/freebsd-x64@0.27.3': optional: true - '@esbuild/linux-arm64@0.27.2': + '@esbuild/linux-arm64@0.27.3': optional: true - '@esbuild/linux-arm@0.27.2': + '@esbuild/linux-arm@0.27.3': optional: true - '@esbuild/linux-ia32@0.27.2': + '@esbuild/linux-ia32@0.27.3': optional: true - '@esbuild/linux-loong64@0.27.2': + '@esbuild/linux-loong64@0.27.3': optional: true - '@esbuild/linux-mips64el@0.27.2': + '@esbuild/linux-mips64el@0.27.3': optional: true - '@esbuild/linux-ppc64@0.27.2': + '@esbuild/linux-ppc64@0.27.3': optional: true - '@esbuild/linux-riscv64@0.27.2': + '@esbuild/linux-riscv64@0.27.3': optional: true - '@esbuild/linux-s390x@0.27.2': + '@esbuild/linux-s390x@0.27.3': optional: true - '@esbuild/linux-x64@0.27.2': + '@esbuild/linux-x64@0.27.3': optional: true - '@esbuild/netbsd-arm64@0.27.2': + '@esbuild/netbsd-arm64@0.27.3': optional: true - '@esbuild/netbsd-x64@0.27.2': + '@esbuild/netbsd-x64@0.27.3': optional: true - '@esbuild/openbsd-arm64@0.27.2': + '@esbuild/openbsd-arm64@0.27.3': optional: true - '@esbuild/openbsd-x64@0.27.2': + '@esbuild/openbsd-x64@0.27.3': optional: true - '@esbuild/openharmony-arm64@0.27.2': + '@esbuild/openharmony-arm64@0.27.3': optional: true - '@esbuild/sunos-x64@0.27.2': + '@esbuild/sunos-x64@0.27.3': optional: true - '@esbuild/win32-arm64@0.27.2': + '@esbuild/win32-arm64@0.27.3': optional: true - '@esbuild/win32-ia32@0.27.2': + '@esbuild/win32-ia32@0.27.3': optional: true - '@esbuild/win32-x64@0.27.2': + '@esbuild/win32-x64@0.27.3': optional: true - '@eslint-community/eslint-utils@4.9.0(eslint@9.39.2)': + '@eslint-community/eslint-utils@4.9.1(eslint@10.0.3)': dependencies: - eslint: 9.39.2 + eslint: 10.0.3 eslint-visitor-keys: 3.4.3 '@eslint-community/regexpp@4.12.2': {} - '@eslint/config-array@0.21.1': + '@eslint/config-array@0.23.3': dependencies: - '@eslint/object-schema': 2.1.7 + '@eslint/object-schema': 3.0.3 debug: 4.4.3 - minimatch: 3.1.2 + minimatch: 10.2.4 transitivePeerDependencies: - supports-color - '@eslint/config-helpers@0.4.2': + '@eslint/config-helpers@0.5.3': dependencies: - '@eslint/core': 0.17.0 + '@eslint/core': 1.1.1 - '@eslint/core@0.17.0': + '@eslint/core@1.1.1': dependencies: '@types/json-schema': 7.0.15 - '@eslint/eslintrc@3.3.3': + '@eslint/js@10.0.1(eslint@10.0.3)': + optionalDependencies: + eslint: 10.0.3 + + '@eslint/object-schema@3.0.3': {} + + '@eslint/plugin-kit@0.6.1': dependencies: - ajv: 6.12.6 - debug: 4.4.3 - espree: 10.4.0 - globals: 14.0.0 - ignore: 5.3.2 - import-fresh: 3.3.1 - js-yaml: 4.1.1 - minimatch: 3.1.2 - strip-json-comments: 3.1.1 - transitivePeerDependencies: - - supports-color - - '@eslint/js@9.39.2': {} - - '@eslint/object-schema@2.1.7': {} - - '@eslint/plugin-kit@0.4.1': - dependencies: - '@eslint/core': 0.17.0 + '@eslint/core': 1.1.1 levn: 0.4.1 '@humanfs/core@0.19.1': {} @@ -1416,70 +1366,79 @@ snapshots: '@jridgewell/resolve-uri': 3.1.2 '@jridgewell/sourcemap-codec': 1.5.5 - '@rollup/rollup-android-arm-eabi@4.54.0': + '@rollup/rollup-android-arm-eabi@4.59.0': optional: true - '@rollup/rollup-android-arm64@4.54.0': + '@rollup/rollup-android-arm64@4.59.0': optional: true - '@rollup/rollup-darwin-arm64@4.54.0': + '@rollup/rollup-darwin-arm64@4.59.0': optional: true - '@rollup/rollup-darwin-x64@4.54.0': + '@rollup/rollup-darwin-x64@4.59.0': optional: true - '@rollup/rollup-freebsd-arm64@4.54.0': + '@rollup/rollup-freebsd-arm64@4.59.0': optional: true - '@rollup/rollup-freebsd-x64@4.54.0': + '@rollup/rollup-freebsd-x64@4.59.0': optional: true - '@rollup/rollup-linux-arm-gnueabihf@4.54.0': + '@rollup/rollup-linux-arm-gnueabihf@4.59.0': optional: true - '@rollup/rollup-linux-arm-musleabihf@4.54.0': + '@rollup/rollup-linux-arm-musleabihf@4.59.0': optional: true - '@rollup/rollup-linux-arm64-gnu@4.54.0': + '@rollup/rollup-linux-arm64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-arm64-musl@4.54.0': + '@rollup/rollup-linux-arm64-musl@4.59.0': optional: true - '@rollup/rollup-linux-loong64-gnu@4.54.0': + '@rollup/rollup-linux-loong64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-ppc64-gnu@4.54.0': + '@rollup/rollup-linux-loong64-musl@4.59.0': optional: true - '@rollup/rollup-linux-riscv64-gnu@4.54.0': + '@rollup/rollup-linux-ppc64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-riscv64-musl@4.54.0': + '@rollup/rollup-linux-ppc64-musl@4.59.0': optional: true - '@rollup/rollup-linux-s390x-gnu@4.54.0': + '@rollup/rollup-linux-riscv64-gnu@4.59.0': optional: true - '@rollup/rollup-linux-x64-gnu@4.54.0': + '@rollup/rollup-linux-riscv64-musl@4.59.0': optional: true - '@rollup/rollup-linux-x64-musl@4.54.0': + '@rollup/rollup-linux-s390x-gnu@4.59.0': optional: true - '@rollup/rollup-openharmony-arm64@4.54.0': + '@rollup/rollup-linux-x64-gnu@4.59.0': optional: true - '@rollup/rollup-win32-arm64-msvc@4.54.0': + '@rollup/rollup-linux-x64-musl@4.59.0': optional: true - '@rollup/rollup-win32-ia32-msvc@4.54.0': + '@rollup/rollup-openbsd-x64@4.59.0': optional: true - '@rollup/rollup-win32-x64-gnu@4.54.0': + '@rollup/rollup-openharmony-arm64@4.59.0': optional: true - '@rollup/rollup-win32-x64-msvc@4.54.0': + '@rollup/rollup-win32-arm64-msvc@4.59.0': + optional: true + + '@rollup/rollup-win32-ia32-msvc@4.59.0': + optional: true + + '@rollup/rollup-win32-x64-gnu@4.59.0': + optional: true + + '@rollup/rollup-win32-x64-msvc@4.59.0': optional: true '@standard-schema/spec@1.1.0': {} @@ -1491,193 +1450,186 @@ snapshots: '@types/deep-eql@4.0.2': {} + '@types/esrecurse@4.3.1': {} + '@types/estree@1.0.8': {} '@types/json-schema@7.0.15': {} - '@types/node@25.0.3': + '@types/node@25.4.0': dependencies: - undici-types: 7.16.0 + undici-types: 7.18.2 - '@typescript-eslint/eslint-plugin@8.50.1(@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3))(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/eslint-plugin@8.57.0(@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3))(eslint@10.0.3)(typescript@5.9.3)': dependencies: '@eslint-community/regexpp': 4.12.2 - '@typescript-eslint/parser': 8.50.1(eslint@9.39.2)(typescript@5.9.3) - '@typescript-eslint/scope-manager': 8.50.1 - '@typescript-eslint/type-utils': 8.50.1(eslint@9.39.2)(typescript@5.9.3) - '@typescript-eslint/utils': 8.50.1(eslint@9.39.2)(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.50.1 - eslint: 9.39.2 + '@typescript-eslint/parser': 8.57.0(eslint@10.0.3)(typescript@5.9.3) + '@typescript-eslint/scope-manager': 8.57.0 + '@typescript-eslint/type-utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.0 + eslint: 10.0.3 ignore: 7.0.5 natural-compare: 1.4.0 - ts-api-utils: 2.1.0(typescript@5.9.3) + ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/parser@8.50.1(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/parser@8.57.0(eslint@10.0.3)(typescript@5.9.3)': dependencies: - '@typescript-eslint/scope-manager': 8.50.1 - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/typescript-estree': 8.50.1(typescript@5.9.3) - '@typescript-eslint/visitor-keys': 8.50.1 + '@typescript-eslint/scope-manager': 8.57.0 + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) + '@typescript-eslint/visitor-keys': 8.57.0 debug: 4.4.3 - eslint: 9.39.2 + eslint: 10.0.3 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/project-service@8.50.1(typescript@5.9.3)': + '@typescript-eslint/project-service@8.57.0(typescript@5.9.3)': dependencies: - '@typescript-eslint/tsconfig-utils': 8.50.1(typescript@5.9.3) - '@typescript-eslint/types': 8.50.1 + '@typescript-eslint/tsconfig-utils': 8.57.0(typescript@5.9.3) + '@typescript-eslint/types': 8.57.0 debug: 4.4.3 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/scope-manager@8.50.1': + '@typescript-eslint/scope-manager@8.57.0': dependencies: - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/visitor-keys': 8.50.1 + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/visitor-keys': 8.57.0 - '@typescript-eslint/tsconfig-utils@8.50.1(typescript@5.9.3)': + '@typescript-eslint/tsconfig-utils@8.57.0(typescript@5.9.3)': dependencies: typescript: 5.9.3 - '@typescript-eslint/type-utils@8.50.1(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/type-utils@8.57.0(eslint@10.0.3)(typescript@5.9.3)': dependencies: - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/typescript-estree': 8.50.1(typescript@5.9.3) - '@typescript-eslint/utils': 8.50.1(eslint@9.39.2)(typescript@5.9.3) + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) + '@typescript-eslint/utils': 8.57.0(eslint@10.0.3)(typescript@5.9.3) debug: 4.4.3 - eslint: 9.39.2 - ts-api-utils: 2.1.0(typescript@5.9.3) + eslint: 10.0.3 + ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/types@8.50.1': {} + '@typescript-eslint/types@8.57.0': {} - '@typescript-eslint/typescript-estree@8.50.1(typescript@5.9.3)': + '@typescript-eslint/typescript-estree@8.57.0(typescript@5.9.3)': dependencies: - '@typescript-eslint/project-service': 8.50.1(typescript@5.9.3) - '@typescript-eslint/tsconfig-utils': 8.50.1(typescript@5.9.3) - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/visitor-keys': 8.50.1 + '@typescript-eslint/project-service': 8.57.0(typescript@5.9.3) + '@typescript-eslint/tsconfig-utils': 8.57.0(typescript@5.9.3) + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/visitor-keys': 8.57.0 debug: 4.4.3 - minimatch: 9.0.5 - semver: 7.7.3 + minimatch: 10.2.4 + semver: 7.7.4 tinyglobby: 0.2.15 - ts-api-utils: 2.1.0(typescript@5.9.3) + ts-api-utils: 2.4.0(typescript@5.9.3) typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/utils@8.50.1(eslint@9.39.2)(typescript@5.9.3)': + '@typescript-eslint/utils@8.57.0(eslint@10.0.3)(typescript@5.9.3)': dependencies: - '@eslint-community/eslint-utils': 4.9.0(eslint@9.39.2) - '@typescript-eslint/scope-manager': 8.50.1 - '@typescript-eslint/types': 8.50.1 - '@typescript-eslint/typescript-estree': 8.50.1(typescript@5.9.3) - eslint: 9.39.2 + '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3) + '@typescript-eslint/scope-manager': 8.57.0 + '@typescript-eslint/types': 8.57.0 + '@typescript-eslint/typescript-estree': 8.57.0(typescript@5.9.3) + eslint: 10.0.3 typescript: 5.9.3 transitivePeerDependencies: - supports-color - '@typescript-eslint/visitor-keys@8.50.1': + '@typescript-eslint/visitor-keys@8.57.0': dependencies: - '@typescript-eslint/types': 8.50.1 - eslint-visitor-keys: 4.2.1 + '@typescript-eslint/types': 8.57.0 + eslint-visitor-keys: 5.0.1 - '@vitest/coverage-v8@4.0.16(vitest@4.0.16(@types/node@25.0.3))': + '@vitest/coverage-v8@4.0.18(vitest@4.0.18(@types/node@25.4.0))': dependencies: '@bcoe/v8-coverage': 1.0.2 - '@vitest/utils': 4.0.16 - ast-v8-to-istanbul: 0.3.10 + '@vitest/utils': 4.0.18 + ast-v8-to-istanbul: 0.3.12 istanbul-lib-coverage: 3.2.2 istanbul-lib-report: 3.0.1 - istanbul-lib-source-maps: 5.0.6 istanbul-reports: 3.2.0 - magicast: 0.5.1 + magicast: 0.5.2 obug: 2.1.1 std-env: 3.10.0 tinyrainbow: 3.0.3 - vitest: 4.0.16(@types/node@25.0.3) - transitivePeerDependencies: - - supports-color + vitest: 4.0.18(@types/node@25.4.0) - '@vitest/expect@4.0.16': + '@vitest/expect@4.0.18': dependencies: '@standard-schema/spec': 1.1.0 '@types/chai': 5.2.3 - '@vitest/spy': 4.0.16 - '@vitest/utils': 4.0.16 + '@vitest/spy': 4.0.18 + '@vitest/utils': 4.0.18 chai: 6.2.2 tinyrainbow: 3.0.3 - '@vitest/mocker@4.0.16(vite@7.3.0(@types/node@25.0.3))': + '@vitest/mocker@4.0.18(vite@7.3.1(@types/node@25.4.0))': dependencies: - '@vitest/spy': 4.0.16 + '@vitest/spy': 4.0.18 estree-walker: 3.0.3 magic-string: 0.30.21 optionalDependencies: - vite: 7.3.0(@types/node@25.0.3) + vite: 7.3.1(@types/node@25.4.0) - '@vitest/pretty-format@4.0.16': + '@vitest/pretty-format@4.0.18': dependencies: tinyrainbow: 3.0.3 - '@vitest/runner@4.0.16': + '@vitest/runner@4.0.18': dependencies: - '@vitest/utils': 4.0.16 + '@vitest/utils': 4.0.18 pathe: 2.0.3 - '@vitest/snapshot@4.0.16': + '@vitest/snapshot@4.0.18': dependencies: - '@vitest/pretty-format': 4.0.16 + '@vitest/pretty-format': 4.0.18 magic-string: 0.30.21 pathe: 2.0.3 - '@vitest/spy@4.0.16': {} + '@vitest/spy@4.0.18': {} - '@vitest/utils@4.0.16': + '@vitest/utils@4.0.18': dependencies: - '@vitest/pretty-format': 4.0.16 + '@vitest/pretty-format': 4.0.18 tinyrainbow: 3.0.3 - acorn-jsx@5.3.2(acorn@8.15.0): + acorn-jsx@5.3.2(acorn@8.16.0): dependencies: - acorn: 8.15.0 + acorn: 8.16.0 - acorn@8.15.0: {} + acorn@8.16.0: {} - ajv@6.12.6: + ajv@6.14.0: dependencies: fast-deep-equal: 3.1.3 fast-json-stable-stringify: 2.1.0 json-schema-traverse: 0.4.1 uri-js: 4.4.1 - ansi-styles@4.3.0: - dependencies: - color-convert: 2.0.1 - any-promise@1.3.0: {} - argparse@2.0.1: {} - assertion-error@2.0.1: {} - ast-v8-to-istanbul@0.3.10: + ast-v8-to-istanbul@0.3.12: dependencies: '@jridgewell/trace-mapping': 0.3.31 estree-walker: 3.0.3 - js-tokens: 9.0.1 + js-tokens: 10.0.0 asynckit@0.4.0: {} - axios@1.13.5: + axios@1.13.6: dependencies: follow-redirects: 1.15.11 form-data: 4.0.5 @@ -1685,20 +1637,15 @@ snapshots: transitivePeerDependencies: - debug - balanced-match@1.0.2: {} + balanced-match@4.0.4: {} - brace-expansion@1.1.12: + brace-expansion@5.0.4: dependencies: - balanced-match: 1.0.2 - concat-map: 0.0.1 + balanced-match: 4.0.4 - brace-expansion@2.0.2: + bundle-require@5.1.0(esbuild@0.27.3): dependencies: - balanced-match: 1.0.2 - - bundle-require@5.1.0(esbuild@0.27.2): - dependencies: - esbuild: 0.27.2 + esbuild: 0.27.3 load-tsconfig: 0.2.5 cac@6.7.14: {} @@ -1708,33 +1655,18 @@ snapshots: es-errors: 1.3.0 function-bind: 1.1.2 - callsites@3.1.0: {} - chai@6.2.2: {} - chalk@4.1.2: - dependencies: - ansi-styles: 4.3.0 - supports-color: 7.2.0 - chokidar@4.0.3: dependencies: readdirp: 4.1.2 - color-convert@2.0.1: - dependencies: - color-name: 1.1.4 - - color-name@1.1.4: {} - combined-stream@1.0.8: dependencies: delayed-stream: 1.0.0 commander@4.1.1: {} - concat-map@0.0.1: {} - confbox@0.1.8: {} consola@3.4.2: {} @@ -1776,69 +1708,68 @@ snapshots: has-tostringtag: 1.0.2 hasown: 2.0.2 - esbuild@0.27.2: + esbuild@0.27.3: optionalDependencies: - '@esbuild/aix-ppc64': 0.27.2 - '@esbuild/android-arm': 0.27.2 - '@esbuild/android-arm64': 0.27.2 - '@esbuild/android-x64': 0.27.2 - '@esbuild/darwin-arm64': 0.27.2 - '@esbuild/darwin-x64': 0.27.2 - '@esbuild/freebsd-arm64': 0.27.2 - '@esbuild/freebsd-x64': 0.27.2 - '@esbuild/linux-arm': 0.27.2 - '@esbuild/linux-arm64': 0.27.2 - '@esbuild/linux-ia32': 0.27.2 - '@esbuild/linux-loong64': 0.27.2 - '@esbuild/linux-mips64el': 0.27.2 - '@esbuild/linux-ppc64': 0.27.2 - '@esbuild/linux-riscv64': 0.27.2 - '@esbuild/linux-s390x': 0.27.2 - '@esbuild/linux-x64': 0.27.2 - '@esbuild/netbsd-arm64': 0.27.2 - '@esbuild/netbsd-x64': 0.27.2 - '@esbuild/openbsd-arm64': 0.27.2 - '@esbuild/openbsd-x64': 0.27.2 - '@esbuild/openharmony-arm64': 0.27.2 - '@esbuild/sunos-x64': 0.27.2 - '@esbuild/win32-arm64': 0.27.2 - '@esbuild/win32-ia32': 0.27.2 - '@esbuild/win32-x64': 0.27.2 + '@esbuild/aix-ppc64': 0.27.3 + '@esbuild/android-arm': 0.27.3 + '@esbuild/android-arm64': 0.27.3 + '@esbuild/android-x64': 0.27.3 + '@esbuild/darwin-arm64': 0.27.3 + '@esbuild/darwin-x64': 0.27.3 + '@esbuild/freebsd-arm64': 0.27.3 + '@esbuild/freebsd-x64': 0.27.3 + '@esbuild/linux-arm': 0.27.3 + '@esbuild/linux-arm64': 0.27.3 + '@esbuild/linux-ia32': 0.27.3 + '@esbuild/linux-loong64': 0.27.3 + '@esbuild/linux-mips64el': 0.27.3 + '@esbuild/linux-ppc64': 0.27.3 + '@esbuild/linux-riscv64': 0.27.3 + '@esbuild/linux-s390x': 0.27.3 + '@esbuild/linux-x64': 0.27.3 + '@esbuild/netbsd-arm64': 0.27.3 + '@esbuild/netbsd-x64': 0.27.3 + '@esbuild/openbsd-arm64': 0.27.3 + '@esbuild/openbsd-x64': 0.27.3 + '@esbuild/openharmony-arm64': 0.27.3 + '@esbuild/sunos-x64': 0.27.3 + '@esbuild/win32-arm64': 0.27.3 + '@esbuild/win32-ia32': 0.27.3 + '@esbuild/win32-x64': 0.27.3 escape-string-regexp@4.0.0: {} - eslint-scope@8.4.0: + eslint-scope@9.1.2: dependencies: + '@types/esrecurse': 4.3.1 + '@types/estree': 1.0.8 esrecurse: 4.3.0 estraverse: 5.3.0 eslint-visitor-keys@3.4.3: {} - eslint-visitor-keys@4.2.1: {} + eslint-visitor-keys@5.0.1: {} - eslint@9.39.2: + eslint@10.0.3: dependencies: - '@eslint-community/eslint-utils': 4.9.0(eslint@9.39.2) + '@eslint-community/eslint-utils': 4.9.1(eslint@10.0.3) '@eslint-community/regexpp': 4.12.2 - '@eslint/config-array': 0.21.1 - '@eslint/config-helpers': 0.4.2 - '@eslint/core': 0.17.0 - '@eslint/eslintrc': 3.3.3 - '@eslint/js': 9.39.2 - '@eslint/plugin-kit': 0.4.1 + '@eslint/config-array': 0.23.3 + '@eslint/config-helpers': 0.5.3 + '@eslint/core': 1.1.1 + '@eslint/plugin-kit': 0.6.1 '@humanfs/node': 0.16.7 '@humanwhocodes/module-importer': 1.0.1 '@humanwhocodes/retry': 0.4.3 '@types/estree': 1.0.8 - ajv: 6.12.6 - chalk: 4.1.2 + ajv: 6.14.0 cross-spawn: 7.0.6 debug: 4.4.3 escape-string-regexp: 4.0.0 - eslint-scope: 8.4.0 - eslint-visitor-keys: 4.2.1 - espree: 10.4.0 - esquery: 1.6.0 + eslint-scope: 9.1.2 + eslint-visitor-keys: 5.0.1 + espree: 11.2.0 + esquery: 1.7.0 esutils: 2.0.3 fast-deep-equal: 3.1.3 file-entry-cache: 8.0.0 @@ -1848,20 +1779,19 @@ snapshots: imurmurhash: 0.1.4 is-glob: 4.0.3 json-stable-stringify-without-jsonify: 1.0.1 - lodash.merge: 4.6.2 - minimatch: 3.1.2 + minimatch: 10.2.4 natural-compare: 1.4.0 optionator: 0.9.4 transitivePeerDependencies: - supports-color - espree@10.4.0: + espree@11.2.0: dependencies: - acorn: 8.15.0 - acorn-jsx: 5.3.2(acorn@8.15.0) - eslint-visitor-keys: 4.2.1 + acorn: 8.16.0 + acorn-jsx: 5.3.2(acorn@8.16.0) + eslint-visitor-keys: 5.0.1 - esquery@1.6.0: + esquery@1.7.0: dependencies: estraverse: 5.3.0 @@ -1901,15 +1831,15 @@ snapshots: fix-dts-default-cjs-exports@1.0.1: dependencies: magic-string: 0.30.21 - mlly: 1.8.0 - rollup: 4.54.0 + mlly: 1.8.1 + rollup: 4.59.0 flat-cache@4.0.1: dependencies: - flatted: 3.3.3 + flatted: 3.4.1 keyv: 4.5.4 - flatted@3.3.3: {} + flatted@3.4.1: {} follow-redirects@1.15.11: {} @@ -1948,8 +1878,6 @@ snapshots: dependencies: is-glob: 4.0.3 - globals@14.0.0: {} - gopd@1.2.0: {} has-flag@4.0.0: {} @@ -1970,11 +1898,6 @@ snapshots: ignore@7.0.5: {} - import-fresh@3.3.1: - dependencies: - parent-module: 1.0.1 - resolve-from: 4.0.0 - imurmurhash@0.1.4: {} is-extglob@2.1.1: {} @@ -1993,14 +1916,6 @@ snapshots: make-dir: 4.0.0 supports-color: 7.2.0 - istanbul-lib-source-maps@5.0.6: - dependencies: - '@jridgewell/trace-mapping': 0.3.31 - debug: 4.4.3 - istanbul-lib-coverage: 3.2.2 - transitivePeerDependencies: - - supports-color - istanbul-reports@3.2.0: dependencies: html-escaper: 2.0.2 @@ -2008,11 +1923,7 @@ snapshots: joycon@3.1.1: {} - js-tokens@9.0.1: {} - - js-yaml@4.1.1: - dependencies: - argparse: 2.0.1 + js-tokens@10.0.0: {} json-buffer@3.0.1: {} @@ -2039,21 +1950,19 @@ snapshots: dependencies: p-locate: 5.0.0 - lodash.merge@4.6.2: {} - magic-string@0.30.21: dependencies: '@jridgewell/sourcemap-codec': 1.5.5 - magicast@0.5.1: + magicast@0.5.2: dependencies: - '@babel/parser': 7.28.5 - '@babel/types': 7.28.5 + '@babel/parser': 7.29.0 + '@babel/types': 7.29.0 source-map-js: 1.2.1 make-dir@4.0.0: dependencies: - semver: 7.7.3 + semver: 7.7.4 math-intrinsics@1.1.0: {} @@ -2063,20 +1972,16 @@ snapshots: dependencies: mime-db: 1.52.0 - minimatch@3.1.2: + minimatch@10.2.4: dependencies: - brace-expansion: 1.1.12 + brace-expansion: 5.0.4 - minimatch@9.0.5: + mlly@1.8.1: dependencies: - brace-expansion: 2.0.2 - - mlly@1.8.0: - dependencies: - acorn: 8.15.0 + acorn: 8.16.0 pathe: 2.0.3 pkg-types: 1.3.1 - ufo: 1.6.1 + ufo: 1.6.3 ms@2.1.3: {} @@ -2111,10 +2016,6 @@ snapshots: dependencies: p-limit: 3.1.0 - parent-module@1.0.1: - dependencies: - callsites: 3.1.0 - path-exists@4.0.0: {} path-key@3.1.1: {} @@ -2130,16 +2031,16 @@ snapshots: pkg-types@1.3.1: dependencies: confbox: 0.1.8 - mlly: 1.8.0 + mlly: 1.8.1 pathe: 2.0.3 - postcss-load-config@6.0.1(postcss@8.5.6): + postcss-load-config@6.0.1(postcss@8.5.8): dependencies: lilconfig: 3.1.3 optionalDependencies: - postcss: 8.5.6 + postcss: 8.5.8 - postcss@8.5.6: + postcss@8.5.8: dependencies: nanoid: 3.3.11 picocolors: 1.1.1 @@ -2153,39 +2054,40 @@ snapshots: readdirp@4.1.2: {} - resolve-from@4.0.0: {} - resolve-from@5.0.0: {} - rollup@4.54.0: + rollup@4.59.0: dependencies: '@types/estree': 1.0.8 optionalDependencies: - '@rollup/rollup-android-arm-eabi': 4.54.0 - '@rollup/rollup-android-arm64': 4.54.0 - '@rollup/rollup-darwin-arm64': 4.54.0 - '@rollup/rollup-darwin-x64': 4.54.0 - '@rollup/rollup-freebsd-arm64': 4.54.0 - '@rollup/rollup-freebsd-x64': 4.54.0 - '@rollup/rollup-linux-arm-gnueabihf': 4.54.0 - '@rollup/rollup-linux-arm-musleabihf': 4.54.0 - '@rollup/rollup-linux-arm64-gnu': 4.54.0 - '@rollup/rollup-linux-arm64-musl': 4.54.0 - '@rollup/rollup-linux-loong64-gnu': 4.54.0 - '@rollup/rollup-linux-ppc64-gnu': 4.54.0 - '@rollup/rollup-linux-riscv64-gnu': 4.54.0 - '@rollup/rollup-linux-riscv64-musl': 4.54.0 - '@rollup/rollup-linux-s390x-gnu': 4.54.0 - '@rollup/rollup-linux-x64-gnu': 4.54.0 - '@rollup/rollup-linux-x64-musl': 4.54.0 - '@rollup/rollup-openharmony-arm64': 4.54.0 - '@rollup/rollup-win32-arm64-msvc': 4.54.0 - '@rollup/rollup-win32-ia32-msvc': 4.54.0 - '@rollup/rollup-win32-x64-gnu': 4.54.0 - '@rollup/rollup-win32-x64-msvc': 4.54.0 + '@rollup/rollup-android-arm-eabi': 4.59.0 + '@rollup/rollup-android-arm64': 4.59.0 + '@rollup/rollup-darwin-arm64': 4.59.0 + '@rollup/rollup-darwin-x64': 4.59.0 + '@rollup/rollup-freebsd-arm64': 4.59.0 + '@rollup/rollup-freebsd-x64': 4.59.0 + '@rollup/rollup-linux-arm-gnueabihf': 4.59.0 + '@rollup/rollup-linux-arm-musleabihf': 4.59.0 + '@rollup/rollup-linux-arm64-gnu': 4.59.0 + '@rollup/rollup-linux-arm64-musl': 4.59.0 + '@rollup/rollup-linux-loong64-gnu': 4.59.0 + '@rollup/rollup-linux-loong64-musl': 4.59.0 + '@rollup/rollup-linux-ppc64-gnu': 4.59.0 + '@rollup/rollup-linux-ppc64-musl': 4.59.0 + '@rollup/rollup-linux-riscv64-gnu': 4.59.0 + '@rollup/rollup-linux-riscv64-musl': 4.59.0 + '@rollup/rollup-linux-s390x-gnu': 4.59.0 + '@rollup/rollup-linux-x64-gnu': 4.59.0 + '@rollup/rollup-linux-x64-musl': 4.59.0 + '@rollup/rollup-openbsd-x64': 4.59.0 + '@rollup/rollup-openharmony-arm64': 4.59.0 + '@rollup/rollup-win32-arm64-msvc': 4.59.0 + '@rollup/rollup-win32-ia32-msvc': 4.59.0 + '@rollup/rollup-win32-x64-gnu': 4.59.0 + '@rollup/rollup-win32-x64-msvc': 4.59.0 fsevents: 2.3.3 - semver@7.7.3: {} + semver@7.7.4: {} shebang-command@2.0.0: dependencies: @@ -2203,8 +2105,6 @@ snapshots: std-env@3.10.0: {} - strip-json-comments@3.1.1: {} - sucrase@3.35.1: dependencies: '@jridgewell/gen-mapping': 0.3.13 @@ -2242,33 +2142,33 @@ snapshots: tree-kill@1.2.2: {} - ts-api-utils@2.1.0(typescript@5.9.3): + ts-api-utils@2.4.0(typescript@5.9.3): dependencies: typescript: 5.9.3 ts-interface-checker@0.1.13: {} - tsup@8.5.1(postcss@8.5.6)(typescript@5.9.3): + tsup@8.5.1(postcss@8.5.8)(typescript@5.9.3): dependencies: - bundle-require: 5.1.0(esbuild@0.27.2) + bundle-require: 5.1.0(esbuild@0.27.3) cac: 6.7.14 chokidar: 4.0.3 consola: 3.4.2 debug: 4.4.3 - esbuild: 0.27.2 + esbuild: 0.27.3 fix-dts-default-cjs-exports: 1.0.1 joycon: 3.1.1 picocolors: 1.1.1 - postcss-load-config: 6.0.1(postcss@8.5.6) + postcss-load-config: 6.0.1(postcss@8.5.8) resolve-from: 5.0.0 - rollup: 4.54.0 + rollup: 4.59.0 source-map: 0.7.6 sucrase: 3.35.1 tinyexec: 0.3.2 tinyglobby: 0.2.15 tree-kill: 1.2.2 optionalDependencies: - postcss: 8.5.6 + postcss: 8.5.8 typescript: 5.9.3 transitivePeerDependencies: - jiti @@ -2282,35 +2182,35 @@ snapshots: typescript@5.9.3: {} - ufo@1.6.1: {} + ufo@1.6.3: {} - undici-types@7.16.0: {} + undici-types@7.18.2: {} uri-js@4.4.1: dependencies: punycode: 2.3.1 - vite@7.3.0(@types/node@25.0.3): + vite@7.3.1(@types/node@25.4.0): dependencies: - esbuild: 0.27.2 + esbuild: 0.27.3 fdir: 6.5.0(picomatch@4.0.3) picomatch: 4.0.3 - postcss: 8.5.6 - rollup: 4.54.0 + postcss: 8.5.8 + rollup: 4.59.0 tinyglobby: 0.2.15 optionalDependencies: - '@types/node': 25.0.3 + '@types/node': 25.4.0 fsevents: 2.3.3 - vitest@4.0.16(@types/node@25.0.3): + vitest@4.0.18(@types/node@25.4.0): dependencies: - '@vitest/expect': 4.0.16 - '@vitest/mocker': 4.0.16(vite@7.3.0(@types/node@25.0.3)) - '@vitest/pretty-format': 4.0.16 - '@vitest/runner': 4.0.16 - '@vitest/snapshot': 4.0.16 - '@vitest/spy': 4.0.16 - '@vitest/utils': 4.0.16 + '@vitest/expect': 4.0.18 + '@vitest/mocker': 4.0.18(vite@7.3.1(@types/node@25.4.0)) + '@vitest/pretty-format': 4.0.18 + '@vitest/runner': 4.0.18 + '@vitest/snapshot': 4.0.18 + '@vitest/spy': 4.0.18 + '@vitest/utils': 4.0.18 es-module-lexer: 1.7.0 expect-type: 1.3.0 magic-string: 0.30.21 @@ -2322,10 +2222,10 @@ snapshots: tinyexec: 1.0.2 tinyglobby: 0.2.15 tinyrainbow: 3.0.3 - vite: 7.3.0(@types/node@25.0.3) + vite: 7.3.1(@types/node@25.4.0) why-is-node-running: 2.3.0 optionalDependencies: - '@types/node': 25.0.3 + '@types/node': 25.4.0 transitivePeerDependencies: - jiti - less diff --git a/web/.env.example b/web/.env.example index df4e725c51..079c3bdeef 100644 --- a/web/.env.example +++ b/web/.env.example @@ -6,15 +6,24 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED NEXT_PUBLIC_BASE_PATH= # The base URL of console application, refers to the Console base URL of WEB service if console domain is # different from api or web app domain. -# example: http://cloud.dify.ai/console/api +# example: https://cloud.dify.ai/console/api NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api # The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from # console or api domain. -# example: http://udify.app/api +# example: https://udify.app/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api # When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. NEXT_PUBLIC_COOKIE_DOMAIN= +# Dev-only Hono proxy targets. +# The frontend keeps requesting http://localhost:5001 directly, +# the proxy server will forward the request to the target server, +# so that you don't need to run a separate backend server and use online API in development. +HONO_PROXY_HOST=127.0.0.1 +HONO_PROXY_PORT=5001 +HONO_CONSOLE_API_PROXY_TARGET= +HONO_PUBLIC_API_PROXY_TARGET= + # The API PREFIX for MARKETPLACE NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1 # The URL for MARKETPLACE diff --git a/web/.husky/pre-commit b/web/.husky/pre-commit index dd4140b47e..3f25de256f 100644 --- a/web/.husky/pre-commit +++ b/web/.husky/pre-commit @@ -6,6 +6,20 @@ files=$(git diff --cached --name-only) api_modified=false web_modified=false +skip_web_checks=false + +git_path() { + git rev-parse --git-path "$1" +} + +if [ -f "$(git_path MERGE_HEAD)" ] || \ + [ -f "$(git_path CHERRY_PICK_HEAD)" ] || \ + [ -f "$(git_path REVERT_HEAD)" ] || \ + [ -f "$(git_path SQUASH_MSG)" ] || \ + [ -d "$(git_path rebase-merge)" ] || \ + [ -d "$(git_path rebase-apply)" ]; then + skip_web_checks=true +fi for file in $files do @@ -43,6 +57,11 @@ if $api_modified; then fi if $web_modified; then + if $skip_web_checks; then + echo "Git operation in progress, skipping web checks" + exit 0 + fi + echo "Running ESLint on web module" if git diff --cached --quiet -- 'web/**/*.ts' 'web/**/*.tsx'; then diff --git a/web/.nvmrc b/web/.nvmrc index a45fd52cc5..2bd5a0a98a 100644 --- a/web/.nvmrc +++ b/web/.nvmrc @@ -1 +1 @@ -24 +22 diff --git a/web/AGENTS.md b/web/AGENTS.md index 5dd41b8a3c..97f74441a7 100644 --- a/web/AGENTS.md +++ b/web/AGENTS.md @@ -2,6 +2,16 @@ - Refer to the `./docs/test.md` and `./docs/lint.md` for detailed frontend workflow instructions. +## Overlay Components (Mandatory) + +- `./docs/overlay-migration.md` is the source of truth for overlay-related work. +- In new or modified code, use only overlay primitives from `@/app/components/base/ui/*`. +- Do not introduce deprecated overlay imports from `@/app/components/base/*`; when touching legacy callers, prefer migrating them and keep the allowlist shrinking (never expanding). + +## Query & Mutation (Mandatory) + +- `frontend-query-mutation` is the source of truth for Dify frontend contracts, query and mutation call-site patterns, conditional queries, invalidation, and mutation error handling. + ## Automated Test Generation - Use `./docs/test.md` as the canonical instruction set for generating frontend automated tests. diff --git a/web/Dockerfile b/web/Dockerfile index fe4ea1a579..b54bae706c 100644 --- a/web/Dockerfile +++ b/web/Dockerfile @@ -1,5 +1,5 @@ # base image -FROM node:24-alpine AS base +FROM node:22-alpine AS base LABEL maintainer="takatost@gmail.com" # if you located in China, you can use aliyun mirror to speed up @@ -35,7 +35,7 @@ COPY --from=packages /app/web/ . COPY . . ENV NODE_OPTIONS="--max-old-space-size=4096" -RUN pnpm build:docker +RUN pnpm build # production stage diff --git a/web/README.md b/web/README.md index 1e57e7c6a9..14ca856875 100644 --- a/web/README.md +++ b/web/README.md @@ -1,6 +1,6 @@ # Dify Frontend -This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next-app`](https://github.com/vercel/next.js/tree/canary/packages/create-next-app). +This is a [Next.js] project, but you can dev with [vinext]. ## Getting Started @@ -8,8 +8,11 @@ This is a [Next.js](https://nextjs.org/) project bootstrapped with [`create-next Before starting the web frontend service, please make sure the following environment is ready. -- [Node.js](https://nodejs.org) -- [pnpm](https://pnpm.io) +- [Node.js] +- [pnpm] + +You can also use [Vite+] with the corresponding `vp` commands. +For example, use `vp install` instead of `pnpm install` and `vp test` instead of `pnpm run test`. > [!TIP] > It is recommended to install and enable Corepack to manage package manager versions automatically: @@ -19,7 +22,7 @@ Before starting the web frontend service, please make sure the following environ > corepack enable > ``` > -> Learn more: [Corepack](https://github.com/nodejs/corepack#readme) +> Learn more: [Corepack] First, install the dependencies: @@ -27,31 +30,14 @@ First, install the dependencies: pnpm install ``` -Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements: +Then, configure the environment variables. +Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. +Modify the values of these environment variables according to your requirements: ```bash cp .env.example .env.local ``` -```txt -# For production release, change this to PRODUCTION -NEXT_PUBLIC_DEPLOY_ENV=DEVELOPMENT -# The deployment edition, SELF_HOSTED -NEXT_PUBLIC_EDITION=SELF_HOSTED -# The base URL of console application, refers to the Console base URL of WEB service if console domain is -# different from api or web app domain. -# example: http://cloud.dify.ai/console/api -NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api -NEXT_PUBLIC_COOKIE_DOMAIN= -# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from -# console or api domain. -# example: http://udify.app/api -NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api - -# SENTRY -NEXT_PUBLIC_SENTRY_DSN= -``` - > [!IMPORTANT] > > 1. When the frontend and backend run on different subdomains, set NEXT_PUBLIC_COOKIE_DOMAIN=1. The frontend and backend must be under the same top-level domain in order to share authentication cookies. @@ -61,11 +47,16 @@ Finally, run the development server: ```bash pnpm run dev +# or if you are using vinext which provides a better development experience +pnpm run dev:vinext +# (optional) start the dev proxy server so that you can use online API in development +pnpm run dev:proxy ``` -Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. +Open with your browser to see the result. -You can start editing the file under folder `app`. The page auto-updates as you edit the file. +You can start editing the file under folder `app`. +The page auto-updates as you edit the file. ## Deploy @@ -91,7 +82,7 @@ pnpm run start --port=3001 --host=0.0.0.0 ## Storybook -This project uses [Storybook](https://storybook.js.org/) for UI component development. +This project uses [Storybook] for UI component development. To start the storybook server, run: @@ -99,19 +90,24 @@ To start the storybook server, run: pnpm storybook ``` -Open [http://localhost:6006](http://localhost:6006) with your browser to see the result. +Open with your browser to see the result. ## Lint Code If your IDE is VSCode, rename `web/.vscode/settings.example.json` to `web/.vscode/settings.json` for lint code setting. -Then follow the [Lint Documentation](./docs/lint.md) to lint the code. +Then follow the [Lint Documentation] to lint the code. ## Test -We use [Vitest](https://vitest.dev/) and [React Testing Library](https://testing-library.com/docs/react-testing-library/intro/) for Unit Testing. +We use [Vitest] and [React Testing Library] for Unit Testing. -**📖 Complete Testing Guide**: See [web/testing/testing.md](./testing/testing.md) for detailed testing specifications, best practices, and examples. +**📖 Complete Testing Guide**: See [web/docs/test.md] for detailed testing specifications, best practices, and examples. + +> [!IMPORTANT] +> As we are using Vite+, the `vitest` command is not available. +> Please make sure to run tests with `vp` commands. +> For example, use `npx vp test` instead of `npx vitest`. Run test: @@ -119,12 +115,17 @@ Run test: pnpm test ``` +> [!NOTE] +> Our test is not fully stable yet, and we are actively working on improving it. +> If you encounter test failures only in CI but not locally, please feel free to ignore them and report the issue to us. +> You can try to re-run the test in CI, and it may pass successfully. + ### Example Code If you are not familiar with writing tests, refer to: -- [classnames.spec.ts](./utils/classnames.spec.ts) - Utility function test example -- [index.spec.tsx](./app/components/base/button/index.spec.tsx) - Component test example +- [classnames.spec.ts] - Utility function test example +- [index.spec.tsx] - Component test example ### Analyze Component Complexity @@ -134,7 +135,7 @@ Before writing tests, use the script to analyze component complexity: pnpm analyze-component app/components/your-component/index.tsx ``` -This will help you determine the testing strategy. See [web/testing/testing.md](./testing/testing.md) for details. +This will help you determine the testing strategy. See [web/testing/testing.md] for details. ## Documentation @@ -142,4 +143,19 @@ Visit to view the full documentation. ## Community -The Dify community can be found on [Discord community](https://discord.gg/5AEfbxcd9k), where you can ask questions, voice ideas, and share your projects. +The Dify community can be found on [Discord community], where you can ask questions, voice ideas, and share your projects. + +[Corepack]: https://github.com/nodejs/corepack#readme +[Discord community]: https://discord.gg/5AEfbxcd9k +[Lint Documentation]: ./docs/lint.md +[Next.js]: https://nextjs.org +[Node.js]: https://nodejs.org +[React Testing Library]: https://testing-library.com/docs/react-testing-library/intro +[Storybook]: https://storybook.js.org +[Vite+]: https://viteplus.dev +[Vitest]: https://vitest.dev +[classnames.spec.ts]: ./utils/classnames.spec.ts +[index.spec.tsx]: ./app/components/base/button/index.spec.tsx +[pnpm]: https://pnpm.io +[vinext]: https://github.com/cloudflare/vinext +[web/docs/test.md]: ./docs/test.md diff --git a/web/__tests__/apps/app-card-operations-flow.test.tsx b/web/__tests__/apps/app-card-operations-flow.test.tsx index 1aa6706b82..c5766878a1 100644 --- a/web/__tests__/apps/app-card-operations-flow.test.tsx +++ b/web/__tests__/apps/app-card-operations-flow.test.tsx @@ -14,7 +14,7 @@ import { fireEvent, render, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import AppCard from '@/app/components/apps/app-card' import { AccessMode } from '@/models/access-control' -import { deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { exportAppConfig, updateAppInfo } from '@/service/apps' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -26,8 +26,10 @@ let mockSystemFeatures = { const mockRouterPush = vi.fn() const mockNotify = vi.fn() const mockOnPlanInfoChanged = vi.fn() +const mockDeleteAppMutation = vi.fn().mockResolvedValue(undefined) +let mockDeleteMutationPending = false -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, }), @@ -55,7 +57,7 @@ vi.mock('@headlessui/react', async () => { } }) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { @@ -117,6 +119,13 @@ vi.mock('@/service/tag', () => ({ fetchTagList: vi.fn().mockResolvedValue([]), })) +vi.mock('@/service/use-apps', () => ({ + useDeleteAppMutation: () => ({ + mutateAsync: mockDeleteAppMutation, + isPending: mockDeleteMutationPending, + }), +})) + vi.mock('@/service/apps', () => ({ deleteApp: vi.fn().mockResolvedValue({}), updateAppInfo: vi.fn().mockResolvedValue({}), @@ -270,6 +279,7 @@ const renderAppCard = (app?: Partial) => { describe('App Card Operations Flow', () => { beforeEach(() => { vi.clearAllMocks() + mockDeleteMutationPending = false mockIsCurrentWorkspaceEditor = true mockSystemFeatures = { branding: { enabled: false }, @@ -341,7 +351,7 @@ describe('App Card Operations Flow', () => { fireEvent.click(confirmBtn) await waitFor(() => { - expect(deleteApp).toHaveBeenCalledWith('app-to-delete') + expect(mockDeleteAppMutation).toHaveBeenCalledWith('app-to-delete') }) } } diff --git a/web/__tests__/apps/app-list-browsing-flow.test.tsx b/web/__tests__/apps/app-list-browsing-flow.test.tsx index 1c046f5dd0..1be7e56086 100644 --- a/web/__tests__/apps/app-list-browsing-flow.test.tsx +++ b/web/__tests__/apps/app-list-browsing-flow.test.tsx @@ -8,11 +8,11 @@ */ import type { AppListResponse } from '@/models/app' import type { App } from '@/types/app' -import { fireEvent, render, screen } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { fireEvent, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import List from '@/app/components/apps/list' import { AccessMode } from '@/models/access-control' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -38,7 +38,7 @@ let mockShowTagManagementModal = false const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -46,7 +46,7 @@ vi.mock('next/navigation', () => ({ useSearchParams: () => new URLSearchParams(), })) -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (_loader: () => Promise<{ default: React.ComponentType }>) => { const LazyComponent = (props: Record) => { return
@@ -104,6 +104,10 @@ vi.mock('@/service/use-apps', () => ({ error: mockError, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/hooks/use-pay', () => ({ @@ -161,10 +165,9 @@ const createPage = (apps: App[], hasMore = false, page = 1): AppListResponse => }) const renderList = (searchParams?: Record) => { - return render( - - - , + return renderWithNuqs( + , + { searchParams }, ) } @@ -209,11 +212,7 @@ describe('App List Browsing Flow', () => { it('should transition from loading to content when data loads', () => { mockIsLoading = true - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() const skeletonCards = document.querySelectorAll('.animate-pulse') expect(skeletonCards.length).toBeGreaterThan(0) @@ -224,11 +223,7 @@ describe('App List Browsing Flow', () => { createMockApp({ id: 'app-1', name: 'Loaded App' }), ])] - rerender( - - - , - ) + rerender() expect(screen.getByText('Loaded App')).toBeInTheDocument() }) @@ -424,17 +419,9 @@ describe('App List Browsing Flow', () => { it('should call refetch when controlRefreshList increments', () => { mockPages = [createPage([createMockApp()])] - const { rerender } = render( - - - , - ) + const { rerender } = renderWithNuqs() - rerender( - - - , - ) + rerender() expect(mockRefetch).toHaveBeenCalled() }) diff --git a/web/__tests__/apps/create-app-flow.test.tsx b/web/__tests__/apps/create-app-flow.test.tsx index 556c973b06..bc1f7a3a06 100644 --- a/web/__tests__/apps/create-app-flow.test.tsx +++ b/web/__tests__/apps/create-app-flow.test.tsx @@ -9,11 +9,11 @@ */ import type { AppListResponse } from '@/models/app' import type { App } from '@/types/app' -import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import { NuqsTestingAdapter } from 'nuqs/adapters/testing' +import { fireEvent, screen, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import List from '@/app/components/apps/list' import { AccessMode } from '@/models/access-control' +import { renderWithNuqs } from '@/test/nuqs-testing' import { AppModeEnum } from '@/types/app' let mockIsCurrentWorkspaceEditor = true @@ -35,7 +35,7 @@ const mockRouterPush = vi.fn() const mockRouterReplace = vi.fn() const mockOnPlanInfoChanged = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush, replace: mockRouterReplace, @@ -91,6 +91,10 @@ vi.mock('@/service/use-apps', () => ({ error: null, refetch: mockRefetch, }), + useDeleteAppMutation: () => ({ + mutateAsync: vi.fn(), + isPending: false, + }), })) vi.mock('@/hooks/use-pay', () => ({ @@ -113,7 +117,7 @@ vi.mock('ahooks', async () => { }) // Mock dynamically loaded modals with test stubs -vi.mock('next/dynamic', () => ({ +vi.mock('@/next/dynamic', () => ({ default: (loader: () => Promise<{ default: React.ComponentType }>) => { let Component: React.ComponentType> | null = null loader().then((mod) => { @@ -214,11 +218,7 @@ const createPage = (apps: App[]): AppListResponse => ({ }) const renderList = () => { - return render( - - - , - ) + return renderWithNuqs() } describe('Create App Flow', () => { diff --git a/web/__tests__/billing/billing-integration.test.tsx b/web/__tests__/billing/billing-integration.test.tsx index 4891760df4..64d358cbe6 100644 --- a/web/__tests__/billing/billing-integration.test.tsx +++ b/web/__tests__/billing/billing-integration.test.tsx @@ -64,7 +64,7 @@ vi.mock('@/service/use-education', () => ({ // ─── Navigation mocks ─────────────────────────────────────────────────────── const mockRouterPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx index e01d9250fd..0c1efbe1af 100644 --- a/web/__tests__/billing/cloud-plan-payment-flow.test.tsx +++ b/web/__tests__/billing/cloud-plan-payment-flow.test.tsx @@ -11,6 +11,7 @@ import type { BasicPlan } from '@/app/components/billing/type' import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { ALL_PLANS } from '@/app/components/billing/config' import { PlanRange } from '@/app/components/billing/pricing/plan-switcher/plan-range-switcher' import CloudPlanItem from '@/app/components/billing/pricing/plans/cloud-plan-item' @@ -21,7 +22,6 @@ let mockAppCtx: Record = {} const mockFetchSubscriptionUrls = vi.fn() const mockInvoices = vi.fn() const mockOpenAsyncWindow = vi.fn() -const mockToastNotify = vi.fn() // ─── Context mocks ─────────────────────────────────────────────────────────── vi.mock('@/context/app-context', () => ({ @@ -49,12 +49,8 @@ vi.mock('@/hooks/use-async-window-open', () => ({ useAsyncWindowOpen: () => mockOpenAsyncWindow, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -82,12 +78,15 @@ const renderCloudPlanItem = ({ canPay = true, }: RenderCloudPlanItemOptions = {}) => { return render( - , + <> + + + , ) } @@ -96,6 +95,7 @@ describe('Cloud Plan Payment Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.dismiss() setupAppContext() mockFetchSubscriptionUrls.mockResolvedValue({ url: 'https://pay.example.com/checkout' }) mockInvoices.mockResolvedValue({ url: 'https://billing.example.com/invoices' }) @@ -283,11 +283,7 @@ describe('Cloud Plan Payment Flow', () => { await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ - type: 'error', - }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should not proceed with payment expect(mockFetchSubscriptionUrls).not.toHaveBeenCalled() diff --git a/web/__tests__/billing/education-verification-flow.test.tsx b/web/__tests__/billing/education-verification-flow.test.tsx index 8c35cd9a8c..707f1d690a 100644 --- a/web/__tests__/billing/education-verification-flow.test.tsx +++ b/web/__tests__/billing/education-verification-flow.test.tsx @@ -63,7 +63,7 @@ vi.mock('@/service/use-billing', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: mockRouterPush }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), diff --git a/web/__tests__/billing/partner-stack-flow.test.tsx b/web/__tests__/billing/partner-stack-flow.test.tsx index 4f265478cd..fe642ac70b 100644 --- a/web/__tests__/billing/partner-stack-flow.test.tsx +++ b/web/__tests__/billing/partner-stack-flow.test.tsx @@ -18,7 +18,7 @@ let mockSearchParams = new URLSearchParams() const mockMutateAsync = vi.fn() // ─── Module mocks ──────────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => mockSearchParams, useRouter: () => ({ push: vi.fn() }), usePathname: () => '/', diff --git a/web/__tests__/billing/pricing-modal-flow.test.tsx b/web/__tests__/billing/pricing-modal-flow.test.tsx index 6b8fb57f83..2ec7298618 100644 --- a/web/__tests__/billing/pricing-modal-flow.test.tsx +++ b/web/__tests__/billing/pricing-modal-flow.test.tsx @@ -51,7 +51,7 @@ vi.mock('@/hooks/use-async-window-open', () => ({ })) // ─── Navigation mocks ─────────────────────────────────────────────────────── -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/billing', useSearchParams: () => new URLSearchParams(), @@ -295,24 +295,7 @@ describe('Pricing Modal Flow', () => { }) }) - // ─── 6. Close Handling ─────────────────────────────────────────────────── - describe('Close handling', () => { - it('should call onCancel when pressing ESC key', () => { - render() - - // ahooks useKeyPress listens on document for keydown events - document.dispatchEvent(new KeyboardEvent('keydown', { - key: 'Escape', - code: 'Escape', - keyCode: 27, - bubbles: true, - })) - - expect(onCancel).toHaveBeenCalledTimes(1) - }) - }) - - // ─── 7. Pricing URL ───────────────────────────────────────────────────── + // ─── 6. Pricing URL ───────────────────────────────────────────────────── describe('Pricing page URL', () => { it('should render pricing link with correct URL', () => { render() diff --git a/web/__tests__/billing/self-hosted-plan-flow.test.tsx b/web/__tests__/billing/self-hosted-plan-flow.test.tsx index 810d36da8a..a3386d0092 100644 --- a/web/__tests__/billing/self-hosted-plan-flow.test.tsx +++ b/web/__tests__/billing/self-hosted-plan-flow.test.tsx @@ -10,12 +10,12 @@ import { cleanup, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' import * as React from 'react' +import { toast, ToastHost } from '@/app/components/base/ui/toast' import { contactSalesUrl, getStartedWithCommunityUrl, getWithPremiumUrl } from '@/app/components/billing/config' import SelfHostedPlanItem from '@/app/components/billing/pricing/plans/self-hosted-plan-item' import { SelfHostedPlan } from '@/app/components/billing/type' let mockAppCtx: Record = {} -const mockToastNotify = vi.fn() const originalLocation = window.location let assignedHref = '' @@ -40,10 +40,6 @@ vi.mock('@/app/components/base/icons/src/public/billing', () => ({ AwsMarketplaceDark: () => , })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: (args: unknown) => mockToastNotify(args) }, -})) - vi.mock('@/app/components/billing/pricing/plans/self-hosted-plan-item/list', () => ({ default: ({ plan }: { plan: string }) => (
Features
@@ -57,10 +53,20 @@ const setupAppContext = (overrides: Record = {}) => { } } +const renderSelfHostedPlanItem = (plan: SelfHostedPlan) => { + return render( + <> + + + , + ) +} + describe('Self-Hosted Plan Flow', () => { beforeEach(() => { vi.clearAllMocks() cleanup() + toast.dismiss() setupAppContext() // Mock window.location with minimal getter/setter (Location props are non-enumerable) @@ -85,14 +91,14 @@ describe('Self-Hosted Plan Flow', () => { // ─── 1. Plan Rendering ────────────────────────────────────────────────── describe('Plan rendering', () => { it('should render community plan with name and description', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByText(/plans\.community\.name/i)).toBeInTheDocument() expect(screen.getByText(/plans\.community\.description/i)).toBeInTheDocument() }) it('should render premium plan with cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.name/i)).toBeInTheDocument() expect(screen.getByTestId('icon-azure')).toBeInTheDocument() @@ -100,39 +106,39 @@ describe('Self-Hosted Plan Flow', () => { }) it('should render enterprise plan without cloud provider icons', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByText(/plans\.enterprise\.name/i)).toBeInTheDocument() expect(screen.queryByTestId('icon-azure')).not.toBeInTheDocument() }) it('should not show price tip for community (free) plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.queryByText(/plans\.community\.priceTip/i)).not.toBeInTheDocument() }) it('should show price tip for premium plan', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByText(/plans\.premium\.priceTip/i)).toBeInTheDocument() }) it('should render features list for each plan', () => { - const { unmount: unmount1 } = render() + const { unmount: unmount1 } = renderSelfHostedPlanItem(SelfHostedPlan.community) expect(screen.getByTestId('self-hosted-list-community')).toBeInTheDocument() unmount1() - const { unmount: unmount2 } = render() + const { unmount: unmount2 } = renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('self-hosted-list-premium')).toBeInTheDocument() unmount2() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) expect(screen.getByTestId('self-hosted-list-enterprise')).toBeInTheDocument() }) it('should show AWS marketplace icon for premium plan button', () => { - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) expect(screen.getByTestId('icon-aws-light')).toBeInTheDocument() }) @@ -142,7 +148,7 @@ describe('Self-Hosted Plan Flow', () => { describe('Navigation flow', () => { it('should redirect to GitHub when clicking community plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) @@ -152,7 +158,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to AWS Marketplace when clicking premium plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) @@ -162,7 +168,7 @@ describe('Self-Hosted Plan Flow', () => { it('should redirect to Typeform when clicking enterprise plan button', async () => { const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) @@ -176,15 +182,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks community button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.community) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) // Should NOT redirect expect(assignedHref).toBe('') @@ -193,15 +197,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks premium button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.premium) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) @@ -209,15 +211,13 @@ describe('Self-Hosted Plan Flow', () => { it('should show error toast when non-manager clicks enterprise button', async () => { setupAppContext({ isCurrentWorkspaceManager: false }) const user = userEvent.setup() - render() + renderSelfHostedPlanItem(SelfHostedPlan.enterprise) const button = screen.getByRole('button') await user.click(button) await waitFor(() => { - expect(mockToastNotify).toHaveBeenCalledWith( - expect.objectContaining({ type: 'error' }), - ) + expect(screen.getByText('billing.buyPermissionDeniedTip')).toBeInTheDocument() }) expect(assignedHref).toBe('') }) diff --git a/web/__tests__/datasets/dataset-settings-flow.test.tsx b/web/__tests__/datasets/dataset-settings-flow.test.tsx index 607cd8c2d5..b4a5e78326 100644 --- a/web/__tests__/datasets/dataset-settings-flow.test.tsx +++ b/web/__tests__/datasets/dataset-settings-flow.test.tsx @@ -19,6 +19,10 @@ import { RETRIEVE_METHOD } from '@/types/app' // --- Mocks --- +const { mockToastError } = vi.hoisted(() => ({ + mockToastError: vi.fn(), +})) + const mockMutateDatasets = vi.fn() const mockInvalidDatasetList = vi.fn() const mockUpdateDatasetSetting = vi.fn().mockResolvedValue({}) @@ -55,8 +59,11 @@ vi.mock('@/app/components/datasets/common/check-rerank-model', () => ({ isReRankModelSelected: () => true, })) -vi.mock('@/app/components/base/toast', () => ({ - default: { notify: vi.fn() }, +vi.mock('@/app/components/base/ui/toast', () => ({ + toast: { + error: mockToastError, + success: vi.fn(), + }, })) // --- Dataset factory --- @@ -311,7 +318,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { describe('Form Submission Validation → All Fields Together', () => { it('should reject empty name on save', async () => { - const Toast = await import('@/app/components/base/toast') + const { toast } = await import('@/app/components/base/ui/toast') const { result } = renderHook(() => useFormState()) act(() => { @@ -322,10 +329,7 @@ describe('Dataset Settings Flow - Cross-Module Configuration Cascade', () => { await result.current.handleSave() }) - expect(Toast.default.notify).toHaveBeenCalledWith({ - type: 'error', - message: expect.any(String), - }) + expect(toast.error).toHaveBeenCalledWith(expect.any(String)) expect(mockUpdateDatasetSetting).not.toHaveBeenCalled() }) diff --git a/web/__tests__/datasets/document-management.test.tsx b/web/__tests__/datasets/document-management.test.tsx index 3b901ccee2..f9d80520ed 100644 --- a/web/__tests__/datasets/document-management.test.tsx +++ b/web/__tests__/datasets/document-management.test.tsx @@ -7,12 +7,13 @@ */ import type { SimpleDocumentDetail } from '@/models/datasets' -import { act, renderHook } from '@testing-library/react' +import { act, renderHook, waitFor } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' import { DataSourceType } from '@/models/datasets' +import { renderHookWithNuqs } from '@/test/nuqs-testing' const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSearchParams: () => new URLSearchParams(''), useRouter: () => ({ push: mockPush }), usePathname: () => '/datasets/ds-1/documents', @@ -28,12 +29,16 @@ const { useDocumentSort } = await import( const { useDocumentSelection } = await import( '@/app/components/datasets/documents/components/document-list/hooks/use-document-selection', ) -const { default: useDocumentListQueryState } = await import( +const { useDocumentListQueryState } = await import( '@/app/components/datasets/documents/hooks/use-document-list-query-state', ) type LocalDoc = SimpleDocumentDetail & { percent?: number } +const renderQueryStateHook = (searchParams = '') => { + return renderHookWithNuqs(() => useDocumentListQueryState(), { searchParams }) +} + const createDoc = (overrides?: Partial): LocalDoc => ({ id: `doc-${Math.random().toString(36).slice(2, 8)}`, name: 'test-doc.txt', @@ -85,7 +90,7 @@ describe('Document Management Flow', () => { describe('URL-based Query State', () => { it('should parse default query from empty URL params', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + const { result } = renderQueryStateHook() expect(result.current.query).toEqual({ page: 1, @@ -96,107 +101,85 @@ describe('Document Management Flow', () => { }) }) - it('should update query and push to router', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + it('should update keyword query with replace history', async () => { + const { result, onUrlUpdate } = renderQueryStateHook() act(() => { result.current.updateQuery({ keyword: 'test', page: 2 }) }) - expect(mockPush).toHaveBeenCalled() - // The push call should contain the updated query params - const pushUrl = mockPush.mock.calls[0][0] as string - expect(pushUrl).toContain('keyword=test') - expect(pushUrl).toContain('page=2') + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('replace') + expect(update.searchParams.get('keyword')).toBe('test') + expect(update.searchParams.get('page')).toBe('2') }) - it('should reset query to defaults', () => { - const { result } = renderHook(() => useDocumentListQueryState()) + it('should reset query to defaults', async () => { + const { result, onUrlUpdate } = renderQueryStateHook() act(() => { result.current.resetQuery() }) - expect(mockPush).toHaveBeenCalled() - // Default query omits default values from URL - const pushUrl = mockPush.mock.calls[0][0] as string - expect(pushUrl).toBe('/datasets/ds-1/documents') + await waitFor(() => expect(onUrlUpdate).toHaveBeenCalled()) + const update = onUrlUpdate.mock.calls[onUrlUpdate.mock.calls.length - 1][0] + expect(update.options.history).toBe('replace') + expect(update.searchParams.toString()).toBe('') }) }) describe('Document Sort Integration', () => { - it('should return documents unsorted when no sort field set', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'Banana.txt', word_count: 300 }), - createDoc({ id: 'doc-2', name: 'Apple.txt', word_count: 100 }), - createDoc({ id: 'doc-3', name: 'Cherry.txt', word_count: 200 }), - ] - + it('should derive sort field and order from remote sort value', () => { const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', remoteSortValue: '-created_at', + onRemoteSortChange: vi.fn(), })) - expect(result.current.sortField).toBeNull() - expect(result.current.sortedDocuments).toHaveLength(3) + expect(result.current.sortField).toBe('created_at') + expect(result.current.sortOrder).toBe('desc') }) - it('should sort by name descending', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'Banana.txt' }), - createDoc({ id: 'doc-2', name: 'Apple.txt' }), - createDoc({ id: 'doc-3', name: 'Cherry.txt' }), - ] - + it('should call remote sort change with descending sort for a new field', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', remoteSortValue: '-created_at', + onRemoteSortChange, })) act(() => { - result.current.handleSort('name') + result.current.handleSort('hit_count') }) - expect(result.current.sortField).toBe('name') - expect(result.current.sortOrder).toBe('desc') - const names = result.current.sortedDocuments.map(d => d.name) - expect(names).toEqual(['Cherry.txt', 'Banana.txt', 'Apple.txt']) + expect(onRemoteSortChange).toHaveBeenCalledWith('-hit_count') }) - it('should toggle sort order on same field click', () => { - const docs = [createDoc({ id: 'doc-1', name: 'A.txt' }), createDoc({ id: 'doc-2', name: 'B.txt' })] - + it('should toggle descending to ascending when clicking active field', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: '', - remoteSortValue: '-created_at', + remoteSortValue: '-hit_count', + onRemoteSortChange, })) - act(() => result.current.handleSort('name')) - expect(result.current.sortOrder).toBe('desc') + act(() => { + result.current.handleSort('hit_count') + }) - act(() => result.current.handleSort('name')) - expect(result.current.sortOrder).toBe('asc') + expect(onRemoteSortChange).toHaveBeenCalledWith('hit_count') }) - it('should filter by status before sorting', () => { - const docs = [ - createDoc({ id: 'doc-1', name: 'A.txt', display_status: 'available' }), - createDoc({ id: 'doc-2', name: 'B.txt', display_status: 'error' }), - createDoc({ id: 'doc-3', name: 'C.txt', display_status: 'available' }), - ] - + it('should ignore null sort field updates', () => { + const onRemoteSortChange = vi.fn() const { result } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: 'available', remoteSortValue: '-created_at', + onRemoteSortChange, })) - // Only 'available' documents should remain - expect(result.current.sortedDocuments).toHaveLength(2) - expect(result.current.sortedDocuments.every(d => d.display_status === 'available')).toBe(true) + act(() => { + result.current.handleSort(null) + }) + + expect(onRemoteSortChange).not.toHaveBeenCalled() }) }) @@ -309,14 +292,13 @@ describe('Document Management Flow', () => { describe('Cross-Module: Query State → Sort → Selection Pipeline', () => { it('should maintain consistent default state across all hooks', () => { const docs = [createDoc({ id: 'doc-1' })] - const { result: queryResult } = renderHook(() => useDocumentListQueryState()) + const { result: queryResult } = renderQueryStateHook() const { result: sortResult } = renderHook(() => useDocumentSort({ - documents: docs, - statusFilterValue: queryResult.current.query.status, remoteSortValue: queryResult.current.query.sort, + onRemoteSortChange: vi.fn(), })) const { result: selResult } = renderHook(() => useDocumentSelection({ - documents: sortResult.current.sortedDocuments, + documents: docs, selectedIds: [], onSelectedIdChange: vi.fn(), })) @@ -325,8 +307,9 @@ describe('Document Management Flow', () => { expect(queryResult.current.query.sort).toBe('-created_at') expect(queryResult.current.query.status).toBe('all') - // Sort inherits 'all' status → no filtering applied - expect(sortResult.current.sortedDocuments).toHaveLength(1) + // Sort state is derived from URL default sort. + expect(sortResult.current.sortField).toBe('created_at') + expect(sortResult.current.sortOrder).toBe('desc') // Selection starts empty expect(selResult.current.isAllSelected).toBe(false) diff --git a/web/__tests__/document-detail-navigation-fix.test.tsx b/web/__tests__/document-detail-navigation-fix.test.tsx index 6b348cd15b..5cb115830e 100644 --- a/web/__tests__/document-detail-navigation-fix.test.tsx +++ b/web/__tests__/document-detail-navigation-fix.test.tsx @@ -7,12 +7,12 @@ import type { Mock } from 'vitest' */ import { fireEvent, render, screen } from '@testing-library/react' -import { useRouter } from 'next/navigation' +import { useRouter } from '@/next/navigation' import { useDocumentDetail, useDocumentMetadata } from '@/service/knowledge/use-document' // Mock Next.js router const mockPush = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: vi.fn(() => ({ push: mockPush, })), diff --git a/web/__tests__/embedded-user-id-auth.test.tsx b/web/__tests__/embedded-user-id-auth.test.tsx index 9231ac6199..cacd6331f8 100644 --- a/web/__tests__/embedded-user-id-auth.test.tsx +++ b/web/__tests__/embedded-user-id-auth.test.tsx @@ -8,7 +8,7 @@ const replaceMock = vi.fn() const backMock = vi.fn() const useSearchParamsMock = vi.fn(() => new URLSearchParams()) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/test-app'), useRouter: vi.fn(() => ({ replace: replaceMock, diff --git a/web/__tests__/embedded-user-id-store.test.tsx b/web/__tests__/embedded-user-id-store.test.tsx index 901218e76b..04597ccfeb 100644 --- a/web/__tests__/embedded-user-id-store.test.tsx +++ b/web/__tests__/embedded-user-id-store.test.tsx @@ -4,7 +4,7 @@ import WebAppStoreProvider, { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: vi.fn(() => '/chatbot/sample-app'), useSearchParams: vi.fn(() => { const params = new URLSearchParams() diff --git a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx index e2c18bcc4f..64dd5321ac 100644 --- a/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx +++ b/web/__tests__/explore/sidebar-lifecycle-flow.test.tsx @@ -7,19 +7,23 @@ */ import type { InstalledApp } from '@/models/explore' import { fireEvent, render, screen, waitFor } from '@testing-library/react' -import Toast from '@/app/components/base/toast' import SideBar from '@/app/components/explore/sidebar' import { MediaType } from '@/hooks/use-breakpoints' import { AppModeEnum } from '@/types/app' +const { mockToastSuccess } = vi.hoisted(() => ({ + mockToastSuccess: vi.fn(), +})) + let mockMediaType: string = MediaType.pc const mockSegments = ['apps'] const mockPush = vi.fn() const mockUninstall = vi.fn() const mockUpdatePinStatus = vi.fn() let mockInstalledApps: InstalledApp[] = [] +let mockIsUninstallPending = false -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegments: () => mockSegments, useRouter: () => ({ push: mockPush, @@ -42,12 +46,24 @@ vi.mock('@/service/use-explore', () => ({ }), useUninstallApp: () => ({ mutateAsync: mockUninstall, + isPending: mockIsUninstallPending, }), useUpdateAppPinStatus: () => ({ mutateAsync: mockUpdatePinStatus, }), })) +vi.mock('@/app/components/base/ui/toast', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + toast: { + ...actual.toast, + success: mockToastSuccess, + }, + } +}) + const createInstalledApp = (overrides: Partial = {}): InstalledApp => ({ id: overrides.id ?? 'app-1', uninstallable: overrides.uninstallable ?? false, @@ -74,7 +90,7 @@ describe('Sidebar Lifecycle Flow', () => { vi.clearAllMocks() mockMediaType = MediaType.pc mockInstalledApps = [] - vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() })) + mockIsUninstallPending = false }) describe('Pin / Unpin / Delete Flow', () => { @@ -91,9 +107,7 @@ describe('Sidebar Lifecycle Flow', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: true }) - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - })) + expect(mockToastSuccess).toHaveBeenCalled() }) // Step 2: Simulate refetch returning pinned state, then unpin @@ -110,9 +124,7 @@ describe('Sidebar Lifecycle Flow', () => { await waitFor(() => { expect(mockUpdatePinStatus).toHaveBeenCalledWith({ appId: 'app-1', isPinned: false }) - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - })) + expect(mockToastSuccess).toHaveBeenCalled() }) }) @@ -136,10 +148,7 @@ describe('Sidebar Lifecycle Flow', () => { // Step 4: Uninstall API called and success toast shown await waitFor(() => { expect(mockUninstall).toHaveBeenCalledWith('app-1') - expect(Toast.notify).toHaveBeenCalledWith(expect.objectContaining({ - type: 'success', - message: 'common.api.remove', - })) + expect(mockToastSuccess).toHaveBeenCalledWith('common.api.remove') }) }) diff --git a/web/__tests__/plugins/plugin-card-rendering.test.tsx b/web/__tests__/plugins/plugin-card-rendering.test.tsx index 7abcb01b49..5bd7f0c8bf 100644 --- a/web/__tests__/plugins/plugin-card-rendering.test.tsx +++ b/web/__tests__/plugins/plugin-card-rendering.test.tsx @@ -8,6 +8,8 @@ import { cleanup, render, screen } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' +let mockTheme = 'light' + vi.mock('#i18n', () => ({ useTranslation: () => ({ t: (key: string) => key, @@ -19,16 +21,16 @@ vi.mock('@/context/i18n', () => ({ })) vi.mock('@/hooks/use-theme', () => ({ - default: () => ({ theme: 'light' }), + default: () => ({ theme: mockTheme }), })) vi.mock('@/i18n-config', () => ({ renderI18nObject: (obj: Record, locale: string) => obj[locale] || obj.en_US || '', })) -vi.mock('@/types/app', () => ({ - Theme: { dark: 'dark', light: 'light' }, -})) +vi.mock('@/types/app', async () => { + return vi.importActual('@/types/app') +}) vi.mock('@/utils/classnames', () => ({ cn: (...args: unknown[]) => args.filter(a => typeof a === 'string' && a).join(' '), @@ -100,6 +102,7 @@ type CardPayload = Parameters[0]['payload'] describe('Plugin Card Rendering Integration', () => { beforeEach(() => { cleanup() + mockTheme = 'light' }) const makePayload = (overrides = {}) => ({ @@ -194,9 +197,7 @@ describe('Plugin Card Rendering Integration', () => { }) it('uses dark icon when theme is dark and icon_dark is provided', () => { - vi.doMock('@/hooks/use-theme', () => ({ - default: () => ({ theme: 'dark' }), - })) + mockTheme = 'dark' const payload = makePayload({ icon: 'https://example.com/icon-light.png', @@ -204,7 +205,7 @@ describe('Plugin Card Rendering Integration', () => { }) render() - expect(screen.getByTestId('card-icon')).toBeInTheDocument() + expect(screen.getByTestId('card-icon')).toHaveTextContent('https://example.com/icon-dark.png') }) it('shows loading placeholder when isLoading is true', () => { diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 7ceca4535b..8edb6705d4 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -22,33 +22,6 @@ vi.mock('@/service/plugins', () => ({ checkTaskStatus: vi.fn(), })) -vi.mock('@/utils/semver', () => ({ - compareVersion: (a: string, b: string) => { - const parse = (v: string) => v.replace(/^v/, '').split('.').map(Number) - const [aMajor, aMinor = 0, aPatch = 0] = parse(a) - const [bMajor, bMinor = 0, bPatch = 0] = parse(b) - if (aMajor !== bMajor) - return aMajor > bMajor ? 1 : -1 - if (aMinor !== bMinor) - return aMinor > bMinor ? 1 : -1 - if (aPatch !== bPatch) - return aPatch > bPatch ? 1 : -1 - return 0 - }, - getLatestVersion: (versions: string[]) => { - return versions.sort((a, b) => { - const parse = (v: string) => v.replace(/^v/, '').split('.').map(Number) - const [aMaj, aMin = 0, aPat = 0] = parse(a) - const [bMaj, bMin = 0, bPat = 0] = parse(b) - if (aMaj !== bMaj) - return bMaj - aMaj - if (aMin !== bMin) - return bMin - aMin - return bPat - aPat - })[0] - }, -})) - const { useGitHubReleases, useGitHubUpload } = await import( '@/app/components/plugins/install-plugin/hooks', ) diff --git a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts index 578552840d..dc5ab3fc86 100644 --- a/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts +++ b/web/__tests__/rag-pipeline/dsl-export-import-flow.test.ts @@ -19,7 +19,7 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('@/app/components/base/toast', () => ({ +vi.mock('@/app/components/base/toast/context', () => ({ useToastContext: () => ({ notify: mockNotify }), })) diff --git a/web/__tests__/share/text-generation-index-flow.test.tsx b/web/__tests__/share/text-generation-index-flow.test.tsx new file mode 100644 index 0000000000..2fec054a47 --- /dev/null +++ b/web/__tests__/share/text-generation-index-flow.test.tsx @@ -0,0 +1,235 @@ +import type { AccessMode } from '@/models/access-control' +import { fireEvent, render, screen, waitFor } from '@testing-library/react' +import * as React from 'react' +import TextGeneration from '@/app/components/share/text-generation' + +const useSearchParamsMock = vi.fn(() => new URLSearchParams()) + +vi.mock('@/next/navigation', () => ({ + useSearchParams: () => useSearchParamsMock(), +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: vi.fn(() => 'pc'), + MediaType: { pc: 'pc', pad: 'pad', mobile: 'mobile' }, +})) + +vi.mock('@/hooks/use-app-favicon', () => ({ + useAppFavicon: vi.fn(), +})) + +vi.mock('@/hooks/use-document-title', () => ({ + default: vi.fn(), +})) + +vi.mock('@/i18n-config/client', () => ({ + changeLanguage: vi.fn(() => Promise.resolve()), +})) + +vi.mock('@/app/components/share/text-generation/run-once', () => ({ + default: ({ + inputs, + onInputsChange, + onSend, + runControl, + }: { + inputs: Record + onInputsChange: (inputs: Record) => void + onSend: () => void + runControl?: { isStopping: boolean } | null + }) => ( +
+ {String(inputs.name ?? '')} + + + {runControl ? 'stop-ready' : 'idle'} +
+ ), +})) + +vi.mock('@/app/components/share/text-generation/run-batch', () => ({ + default: ({ onSend }: { onSend: (data: string[][]) => void }) => ( + + ), +})) + +vi.mock('@/app/components/app/text-generate/saved-items', () => ({ + default: ({ list }: { list: { id: string }[] }) =>
{list.length}
, +})) + +vi.mock('@/app/components/share/text-generation/menu-dropdown', () => ({ + default: () =>
, +})) + +vi.mock('@/app/components/share/text-generation/result', () => { + const MockResult = ({ + isCallBatchAPI, + onRunControlChange, + onRunStart, + taskId, + }: { + isCallBatchAPI: boolean + onRunControlChange?: (control: { onStop: () => void, isStopping: boolean } | null) => void + onRunStart: () => void + taskId?: number + }) => { + const runControlRef = React.useRef(false) + + React.useEffect(() => { + onRunStart() + }, [onRunStart]) + + React.useEffect(() => { + if (!isCallBatchAPI && !runControlRef.current) { + runControlRef.current = true + onRunControlChange?.({ onStop: vi.fn(), isStopping: false }) + } + }, [isCallBatchAPI, onRunControlChange]) + + return
+ } + + return { + default: MockResult, + } +}) + +const fetchSavedMessageMock = vi.fn() + +vi.mock('@/service/share', async () => { + const actual = await vi.importActual('@/service/share') + return { + ...actual, + fetchSavedMessage: (...args: Parameters) => fetchSavedMessageMock(...args), + removeMessage: vi.fn(), + saveMessage: vi.fn(), + } +}) + +const mockSystemFeatures = { + branding: { + enabled: false, + workspace_logo: null, + }, +} + +const mockWebAppState = { + appInfo: { + app_id: 'app-123', + site: { + title: 'Text Generation', + description: 'Share description', + default_language: 'en-US', + icon_type: 'emoji', + icon: 'robot', + icon_background: '#fff', + icon_url: '', + }, + custom_config: { + remove_webapp_brand: false, + replace_webapp_logo: '', + }, + }, + appParams: { + user_input_form: [ + { + 'text-input': { + label: 'Name', + variable: 'name', + required: true, + max_length: 48, + default: '', + hide: false, + }, + }, + ], + more_like_this: { + enabled: true, + }, + file_upload: { + enabled: false, + number_limits: 2, + detail: 'low', + allowed_upload_methods: ['local_file'], + }, + text_to_speech: { + enabled: true, + }, + system_parameters: { + image_file_size_limit: 10, + }, + }, + webAppAccessMode: 'public' as AccessMode, +} + +vi.mock('@/context/global-public-context', () => ({ + useGlobalPublicStore: (selector: (state: { systemFeatures: typeof mockSystemFeatures }) => unknown) => + selector({ systemFeatures: mockSystemFeatures }), +})) + +vi.mock('@/context/web-app-context', () => ({ + useWebAppStore: (selector: (state: typeof mockWebAppState) => unknown) => selector(mockWebAppState), +})) + +describe('TextGeneration', () => { + beforeEach(() => { + vi.clearAllMocks() + useSearchParamsMock.mockReturnValue(new URLSearchParams()) + fetchSavedMessageMock.mockResolvedValue({ + data: [{ id: 'saved-1' }, { id: 'saved-2' }], + }) + }) + + it('should switch between create, batch, and saved tabs after app state loads', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + }) + expect(screen.getByTestId('run-once-input-name')).toHaveTextContent('') + + fireEvent.click(screen.getByRole('button', { name: 'change-inputs' })) + await waitFor(() => { + expect(screen.getByTestId('run-once-input-name')).toHaveTextContent('Gamma') + }) + + fireEvent.click(screen.getByTestId('tab-header-item-batch')) + expect(screen.getByRole('button', { name: 'run-batch' })).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('tab-header-item-saved')) + expect(screen.getByTestId('saved-items-mock')).toHaveTextContent('2') + + fireEvent.click(screen.getByTestId('tab-header-item-create')) + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + }) + + it('should wire single-run stop control and clear it when batch execution starts', async () => { + render() + + await waitFor(() => { + expect(screen.getByTestId('run-once-mock')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByRole('button', { name: 'run-once' })) + await waitFor(() => { + expect(screen.getByText('stop-ready')).toBeInTheDocument() + }) + expect(screen.getByTestId('result-single')).toBeInTheDocument() + + fireEvent.click(screen.getByTestId('tab-header-item-batch')) + fireEvent.click(screen.getByRole('button', { name: 'run-batch' })) + await waitFor(() => { + expect(screen.getByText('idle')).toBeInTheDocument() + }) + expect(screen.getByTestId('result-task-1')).toBeInTheDocument() + expect(screen.getByTestId('result-task-2')).toBeInTheDocument() + }) +}) diff --git a/web/__tests__/tools/tool-browsing-and-filtering.test.tsx b/web/__tests__/tools/tool-browsing-and-filtering.test.tsx index 4e7fa4952b..dbefb1fdc3 100644 --- a/web/__tests__/tools/tool-browsing-and-filtering.test.tsx +++ b/web/__tests__/tools/tool-browsing-and-filtering.test.tsx @@ -28,9 +28,13 @@ vi.mock('react-i18next', () => ({ }), })) -vi.mock('nuqs', () => ({ - useQueryState: () => ['builtin', vi.fn()], -})) +vi.mock('nuqs', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + useQueryState: () => ['builtin', vi.fn()], + } +}) vi.mock('@/context/global-public-context', () => ({ useGlobalPublicStore: () => ({ enable_marketplace: false }), @@ -212,6 +216,12 @@ vi.mock('@/app/components/tools/marketplace', () => ({ default: () => null, })) +vi.mock('@/app/components/tools/marketplace/hooks', () => ({ + useMarketplace: () => ({ + handleScroll: vi.fn(), + }), +})) + vi.mock('@/app/components/tools/mcp', () => ({ default: () =>
MCP List
, })) diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx index 470f4477fa..0c87fd1a4d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/layout-main.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import type { NavIcon } from '@/app/components/app-sidebar/navLink' +import type { NavIcon } from '@/app/components/app-sidebar/nav-link' import type { App } from '@/types/app' import { RiDashboard2Fill, @@ -13,8 +13,6 @@ import { RiTerminalWindowLine, } from '@remixicon/react' import { useUnmount } from 'ahooks' -import dynamic from 'next/dynamic' -import { usePathname, useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -26,6 +24,8 @@ import { useStore as useTagStore } from '@/app/components/base/tag-management/st import { useAppContext } from '@/context/app-context' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import dynamic from '@/next/dynamic' +import { usePathname, useRouter } from '@/next/navigation' import { fetchAppDetailDirect } from '@/service/apps' import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx index f07b2932c9..8c1df8d63d 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view.tsx @@ -12,7 +12,7 @@ import AppCard from '@/app/components/app/overview/app-card' import TriggerCard from '@/app/components/app/overview/trigger-card' import { useStore as useAppStore } from '@/app/components/app/store' import Loading from '@/app/components/base/loading' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import MCPServiceCard from '@/app/components/tools/mcp/mcp-service-card' import { isTriggerNode } from '@/app/components/workflow/types' import { NEED_REFRESH_APP_LIST_KEY } from '@/config' diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts index 221ba2808f..71f5b009d3 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/config.ts @@ -5,7 +5,7 @@ export const docURL = { [TracingProvider.phoenix]: 'https://docs.arize.com/phoenix', [TracingProvider.langSmith]: 'https://docs.smith.langchain.com/', [TracingProvider.langfuse]: 'https://docs.langfuse.com', - [TracingProvider.opik]: 'https://www.comet.com/docs/opik/tracing/integrations/dify#setup-instructions', + [TracingProvider.opik]: 'https://www.comet.com/docs/opik/integrations/dify', [TracingProvider.weave]: 'https://weave-docs.wandb.ai/', [TracingProvider.aliyun]: 'https://help.aliyun.com/zh/arms/tracing-analysis/untitled-document-1750672984680', [TracingProvider.mlflow]: 'https://mlflow.org/docs/latest/genai/', diff --git a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx index 5e7d98d191..4201d11490 100644 --- a/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx +++ b/web/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/tracing/panel.tsx @@ -7,7 +7,6 @@ import { RiEqualizer2Line, } from '@remixicon/react' import { useBoolean } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -17,6 +16,7 @@ import Loading from '@/app/components/base/loading' import Toast from '@/app/components/base/toast' import Indicator from '@/app/components/header/indicator' import { useAppContext } from '@/context/app-context' +import { usePathname } from '@/next/navigation' import { fetchTracingConfig as doFetchTracingConfig, fetchTracingStatus, updateTracingStatus } from '@/service/apps' import { cn } from '@/utils/classnames' import ConfigButton from './config-button' diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index 4f3f724e62..730b76ee19 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -9,7 +9,6 @@ import { RiFocus2Fill, RiFocus2Line, } from '@remixicon/react' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -23,6 +22,7 @@ import DatasetDetailContext from '@/context/dataset-detail' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useDocumentTitle from '@/hooks/use-document-title' +import { usePathname } from '@/next/navigation' import { useDatasetDetail, useDatasetRelatedApps } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' diff --git a/web/app/(commonLayout)/datasets/layout.spec.tsx b/web/app/(commonLayout)/datasets/layout.spec.tsx index 5873f344d0..9c01cffba8 100644 --- a/web/app/(commonLayout)/datasets/layout.spec.tsx +++ b/web/app/(commonLayout)/datasets/layout.spec.tsx @@ -6,7 +6,7 @@ import DatasetsLayout from './layout' const mockReplace = vi.fn() const mockUseAppContext = vi.fn() -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/(commonLayout)/datasets/layout.tsx b/web/app/(commonLayout)/datasets/layout.tsx index b543c42570..a465f8222b 100644 --- a/web/app/(commonLayout)/datasets/layout.tsx +++ b/web/app/(commonLayout)/datasets/layout.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' import { ExternalApiPanelProvider } from '@/context/external-api-panel-context' import { ExternalKnowledgeApiProvider } from '@/context/external-knowledge-api-context' +import { useRouter } from '@/next/navigation' export default function DatasetsLayout({ children }: { children: React.ReactNode }) { const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator, currentWorkspace, isLoadingCurrentWorkspace } = useAppContext() diff --git a/web/app/(commonLayout)/education-apply/page.tsx b/web/app/(commonLayout)/education-apply/page.tsx index fce6fe1d5d..44ba5ee8ad 100644 --- a/web/app/(commonLayout)/education-apply/page.tsx +++ b/web/app/(commonLayout)/education-apply/page.tsx @@ -1,15 +1,15 @@ 'use client' -import { - useRouter, - useSearchParams, -} from 'next/navigation' import { useEffect, useMemo, } from 'react' import EducationApplyPage from '@/app/education-apply/education-apply-page' import { useProviderContext } from '@/context/provider-context' +import { + useRouter, + useSearchParams, +} from '@/next/navigation' export default function EducationApply() { const router = useRouter() diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index abd5dd96fd..5ac39f1e39 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -1,6 +1,7 @@ import type { ReactNode } from 'react' import * as React from 'react' import { AppInitializer } from '@/app/components/app-initializer' +import InSiteMessageNotification from '@/app/components/app/in-site-message/notification' import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import Zendesk from '@/app/components/base/zendesk' @@ -8,10 +9,10 @@ import GotoAnything from '@/app/components/goto-anything' import Header from '@/app/components/header' import HeaderWrapper from '@/app/components/header/header-wrapper' import ReadmePanel from '@/app/components/plugins/readme-panel' -import { AppContextProvider } from '@/context/app-context' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { AppContextProvider } from '@/context/app-context-provider' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import PartnerStack from '../components/billing/partner-stack' import Splash from '../components/splash' import RoleRouteGuard from './role-route-guard' @@ -32,6 +33,7 @@ const Layout = ({ children }: { children: ReactNode }) => { {children} + diff --git a/web/app/(commonLayout)/role-route-guard.spec.tsx b/web/app/(commonLayout)/role-route-guard.spec.tsx index 87bf9be8af..ca1550f0b8 100644 --- a/web/app/(commonLayout)/role-route-guard.spec.tsx +++ b/web/app/(commonLayout)/role-route-guard.spec.tsx @@ -6,7 +6,7 @@ const mockReplace = vi.fn() const mockUseAppContext = vi.fn() let mockPathname = '/apps' -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ usePathname: () => mockPathname, useRouter: () => ({ replace: mockReplace, diff --git a/web/app/(commonLayout)/role-route-guard.tsx b/web/app/(commonLayout)/role-route-guard.tsx index 1c42be9d15..483dfef095 100644 --- a/web/app/(commonLayout)/role-route-guard.tsx +++ b/web/app/(commonLayout)/role-route-guard.tsx @@ -1,10 +1,10 @@ 'use client' import type { ReactNode } from 'react' -import { usePathname, useRouter } from 'next/navigation' import { useEffect } from 'react' import Loading from '@/app/components/base/loading' import { useAppContext } from '@/context/app-context' +import { usePathname, useRouter } from '@/next/navigation' const datasetOperatorRedirectRoutes = ['/apps', '/app', '/explore', '/tools'] as const diff --git a/web/app/(humanInputLayout)/form/[token]/form.tsx b/web/app/(humanInputLayout)/form/[token]/form.tsx index d027ef8b7d..035da6be8a 100644 --- a/web/app/(humanInputLayout)/form/[token]/form.tsx +++ b/web/app/(humanInputLayout)/form/[token]/form.tsx @@ -9,7 +9,6 @@ import { RiInformation2Fill, } from '@remixicon/react' import { produce } from 'immer' -import { useParams } from 'next/navigation' import * as React from 'react' import { useEffect, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -21,6 +20,7 @@ import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-inp import Loading from '@/app/components/base/loading' import DifyLogo from '@/app/components/base/logo/dify-logo' import useDocumentTitle from '@/hooks/use-document-title' +import { useParams } from '@/next/navigation' import { useGetHumanInputForm, useSubmitHumanInputForm } from '@/service/use-share' import { cn } from '@/utils/classnames' diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx index c874990448..9f956a8501 100644 --- a/web/app/(shareLayout)/components/authenticated-layout.tsx +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -1,12 +1,12 @@ 'use client' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { useGetUserCanAccessApp } from '@/service/access-control' import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' import { webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx index a2b847f74f..402005752d 100644 --- a/web/app/(shareLayout)/components/splash.tsx +++ b/web/app/(shareLayout)/components/splash.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC, PropsWithChildren } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport, webAppLoginStatus, webAppLogout } from '@/service/webapp-auth' diff --git a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx index fbf45259e5..1d1c6518fe 100644 --- a/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/check-code/page.tsx @@ -1,14 +1,14 @@ 'use client' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Countdown from '@/app/components/signin/countdown' - import { useLocale } from '@/context/i18n' + +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppResetPasswordCode, verifyWebAppResetPasswordCode } from '@/service/common' export default function CheckCode() { @@ -24,17 +24,11 @@ export default function CheckCode() { const verify = async () => { try { if (!code.trim()) { - Toast.notify({ - type: 'error', - message: t('checkCode.emptyCode', { ns: 'login' }), - }) + toast.error(t('checkCode.emptyCode', { ns: 'login' })) return } if (!/\d{6}/.test(code)) { - Toast.notify({ - type: 'error', - message: t('checkCode.invalidCode', { ns: 'login' }), - }) + toast.error(t('checkCode.invalidCode', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-reset-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/page.tsx index 9b9a853cdd..0cdfb4ec11 100644 --- a/web/app/(shareLayout)/webapp-reset-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/page.tsx @@ -1,18 +1,18 @@ 'use client' import { RiArrowLeftLine, RiLockPasswordLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' - import { useLocale } from '@/context/i18n' import useDocumentTitle from '@/hooks/use-document-title' + +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendResetPasswordCode } from '@/service/common' export default function CheckCode() { @@ -27,15 +27,12 @@ export default function CheckCode() { const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - Toast.notify({ - type: 'error', - message: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } setIsLoading(true) @@ -48,16 +45,10 @@ export default function CheckCode() { router.push(`/webapp-reset-password/check-code?${params.toString()}`) } else if (res.code === 'account_not_found') { - Toast.notify({ - type: 'error', - message: t('error.registrationNotAllowed', { ns: 'login' }), - }) + toast.error(t('error.registrationNotAllowed', { ns: 'login' })) } else { - Toast.notify({ - type: 'error', - message: res.data, - }) + toast.error(res.data) } } catch (error) { diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 9f59e8f9eb..bc8f651d17 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -1,13 +1,13 @@ 'use client' import { RiCheckboxCircleFill } from '@remixicon/react' import { useCountDown } from 'ahooks' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { validPassword } from '@/config' +import { useRouter, useSearchParams } from '@/next/navigation' import { changeWebAppPasswordWithToken } from '@/service/common' import { cn } from '@/utils/classnames' @@ -24,10 +24,7 @@ const ChangePasswordForm = () => { const [showConfirmPassword, setShowConfirmPassword] = useState(false) const showErrorMessage = useCallback((message: string) => { - Toast.notify({ - type: 'error', - message, - }) + toast.error(message) }, []) const getSignInUrl = () => { diff --git a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx index afea9d668b..f209ad9e5c 100644 --- a/web/app/(shareLayout)/webapp-signin/check-code/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/check-code/page.tsx @@ -1,15 +1,15 @@ 'use client' import type { FormEvent } from 'react' import { RiArrowLeftLine, RiMailSendFill } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import Countdown from '@/app/components/signin/countdown' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode, webAppEmailLoginWithCode } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' @@ -43,24 +43,15 @@ export default function CheckCode() { try { const appCode = getAppCodeFromRedirectUrl() if (!code.trim()) { - Toast.notify({ - type: 'error', - message: t('checkCode.emptyCode', { ns: 'login' }), - }) + toast.error(t('checkCode.emptyCode', { ns: 'login' })) return } if (!/\d{6}/.test(code)) { - Toast.notify({ - type: 'error', - message: t('checkCode.invalidCode', { ns: 'login' }), - }) + toast.error(t('checkCode.invalidCode', { ns: 'login' })) return } if (!redirectUrl || !appCode) { - Toast.notify({ - type: 'error', - message: t('error.redirectUrlMissing', { ns: 'login' }), - }) + toast.error(t('error.redirectUrlMissing', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx index 0776df036d..9b4a369908 100644 --- a/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/external-member-sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect } from 'react' import AppUnavailable from '@/app/components/base/app-unavailable' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchWebOAuth2SSOUrl, fetchWebOIDCSSOUrl, fetchWebSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' @@ -17,10 +17,7 @@ const ExternalMemberSSOAuth = () => { const redirectUrl = searchParams.get('redirect_url') const showErrorToast = (message: string) => { - Toast.notify({ - type: 'error', - message, - }) + toast.error(message) } const getAppCodeFromRedirectUrl = useCallback(() => { diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx index 5aa9d9f141..fbd6b216df 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-code-auth.tsx @@ -1,13 +1,13 @@ import { noop } from 'es-toolkit/function' -import { useRouter, useSearchParams } from 'next/navigation' import { useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { COUNT_DOWN_KEY, COUNT_DOWN_TIME_MS } from '@/app/components/signin/countdown' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' +import { useRouter, useSearchParams } from '@/next/navigation' import { sendWebAppEMailLoginCode } from '@/service/common' export default function MailAndCodeAuth() { @@ -22,15 +22,12 @@ export default function MailAndCodeAuth() { const handleGetEMailVerificationCode = async () => { try { if (!email) { - Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - Toast.notify({ - type: 'error', - message: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } setIsLoading(true) diff --git a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx index e49559401d..1e9355e7ba 100644 --- a/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/mail-and-password-auth.tsx @@ -1,15 +1,15 @@ 'use client' import { noop } from 'es-toolkit/function' -import Link from 'next/link' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { emailRegex } from '@/config' import { useLocale } from '@/context/i18n' import { useWebAppStore } from '@/context/web-app-context' +import Link from '@/next/link' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogin } from '@/service/common' import { fetchAccessToken } from '@/service/share' import { setWebAppAccessToken, setWebAppPassport } from '@/service/webapp-auth' @@ -46,26 +46,20 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut const appCode = getAppCodeFromRedirectUrl() const handleEmailPasswordLogin = async () => { if (!email) { - Toast.notify({ type: 'error', message: t('error.emailEmpty', { ns: 'login' }) }) + toast.error(t('error.emailEmpty', { ns: 'login' })) return } if (!emailRegex.test(email)) { - Toast.notify({ - type: 'error', - message: t('error.emailInValid', { ns: 'login' }), - }) + toast.error(t('error.emailInValid', { ns: 'login' })) return } if (!password?.trim()) { - Toast.notify({ type: 'error', message: t('error.passwordEmpty', { ns: 'login' }) }) + toast.error(t('error.passwordEmpty', { ns: 'login' })) return } if (!redirectUrl || !appCode) { - Toast.notify({ - type: 'error', - message: t('error.redirectUrlMissing', { ns: 'login' }), - }) + toast.error(t('error.redirectUrlMissing', { ns: 'login' })) return } try { @@ -94,15 +88,12 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut router.replace(decodeURIComponent(redirectUrl)) } else { - Toast.notify({ - type: 'error', - message: res.data, - }) + toast.error(res.data) } } catch (e: any) { if (e.code === 'authentication_failed') - Toast.notify({ type: 'error', message: e.message }) + toast.error(e.message) } finally { setIsLoading(false) diff --git a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx index d8f3854868..3178c638cc 100644 --- a/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx +++ b/web/app/(shareLayout)/webapp-signin/components/sso-auth.tsx @@ -1,11 +1,11 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { Lock01 } from '@/app/components/base/icons/src/vender/solid/security' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' +import { useRouter, useSearchParams } from '@/next/navigation' import { fetchMembersOAuth2SSOUrl, fetchMembersOIDCSSOUrl, fetchMembersSAMLSSOUrl } from '@/service/share' import { SSOProtocol } from '@/types/feature' @@ -37,10 +37,7 @@ const SSOAuth: FC = ({ const handleSSOLogin = () => { const appCode = getAppCodeFromRedirectUrl() if (!redirectUrl || !appCode) { - Toast.notify({ - type: 'error', - message: 'invalid redirect URL or app code', - }) + toast.error(t('error.invalidRedirectUrlOrAppCode', { ns: 'login' })) return } setIsLoading(true) @@ -66,10 +63,7 @@ const SSOAuth: FC = ({ }) } else { - Toast.notify({ - type: 'error', - message: 'invalid SSO protocol', - }) + toast.error(t('error.invalidSSOProtocol', { ns: 'login' })) setIsLoading(false) } } diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx index b15145346f..7ee08d66ae 100644 --- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx +++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx @@ -1,12 +1,12 @@ 'use client' import { RiContractLine, RiDoorLockLine, RiErrorWarningFill } from '@remixicon/react' -import Link from 'next/link' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Loading from '@/app/components/base/loading' import { IS_CE_EDITION } from '@/config' import { useGlobalPublicStore } from '@/context/global-public-context' +import Link from '@/next/link' import { LicenseStatus } from '@/types/feature' import { cn } from '@/utils/classnames' import MailAndCodeAuth from './components/mail-and-code-auth' diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx index b3ad1d48a6..a5c2528cc7 100644 --- a/web/app/(shareLayout)/webapp-signin/page.tsx +++ b/web/app/(shareLayout)/webapp-signin/page.tsx @@ -1,6 +1,5 @@ 'use client' import type { FC } from 'react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' @@ -8,6 +7,7 @@ import AppUnavailable from '@/app/components/base/app-unavailable' import { useGlobalPublicStore } from '@/context/global-public-context' import { useWebAppStore } from '@/context/web-app-context' import { AccessMode } from '@/models/access-control' +import { useRouter, useSearchParams } from '@/next/navigation' import { webAppLogout } from '@/service/webapp-auth' import ExternalMemberSsoAuth from './components/external-member-sso-auth' import NormalForm from './normalForm' diff --git a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx index 49b59704b1..3fc677d8d8 100644 --- a/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx +++ b/web/app/account/(commonLayout)/account-page/AvatarWithEdit.tsx @@ -11,12 +11,12 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import ImageInput from '@/app/components/base/app-icon-picker/ImageInput' import getCroppedImg from '@/app/components/base/app-icon-picker/utils' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { DISABLE_UPLOAD_IMAGE_AS_ICON } from '@/config' import { updateUserProfile } from '@/service/common' @@ -103,7 +103,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { <>
- setOnAvatarError(x)} /> + setOnAvatarError(status === 'error')} />
{ @@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => { isShow={isShowDeleteConfirm} onClose={() => setIsShowDeleteConfirm(false)} > -
{t('avatar.deleteTitle', { ns: 'common' })}
+
{t('avatar.deleteTitle', { ns: 'common' })}

{t('avatar.deleteDescription', { ns: 'common' })}

diff --git a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx index 87ca6a689c..f0dfd4f12f 100644 --- a/web/app/account/(commonLayout)/account-page/email-change-modal.tsx +++ b/web/app/account/(commonLayout)/account-page/email-change-modal.tsx @@ -1,7 +1,6 @@ import type { ResponseError } from '@/service/fetch' import { RiCloseLine } from '@remixicon/react' import { noop } from 'es-toolkit/function' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useState } from 'react' import { Trans, useTranslation } from 'react-i18next' @@ -9,7 +8,8 @@ import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' +import { useRouter } from '@/next/navigation' import { checkEmailExisted, resetEmail, @@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
{step === STEP.start && ( <> -
{t('account.changeEmail.title', { ns: 'common' })}
+
{t('account.changeEmail.title', { ns: 'common' })}
-
{t('account.changeEmail.authTip', { ns: 'common' })}
-
+
{t('account.changeEmail.authTip', { ns: 'common' })}
+
}} + components={{ email: }} values={{ email }} />
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { )} {step === STEP.verifyOrigin && ( <> -
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
+
{t('account.changeEmail.verifyEmail', { ns: 'common' })}
-
+
}} + components={{ email: }} values={{ email }} />
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
{ {t('operation.cancel', { ns: 'common' })}
-
+
{t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} + {t('account.changeEmail.resend', { ns: 'common' })} )}
)} {step === STEP.newEmail && ( <> -
{t('account.changeEmail.newEmail', { ns: 'common' })}
+
{t('account.changeEmail.newEmail', { ns: 'common' })}
-
{t('account.changeEmail.content3', { ns: 'common' })}
+
{t('account.changeEmail.content3', { ns: 'common' })}
-
{t('account.changeEmail.emailLabel', { ns: 'common' })}
+
{t('account.changeEmail.emailLabel', { ns: 'common' })}
{ destructive={newEmailExited || unAvailableEmail} /> {newEmailExited && ( -
{t('account.changeEmail.existingEmail', { ns: 'common' })}
+
{t('account.changeEmail.existingEmail', { ns: 'common' })}
)} {unAvailableEmail && ( -
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
+
{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}
)}
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => { )} {step === STEP.verifyNew && ( <> -
{t('account.changeEmail.verifyNew', { ns: 'common' })}
+
{t('account.changeEmail.verifyNew', { ns: 'common' })}
-
+
}} + components={{ email: }} values={{ email: mail }} />
-
{t('account.changeEmail.codeLabel', { ns: 'common' })}
+
{t('account.changeEmail.codeLabel', { ns: 'common' })}
{ {t('operation.cancel', { ns: 'common' })}
-
+
{t('account.changeEmail.resendTip', { ns: 'common' })} {time > 0 && ( {t('account.changeEmail.resendCount', { ns: 'common', count: time })} )} {!time && ( - {t('account.changeEmail.resend', { ns: 'common' })} + {t('account.changeEmail.resend', { ns: 'common' })} )}
diff --git a/web/app/account/(commonLayout)/account-page/index.tsx b/web/app/account/(commonLayout)/account-page/index.tsx index f01efc002c..9a104619da 100644 --- a/web/app/account/(commonLayout)/account-page/index.tsx +++ b/web/app/account/(commonLayout)/account-page/index.tsx @@ -4,6 +4,7 @@ import type { App } from '@/types/app' import { RiGraduationCapFill, } from '@remixicon/react' +import { useQueryClient } from '@tanstack/react-query' import { useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -12,14 +13,14 @@ import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Modal from '@/app/components/base/modal' import PremiumBadge from '@/app/components/base/premium-badge' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import Collapse from '@/app/components/header/account-setting/collapse' import { IS_CE_EDITION, validPassword } from '@/config' -import { useAppContext } from '@/context/app-context' import { useGlobalPublicStore } from '@/context/global-public-context' import { useProviderContext } from '@/context/provider-context' import { updateUserProfile } from '@/service/common' import { useAppList } from '@/service/use-apps' +import { commonQueryKeys, useUserProfile } from '@/service/use-common' import DeleteAccount from '../delete-account' import AvatarWithEdit from './AvatarWithEdit' @@ -37,7 +38,10 @@ export default function AccountPage() { const { systemFeatures } = useGlobalPublicStore() const { data: appList } = useAppList({ page: 1, limit: 100, name: '' }) const apps = appList?.data || [] - const { mutateUserProfile, userProfile } = useAppContext() + const queryClient = useQueryClient() + const { data: userProfileResp } = useUserProfile() + const userProfile = userProfileResp?.profile + const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: commonQueryKeys.userProfile }) const { isEducationAccount } = useProviderContext() const { notify } = useContext(ToastContext) const [editNameModalVisible, setEditNameModalVisible] = useState(false) @@ -53,6 +57,9 @@ export default function AccountPage() { const [showConfirmPassword, setShowConfirmPassword] = useState(false) const [showUpdateEmail, setShowUpdateEmail] = useState(false) + if (!userProfile) + return null + const handleEditName = () => { setEditNameModalVisible(true) setEditName(userProfile.name) @@ -138,7 +145,7 @@ export default function AccountPage() { imageUrl={icon_url} />
-
{item.name}
+
{item.name}
) } @@ -146,12 +153,12 @@ export default function AccountPage() { return ( <>
-

{t('account.myAccount', { ns: 'common' })}

+

{t('account.myAccount', { ns: 'common' })}

- +
-

+

{userProfile.name} {isEducationAccount && ( @@ -160,16 +167,16 @@ export default function AccountPage() { )}

-

{userProfile.email}

+

{userProfile.email}

{t('account.name', { ns: 'common' })}
-
+
{userProfile.name}
-
+
{t('operation.edit', { ns: 'common' })}
@@ -177,11 +184,11 @@ export default function AccountPage() {
{t('account.email', { ns: 'common' })}
-
+
{userProfile.email}
{systemFeatures.enable_change_email && ( -
setShowUpdateEmail(true)}> +
setShowUpdateEmail(true)}> {t('operation.change', { ns: 'common' })}
)} @@ -191,8 +198,8 @@ export default function AccountPage() { systemFeatures.enable_email_password_login && (
-
{t('account.password', { ns: 'common' })}
-
{t('account.passwordTip', { ns: 'common' })}
+
{t('account.password', { ns: 'common' })}
+
{t('account.passwordTip', { ns: 'common' })}
@@ -219,7 +226,7 @@ export default function AccountPage() { onClose={() => setEditNameModalVisible(false)} className="!w-[420px] !p-6" > -
{t('account.editName', { ns: 'common' })}
+
{t('account.editName', { ns: 'common' })}
{t('account.name', { ns: 'common' })}
-
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
+
{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}
{userProfile.is_password_set && ( <>
{t('account.currentPassword', { ns: 'common' })}
@@ -272,7 +279,7 @@ export default function AccountPage() {
)} -
+
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
@@ -291,7 +298,7 @@ export default function AccountPage() {
-
{t('account.confirmPassword', { ns: 'common' })}
+
{t('account.confirmPassword', { ns: 'common' })}
{ await logout() @@ -50,7 +54,7 @@ export default function AppSelector() { ${open && 'bg-components-panel-bg-blur'} `} > - +
{userProfile.email}
- +
diff --git a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx index 65a58c936e..e0f00189b2 100644 --- a/web/app/account/(commonLayout)/delete-account/components/check-email.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/check-email.tsx @@ -1,10 +1,10 @@ 'use client' -import Link from 'next/link' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import { useAppContext } from '@/context/app-context' +import Link from '@/next/link' import { useSendDeleteAccountEmail } from '../state' type DeleteAccountProps = { diff --git a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx index 67fea3c141..ae73d778f8 100644 --- a/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/feed-back.tsx @@ -1,5 +1,4 @@ 'use client' -import { useRouter } from 'next/navigation' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' @@ -7,6 +6,7 @@ import CustomDialog from '@/app/components/base/dialog' import Textarea from '@/app/components/base/textarea' import Toast from '@/app/components/base/toast' import { useAppContext } from '@/context/app-context' +import { useRouter } from '@/next/navigation' import { useLogout } from '@/service/use-common' import { useDeleteAccountFeedback } from '../state' diff --git a/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx b/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx index d7590c27f9..5d76f13f34 100644 --- a/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx +++ b/web/app/account/(commonLayout)/delete-account/components/verify-email.tsx @@ -1,10 +1,10 @@ 'use client' -import Link from 'next/link' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Countdown from '@/app/components/signin/countdown' +import Link from '@/next/link' import { useAccountDeleteStore, useConfirmDeleteAccount, useSendDeleteAccountEmail } from '../state' const CODE_EXP = /[A-Z\d]{6}/gi diff --git a/web/app/account/(commonLayout)/header.tsx b/web/app/account/(commonLayout)/header.tsx index bb58be87a8..5ef84a8f1e 100644 --- a/web/app/account/(commonLayout)/header.tsx +++ b/web/app/account/(commonLayout)/header.tsx @@ -1,11 +1,11 @@ 'use client' import { RiArrowRightUpLine, RiRobot2Line } from '@remixicon/react' -import { useRouter } from 'next/navigation' import { useCallback } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import DifyLogo from '@/app/components/base/logo/dify-logo' import { useGlobalPublicStore } from '@/context/global-public-context' +import { useRouter } from '@/next/navigation' import Avatar from './avatar' const Header = () => { diff --git a/web/app/account/(commonLayout)/layout.tsx b/web/app/account/(commonLayout)/layout.tsx index e4125015d9..8fdbd8a238 100644 --- a/web/app/account/(commonLayout)/layout.tsx +++ b/web/app/account/(commonLayout)/layout.tsx @@ -4,10 +4,10 @@ import { AppInitializer } from '@/app/components/app-initializer' import AmplitudeProvider from '@/app/components/base/amplitude' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' -import { AppContextProvider } from '@/context/app-context' -import { EventEmitterContextProvider } from '@/context/event-emitter' -import { ModalContextProvider } from '@/context/modal-context' -import { ProviderContextProvider } from '@/context/provider-context' +import { AppContextProvider } from '@/context/app-context-provider' +import { EventEmitterContextProvider } from '@/context/event-emitter-provider' +import { ModalContextProvider } from '@/context/modal-context-provider' +import { ProviderContextProvider } from '@/context/provider-context-provider' import Header from './header' const Layout = ({ children }: { children: ReactNode }) => { diff --git a/web/app/account/oauth/authorize/layout.tsx b/web/app/account/oauth/authorize/layout.tsx index 189971b16f..7f6b270b45 100644 --- a/web/app/account/oauth/authorize/layout.tsx +++ b/web/app/account/oauth/authorize/layout.tsx @@ -2,7 +2,7 @@ import Loading from '@/app/components/base/loading' import Header from '@/app/signin/_header' -import { AppContextProvider } from '@/context/app-context' +import { AppContextProvider } from '@/context/app-context-provider' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' import { useIsLogin } from '@/service/use-common' @@ -38,7 +38,7 @@ export default function SignInLayout({ children }: any) {
{systemFeatures.branding.enabled === false && ( -
+
© {' '} {new Date().getFullYear()} diff --git a/web/app/account/oauth/authorize/page.tsx b/web/app/account/oauth/authorize/page.tsx index d718e0941d..670f6ec593 100644 --- a/web/app/account/oauth/authorize/page.tsx +++ b/web/app/account/oauth/authorize/page.tsx @@ -7,18 +7,17 @@ import { RiMailLine, RiTranslate2, } from '@remixicon/react' -import { useRouter, useSearchParams } from 'next/navigation' import * as React from 'react' import { useEffect, useRef } from 'react' import { useTranslation } from 'react-i18next' -import Avatar from '@/app/components/base/avatar' +import { Avatar } from '@/app/components/base/avatar' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks' import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect' -import { useAppContext } from '@/context/app-context' -import { useIsLogin } from '@/service/use-common' +import { useRouter, useSearchParams } from '@/next/navigation' +import { useIsLogin, useUserProfile } from '@/service/use-common' import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth' function buildReturnUrl(pathname: string, search: string) { @@ -62,7 +61,8 @@ export default function OAuthAuthorize() { const searchParams = useSearchParams() const client_id = decodeURIComponent(searchParams.get('client_id') || '') const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '') - const { userProfile } = useAppContext() + const { data: userProfileResp } = useUserProfile() + const userProfile = userProfileResp?.profile const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri) const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp() const hasNotifiedRef = useRef(false) @@ -91,10 +91,7 @@ export default function OAuthAuthorize() { globalThis.location.href = url.toString() } catch (err: any) { - Toast.notify({ - type: 'error', - message: `${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`, - }) + toast.error(`${t('error.authorizeFailed', { ns: 'oauth' })}: ${err.message}`) } } @@ -102,11 +99,10 @@ export default function OAuthAuthorize() { const invalidParams = !client_id || !redirect_uri if ((invalidParams || isError) && !hasNotifiedRef.current) { hasNotifiedRef.current = true - Toast.notify({ - type: 'error', - message: invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), - duration: 0, - }) + toast.error( + invalidParams ? t('error.invalidParams', { ns: 'oauth' }) : t('error.authAppInfoFetchFailed', { ns: 'oauth' }), + { timeout: 0 }, + ) } }, [client_id, redirect_uri, isError]) @@ -138,7 +134,7 @@ export default function OAuthAuthorize() { {isLoggedIn && userProfile && (
- +
{userProfile.name}
{userProfile.email}
diff --git a/web/app/activate/activateForm.tsx b/web/app/activate/activateForm.tsx index 421b816652..418d3b8bb1 100644 --- a/web/app/activate/activateForm.tsx +++ b/web/app/activate/activateForm.tsx @@ -1,11 +1,11 @@ 'use client' -import { useRouter, useSearchParams } from 'next/navigation' import { useEffect } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Loading from '@/app/components/base/loading' - import useDocumentTitle from '@/hooks/use-document-title' + +import { useRouter, useSearchParams } from '@/next/navigation' import { useInvitationCheck } from '@/service/use-common' import { cn } from '@/utils/classnames' diff --git a/web/app/components/browser-initializer.spec.ts b/web/app/components/__tests__/browser-initializer.spec.ts similarity index 100% rename from web/app/components/browser-initializer.spec.ts rename to web/app/components/__tests__/browser-initializer.spec.ts diff --git a/web/app/components/app-initializer.tsx b/web/app/components/app-initializer.tsx index e4cd10175a..e08ece6666 100644 --- a/web/app/components/app-initializer.tsx +++ b/web/app/components/app-initializer.tsx @@ -2,13 +2,13 @@ import type { ReactNode } from 'react' import Cookies from 'js-cookie' -import { usePathname, useRouter, useSearchParams } from 'next/navigation' import { parseAsBoolean, useQueryState } from 'nuqs' import { useCallback, useEffect, useState } from 'react' import { EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION, EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, } from '@/app/education-apply/constants' +import { usePathname, useRouter, useSearchParams } from '@/next/navigation' import { sendGAEvent } from '@/utils/gtag' import { fetchSetupStatusWithCache } from '@/utils/setup-status' import { resolvePostLoginRedirect } from '../signin/utils/post-login-redirect' @@ -26,11 +26,10 @@ export const AppInitializer = ({ // Tokens are now stored in cookies, no need to check localStorage const pathname = usePathname() const [init, setInit] = useState(false) - const [oauthNewUser, setOauthNewUser] = useQueryState( + const [oauthNewUser] = useQueryState( 'oauth_new_user', parseAsBoolean.withOptions({ history: 'replace' }), ) - const isSetupFinished = useCallback(async () => { try { const setUpStatus = await fetchSetupStatusWithCache() @@ -69,11 +68,12 @@ export const AppInitializer = ({ ...utmInfo, }) - // Clean up: remove utm_info cookie and URL params Cookies.remove('utm_info') - setOauthNewUser(null) } + if (oauthNewUser !== null) + router.replace(pathname) + if (action === EDUCATION_VERIFY_URL_SEARCHPARAMS_ACTION) localStorage.setItem(EDUCATION_VERIFYING_LOCALSTORAGE_ITEM, 'yes') @@ -96,7 +96,7 @@ export const AppInitializer = ({ router.replace('/signin') } })() - }, [isSetupFinished, router, pathname, searchParams, oauthNewUser, setOauthNewUser]) + }, [isSetupFinished, router, pathname, searchParams, oauthNewUser]) return init ? children : null } diff --git a/web/app/components/app-sidebar/__tests__/app-sidebar-dropdown.spec.tsx b/web/app/components/app-sidebar/__tests__/app-sidebar-dropdown.spec.tsx new file mode 100644 index 0000000000..5018709da1 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/app-sidebar-dropdown.spec.tsx @@ -0,0 +1,177 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppSidebarDropdown from '../app-sidebar-dropdown' + +let mockAppDetail: (App & Partial) | undefined + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: Record) => unknown) => selector({ + appDetail: mockAppDetail, + }), +})) + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: true, + }), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +vi.mock('../../base/app-icon', () => ({ + default: ({ size, icon }: { size: string, icon: string }) => ( +
+ ), +})) + +vi.mock('../../base/divider', () => ({ + default: () =>
, +})) + +vi.mock('../app-info', () => ({ + default: ({ expand, onlyShowDetail, openState }: { + expand: boolean + onlyShowDetail?: boolean + openState?: boolean + }) => ( +
+ ), +})) + +vi.mock('../nav-link', () => ({ + default: ({ name, href, mode }: { name: string, href: string, mode?: string }) => ( + {name} + ), +})) + +const MockIcon = (props: React.SVGProps) => + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: '', + use_icon_as_answer_icon: false, + ...overrides, +} as App & Partial) + +const navigation = [ + { name: 'Overview', href: '/overview', icon: MockIcon, selectedIcon: MockIcon }, + { name: 'Logs', href: '/logs', icon: MockIcon, selectedIcon: MockIcon }, +] + +describe('AppSidebarDropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetail = createAppDetail() + }) + + it('should return null when appDetail is not available', () => { + mockAppDetail = undefined + const { container } = render() + expect(container.innerHTML).toBe('') + }) + + it('should render trigger with app icon', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const smallIcon = icons.find(i => i.getAttribute('data-size') === 'small') + expect(smallIcon).toBeInTheDocument() + }) + + it('should render navigation links', () => { + render() + expect(screen.getByTestId('nav-link-Overview')).toBeInTheDocument() + expect(screen.getByTestId('nav-link-Logs')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should display app mode label', () => { + render() + expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() + }) + + it('should display mode labels for different modes', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.ADVANCED_CHAT }) + render() + expect(screen.getByText('app.types.advanced')).toBeInTheDocument() + }) + + it('should render AppInfo component for detail expand', () => { + render() + expect(screen.getByTestId('app-info')).toBeInTheDocument() + expect(screen.getByTestId('app-info')).toHaveAttribute('data-only-detail', 'true') + }) + + it('should toggle portal open state when trigger is clicked', async () => { + const user = userEvent.setup() + render() + + const trigger = screen.getByTestId('portal-trigger') + await user.click(trigger) + + const portal = screen.getByTestId('portal-elem') + expect(portal).toHaveAttribute('data-open', 'true') + }) + + it('should render divider between app info and navigation', () => { + render() + expect(screen.getByTestId('divider')).toBeInTheDocument() + }) + + it('should render large app icon in dropdown content', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const largeIcon = icons.find(icon => icon.getAttribute('data-size') === 'large') + expect(largeIcon).toBeInTheDocument() + }) + + it('should set detailExpand when clicking app info area', async () => { + const user = userEvent.setup() + render() + + const appName = screen.getByText('Test App') + const appInfoArea = appName.closest('[class*="cursor-pointer"]') + if (appInfoArea) + await user.click(appInfoArea) + }) + + it('should display workflow mode label', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.WORKFLOW }) + render() + expect(screen.getByText('app.types.workflow')).toBeInTheDocument() + }) + + it('should display agent mode label', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.AGENT_CHAT }) + render() + expect(screen.getByText('app.types.agent')).toBeInTheDocument() + }) + + it('should display completion mode label', () => { + mockAppDetail = createAppDetail({ mode: AppModeEnum.COMPLETION }) + render() + expect(screen.getByText('app.types.completion')).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app-sidebar/__tests__/basic.spec.tsx b/web/app/components/app-sidebar/__tests__/basic.spec.tsx new file mode 100644 index 0000000000..67e708eb02 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/basic.spec.tsx @@ -0,0 +1,110 @@ +import { render, screen } from '@testing-library/react' +import * as React from 'react' +import AppBasic from '../basic' + +vi.mock('@/app/components/base/icons/src/vender/workflow', () => ({ + ApiAggregate: (props: React.SVGProps) => , + WindowCursor: (props: React.SVGProps) => , +})) + +vi.mock('@/app/components/base/tooltip', () => ({ + default: ({ popupContent }: { popupContent: React.ReactNode }) => ( +
{popupContent}
+ ), +})) + +vi.mock('../../base/app-icon', () => ({ + default: ({ icon, background, innerIcon, className }: { + icon?: string + background?: string + innerIcon?: React.ReactNode + className?: string + }) => ( +
+ {innerIcon} +
+ ), +})) + +describe('AppBasic', () => { + describe('Icon rendering', () => { + it('should render app icon when iconType is app with valid icon and background', () => { + render() + expect(screen.getByTestId('app-icon')).toBeInTheDocument() + }) + + it('should not render app icon when icon is empty', () => { + render() + expect(screen.queryByTestId('app-icon')).not.toBeInTheDocument() + }) + + it('should render api icon when iconType is api', () => { + render() + expect(screen.getByTestId('api-icon')).toBeInTheDocument() + }) + + it('should render webapp icon when iconType is webapp', () => { + render() + expect(screen.getByTestId('webapp-icon')).toBeInTheDocument() + }) + + it('should render dataset icon when iconType is dataset', () => { + render() + const icons = screen.getAllByTestId('app-icon') + expect(icons.length).toBeGreaterThan(0) + }) + + it('should render notion icon when iconType is notion', () => { + render() + const icons = screen.getAllByTestId('app-icon') + expect(icons.length).toBeGreaterThan(0) + }) + }) + + describe('Expand mode', () => { + it('should show name and type in expand mode', () => { + render() + expect(screen.getByText('My App')).toBeInTheDocument() + expect(screen.getByText('Chatbot')).toBeInTheDocument() + }) + + it('should hide name and type in collapse mode', () => { + render() + expect(screen.queryByText('My App')).not.toBeInTheDocument() + }) + + it('should show hover tip when provided', () => { + render() + expect(screen.getByTestId('tooltip')).toBeInTheDocument() + expect(screen.getByText('Some tip')).toBeInTheDocument() + }) + + it('should not show hover tip when not provided', () => { + render() + expect(screen.queryByTestId('tooltip')).not.toBeInTheDocument() + }) + }) + + describe('Type display', () => { + it('should hide type when hideType is true', () => { + render() + expect(screen.queryByText('Chatbot')).not.toBeInTheDocument() + }) + + it('should show external tag when isExternal is true', () => { + render() + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + }) + + it('should show type inline when isExtraInLine is true and hideType is false', () => { + render() + expect(screen.getByText('Chatbot')).toBeInTheDocument() + }) + + it('should apply custom text styles', () => { + render() + const nameContainer = screen.getByText('My App').parentElement + expect(nameContainer).toHaveClass('text-red-500') + }) + }) +}) diff --git a/web/app/components/app-sidebar/__tests__/dataset-sidebar-dropdown.spec.tsx b/web/app/components/app-sidebar/__tests__/dataset-sidebar-dropdown.spec.tsx new file mode 100644 index 0000000000..1f3a5f9ad8 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/dataset-sidebar-dropdown.spec.tsx @@ -0,0 +1,193 @@ +import type { DataSet } from '@/models/datasets' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import DatasetSidebarDropdown from '../dataset-sidebar-dropdown' + +let mockDataset: DataSet + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset: DataSet }) => unknown) => + selector({ dataset: mockDataset }), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + useDatasetRelatedApps: () => ({ data: [] }), +})) + +vi.mock('@/hooks/use-knowledge', () => ({ + useKnowledge: () => ({ + formatIndexingTechniqueAndMethod: () => 'method-text', + }), +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) => ( +
{children}
+ ), +})) + +vi.mock('../../base/app-icon', () => ({ + default: ({ size, icon }: { size: string, icon: string }) => ( +
+ ), +})) + +vi.mock('../../base/divider', () => ({ + default: () =>
, +})) + +vi.mock('../../base/effect', () => ({ + default: ({ className }: { className?: string }) =>
, +})) + +vi.mock('../../datasets/extra-info', () => ({ + default: ({ expand, documentCount }: { + relatedApps?: unknown[] + expand: boolean + documentCount: number + }) => ( +
+ ), +})) + +vi.mock('../dataset-info/dropdown', () => ({ + default: ({ expand }: { expand: boolean }) => ( +
+ ), +})) + +vi.mock('../nav-link', () => ({ + default: ({ name, href, mode, disabled }: { name: string, href: string, mode?: string, disabled?: boolean }) => ( + {name} + ), +})) + +const MockIcon = (props: React.SVGProps) => + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Test Dataset', + description: 'A test dataset', + provider: 'internal', + icon_info: { + icon: '📙', + icon_type: 'emoji', + icon_background: '#FFF4ED', + icon_url: '', + }, + doc_form: 'text_model' as DataSet['doc_form'], + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + document_count: 10, + runtime_mode: 'general', + retrieval_model_dict: { + search_method: 'semantic_search' as DataSet['retrieval_model_dict']['search_method'], + reranking_enable: false, + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + ...overrides, +} as DataSet) + +const navigation = [ + { name: 'Documents', href: '/documents', icon: MockIcon, selectedIcon: MockIcon }, + { name: 'Settings', href: '/settings', icon: MockIcon, selectedIcon: MockIcon, disabled: true }, +] + +describe('DatasetSidebarDropdown', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset() + }) + + it('should render trigger with dataset icon', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const smallIcon = icons.find(i => i.getAttribute('data-size') === 'small') + expect(smallIcon).toBeInTheDocument() + expect(smallIcon).toHaveAttribute('data-icon', '📙') + }) + + it('should display dataset name in dropdown content', () => { + render() + expect(screen.getByText('Test Dataset')).toBeInTheDocument() + }) + + it('should display dataset description', () => { + render() + expect(screen.getByText('A test dataset')).toBeInTheDocument() + }) + + it('should not display description when empty', () => { + mockDataset = createDataset({ description: '' }) + render() + expect(screen.queryByText('A test dataset')).not.toBeInTheDocument() + }) + + it('should render navigation links', () => { + render() + expect(screen.getByTestId('nav-link-Documents')).toBeInTheDocument() + expect(screen.getByTestId('nav-link-Settings')).toBeInTheDocument() + }) + + it('should render ExtraInfo', () => { + render() + const extraInfo = screen.getByTestId('extra-info') + expect(extraInfo).toHaveAttribute('data-expand', 'true') + expect(extraInfo).toHaveAttribute('data-doc-count', '10') + }) + + it('should render Effect component', () => { + render() + expect(screen.getByTestId('effect')).toBeInTheDocument() + }) + + it('should render Dropdown component with expand=true', () => { + render() + expect(screen.getByTestId('dataset-dropdown')).toHaveAttribute('data-expand', 'true') + }) + + it('should show external tag for external provider', () => { + mockDataset = createDataset({ provider: 'external' }) + render() + expect(screen.getByText('dataset.externalTag')).toBeInTheDocument() + }) + + it('should use fallback icon info when icon_info is missing', () => { + mockDataset = createDataset({ icon_info: undefined as unknown as DataSet['icon_info'] }) + render() + const icons = screen.getAllByTestId('app-icon') + const fallbackIcon = icons.find(i => i.getAttribute('data-icon') === '📙') + expect(fallbackIcon).toBeInTheDocument() + }) + + it('should toggle dropdown open state on trigger click', async () => { + const user = userEvent.setup() + render() + + const trigger = screen.getByTestId('portal-trigger') + await user.click(trigger) + + expect(screen.getByTestId('portal-elem')).toHaveAttribute('data-open', 'true') + }) + + it('should render divider', () => { + render() + expect(screen.getByTestId('divider')).toBeInTheDocument() + }) + + it('should render medium app icon in content area', () => { + render() + const icons = screen.getAllByTestId('app-icon') + const mediumIcon = icons.find(i => i.getAttribute('data-size') === 'medium') + expect(mediumIcon).toBeInTheDocument() + }) +}) diff --git a/web/app/components/app-sidebar/__tests__/index.spec.tsx b/web/app/components/app-sidebar/__tests__/index.spec.tsx new file mode 100644 index 0000000000..b2e1e92bbb --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/index.spec.tsx @@ -0,0 +1,298 @@ +import { act, render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import AppDetailNav from '..' + +let mockAppSidebarExpand = 'expand' +const mockSetAppSidebarExpand = vi.fn() +let mockPathname = '/app/123/overview' + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: Record) => unknown) => selector({ + appDetail: { id: 'app-1', name: 'Test', mode: 'chat', icon: '🤖', icon_type: 'emoji', icon_background: '#fff' }, + appSidebarExpand: mockAppSidebarExpand, + setAppSidebarExpand: mockSetAppSidebarExpand, + }), +})) + +vi.mock('zustand/react/shallow', () => ({ + useShallow: (fn: unknown) => fn, +})) + +vi.mock('@/next/navigation', () => ({ + usePathname: () => mockPathname, +})) + +let mockIsHovering = true +let mockKeyPressCallback: ((e: { preventDefault: () => void }) => void) | null = null + +vi.mock('ahooks', () => ({ + useHover: () => mockIsHovering, + useKeyPress: (_key: string, cb: (e: { preventDefault: () => void }) => void) => { + mockKeyPressCallback = cb + }, +})) + +vi.mock('@/hooks/use-breakpoints', () => ({ + default: () => 'desktop', + MediaType: { mobile: 'mobile', desktop: 'desktop' }, +})) + +let mockSubscriptionCallback: ((v: unknown) => void) | null = null + +vi.mock('@/context/event-emitter', () => ({ + useEventEmitterContextContext: () => ({ + eventEmitter: { + useSubscription: (cb: (v: unknown) => void) => { mockSubscriptionCallback = cb }, + }, + }), +})) + +vi.mock('../../base/divider', () => ({ + default: ({ className }: { className?: string }) =>
, +})) + +vi.mock('@/app/components/workflow/utils', () => ({ + getKeyboardKeyCodeBySystem: () => 'ctrl', +})) + +vi.mock('../app-info', () => ({ + default: ({ expand }: { expand: boolean }) => ( +
+ ), +})) + +vi.mock('../app-sidebar-dropdown', () => ({ + default: ({ navigation }: { navigation: unknown[] }) => ( +
+ ), +})) + +vi.mock('../dataset-info', () => ({ + default: ({ expand }: { expand: boolean }) => ( +
+ ), +})) + +vi.mock('../dataset-sidebar-dropdown', () => ({ + default: ({ navigation }: { navigation: unknown[] }) => ( +
+ ), +})) + +vi.mock('../nav-link', () => ({ + default: ({ name, href, mode }: { name: string, href: string, mode?: string }) => ( + {name} + ), +})) + +vi.mock('../toggle-button', () => ({ + default: ({ expand, handleToggle, className }: { expand: boolean, handleToggle: () => void, className?: string }) => ( + + ), +})) + +const MockIcon = (props: React.SVGProps) => + +const navigation = [ + { name: 'Overview', href: '/overview', icon: MockIcon, selectedIcon: MockIcon }, + { name: 'Logs', href: '/logs', icon: MockIcon, selectedIcon: MockIcon }, +] + +describe('AppDetailNav', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppSidebarExpand = 'expand' + mockPathname = '/app/123/overview' + mockIsHovering = true + }) + + describe('Normal sidebar mode', () => { + it('should render AppInfo when iconType is app', () => { + render() + expect(screen.getByTestId('app-info')).toBeInTheDocument() + expect(screen.getByTestId('app-info')).toHaveAttribute('data-expand', 'true') + }) + + it('should render DatasetInfo when iconType is dataset', () => { + render() + expect(screen.getByTestId('dataset-info')).toBeInTheDocument() + }) + + it('should render navigation links', () => { + render() + expect(screen.getByTestId('nav-link-Overview')).toBeInTheDocument() + expect(screen.getByTestId('nav-link-Logs')).toBeInTheDocument() + }) + + it('should render divider', () => { + render() + expect(screen.getByTestId('divider')).toBeInTheDocument() + }) + + it('should apply expanded width class', () => { + const { container } = render() + const sidebar = container.firstElementChild as HTMLElement + expect(sidebar).toHaveClass('w-[216px]') + }) + + it('should apply collapsed width class', () => { + mockAppSidebarExpand = 'collapse' + const { container } = render() + const sidebar = container.firstElementChild as HTMLElement + expect(sidebar).toHaveClass('w-14') + }) + + it('should render extraInfo when iconType is dataset and extraInfo provided', () => { + render( +
} + />, + ) + expect(screen.getByTestId('extra-info')).toBeInTheDocument() + }) + + it('should not render extraInfo when iconType is app', () => { + render( +
} + />, + ) + expect(screen.queryByTestId('extra-info')).not.toBeInTheDocument() + }) + }) + + describe('Workflow canvas mode', () => { + it('should render AppSidebarDropdown when in workflow canvas with hidden header', () => { + mockPathname = '/app/123/workflow' + localStorage.setItem('workflow-canvas-maximize', 'true') + + render() + + expect(screen.getByTestId('app-sidebar-dropdown')).toBeInTheDocument() + expect(screen.queryByTestId('app-info')).not.toBeInTheDocument() + }) + + it('should render normal sidebar when workflow canvas is not maximized', () => { + mockPathname = '/app/123/workflow' + localStorage.setItem('workflow-canvas-maximize', 'false') + + render() + + expect(screen.queryByTestId('app-sidebar-dropdown')).not.toBeInTheDocument() + expect(screen.getByTestId('app-info')).toBeInTheDocument() + }) + }) + + describe('Pipeline canvas mode', () => { + it('should render DatasetSidebarDropdown when in pipeline canvas with hidden header', () => { + mockPathname = '/dataset/123/pipeline' + localStorage.setItem('workflow-canvas-maximize', 'true') + + render() + + expect(screen.getByTestId('dataset-sidebar-dropdown')).toBeInTheDocument() + expect(screen.queryByTestId('app-info')).not.toBeInTheDocument() + }) + }) + + describe('Navigation mode', () => { + it('should pass expand mode to nav links when expanded', () => { + render() + expect(screen.getByTestId('nav-link-Overview')).toHaveAttribute('data-mode', 'expand') + }) + + it('should pass collapse mode to nav links when collapsed', () => { + mockAppSidebarExpand = 'collapse' + render() + expect(screen.getByTestId('nav-link-Overview')).toHaveAttribute('data-mode', 'collapse') + }) + }) + + describe('Toggle behavior', () => { + it('should call setAppSidebarExpand on toggle', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('toggle-button')) + + expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('collapse') + }) + + it('should toggle from collapse to expand', async () => { + const user = userEvent.setup() + mockAppSidebarExpand = 'collapse' + render() + + await user.click(screen.getByTestId('toggle-button')) + + expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('expand') + }) + }) + + describe('Sidebar persistence', () => { + it('should persist expand state to localStorage', () => { + render() + expect(localStorage.setItem).toHaveBeenCalledWith('app-detail-collapse-or-expand', 'expand') + }) + }) + + describe('Disabled navigation items', () => { + it('should render disabled navigation items', () => { + const navWithDisabled = [ + ...navigation, + { name: 'Disabled', href: '/disabled', icon: MockIcon, selectedIcon: MockIcon, disabled: true }, + ] + render() + expect(screen.getByTestId('nav-link-Disabled')).toBeInTheDocument() + }) + }) + + describe('Event emitter subscription', () => { + it('should handle workflow-canvas-maximize event', () => { + mockPathname = '/app/123/workflow' + render() + + const cb = mockSubscriptionCallback + expect(cb).not.toBeNull() + act(() => { + cb!({ type: 'workflow-canvas-maximize', payload: true }) + }) + }) + + it('should ignore non-maximize events', () => { + render() + + const cb = mockSubscriptionCallback + act(() => { + cb!({ type: 'other-event' }) + }) + }) + }) + + describe('Keyboard shortcut', () => { + it('should toggle sidebar on ctrl+b', () => { + render() + + const cb = mockKeyPressCallback + expect(cb).not.toBeNull() + act(() => { + cb!({ preventDefault: vi.fn() }) + }) + expect(mockSetAppSidebarExpand).toHaveBeenCalledWith('collapse') + }) + }) + + describe('Hover-based toggle button visibility', () => { + it('should hide toggle button when not hovering', () => { + mockIsHovering = false + render() + expect(screen.queryByTestId('toggle-button')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx b/web/app/components/app-sidebar/__tests__/sidebar-animation-issues.spec.tsx similarity index 80% rename from web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx rename to web/app/components/app-sidebar/__tests__/sidebar-animation-issues.spec.tsx index 5d85b99d9a..fef65fcad3 100644 --- a/web/app/components/app-sidebar/sidebar-animation-issues.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/sidebar-animation-issues.spec.tsx @@ -143,12 +143,6 @@ describe('Sidebar Animation Issues Reproduction', () => { expect(toggleSection).toHaveClass('px-4') // Same consistent padding expect(toggleSection).not.toHaveClass('px-5') expect(toggleSection).not.toHaveClass('px-6') - - // THE FIX: px-4 in both states prevents position movement - console.log('✅ Issue #1 FIXED: Toggle button now has consistent padding') - console.log(' - Before: px-4 (collapsed) vs px-6 (expanded) - 8px difference') - console.log(' - After: px-4 (both states) - 0px difference') - console.log(' - Result: No button position movement during transition') }) it('should verify sidebar width animation is working correctly', () => { @@ -164,8 +158,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // Expanded state rerender() expect(container).toHaveClass('w-[216px]') - - console.log('✅ Sidebar width transition is properly configured') }) }) @@ -188,13 +180,6 @@ describe('Sidebar Animation Issues Reproduction', () => { expect(link).toHaveClass('px-3') // 12px padding (+2px) expect(icon).toHaveClass('mr-2') // 8px margin (+8px) expect(screen.getByTestId('nav-text-Orchestrate')).toBeInTheDocument() - - // THE BUG: Multiple simultaneous changes create squeeze effect - console.log('🐛 Issue #2 Reproduced: Text squeeze effect from multiple layout changes') - console.log(' - Link padding: px-2.5 → px-3 (+2px)') - console.log(' - Icon margin: mr-0 → mr-2 (+8px)') - console.log(' - Text appears: none → visible (abrupt)') - console.log(' - Result: Text appears with squeeze effect due to layout shifts') }) it('should document the abrupt text rendering issue', () => { @@ -207,10 +192,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // Text suddenly appears - no transition expect(screen.getByTestId('nav-text-API Access')).toBeInTheDocument() - - console.log('🐛 Issue #2 Detail: Conditional rendering {mode === "expand" && name}') - console.log(' - Problem: Text appears/disappears abruptly without transition') - console.log(' - Should use: opacity or width transition for smooth appearance') }) }) @@ -234,13 +215,6 @@ describe('Sidebar Animation Issues Reproduction', () => { expect(iconContainer).toHaveClass('gap-1') expect(iconContainer).not.toHaveClass('justify-between') expect(appIcon).toHaveAttribute('data-size', 'small') - - // THE BUG: Layout mode switch causes icon to "bounce" - console.log('🐛 Issue #3 Reproduced: Icon bounce from layout mode switching') - console.log(' - Layout change: justify-between → flex-col gap-1') - console.log(' - Icon size: large (40px) → small (24px)') - console.log(' - Transition: transition-all causes excessive animation') - console.log(' - Result: Icon appears to bounce to right then back during collapse') }) it('should identify the problematic transition-all property', () => { @@ -251,10 +225,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // The problematic broad transition expect(computedStyle.transition).toContain('all') - - console.log('🐛 Issue #3 Detail: transition-all affects ALL CSS properties') - console.log(' - Problem: Animates layout properties that should not transition') - console.log(' - Solution: Use specific transition properties instead of "all"') }) }) @@ -276,7 +246,6 @@ describe('Sidebar Animation Issues Reproduction', () => { // Initial state verification expect(expanded).toBe(false) - console.log('🔄 Starting interactive test - all issues will be reproduced') // Simulate toggle click fireEvent.click(toggleButton) @@ -287,11 +256,6 @@ describe('Sidebar Animation Issues Reproduction', () => {
, ) - - console.log('✨ All three issues successfully reproduced in interactive test:') - console.log(' 1. Toggle button position movement (padding inconsistency)') - console.log(' 2. Navigation text squeeze effect (multiple layout changes)') - console.log(' 3. App icon bounce animation (layout mode switching)') }) }) }) diff --git a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx similarity index 65% rename from web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx rename to web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx index f7e91b3dea..a3868a8330 100644 --- a/web/app/components/app-sidebar/text-squeeze-fix-verification.spec.tsx +++ b/web/app/components/app-sidebar/__tests__/text-squeeze-fix-verification.spec.tsx @@ -7,13 +7,13 @@ import { render } from '@testing-library/react' import * as React from 'react' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock classnames utility vi.mock('@/utils/classnames', () => ({ - default: (...classes: any[]) => classes.filter(Boolean).join(' '), + default: (...classes: unknown[]) => classes.filter(Boolean).join(' '), })) // Simplified NavLink component to test the fix @@ -101,12 +101,6 @@ describe('Text Squeeze Fix Verification', () => { expect(textElement).toHaveClass('whitespace-nowrap') expect(textElement).toHaveClass('transition-all') - console.log('✅ NavLink Collapsed State:') - console.log(' - Text is in DOM but visually hidden') - console.log(' - Uses opacity-0 and w-0 for hiding') - console.log(' - Has whitespace-nowrap to prevent wrapping') - console.log(' - Has transition-all for smooth animation') - // Switch to expanded state rerender() @@ -115,13 +109,6 @@ describe('Text Squeeze Fix Verification', () => { expect(expandedText).toHaveClass('opacity-100') expect(expandedText).toHaveClass('w-auto') expect(expandedText).not.toHaveClass('pointer-events-none') - - console.log('✅ NavLink Expanded State:') - console.log(' - Text is visible with opacity-100') - console.log(' - Uses w-auto for natural width') - console.log(' - No layout jumps during transition') - - console.log('🎯 NavLink Fix Result: Text squeeze effect ELIMINATED') }) it('should verify smooth transition properties', () => { @@ -131,11 +118,6 @@ describe('Text Squeeze Fix Verification', () => { expect(textElement).toHaveClass('transition-all') expect(textElement).toHaveClass('duration-200') expect(textElement).toHaveClass('ease-in-out') - - console.log('✅ Transition Properties Verified:') - console.log(' - transition-all: Smooth property changes') - console.log(' - duration-200: 200ms transition time') - console.log(' - ease-in-out: Smooth easing function') }) }) @@ -159,11 +141,6 @@ describe('Text Squeeze Fix Verification', () => { expect(appName).toHaveClass('whitespace-nowrap') expect(appType).toHaveClass('whitespace-nowrap') - console.log('✅ AppInfo Collapsed State:') - console.log(' - Text container is in DOM but visually hidden') - console.log(' - App name and type elements always present') - console.log(' - Uses whitespace-nowrap to prevent wrapping') - // Switch to expanded state rerender() @@ -172,13 +149,6 @@ describe('Text Squeeze Fix Verification', () => { expect(expandedContainer).toHaveClass('opacity-100') expect(expandedContainer).toHaveClass('w-auto') expect(expandedContainer).not.toHaveClass('pointer-events-none') - - console.log('✅ AppInfo Expanded State:') - console.log(' - Text container is visible with opacity-100') - console.log(' - Uses w-auto for natural width') - console.log(' - No layout jumps during transition') - - console.log('🎯 AppInfo Fix Result: Text squeeze effect ELIMINATED') }) it('should verify transition properties on text container', () => { @@ -188,45 +158,11 @@ describe('Text Squeeze Fix Verification', () => { expect(textContainer).toHaveClass('transition-all') expect(textContainer).toHaveClass('duration-200') expect(textContainer).toHaveClass('ease-in-out') - - console.log('✅ AppInfo Transition Properties Verified:') - console.log(' - Container has smooth CSS transitions') - console.log(' - Same 200ms duration as NavLink for consistency') }) }) describe('Fix Strategy Comparison', () => { it('should document the fix strategy differences', () => { - console.log('\n📋 TEXT SQUEEZE FIX STRATEGY COMPARISON') - console.log('='.repeat(60)) - - console.log('\n❌ BEFORE (Problematic):') - console.log(' NavLink: {mode === "expand" && name}') - console.log(' AppInfo: {expand && (
...
)}') - console.log(' Problem: Conditional rendering causes abrupt appearance') - console.log(' Result: Text "squeezes" from center during layout changes') - - console.log('\n✅ AFTER (Fixed):') - console.log(' NavLink: {name}') - console.log(' AppInfo:
...
') - console.log(' Solution: CSS controls visibility, element always in DOM') - console.log(' Result: Smooth opacity and width transitions') - - console.log('\n🎯 KEY FIX PRINCIPLES:') - console.log(' 1. ✅ Always keep text elements in DOM') - console.log(' 2. ✅ Use opacity for show/hide transitions') - console.log(' 3. ✅ Use width (w-0/w-auto) for layout control') - console.log(' 4. ✅ Add whitespace-nowrap to prevent wrapping') - console.log(' 5. ✅ Use pointer-events-none when hidden') - console.log(' 6. ✅ Add overflow-hidden for clean hiding') - - console.log('\n🚀 BENEFITS:') - console.log(' - No more abrupt text appearance') - console.log(' - Smooth 200ms transitions') - console.log(' - No layout jumps or shifts') - console.log(' - Consistent animation timing') - console.log(' - Better user experience') - // Always pass documentation test expect(true).toBe(true) }) diff --git a/web/app/components/app-sidebar/__tests__/toggle-button.spec.tsx b/web/app/components/app-sidebar/__tests__/toggle-button.spec.tsx new file mode 100644 index 0000000000..1a117ac5e3 --- /dev/null +++ b/web/app/components/app-sidebar/__tests__/toggle-button.spec.tsx @@ -0,0 +1,46 @@ +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import ToggleButton from '../toggle-button' + +vi.mock('@/app/components/workflow/shortcuts-name', () => ({ + default: ({ keys }: { keys: string[] }) => ( + {keys.join('+')} + ), +})) + +describe('ToggleButton', () => { + it('should render collapse arrow when expanded', () => { + render() + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('should render expand arrow when collapsed', () => { + render() + const button = screen.getByRole('button') + expect(button).toBeInTheDocument() + }) + + it('should call handleToggle when clicked', async () => { + const user = userEvent.setup() + const handleToggle = vi.fn() + render() + + await user.click(screen.getByRole('button')) + + expect(handleToggle).toHaveBeenCalledTimes(1) + }) + + it('should apply custom className', () => { + render() + const button = screen.getByRole('button') + expect(button).toHaveClass('custom-class') + }) + + it('should have rounded-full style', () => { + render() + const button = screen.getByRole('button') + expect(button).toHaveClass('rounded-full') + }) +}) diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx deleted file mode 100644 index aa31f0201f..0000000000 --- a/web/app/components/app-sidebar/app-info.tsx +++ /dev/null @@ -1,474 +0,0 @@ -import type { Operation } from './app-operations' -import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' -import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' -import type { EnvironmentVariable } from '@/app/components/workflow/types' -import { - RiDeleteBinLine, - RiEditLine, - RiEqualizer2Line, - RiExchange2Line, - RiFileCopy2Line, - RiFileDownloadLine, - RiFileUploadLine, -} from '@remixicon/react' -import dynamic from 'next/dynamic' -import { useRouter } from 'next/navigation' -import * as React from 'react' -import { useCallback, useState } from 'react' -import { useTranslation } from 'react-i18next' -import { useContext } from 'use-context-selector' -import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view' -import { useStore as useAppStore } from '@/app/components/app/store' -import Button from '@/app/components/base/button' -import ContentDialog from '@/app/components/base/content-dialog' -import { ToastContext } from '@/app/components/base/toast' -import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { useAppContext } from '@/context/app-context' -import { useProviderContext } from '@/context/provider-context' -import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' -import { useInvalidateAppList } from '@/service/use-apps' -import { fetchWorkflowDraft } from '@/service/workflow' -import { AppModeEnum } from '@/types/app' -import { getRedirection } from '@/utils/app-redirection' -import { cn } from '@/utils/classnames' -import { downloadBlob } from '@/utils/download' -import AppIcon from '../base/app-icon' -import AppOperations from './app-operations' - -const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { - ssr: false, -}) -const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { - ssr: false, -}) -const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-modal'), { - ssr: false, -}) -const Confirm = dynamic(() => import('@/app/components/base/confirm'), { - ssr: false, -}) -const UpdateDSLModal = dynamic(() => import('@/app/components/workflow/update-dsl-modal'), { - ssr: false, -}) -const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), { - ssr: false, -}) - -export type IAppInfoProps = { - expand: boolean - onlyShowDetail?: boolean - openState?: boolean - onDetailExpand?: (expand: boolean) => void -} - -const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailExpand }: IAppInfoProps) => { - const { t } = useTranslation() - const { notify } = useContext(ToastContext) - const { replace } = useRouter() - const { onPlanInfoChanged } = useProviderContext() - const appDetail = useAppStore(state => state.appDetail) - const setAppDetail = useAppStore(state => state.setAppDetail) - const invalidateAppList = useInvalidateAppList() - const [open, setOpen] = useState(openState) - const [showEditModal, setShowEditModal] = useState(false) - const [showDuplicateModal, setShowDuplicateModal] = useState(false) - const [showConfirmDelete, setShowConfirmDelete] = useState(false) - const [showSwitchModal, setShowSwitchModal] = useState(false) - const [showImportDSLModal, setShowImportDSLModal] = useState(false) - const [secretEnvList, setSecretEnvList] = useState([]) - const [showExportWarning, setShowExportWarning] = useState(false) - - const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ - name, - icon_type, - icon, - icon_background, - description, - use_icon_as_answer_icon, - max_active_requests, - }) => { - if (!appDetail) - return - try { - const app = await updateAppInfo({ - appID: appDetail.id, - name, - icon_type, - icon, - icon_background, - description, - use_icon_as_answer_icon, - max_active_requests, - }) - setShowEditModal(false) - notify({ - type: 'success', - message: t('editDone', { ns: 'app' }), - }) - setAppDetail(app) - } - catch { - notify({ type: 'error', message: t('editFailed', { ns: 'app' }) }) - } - }, [appDetail, notify, setAppDetail, t]) - - const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { - if (!appDetail) - return - try { - const newApp = await copyApp({ - appID: appDetail.id, - name, - icon_type, - icon, - icon_background, - mode: appDetail.mode, - }) - setShowDuplicateModal(false) - notify({ - type: 'success', - message: t('newApp.appCreated', { ns: 'app' }), - }) - localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - onPlanInfoChanged() - getRedirection(true, newApp, replace) - } - catch { - notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) - } - } - - const onExport = async (include = false) => { - if (!appDetail) - return - try { - const { data } = await exportAppConfig({ - appID: appDetail.id, - include, - }) - const file = new Blob([data], { type: 'application/yaml' }) - downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) - } - catch { - notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) - } - } - - const exportCheck = async () => { - if (!appDetail) - return - if (appDetail.mode !== AppModeEnum.WORKFLOW && appDetail.mode !== AppModeEnum.ADVANCED_CHAT) { - onExport() - return - } - - setShowExportWarning(true) - } - - const handleConfirmExport = async () => { - if (!appDetail) - return - setShowExportWarning(false) - try { - const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) - const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') - if (list.length === 0) { - onExport() - return - } - setSecretEnvList(list) - } - catch { - notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) - } - } - - const onConfirmDelete = useCallback(async () => { - if (!appDetail) - return - try { - await deleteApp(appDetail.id) - notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) - invalidateAppList() - onPlanInfoChanged() - setAppDetail() - replace('/apps') - } - catch (e: any) { - notify({ - type: 'error', - message: `${t('appDeleteFailed', { ns: 'app' })}${'message' in e ? `: ${e.message}` : ''}`, - }) - } - setShowConfirmDelete(false) - }, [appDetail, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t]) - - const { isCurrentWorkspaceEditor } = useAppContext() - - if (!appDetail) - return null - - const primaryOperations = [ - { - id: 'edit', - title: t('editApp', { ns: 'app' }), - icon: , - onClick: () => { - setOpen(false) - onDetailExpand?.(false) - setShowEditModal(true) - }, - }, - { - id: 'duplicate', - title: t('duplicate', { ns: 'app' }), - icon: , - onClick: () => { - setOpen(false) - onDetailExpand?.(false) - setShowDuplicateModal(true) - }, - }, - { - id: 'export', - title: t('export', { ns: 'app' }), - icon: , - onClick: exportCheck, - }, - ] - - const secondaryOperations: Operation[] = [ - // Import DSL (conditional) - ...(appDetail.mode === AppModeEnum.ADVANCED_CHAT || appDetail.mode === AppModeEnum.WORKFLOW) - ? [{ - id: 'import', - title: t('common.importDSL', { ns: 'workflow' }), - icon: , - onClick: () => { - setOpen(false) - onDetailExpand?.(false) - setShowImportDSLModal(true) - }, - }] - : [], - // Divider - { - id: 'divider-1', - title: '', - icon: <>, - onClick: () => { /* divider has no action */ }, - type: 'divider' as const, - }, - // Delete operation - { - id: 'delete', - title: t('operation.delete', { ns: 'common' }), - icon: , - onClick: () => { - setOpen(false) - onDetailExpand?.(false) - setShowConfirmDelete(true) - }, - }, - ] - - // Keep the switch operation separate as it's not part of the main operations - const switchOperation = (appDetail.mode === AppModeEnum.COMPLETION || appDetail.mode === AppModeEnum.CHAT) - ? { - id: 'switch', - title: t('switch', { ns: 'app' }), - icon: , - onClick: () => { - setOpen(false) - onDetailExpand?.(false) - setShowSwitchModal(true) - }, - } - : null - - return ( -
- {!onlyShowDetail && ( - - )} - { - setOpen(false) - onDetailExpand?.(false) - }} - className="absolute bottom-2 left-2 top-2 flex w-[420px] flex-col rounded-2xl !p-0" - > -
-
- -
-
{appDetail.name}
-
{appDetail.mode === AppModeEnum.ADVANCED_CHAT ? t('types.advanced', { ns: 'app' }) : appDetail.mode === AppModeEnum.AGENT_CHAT ? t('types.agent', { ns: 'app' }) : appDetail.mode === AppModeEnum.CHAT ? t('types.chatbot', { ns: 'app' }) : appDetail.mode === AppModeEnum.COMPLETION ? t('types.completion', { ns: 'app' }) : t('types.workflow', { ns: 'app' })}
-
-
- {/* description */} - {appDetail.description && ( -
{appDetail.description}
- )} - {/* operations */} - -
- - {/* Switch operation (if available) */} - {switchOperation && ( -
- -
- )} -
- {showSwitchModal && ( - setShowSwitchModal(false)} - onSuccess={() => setShowSwitchModal(false)} - /> - )} - {showEditModal && ( - setShowEditModal(false)} - /> - )} - {showDuplicateModal && ( - setShowDuplicateModal(false)} - /> - )} - {showConfirmDelete && ( - setShowConfirmDelete(false)} - /> - )} - {showImportDSLModal && ( - setShowImportDSLModal(false)} - onBackup={exportCheck} - /> - )} - {secretEnvList.length > 0 && ( - setSecretEnvList([])} - /> - )} - {showExportWarning && ( - setShowExportWarning(false)} - /> - )} -
- ) -} - -export default React.memo(AppInfo) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx new file mode 100644 index 0000000000..3082eb3789 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-detail-panel.spec.tsx @@ -0,0 +1,298 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfoDetailPanel from '../app-info-detail-panel' + +vi.mock('../../../base/app-icon', () => ({ + default: ({ size, icon }: { size: string, icon: string }) => ( +
+ ), +})) + +vi.mock('@/app/components/base/content-dialog', () => ({ + default: ({ show, onClose, children, className }: { + show: boolean + onClose: () => void + children: React.ReactNode + className?: string + }) => ( + show + ? ( +
+ + {children} +
+ ) + : null + ), +})) + +vi.mock('@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view', () => ({ + default: ({ appId }: { appId: string }) => ( +
+ ), +})) + +vi.mock('@/app/components/base/button', () => ({ + default: ({ children, onClick, className, size, variant }: { + children: React.ReactNode + onClick?: () => void + className?: string + size?: string + variant?: string + }) => ( + + ), +})) + +vi.mock('../app-operations', () => ({ + default: ({ primaryOperations, secondaryOperations }: { + primaryOperations?: Array<{ id: string, title: string, onClick: () => void }> + secondaryOperations?: Array<{ id: string, title: string, onClick: () => void, type?: string }> + }) => ( +
+ {primaryOperations?.map(op => ( + + ))} + {secondaryOperations?.map(op => ( + op.type === 'divider' + ? + : + ))} +
+ ), +})) + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: 'A test description', + use_icon_as_answer_icon: false, + ...overrides, +} as App & Partial) + +describe('AppInfoDetailPanel', () => { + const defaultProps = { + appDetail: createAppDetail(), + show: true, + onClose: vi.fn(), + openModal: vi.fn(), + exportCheck: vi.fn(), + } + + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering', () => { + it('should not render when show is false', () => { + render() + expect(screen.queryByTestId('content-dialog')).not.toBeInTheDocument() + }) + + it('should render dialog when show is true', () => { + render() + expect(screen.getByTestId('content-dialog')).toBeInTheDocument() + }) + + it('should display app name', () => { + render() + expect(screen.getByText('Test App')).toBeInTheDocument() + }) + + it('should display app mode label', () => { + render() + expect(screen.getByText('app.types.chatbot')).toBeInTheDocument() + }) + + it('should display description when available', () => { + render() + expect(screen.getByText('A test description')).toBeInTheDocument() + }) + + it('should not display description when empty', () => { + render() + expect(screen.queryByText('A test description')).not.toBeInTheDocument() + }) + + it('should not display description when undefined', () => { + render() + expect(screen.queryByText('A test description')).not.toBeInTheDocument() + }) + + it('should render CardView with correct appId', () => { + render() + const cardView = screen.getByTestId('card-view') + expect(cardView).toHaveAttribute('data-app-id', 'app-1') + }) + + it('should render app icon with large size', () => { + render() + const icon = screen.getByTestId('app-icon') + expect(icon).toHaveAttribute('data-size', 'large') + }) + }) + + describe('Operations', () => { + it('should render edit, duplicate, and export operations', () => { + render() + expect(screen.getByTestId('op-edit')).toBeInTheDocument() + expect(screen.getByTestId('op-duplicate')).toBeInTheDocument() + expect(screen.getByTestId('op-export')).toBeInTheDocument() + }) + + it('should call openModal with edit when edit is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-edit')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('edit') + }) + + it('should call openModal with duplicate when duplicate is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-duplicate')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('duplicate') + }) + + it('should call exportCheck when export is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-export')) + + expect(defaultProps.exportCheck).toHaveBeenCalledTimes(1) + }) + + it('should render delete operation', () => { + render() + expect(screen.getByTestId('op-delete')).toBeInTheDocument() + }) + + it('should call openModal with delete when delete is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('op-delete')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('delete') + }) + }) + + describe('Import DSL option', () => { + it('should show import DSL for advanced_chat mode', () => { + render( + , + ) + expect(screen.getByTestId('op-import')).toBeInTheDocument() + }) + + it('should show import DSL for workflow mode', () => { + render( + , + ) + expect(screen.getByTestId('op-import')).toBeInTheDocument() + }) + + it('should not show import DSL for chat mode', () => { + render() + expect(screen.queryByTestId('op-import')).not.toBeInTheDocument() + }) + + it('should call openModal with importDSL when import is clicked', async () => { + const user = userEvent.setup() + render( + , + ) + await user.click(screen.getByTestId('op-import')) + expect(defaultProps.openModal).toHaveBeenCalledWith('importDSL') + }) + + it('should render divider in secondary operations', async () => { + const user = userEvent.setup() + render() + const divider = screen.getByTestId('op-divider-1') + expect(divider).toBeInTheDocument() + await user.click(divider) + }) + }) + + describe('Switch operation', () => { + it('should show switch button for chat mode', () => { + render() + expect(screen.getByText('app.switch')).toBeInTheDocument() + }) + + it('should show switch button for completion mode', () => { + render( + , + ) + expect(screen.getByText('app.switch')).toBeInTheDocument() + }) + + it('should not show switch button for workflow mode', () => { + render( + , + ) + expect(screen.queryByText('app.switch')).not.toBeInTheDocument() + }) + + it('should not show switch button for advanced_chat mode', () => { + render( + , + ) + expect(screen.queryByText('app.switch')).not.toBeInTheDocument() + }) + + it('should call openModal with switch when switch button is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByText('app.switch')) + + expect(defaultProps.openModal).toHaveBeenCalledWith('switch') + }) + }) + + describe('Dialog interactions', () => { + it('should call onClose when dialog close button is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('dialog-close')) + + expect(defaultProps.onClose).toHaveBeenCalledTimes(1) + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx new file mode 100644 index 0000000000..2f98089e40 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-modals.spec.tsx @@ -0,0 +1,264 @@ +import type { App, AppSSO } from '@/types/app' +import { act, render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfoModals from '../app-info-modals' + +vi.mock('@/next/dynamic', () => ({ + default: (loader: () => Promise<{ default: React.ComponentType }>) => { + const LazyComp = React.lazy(loader) + return function DynamicWrapper(props: Record) { + return React.createElement( + React.Suspense, + { fallback: null }, + React.createElement(LazyComp, props), + ) + } + }, +})) + +vi.mock('@/app/components/app/switch-app-modal', () => ({ + default: ({ show, onClose }: { show: boolean, onClose: () => void }) => ( + show ?
: null + ), +})) + +vi.mock('@/app/components/explore/create-app-modal', () => ({ + default: ({ show, onHide, isEditModal }: { show: boolean, onHide: () => void, isEditModal?: boolean }) => ( + show ?
: null + ), +})) + +vi.mock('@/app/components/app/duplicate-modal', () => ({ + default: ({ show, onHide }: { show: boolean, onHide: () => void }) => ( + show ?
: null + ), +})) + +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ isShow, title, onConfirm, onCancel }: { + isShow: boolean + title: string + onConfirm: () => void + onCancel: () => void + }) => ( + isShow + ? ( +
+ + +
+ ) + : null + ), +})) + +vi.mock('@/app/components/workflow/update-dsl-modal', () => ({ + default: ({ onCancel, onBackup }: { onCancel: () => void, onBackup: () => void }) => ( +
+ + +
+ ), +})) + +vi.mock('@/app/components/workflow/dsl-export-confirm-modal', () => ({ + default: ({ onConfirm, onClose }: { onConfirm: (include?: boolean) => void, onClose: () => void }) => ( +
+ + +
+ ), +})) + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: '', + use_icon_as_answer_icon: false, + max_active_requests: null, + ...overrides, +} as App & Partial) + +const defaultProps = { + appDetail: createAppDetail(), + closeModal: vi.fn(), + secretEnvList: [] as never[], + setSecretEnvList: vi.fn(), + onEdit: vi.fn(), + onCopy: vi.fn(), + onExport: vi.fn(), + exportCheck: vi.fn(), + handleConfirmExport: vi.fn(), + onConfirmDelete: vi.fn(), +} + +describe('AppInfoModals', () => { + beforeAll(async () => { + await new Promise(resolve => setTimeout(resolve, 0)) + }) + + beforeEach(() => { + vi.clearAllMocks() + }) + + it('should render nothing when activeModal is null', async () => { + await act(async () => { + render() + }) + expect(screen.queryByTestId('switch-modal')).not.toBeInTheDocument() + expect(screen.queryByTestId('confirm-modal')).not.toBeInTheDocument() + }) + + it('should render SwitchAppModal when activeModal is switch', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('switch-modal')).toBeInTheDocument() + }) + }) + + it('should render CreateAppModal in edit mode when activeModal is edit', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('edit-modal')).toBeInTheDocument() + }) + }) + + it('should render DuplicateAppModal when activeModal is duplicate', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('duplicate-modal')).toBeInTheDocument() + }) + }) + + it('should render Confirm for delete when activeModal is delete', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + const confirm = screen.getByTestId('confirm-modal') + expect(confirm).toBeInTheDocument() + expect(confirm).toHaveAttribute('data-title', 'app.deleteAppConfirmTitle') + }) + }) + + it('should render UpdateDSLModal when activeModal is importDSL', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + expect(screen.getByTestId('import-dsl-modal')).toBeInTheDocument() + }) + }) + + it('should render export warning Confirm when activeModal is exportWarning', async () => { + await act(async () => { + render() + }) + await waitFor(() => { + const confirm = screen.getByTestId('confirm-modal') + expect(confirm).toBeInTheDocument() + expect(confirm).toHaveAttribute('data-title', 'workflow.sidebar.exportWarning') + }) + }) + + it('should render DSLExportConfirmModal when secretEnvList is not empty', async () => { + await act(async () => { + render( + , + ) + }) + await waitFor(() => { + expect(screen.getByTestId('dsl-export-confirm-modal')).toBeInTheDocument() + }) + }) + + it('should not render DSLExportConfirmModal when secretEnvList is empty', async () => { + await act(async () => { + render() + }) + expect(screen.queryByTestId('dsl-export-confirm-modal')).not.toBeInTheDocument() + }) + + it('should call closeModal when cancel on delete modal', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Cancel')).toBeInTheDocument()) + await user.click(screen.getByText('Cancel')) + + expect(defaultProps.closeModal).toHaveBeenCalledTimes(1) + }) + + it('should call onConfirmDelete when confirm on delete modal', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Confirm')).toBeInTheDocument()) + await user.click(screen.getByText('Confirm')) + + expect(defaultProps.onConfirmDelete).toHaveBeenCalledTimes(1) + }) + + it('should call handleConfirmExport when confirm on export warning', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Confirm')).toBeInTheDocument()) + await user.click(screen.getByText('Confirm')) + + expect(defaultProps.handleConfirmExport).toHaveBeenCalledTimes(1) + }) + + it('should call exportCheck when backup on importDSL modal', async () => { + const user = userEvent.setup() + await act(async () => { + render() + }) + + await waitFor(() => expect(screen.getByText('Backup')).toBeInTheDocument()) + await user.click(screen.getByText('Backup')) + + expect(defaultProps.exportCheck).toHaveBeenCalledTimes(1) + }) + + it('should call setSecretEnvList with empty array when closing DSLExportConfirmModal', async () => { + const user = userEvent.setup() + await act(async () => { + render( + , + ) + }) + + await waitFor(() => expect(screen.getByText('Close Export')).toBeInTheDocument()) + await user.click(screen.getByText('Close Export')) + + expect(defaultProps.setSecretEnvList).toHaveBeenCalledWith([]) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-info-trigger.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-info-trigger.spec.tsx new file mode 100644 index 0000000000..65d660876c --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-info-trigger.spec.tsx @@ -0,0 +1,99 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfoTrigger from '../app-info-trigger' + +vi.mock('../../../base/app-icon', () => ({ + default: ({ size, icon, background }: { + size: string + icon: string + background: string + iconType?: string + imageUrl?: string + }) => ( +
+ ), +})) + +const createAppDetail = (overrides: Partial = {}): App & Partial => ({ + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: 'A test app', + use_icon_as_answer_icon: false, + ...overrides, +} as App & Partial) + +describe('AppInfoTrigger', () => { + it('should render app icon with correct size when expanded', () => { + render() + const icon = screen.getByTestId('app-icon') + expect(icon).toHaveAttribute('data-size', 'large') + }) + + it('should render app icon with small size when collapsed', () => { + render() + const icon = screen.getByTestId('app-icon') + expect(icon).toHaveAttribute('data-size', 'small') + }) + + it('should show app name when expanded', () => { + render() + expect(screen.getByText('My Chatbot')).toBeInTheDocument() + }) + + it('should not show app name when collapsed', () => { + render() + expect(screen.queryByText('My Chatbot')).not.toBeInTheDocument() + }) + + it('should show app mode label when expanded', () => { + render() + expect(screen.getByText('app.types.advanced')).toBeInTheDocument() + }) + + it('should not show mode label when collapsed', () => { + render() + expect(screen.queryByText('app.types.chatbot')).not.toBeInTheDocument() + }) + + it('should call onClick when button is clicked', async () => { + const user = userEvent.setup() + const onClick = vi.fn() + render() + + await user.click(screen.getByRole('button')) + + expect(onClick).toHaveBeenCalledTimes(1) + }) + + it('should show settings icon in expanded and collapsed states', () => { + const { container, rerender } = render( + , + ) + expect(container.querySelector('svg')).toBeInTheDocument() + + rerender() + expect(container.querySelector('svg')).toBeInTheDocument() + }) + + it('should apply ml-1 class to icon wrapper when collapsed', () => { + render( + , + ) + const iconWrapper = screen.getByTestId('app-icon').parentElement + expect(iconWrapper).toHaveClass('ml-1') + }) + + it('should not apply ml-1 class when expanded', () => { + render() + const iconWrapper = screen.getByTestId('app-icon').parentElement + expect(iconWrapper).not.toHaveClass('ml-1') + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-mode-labels.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/app-mode-labels.spec.ts new file mode 100644 index 0000000000..ac4318278c --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-mode-labels.spec.ts @@ -0,0 +1,34 @@ +import type { TFunction } from 'i18next' +import { AppModeEnum } from '@/types/app' +import { getAppModeLabel } from '../app-mode-labels' + +describe('getAppModeLabel', () => { + const t: TFunction = ((key: string, options?: Record) => { + const ns = (options?.ns as string | undefined) ?? '' + return ns ? `${ns}.${key}` : key + }) as TFunction + + it('should return advanced chat label', () => { + expect(getAppModeLabel(AppModeEnum.ADVANCED_CHAT, t)).toBe('app.types.advanced') + }) + + it('should return agent chat label', () => { + expect(getAppModeLabel(AppModeEnum.AGENT_CHAT, t)).toBe('app.types.agent') + }) + + it('should return chatbot label', () => { + expect(getAppModeLabel(AppModeEnum.CHAT, t)).toBe('app.types.chatbot') + }) + + it('should return completion label', () => { + expect(getAppModeLabel(AppModeEnum.COMPLETION, t)).toBe('app.types.completion') + }) + + it('should return workflow label for unknown mode', () => { + expect(getAppModeLabel('unknown-mode', t)).toBe('app.types.workflow') + }) + + it('should return workflow label for workflow mode', () => { + expect(getAppModeLabel(AppModeEnum.WORKFLOW, t)).toBe('app.types.workflow') + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/app-operations.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/app-operations.spec.tsx new file mode 100644 index 0000000000..1df23c2d20 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/app-operations.spec.tsx @@ -0,0 +1,253 @@ +import type { Operation } from '../app-operations' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import AppOperations from '../app-operations' + +vi.mock('../../../base/button', () => ({ + default: ({ children, onClick, className, size, variant, id, tabIndex, ...rest }: { + 'children': React.ReactNode + 'onClick'?: () => void + 'className'?: string + 'size'?: string + 'variant'?: string + 'id'?: string + 'tabIndex'?: number + 'data-targetid'?: string + }) => ( + + ), +})) + +vi.mock('../../../base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean }) => ( +
{children}
+ ), + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children, className }: { children: React.ReactNode, className?: string }) => ( +
{children}
+ ), +})) + +const createOperation = (id: string, title: string, type?: 'divider'): Operation => ({ + id, + title, + icon: , + onClick: vi.fn(), + type, +}) + +function setupDomMeasurements(navWidth: number, moreWidth: number, childWidths: number[]) { + const originalClientWidth = Object.getOwnPropertyDescriptor(HTMLElement.prototype, 'clientWidth') + + Object.defineProperty(HTMLElement.prototype, 'clientWidth', { + configurable: true, + get(this: HTMLElement) { + if (this.getAttribute('aria-hidden') === 'true') + return navWidth + if (this.id === 'more-measure') + return moreWidth + if (this.dataset.targetid) { + const idx = Array.from(this.parentElement?.children ?? []).indexOf(this) + return childWidths[idx] ?? 50 + } + return 0 + }, + }) + + return () => { + if (originalClientWidth) + Object.defineProperty(HTMLElement.prototype, 'clientWidth', originalClientWidth) + } +} + +describe('AppOperations', () => { + beforeEach(() => { + vi.clearAllMocks() + }) + + describe('Rendering with operations prop', () => { + it('should render measurement container', () => { + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + const { container } = render() + expect(container.querySelector('[aria-hidden="true"]')).toBeInTheDocument() + }) + + it('should render operation buttons in measurement container', () => { + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should use operations as primary when provided', () => { + const ops = [createOperation('edit', 'Edit')] + const secondary = [createOperation('delete', 'Delete')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + }) + + describe('Rendering with primaryOperations and secondaryOperations', () => { + it('should render primary operations in measurement container', () => { + const primary = [createOperation('edit', 'Edit')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should use secondary operations when provided', () => { + const primary = [createOperation('edit', 'Edit')] + const secondary = [createOperation('delete', 'Delete')] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should use empty operations array when neither operations nor primaryOperations provided', () => { + const { container } = render() + expect(container).toBeInTheDocument() + }) + }) + + describe('Overflow behavior', () => { + it('should show all operations when container is wide enough', () => { + const cleanup = setupDomMeasurements(500, 60, [80, 80]) + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + + render() + + cleanup() + }) + + it('should move operations to more menu when container is narrow', () => { + const cleanup = setupDomMeasurements(100, 60, [80, 80]) + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + + render() + + cleanup() + }) + + it('should show last item without more button if it fits alone', () => { + const cleanup = setupDomMeasurements(90, 60, [80]) + const ops = [createOperation('edit', 'Edit')] + + render() + + cleanup() + }) + }) + + describe('More button', () => { + it('should render more button text in measurement container', () => { + const ops = [createOperation('edit', 'Edit')] + render() + const moreButtons = screen.getAllByText('common.operation.more') + expect(moreButtons.length).toBeGreaterThanOrEqual(1) + }) + + it('should handle trigger more click', async () => { + const cleanup = setupDomMeasurements(100, 60, [80, 80]) + const user = userEvent.setup() + const ops = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + const secondary = [createOperation('delete', 'Delete')] + + render() + + const trigger = screen.queryByTestId('portal-trigger') + if (trigger) + await user.click(trigger) + + cleanup() + }) + }) + + describe('Visible operations click', () => { + it('should call onClick when a visible operation is clicked', async () => { + const cleanup = setupDomMeasurements(500, 60, [80, 80]) + const user = userEvent.setup() + const editOp = createOperation('edit', 'Edit') + const copyOp = createOperation('copy', 'Copy') + + render() + + const visibleButtons = screen.getAllByText('Edit') + const clickableButton = visibleButtons.find(btn => btn.closest('button')?.tabIndex !== -1) + if (clickableButton) + await user.click(clickableButton) + + cleanup() + }) + }) + + describe('Divider operations', () => { + it('should filter out divider operations from inline display', () => { + const ops = [ + createOperation('edit', 'Edit'), + createOperation('div-1', '', 'divider'), + createOperation('delete', 'Delete'), + ] + render() + const editButtons = screen.getAllByText('Edit') + expect(editButtons.length).toBeGreaterThanOrEqual(1) + }) + }) + + describe('Gap styling', () => { + it('should apply gap to measurement and visible containers', () => { + const ops = [createOperation('edit', 'Edit')] + const { container } = render() + const hiddenContainer = container.querySelector('[aria-hidden="true"]') + expect(hiddenContainer).toHaveStyle({ gap: '8px' }) + }) + + it('should apply gap to visible container', () => { + const ops = [createOperation('edit', 'Edit')] + const { container } = render() + const containers = container.querySelectorAll('div[style]') + const visibleContainer = Array.from(containers).find( + el => el.getAttribute('aria-hidden') !== 'true', + ) + if (visibleContainer) + expect(visibleContainer).toHaveStyle({ gap: '4px' }) + }) + }) + + describe('More menu content', () => { + it('should render divider items in more menu', () => { + const cleanup = setupDomMeasurements(100, 60, [80, 80]) + const primary = [createOperation('edit', 'Edit'), createOperation('copy', 'Copy')] + const secondary = [ + createOperation('divider-1', '', 'divider'), + createOperation('delete', 'Delete'), + ] + + render() + + cleanup() + }) + }) + + describe('Empty inline operations', () => { + it('should handle when all operations are dividers', () => { + const ops = [createOperation('div-1', '', 'divider'), createOperation('div-2', '', 'divider')] + const { container } = render() + expect(container).toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/index.spec.tsx b/web/app/components/app-sidebar/app-info/__tests__/index.spec.tsx new file mode 100644 index 0000000000..fc0bb56f75 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/index.spec.tsx @@ -0,0 +1,147 @@ +import type { App, AppSSO } from '@/types/app' +import { render, screen } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { AppModeEnum } from '@/types/app' +import AppInfo from '..' + +let mockIsCurrentWorkspaceEditor = true +const mockSetPanelOpen = vi.fn() + +vi.mock('@/context/app-context', () => ({ + useAppContext: () => ({ + isCurrentWorkspaceEditor: mockIsCurrentWorkspaceEditor, + }), +})) + +vi.mock('../app-info-trigger', () => ({ + default: React.memo(({ appDetail, expand, onClick }: { + appDetail: App & Partial + expand: boolean + onClick: () => void + }) => ( + + )), +})) + +vi.mock('../app-info-detail-panel', () => ({ + default: React.memo(({ show, onClose }: { show: boolean, onClose: () => void }) => ( + show ?
: null + )), +})) + +vi.mock('../app-info-modals', () => ({ + default: React.memo(({ activeModal }: { activeModal: string | null }) => ( + activeModal ?
: null + )), +})) + +const mockAppDetail: App & Partial = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + icon_url: '', + description: '', + use_icon_as_answer_icon: false, +} as App & Partial + +const mockUseAppInfoActions = { + appDetail: mockAppDetail, + panelOpen: false, + setPanelOpen: mockSetPanelOpen, + closePanel: vi.fn(), + activeModal: null as string | null, + openModal: vi.fn(), + closeModal: vi.fn(), + secretEnvList: [], + setSecretEnvList: vi.fn(), + onEdit: vi.fn(), + onCopy: vi.fn(), + onExport: vi.fn(), + exportCheck: vi.fn(), + handleConfirmExport: vi.fn(), + onConfirmDelete: vi.fn(), +} + +vi.mock('../use-app-info-actions', () => ({ + useAppInfoActions: () => mockUseAppInfoActions, +})) + +describe('AppInfo', () => { + beforeEach(() => { + vi.clearAllMocks() + mockIsCurrentWorkspaceEditor = true + mockUseAppInfoActions.appDetail = mockAppDetail + mockUseAppInfoActions.panelOpen = false + mockUseAppInfoActions.activeModal = null + }) + + it('should return null when appDetail is not available', () => { + mockUseAppInfoActions.appDetail = undefined as unknown as App & Partial + const { container } = render() + expect(container.innerHTML).toBe('') + }) + + it('should render trigger when not onlyShowDetail', () => { + render() + expect(screen.getByTestId('trigger')).toBeInTheDocument() + }) + + it('should not render trigger when onlyShowDetail is true', () => { + render() + expect(screen.queryByTestId('trigger')).not.toBeInTheDocument() + }) + + it('should pass expand prop to trigger', () => { + render() + expect(screen.getByTestId('trigger')).toHaveAttribute('data-expand', 'true') + + const { unmount } = render() + const triggers = screen.getAllByTestId('trigger') + expect(triggers[triggers.length - 1]).toHaveAttribute('data-expand', 'false') + unmount() + }) + + it('should toggle panel when trigger is clicked and user is editor', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('trigger')) + + expect(mockSetPanelOpen).toHaveBeenCalled() + const updater = mockSetPanelOpen.mock.calls[0][0] as (v: boolean) => boolean + expect(updater(false)).toBe(true) + expect(updater(true)).toBe(false) + }) + + it('should not toggle panel when trigger is clicked and user is not editor', async () => { + const user = userEvent.setup() + mockIsCurrentWorkspaceEditor = false + render() + + await user.click(screen.getByTestId('trigger')) + + expect(mockSetPanelOpen).not.toHaveBeenCalled() + }) + + it('should show detail panel based on panelOpen when not onlyShowDetail', () => { + mockUseAppInfoActions.panelOpen = true + render() + expect(screen.getByTestId('detail-panel')).toBeInTheDocument() + }) + + it('should show detail panel based on openState when onlyShowDetail', () => { + render() + expect(screen.getByTestId('detail-panel')).toBeInTheDocument() + }) + + it('should hide detail panel when openState is false and onlyShowDetail', () => { + render() + expect(screen.queryByTestId('detail-panel')).not.toBeInTheDocument() + }) +}) diff --git a/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts new file mode 100644 index 0000000000..deea28ce3e --- /dev/null +++ b/web/app/components/app-sidebar/app-info/__tests__/use-app-info-actions.spec.ts @@ -0,0 +1,492 @@ +import { act, renderHook } from '@testing-library/react' +import { AppModeEnum } from '@/types/app' +import { useAppInfoActions } from '../use-app-info-actions' + +const mockNotify = vi.fn() +const mockReplace = vi.fn() +const mockOnPlanInfoChanged = vi.fn() +const mockInvalidateAppList = vi.fn() +const mockSetAppDetail = vi.fn() +const mockUpdateAppInfo = vi.fn() +const mockCopyApp = vi.fn() +const mockExportAppConfig = vi.fn() +const mockDeleteApp = vi.fn() +const mockFetchWorkflowDraft = vi.fn() +const mockDownloadBlob = vi.fn() + +let mockAppDetail: Record | undefined = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', +} + +vi.mock('@/next/navigation', () => ({ + useRouter: () => ({ replace: mockReplace }), +})) + +vi.mock('use-context-selector', () => ({ + useContext: () => ({ notify: mockNotify }), +})) + +vi.mock('@/context/provider-context', () => ({ + useProviderContext: () => ({ onPlanInfoChanged: mockOnPlanInfoChanged }), +})) + +vi.mock('@/app/components/app/store', () => ({ + useStore: (selector: (state: Record) => unknown) => selector({ + appDetail: mockAppDetail, + setAppDetail: mockSetAppDetail, + }), +})) + +vi.mock('@/app/components/base/toast/context', () => ({ + ToastContext: {}, +})) + +vi.mock('@/service/use-apps', () => ({ + useInvalidateAppList: () => mockInvalidateAppList, +})) + +vi.mock('@/service/apps', () => ({ + updateAppInfo: (...args: unknown[]) => mockUpdateAppInfo(...args), + copyApp: (...args: unknown[]) => mockCopyApp(...args), + exportAppConfig: (...args: unknown[]) => mockExportAppConfig(...args), + deleteApp: (...args: unknown[]) => mockDeleteApp(...args), +})) + +vi.mock('@/service/workflow', () => ({ + fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args), +})) + +vi.mock('@/utils/download', () => ({ + downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args), +})) + +vi.mock('@/utils/app-redirection', () => ({ + getRedirection: vi.fn(), +})) + +vi.mock('@/config', () => ({ + NEED_REFRESH_APP_LIST_KEY: 'test-refresh-key', +})) + +describe('useAppInfoActions', () => { + beforeEach(() => { + vi.clearAllMocks() + mockAppDetail = { + id: 'app-1', + name: 'Test App', + mode: AppModeEnum.CHAT, + icon: '🤖', + icon_type: 'emoji', + icon_background: '#FFEAD5', + } + }) + + describe('Initial state', () => { + it('should return initial state correctly', () => { + const { result } = renderHook(() => useAppInfoActions({})) + expect(result.current.appDetail).toEqual(mockAppDetail) + expect(result.current.panelOpen).toBe(false) + expect(result.current.activeModal).toBeNull() + expect(result.current.secretEnvList).toEqual([]) + }) + }) + + describe('Panel management', () => { + it('should toggle panelOpen', () => { + const { result } = renderHook(() => useAppInfoActions({})) + + act(() => { + result.current.setPanelOpen(true) + }) + + expect(result.current.panelOpen).toBe(true) + }) + + it('should close panel and call onDetailExpand', () => { + const onDetailExpand = vi.fn() + const { result } = renderHook(() => useAppInfoActions({ onDetailExpand })) + + act(() => { + result.current.setPanelOpen(true) + }) + + act(() => { + result.current.closePanel() + }) + + expect(result.current.panelOpen).toBe(false) + expect(onDetailExpand).toHaveBeenCalledWith(false) + }) + }) + + describe('Modal management', () => { + it('should open modal and close panel', () => { + const { result } = renderHook(() => useAppInfoActions({})) + + act(() => { + result.current.setPanelOpen(true) + }) + + act(() => { + result.current.openModal('edit') + }) + + expect(result.current.activeModal).toBe('edit') + expect(result.current.panelOpen).toBe(false) + }) + + it('should close modal', () => { + const { result } = renderHook(() => useAppInfoActions({})) + + act(() => { + result.current.openModal('delete') + }) + + act(() => { + result.current.closeModal() + }) + + expect(result.current.activeModal).toBeNull() + }) + }) + + describe('onEdit', () => { + it('should update app info and close modal on success', async () => { + const updatedApp = { ...mockAppDetail, name: 'Updated' } + mockUpdateAppInfo.mockResolvedValue(updatedApp) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + + expect(mockUpdateAppInfo).toHaveBeenCalled() + expect(mockSetAppDetail).toHaveBeenCalledWith(updatedApp) + expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.editDone' }) + }) + + it('should notify error on edit failure', async () => { + mockUpdateAppInfo.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.editFailed' }) + }) + + it('should not call updateAppInfo when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onEdit({ + name: 'Updated', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + description: '', + use_icon_as_answer_icon: false, + }) + }) + + expect(mockUpdateAppInfo).not.toHaveBeenCalled() + }) + }) + + describe('onCopy', () => { + it('should copy app and redirect on success', async () => { + const newApp = { id: 'app-2', name: 'Copy', mode: 'chat' } + mockCopyApp.mockResolvedValue(newApp) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onCopy({ + name: 'Copy', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }) + }) + + expect(mockCopyApp).toHaveBeenCalled() + expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.newApp.appCreated' }) + expect(mockOnPlanInfoChanged).toHaveBeenCalled() + }) + + it('should notify error on copy failure', async () => { + mockCopyApp.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onCopy({ + name: 'Copy', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }) + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.newApp.appCreateFailed' }) + }) + }) + + describe('onCopy - early return', () => { + it('should not call copyApp when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onCopy({ + name: 'Copy', + icon_type: 'emoji', + icon: '🤖', + icon_background: '#fff', + }) + }) + + expect(mockCopyApp).not.toHaveBeenCalled() + }) + }) + + describe('onExport', () => { + it('should export app config and trigger download', async () => { + mockExportAppConfig.mockResolvedValue({ data: 'yaml-content' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onExport(false) + }) + + expect(mockExportAppConfig).toHaveBeenCalledWith({ appID: 'app-1', include: false }) + expect(mockDownloadBlob).toHaveBeenCalled() + }) + + it('should notify error on export failure', async () => { + mockExportAppConfig.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onExport() + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + + describe('onExport - early return', () => { + it('should not export when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onExport() + }) + + expect(mockExportAppConfig).not.toHaveBeenCalled() + }) + }) + + describe('exportCheck', () => { + it('should call onExport directly for non-workflow modes', async () => { + mockExportAppConfig.mockResolvedValue({ data: 'yaml' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockExportAppConfig).toHaveBeenCalled() + }) + + it('should open export warning modal for workflow mode', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.WORKFLOW } + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(result.current.activeModal).toBe('exportWarning') + }) + + it('should open export warning modal for advanced_chat mode', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.ADVANCED_CHAT } + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(result.current.activeModal).toBe('exportWarning') + }) + }) + + describe('exportCheck - early return', () => { + it('should not do anything when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.exportCheck() + }) + + expect(mockExportAppConfig).not.toHaveBeenCalled() + }) + }) + + describe('handleConfirmExport', () => { + it('should export directly when no secret env variables', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.WORKFLOW } + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: [{ value_type: 'string' }], + }) + mockExportAppConfig.mockResolvedValue({ data: 'yaml' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockExportAppConfig).toHaveBeenCalled() + }) + + it('should set secret env list when secret variables exist', async () => { + mockAppDetail = { ...mockAppDetail, mode: AppModeEnum.WORKFLOW } + const secretVars = [{ value_type: 'secret', key: 'API_KEY' }] + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: secretVars, + }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(result.current.secretEnvList).toEqual(secretVars) + }) + + it('should notify error on workflow draft fetch failure', async () => { + mockFetchWorkflowDraft.mockRejectedValue(new Error('fail')) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockNotify).toHaveBeenCalledWith({ type: 'error', message: 'app.exportFailed' }) + }) + }) + + describe('handleConfirmExport - early return', () => { + it('should not do anything when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockFetchWorkflowDraft).not.toHaveBeenCalled() + }) + }) + + describe('handleConfirmExport - with environment variables', () => { + it('should handle empty environment_variables', async () => { + mockFetchWorkflowDraft.mockResolvedValue({ + environment_variables: undefined, + }) + mockExportAppConfig.mockResolvedValue({ data: 'yaml' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.handleConfirmExport() + }) + + expect(mockExportAppConfig).toHaveBeenCalled() + }) + }) + + describe('onConfirmDelete', () => { + it('should delete app and redirect on success', async () => { + mockDeleteApp.mockResolvedValue({}) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onConfirmDelete() + }) + + expect(mockDeleteApp).toHaveBeenCalledWith('app-1') + expect(mockNotify).toHaveBeenCalledWith({ type: 'success', message: 'app.appDeleted' }) + expect(mockInvalidateAppList).toHaveBeenCalled() + expect(mockReplace).toHaveBeenCalledWith('/apps') + expect(mockSetAppDetail).toHaveBeenCalledWith() + }) + + it('should not delete when appDetail is undefined', async () => { + mockAppDetail = undefined + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onConfirmDelete() + }) + + expect(mockDeleteApp).not.toHaveBeenCalled() + }) + + it('should notify error on delete failure', async () => { + mockDeleteApp.mockRejectedValue({ message: 'cannot delete' }) + + const { result } = renderHook(() => useAppInfoActions({})) + + await act(async () => { + await result.current.onConfirmDelete() + }) + + expect(mockNotify).toHaveBeenCalledWith({ + type: 'error', + message: expect.stringContaining('app.appDeleteFailed'), + }) + }) + }) +}) diff --git a/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx b/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx new file mode 100644 index 0000000000..70dcb8df70 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-info-detail-panel.tsx @@ -0,0 +1,151 @@ +import type { Operation } from './app-operations' +import type { AppInfoModalType } from './use-app-info-actions' +import type { App, AppSSO } from '@/types/app' +import { + RiDeleteBinLine, + RiEditLine, + RiExchange2Line, + RiFileCopy2Line, + RiFileDownloadLine, + RiFileUploadLine, +} from '@remixicon/react' +import * as React from 'react' +import { useMemo } from 'react' +import { useTranslation } from 'react-i18next' +import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/card-view' +import Button from '@/app/components/base/button' +import ContentDialog from '@/app/components/base/content-dialog' +import { AppModeEnum } from '@/types/app' +import AppIcon from '../../base/app-icon' +import { getAppModeLabel } from './app-mode-labels' +import AppOperations from './app-operations' + +type AppInfoDetailPanelProps = { + appDetail: App & Partial + show: boolean + onClose: () => void + openModal: (modal: Exclude) => void + exportCheck: () => void +} + +const AppInfoDetailPanel = ({ + appDetail, + show, + onClose, + openModal, + exportCheck, +}: AppInfoDetailPanelProps) => { + const { t } = useTranslation() + + const primaryOperations = useMemo(() => [ + { + id: 'edit', + title: t('editApp', { ns: 'app' }), + icon: , + onClick: () => openModal('edit'), + }, + { + id: 'duplicate', + title: t('duplicate', { ns: 'app' }), + icon: , + onClick: () => openModal('duplicate'), + }, + { + id: 'export', + title: t('export', { ns: 'app' }), + icon: , + onClick: exportCheck, + }, + ], [t, openModal, exportCheck]) + + const secondaryOperations = useMemo(() => [ + ...(appDetail.mode === AppModeEnum.ADVANCED_CHAT || appDetail.mode === AppModeEnum.WORKFLOW) + ? [{ + id: 'import', + title: t('common.importDSL', { ns: 'workflow' }), + icon: , + onClick: () => openModal('importDSL'), + }] + : [], + { + id: 'divider-1', + title: '', + icon: <>, + onClick: () => {}, + type: 'divider' as const, + }, + { + id: 'delete', + title: t('operation.delete', { ns: 'common' }), + icon: , + onClick: () => openModal('delete'), + }, + ], [appDetail.mode, t, openModal]) + + const switchOperation = useMemo(() => { + if (appDetail.mode !== AppModeEnum.COMPLETION && appDetail.mode !== AppModeEnum.CHAT) + return null + return { + id: 'switch', + title: t('switch', { ns: 'app' }), + icon: , + onClick: () => openModal('switch'), + } + }, [appDetail.mode, t, openModal]) + + return ( + +
+
+ +
+
{appDetail.name}
+
+ {getAppModeLabel(appDetail.mode, t)} +
+
+
+ {appDetail.description && ( +
+ {appDetail.description} +
+ )} + +
+ + {switchOperation && ( +
+ +
+ )} +
+ ) +} + +export default React.memo(AppInfoDetailPanel) diff --git a/web/app/components/app-sidebar/app-info/app-info-modals.tsx b/web/app/components/app-sidebar/app-info/app-info-modals.tsx new file mode 100644 index 0000000000..6b76be87bb --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-info-modals.tsx @@ -0,0 +1,132 @@ +import type { AppInfoModalType } from './use-app-info-actions' +import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' +import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import type { App, AppSSO } from '@/types/app' +import * as React from 'react' +import { useState } from 'react' +import { useTranslation } from 'react-i18next' +import dynamic from '@/next/dynamic' + +const SwitchAppModal = dynamic(() => import('@/app/components/app/switch-app-modal'), { ssr: false }) +const CreateAppModal = dynamic(() => import('@/app/components/explore/create-app-modal'), { ssr: false }) +const DuplicateAppModal = dynamic(() => import('@/app/components/app/duplicate-modal'), { ssr: false }) +const Confirm = dynamic(() => import('@/app/components/base/confirm'), { ssr: false }) +const UpdateDSLModal = dynamic(() => import('@/app/components/workflow/update-dsl-modal'), { ssr: false }) +const DSLExportConfirmModal = dynamic(() => import('@/app/components/workflow/dsl-export-confirm-modal'), { ssr: false }) + +type AppInfoModalsProps = { + appDetail: App & Partial + activeModal: AppInfoModalType + closeModal: () => void + secretEnvList: EnvironmentVariable[] + setSecretEnvList: (list: EnvironmentVariable[]) => void + onEdit: CreateAppModalProps['onConfirm'] + onCopy: DuplicateAppModalProps['onConfirm'] + onExport: (include?: boolean) => Promise + exportCheck: () => void + handleConfirmExport: () => void + onConfirmDelete: () => void +} + +const AppInfoModals = ({ + appDetail, + activeModal, + closeModal, + secretEnvList, + setSecretEnvList, + onEdit, + onCopy, + onExport, + exportCheck, + handleConfirmExport, + onConfirmDelete, +}: AppInfoModalsProps) => { + const { t } = useTranslation() + const [confirmDeleteInput, setConfirmDeleteInput] = useState('') + + return ( + <> + {activeModal === 'switch' && ( + + )} + {activeModal === 'edit' && ( + + )} + {activeModal === 'duplicate' && ( + + )} + {activeModal === 'delete' && ( + { + setConfirmDeleteInput('') + closeModal() + }} + /> + )} + {activeModal === 'importDSL' && ( + + )} + {activeModal === 'exportWarning' && ( + + )} + {secretEnvList.length > 0 && ( + setSecretEnvList([])} + /> + )} + + ) +} + +export default React.memo(AppInfoModals) diff --git a/web/app/components/app-sidebar/app-info/app-info-trigger.tsx b/web/app/components/app-sidebar/app-info/app-info-trigger.tsx new file mode 100644 index 0000000000..07a41124e3 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-info-trigger.tsx @@ -0,0 +1,67 @@ +import type { App, AppSSO } from '@/types/app' +import { RiEqualizer2Line } from '@remixicon/react' +import * as React from 'react' +import { useTranslation } from 'react-i18next' +import { cn } from '@/utils/classnames' +import AppIcon from '../../base/app-icon' +import { getAppModeLabel } from './app-mode-labels' + +type AppInfoTriggerProps = { + appDetail: App & Partial + expand: boolean + onClick: () => void +} + +const AppInfoTrigger = ({ appDetail, expand, onClick }: AppInfoTriggerProps) => { + const { t } = useTranslation() + const modeLabel = getAppModeLabel(appDetail.mode, t) + + return ( + + ) +} + +export default React.memo(AppInfoTrigger) diff --git a/web/app/components/app-sidebar/app-info/app-mode-labels.ts b/web/app/components/app-sidebar/app-info/app-mode-labels.ts new file mode 100644 index 0000000000..1d72feb089 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/app-mode-labels.ts @@ -0,0 +1,17 @@ +import type { TFunction } from 'i18next' +import { AppModeEnum } from '@/types/app' + +export function getAppModeLabel(mode: string, t: TFunction): string { + switch (mode) { + case AppModeEnum.ADVANCED_CHAT: + return t('types.advanced', { ns: 'app' }) + case AppModeEnum.AGENT_CHAT: + return t('types.agent', { ns: 'app' }) + case AppModeEnum.CHAT: + return t('types.chatbot', { ns: 'app' }) + case AppModeEnum.COMPLETION: + return t('types.completion', { ns: 'app' }) + default: + return t('types.workflow', { ns: 'app' }) + } +} diff --git a/web/app/components/app-sidebar/app-operations.tsx b/web/app/components/app-sidebar/app-info/app-operations.tsx similarity index 93% rename from web/app/components/app-sidebar/app-operations.tsx rename to web/app/components/app-sidebar/app-info/app-operations.tsx index 871d8a19e8..78dd6f0043 100644 --- a/web/app/components/app-sidebar/app-operations.tsx +++ b/web/app/components/app-sidebar/app-info/app-operations.tsx @@ -3,7 +3,7 @@ import { RiMoreLine } from '@remixicon/react' import { cloneElement, useCallback, useEffect, useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' -import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../base/portal-to-follow-elem' +import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../../base/portal-to-follow-elem' export type Operation = { id: string @@ -134,7 +134,7 @@ const AppOperations = ({ tabIndex={-1} > {cloneElement(operation.icon, { className: 'h-3.5 w-3.5 text-components-button-secondary-text' })} - + {operation.title} @@ -147,7 +147,7 @@ const AppOperations = ({ tabIndex={-1} > - + {t('operation.more', { ns: 'common' })} @@ -163,7 +163,7 @@ const AppOperations = ({ onClick={operation.onClick} > {cloneElement(operation.icon, { className: 'h-3.5 w-3.5 text-components-button-secondary-text' })} - + {operation.title} @@ -182,7 +182,7 @@ const AppOperations = ({ className="gap-[1px]" > - + {t('operation.more', { ns: 'common' })} @@ -200,7 +200,7 @@ const AppOperations = ({ onClick={item.onClick} > {cloneElement(item.icon, { className: 'h-4 w-4 text-text-tertiary' })} - {item.title} + {item.title}
))}
diff --git a/web/app/components/app-sidebar/app-info/index.tsx b/web/app/components/app-sidebar/app-info/index.tsx new file mode 100644 index 0000000000..2530add2dc --- /dev/null +++ b/web/app/components/app-sidebar/app-info/index.tsx @@ -0,0 +1,75 @@ +import * as React from 'react' +import { useAppContext } from '@/context/app-context' +import AppInfoDetailPanel from './app-info-detail-panel' +import AppInfoModals from './app-info-modals' +import AppInfoTrigger from './app-info-trigger' +import { useAppInfoActions } from './use-app-info-actions' + +export type IAppInfoProps = { + expand: boolean + onlyShowDetail?: boolean + openState?: boolean + onDetailExpand?: (expand: boolean) => void +} + +const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailExpand }: IAppInfoProps) => { + const { isCurrentWorkspaceEditor } = useAppContext() + + const { + appDetail, + panelOpen, + setPanelOpen, + closePanel, + activeModal, + openModal, + closeModal, + secretEnvList, + setSecretEnvList, + onEdit, + onCopy, + onExport, + exportCheck, + handleConfirmExport, + onConfirmDelete, + } = useAppInfoActions({ onDetailExpand }) + + if (!appDetail) + return null + + return ( +
+ {!onlyShowDetail && ( + { + if (isCurrentWorkspaceEditor) + setPanelOpen(v => !v) + }} + /> + )} + + +
+ ) +} + +export default React.memo(AppInfo) diff --git a/web/app/components/app-sidebar/app-info/use-app-info-actions.ts b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts new file mode 100644 index 0000000000..55ec13e506 --- /dev/null +++ b/web/app/components/app-sidebar/app-info/use-app-info-actions.ts @@ -0,0 +1,189 @@ +import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' +import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' +import type { EnvironmentVariable } from '@/app/components/workflow/types' +import { useCallback, useState } from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' +import { useStore as useAppStore } from '@/app/components/app/store' +import { ToastContext } from '@/app/components/base/toast/context' +import { NEED_REFRESH_APP_LIST_KEY } from '@/config' +import { useProviderContext } from '@/context/provider-context' +import { useRouter } from '@/next/navigation' +import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' +import { useInvalidateAppList } from '@/service/use-apps' +import { fetchWorkflowDraft } from '@/service/workflow' +import { AppModeEnum } from '@/types/app' +import { getRedirection } from '@/utils/app-redirection' +import { downloadBlob } from '@/utils/download' + +export type AppInfoModalType = 'edit' | 'duplicate' | 'delete' | 'switch' | 'importDSL' | 'exportWarning' | null + +type UseAppInfoActionsParams = { + onDetailExpand?: (expand: boolean) => void +} + +export function useAppInfoActions({ onDetailExpand }: UseAppInfoActionsParams) { + const { t } = useTranslation() + const { notify } = useContext(ToastContext) + const { replace } = useRouter() + const { onPlanInfoChanged } = useProviderContext() + const appDetail = useAppStore(state => state.appDetail) + const setAppDetail = useAppStore(state => state.setAppDetail) + const invalidateAppList = useInvalidateAppList() + + const [panelOpen, setPanelOpen] = useState(false) + const [activeModal, setActiveModal] = useState(null) + const [secretEnvList, setSecretEnvList] = useState([]) + + const closePanel = useCallback(() => { + setPanelOpen(false) + onDetailExpand?.(false) + }, [onDetailExpand]) + + const openModal = useCallback((modal: Exclude) => { + closePanel() + setActiveModal(modal) + }, [closePanel]) + + const closeModal = useCallback(() => { + setActiveModal(null) + }, []) + + const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ + name, + icon_type, + icon, + icon_background, + description, + use_icon_as_answer_icon, + max_active_requests, + }) => { + if (!appDetail) + return + try { + const app = await updateAppInfo({ + appID: appDetail.id, + name, + icon_type, + icon, + icon_background, + description, + use_icon_as_answer_icon, + max_active_requests, + }) + closeModal() + notify({ type: 'success', message: t('editDone', { ns: 'app' }) }) + setAppDetail(app) + } + catch { + notify({ type: 'error', message: t('editFailed', { ns: 'app' }) }) + } + }, [appDetail, closeModal, notify, setAppDetail, t]) + + const onCopy: DuplicateAppModalProps['onConfirm'] = useCallback(async ({ + name, + icon_type, + icon, + icon_background, + }) => { + if (!appDetail) + return + try { + const newApp = await copyApp({ + appID: appDetail.id, + name, + icon_type, + icon, + icon_background, + mode: appDetail.mode, + }) + closeModal() + notify({ type: 'success', message: t('newApp.appCreated', { ns: 'app' }) }) + localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') + onPlanInfoChanged() + getRedirection(true, newApp, replace) + } + catch { + notify({ type: 'error', message: t('newApp.appCreateFailed', { ns: 'app' }) }) + } + }, [appDetail, closeModal, notify, onPlanInfoChanged, replace, t]) + + const onExport = useCallback(async (include = false) => { + if (!appDetail) + return + try { + const { data } = await exportAppConfig({ appID: appDetail.id, include }) + const file = new Blob([data], { type: 'application/yaml' }) + downloadBlob({ data: file, fileName: `${appDetail.name}.yml` }) + } + catch { + notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + } + }, [appDetail, notify, t]) + + const exportCheck = useCallback(async () => { + if (!appDetail) + return + if (appDetail.mode !== AppModeEnum.WORKFLOW && appDetail.mode !== AppModeEnum.ADVANCED_CHAT) { + onExport() + return + } + setActiveModal('exportWarning') + }, [appDetail, onExport]) + + const handleConfirmExport = useCallback(async () => { + if (!appDetail) + return + closeModal() + try { + const workflowDraft = await fetchWorkflowDraft(`/apps/${appDetail.id}/workflows/draft`) + const list = (workflowDraft.environment_variables || []).filter(env => env.value_type === 'secret') + if (list.length === 0) { + onExport() + return + } + setSecretEnvList(list) + } + catch { + notify({ type: 'error', message: t('exportFailed', { ns: 'app' }) }) + } + }, [appDetail, closeModal, notify, onExport, t]) + + const onConfirmDelete = useCallback(async () => { + if (!appDetail) + return + try { + await deleteApp(appDetail.id) + notify({ type: 'success', message: t('appDeleted', { ns: 'app' }) }) + invalidateAppList() + onPlanInfoChanged() + setAppDetail() + replace('/apps') + } + catch (e: unknown) { + notify({ + type: 'error', + message: `${t('appDeleteFailed', { ns: 'app' })}${e instanceof Error && e.message ? `: ${e.message}` : ''}`, + }) + } + closeModal() + }, [appDetail, closeModal, invalidateAppList, notify, onPlanInfoChanged, replace, setAppDetail, t]) + + return { + appDetail, + panelOpen, + setPanelOpen, + closePanel, + activeModal, + openModal, + closeModal, + secretEnvList, + setSecretEnvList, + onEdit, + onCopy, + onExport, + exportCheck, + handleConfirmExport, + onConfirmDelete, + } +} diff --git a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx index 9ed86cbf32..87632ba647 100644 --- a/web/app/components/app-sidebar/app-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/app-sidebar-dropdown.tsx @@ -1,4 +1,4 @@ -import type { NavIcon } from './navLink' +import type { NavIcon } from './nav-link' import { RiEqualizer2Line, RiMenuLine, @@ -13,12 +13,12 @@ import { PortalToFollowElemTrigger, } from '@/app/components/base/portal-to-follow-elem' import { useAppContext } from '@/context/app-context' -import { AppModeEnum } from '@/types/app' import { cn } from '@/utils/classnames' import AppIcon from '../base/app-icon' import Divider from '../base/divider' import AppInfo from './app-info' -import NavLink from './navLink' +import { getAppModeLabel } from './app-info/app-mode-labels' +import NavLink from './nav-link' type Props = { navigation: Array<{ @@ -97,9 +97,9 @@ const AppSidebarDropdown = ({ navigation }: Props) => {
-
{appDetail.name}
+
{appDetail.name}
-
{appDetail.mode === AppModeEnum.ADVANCED_CHAT ? t('types.advanced', { ns: 'app' }) : appDetail.mode === AppModeEnum.AGENT_CHAT ? t('types.agent', { ns: 'app' }) : appDetail.mode === AppModeEnum.CHAT ? t('types.chatbot', { ns: 'app' }) : appDetail.mode === AppModeEnum.COMPLETION ? t('types.completion', { ns: 'app' }) : t('types.workflow', { ns: 'app' })}
+
{getAppModeLabel(appDetail.mode, t)}
diff --git a/web/app/components/app-sidebar/basic.tsx b/web/app/components/app-sidebar/basic.tsx index f5f24f4938..5d4d2e5d1e 100644 --- a/web/app/components/app-sidebar/basic.tsx +++ b/web/app/components/app-sidebar/basic.tsx @@ -76,7 +76,7 @@ export default function AppBasic({ icon, icon_background, name, isExternal, type )} {mode === 'expand' && (
-
+
{name}
@@ -95,10 +95,10 @@ export default function AppBasic({ icon, icon_background, name, isExternal, type )}
{!hideType && isExtraInLine && ( -
{type}
+
{type}
)} {!hideType && !isExtraInLine && ( -
{isExternal ? t('externalTag', { ns: 'dataset' }) : type}
+
{isExternal ? t('externalTag', { ns: 'dataset' }) : type}
)}
)} diff --git a/web/app/components/app-sidebar/completion.png b/web/app/components/app-sidebar/completion.png deleted file mode 100644 index 7a3cbd5107..0000000000 Binary files a/web/app/components/app-sidebar/completion.png and /dev/null differ diff --git a/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx new file mode 100644 index 0000000000..1df6fa79b7 --- /dev/null +++ b/web/app/components/app-sidebar/dataset-info/__tests__/dropdown-callbacks.spec.tsx @@ -0,0 +1,228 @@ +import type { DataSet } from '@/models/datasets' +import { render, screen, waitFor } from '@testing-library/react' +import userEvent from '@testing-library/user-event' +import * as React from 'react' +import { + ChunkingMode, + DatasetPermission, + DataSourceType, +} from '@/models/datasets' +import { RETRIEVE_METHOD } from '@/types/app' +import Dropdown from '../dropdown' + +let mockDataset: DataSet +let mockIsDatasetOperator = false +const mockReplace = vi.fn() +const mockInvalidDatasetList = vi.fn() +const mockInvalidDatasetDetail = vi.fn() +const mockExportPipeline = vi.fn() +const mockCheckIsUsedInApp = vi.fn() +const mockDeleteDataset = vi.fn() + +const createDataset = (overrides: Partial = {}): DataSet => ({ + id: 'dataset-1', + name: 'Dataset Name', + indexing_status: 'completed', + icon_info: { + icon: '📙', + icon_background: '#FFF4ED', + icon_type: 'emoji', + icon_url: '', + }, + description: 'Dataset description', + permission: DatasetPermission.onlyMe, + data_source_type: DataSourceType.FILE, + indexing_technique: 'high_quality' as DataSet['indexing_technique'], + created_by: 'user-1', + updated_by: 'user-1', + updated_at: 1690000000, + app_count: 0, + doc_form: ChunkingMode.text, + document_count: 1, + total_document_count: 1, + word_count: 1000, + provider: 'internal', + embedding_model: 'text-embedding-3', + embedding_model_provider: 'openai', + embedding_available: true, + retrieval_model_dict: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + retrieval_model: { + search_method: RETRIEVE_METHOD.semantic, + reranking_enable: false, + reranking_model: { reranking_provider_name: '', reranking_model_name: '' }, + top_k: 5, + score_threshold_enabled: false, + score_threshold: 0, + }, + tags: [], + external_knowledge_info: { + external_knowledge_id: '', + external_knowledge_api_id: '', + external_knowledge_api_name: '', + external_knowledge_api_endpoint: '', + }, + external_retrieval_model: { + top_k: 0, + score_threshold: 0, + score_threshold_enabled: false, + }, + built_in_field_enabled: false, + runtime_mode: 'rag_pipeline', + enable_api: false, + is_multimodal: false, + ...overrides, +}) + +vi.mock('@/next/navigation', () => ({ + useRouter: () => ({ replace: mockReplace }), +})) + +vi.mock('@/context/dataset-detail', () => ({ + useDatasetDetailContextWithSelector: (selector: (state: { dataset?: DataSet }) => unknown) => selector({ dataset: mockDataset }), +})) + +vi.mock('@/context/app-context', () => ({ + useSelector: (selector: (state: { isCurrentWorkspaceDatasetOperator: boolean }) => unknown) => + selector({ isCurrentWorkspaceDatasetOperator: mockIsDatasetOperator }), +})) + +vi.mock('@/service/knowledge/use-dataset', () => ({ + datasetDetailQueryKeyPrefix: ['dataset', 'detail'], + useInvalidDatasetList: () => mockInvalidDatasetList, +})) + +vi.mock('@/service/use-base', () => ({ + useInvalid: () => mockInvalidDatasetDetail, +})) + +vi.mock('@/service/use-pipeline', () => ({ + useExportPipelineDSL: () => ({ mutateAsync: mockExportPipeline }), +})) + +vi.mock('@/service/datasets', () => ({ + checkIsUsedInApp: (...args: unknown[]) => mockCheckIsUsedInApp(...args), + deleteDataset: (...args: unknown[]) => mockDeleteDataset(...args), +})) + +vi.mock('@/app/components/datasets/rename-modal', () => ({ + default: ({ + show, + onClose, + onSuccess, + }: { + show: boolean + onClose: () => void + onSuccess?: () => void + }) => { + if (!show) + return null + return ( +
+ + +
+ ) + }, +})) + +vi.mock('@/app/components/base/confirm', () => ({ + default: ({ + isShow, + onConfirm, + onCancel, + title, + content, + }: { + isShow: boolean + onConfirm: () => void + onCancel: () => void + title: string + content: string + }) => { + if (!isShow) + return null + return ( +
+ {title} + {content} + + +
+ ) + }, +})) + +vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ + PortalToFollowElem: ({ children }: { children: React.ReactNode }) =>
{children}
, + PortalToFollowElemTrigger: ({ children, onClick }: { children: React.ReactNode, onClick?: () => void }) => ( +
{children}
+ ), + PortalToFollowElemContent: ({ children }: { children: React.ReactNode }) =>
{children}
, +})) + +describe('Dropdown callback coverage', () => { + beforeEach(() => { + vi.clearAllMocks() + mockDataset = createDataset({ pipeline_id: 'pipeline-1', runtime_mode: 'rag_pipeline' }) + mockIsDatasetOperator = false + mockExportPipeline.mockResolvedValue({ data: 'pipeline-content' }) + mockCheckIsUsedInApp.mockResolvedValue({ is_using: false }) + mockDeleteDataset.mockResolvedValue({}) + }) + + it('should call refreshDataset when rename succeeds', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByText('common.operation.edit')) + + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + await user.click(screen.getByText('Success')) + + await waitFor(() => { + expect(mockInvalidDatasetList).toHaveBeenCalled() + expect(mockInvalidDatasetDetail).toHaveBeenCalled() + }) + }) + + it('should close rename modal when onClose is called', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByText('common.operation.edit')) + + expect(screen.getByTestId('rename-modal')).toBeInTheDocument() + await user.click(screen.getByText('Close')) + + await waitFor(() => { + expect(screen.queryByTestId('rename-modal')).not.toBeInTheDocument() + }) + }) + + it('should close confirm dialog when cancel is clicked', async () => { + const user = userEvent.setup() + render() + + await user.click(screen.getByTestId('portal-trigger')) + await user.click(screen.getByText('common.operation.delete')) + + await waitFor(() => { + expect(screen.getByTestId('confirm-dialog')).toBeInTheDocument() + }) + + await user.click(screen.getByText('cancel')) + + await waitFor(() => { + expect(screen.queryByTestId('confirm-dialog')).not.toBeInTheDocument() + }) + }) +}) diff --git a/web/app/components/app-sidebar/dataset-info/index.spec.tsx b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx similarity index 98% rename from web/app/components/app-sidebar/dataset-info/index.spec.tsx rename to web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx index 9996ef2b4d..a1e275d731 100644 --- a/web/app/components/app-sidebar/dataset-info/index.spec.tsx +++ b/web/app/components/app-sidebar/dataset-info/__tests__/index.spec.tsx @@ -9,10 +9,10 @@ import { DataSourceType, } from '@/models/datasets' import { RETRIEVE_METHOD } from '@/types/app' -import Dropdown from './dropdown' -import DatasetInfo from './index' -import Menu from './menu' -import MenuItem from './menu-item' +import DatasetInfo from '..' +import Dropdown from '../dropdown' +import Menu from '../menu' +import MenuItem from '../menu-item' let mockDataset: DataSet let mockIsDatasetOperator = false @@ -90,7 +90,7 @@ const createDataset = (overrides: Partial = {}): DataSet => ({ ...overrides, }) -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ replace: mockReplace, }), diff --git a/web/app/components/app-sidebar/dataset-info/dropdown.tsx b/web/app/components/app-sidebar/dataset-info/dropdown.tsx index 96127c4210..528bac831f 100644 --- a/web/app/components/app-sidebar/dataset-info/dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-info/dropdown.tsx @@ -1,11 +1,11 @@ import type { DataSet } from '@/models/datasets' import { RiMoreFill } from '@remixicon/react' -import { useRouter } from 'next/navigation' import * as React from 'react' import { useCallback, useState } from 'react' import { useTranslation } from 'react-i18next' import { useSelector as useAppContextWithSelector } from '@/context/app-context' import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail' +import { useRouter } from '@/next/navigation' import { checkIsUsedInApp, deleteDataset } from '@/service/datasets' import { datasetDetailQueryKeyPrefix, useInvalidDatasetList } from '@/service/knowledge/use-dataset' import { useInvalid } from '@/service/use-base' diff --git a/web/app/components/app-sidebar/dataset-info/index.tsx b/web/app/components/app-sidebar/dataset-info/index.tsx index ba82099b6c..e46c4e4085 100644 --- a/web/app/components/app-sidebar/dataset-info/index.tsx +++ b/web/app/components/app-sidebar/dataset-info/index.tsx @@ -64,12 +64,12 @@ const DatasetInfo: FC = ({ {expand && (
{dataset.name}
-
+
{isExternalProvider && t('externalTag', { ns: 'dataset' })} {!!(!isExternalProvider && isPipelinePublished && dataset.doc_form && dataset.indexing_technique) && (
@@ -79,7 +79,7 @@ const DatasetInfo: FC = ({ )}
{!!dataset.description && ( -

+

{dataset.description}

)} diff --git a/web/app/components/app-sidebar/dataset-info/menu-item.tsx b/web/app/components/app-sidebar/dataset-info/menu-item.tsx index 441482283b..7ad8d9407f 100644 --- a/web/app/components/app-sidebar/dataset-info/menu-item.tsx +++ b/web/app/components/app-sidebar/dataset-info/menu-item.tsx @@ -22,7 +22,7 @@ const MenuItem = ({ }} > - {name} + {name}
) } diff --git a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx index c81125e973..5beea54ab0 100644 --- a/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx +++ b/web/app/components/app-sidebar/dataset-sidebar-dropdown.tsx @@ -1,4 +1,4 @@ -import type { NavIcon } from './navLink' +import type { NavIcon } from './nav-link' import type { DataSet } from '@/models/datasets' import { RiMenuLine, @@ -21,7 +21,7 @@ import Divider from '../base/divider' import Effect from '../base/effect' import ExtraInfo from '../datasets/extra-info' import Dropdown from './dataset-info/dropdown' -import NavLink from './navLink' +import NavLink from './nav-link' type DatasetSidebarDropdownProps = { navigation: Array<{ @@ -107,12 +107,12 @@ const DatasetSidebarDropdown = ({
{dataset.name}
-
+
{isExternalProvider && t('externalTag', { ns: 'dataset' })} {!!(!isExternalProvider && dataset.doc_form && dataset.indexing_technique) && (
@@ -123,7 +123,7 @@ const DatasetSidebarDropdown = ({
{!!dataset.description && ( -

+

{dataset.description}

)} diff --git a/web/app/components/app-sidebar/expert.png b/web/app/components/app-sidebar/expert.png deleted file mode 100644 index ba941a5865..0000000000 Binary files a/web/app/components/app-sidebar/expert.png and /dev/null differ diff --git a/web/app/components/app-sidebar/index.tsx b/web/app/components/app-sidebar/index.tsx index afc6bd0f13..13fde97f89 100644 --- a/web/app/components/app-sidebar/index.tsx +++ b/web/app/components/app-sidebar/index.tsx @@ -1,12 +1,12 @@ -import type { NavIcon } from './navLink' +import type { NavIcon } from './nav-link' import { useHover, useKeyPress } from 'ahooks' -import { usePathname } from 'next/navigation' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useShallow } from 'zustand/react/shallow' import { useStore as useAppStore } from '@/app/components/app/store' import { useEventEmitterContextContext } from '@/context/event-emitter' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' +import { usePathname } from '@/next/navigation' import { cn } from '@/utils/classnames' import Divider from '../base/divider' import { getKeyboardKeyCodeBySystem } from '../workflow/utils' @@ -14,7 +14,7 @@ import AppInfo from './app-info' import AppSidebarDropdown from './app-sidebar-dropdown' import DatasetInfo from './dataset-info' import DatasetSidebarDropdown from './dataset-sidebar-dropdown' -import NavLink from './navLink' +import NavLink from './nav-link' import ToggleButton from './toggle-button' export type IAppDetailNavProps = { diff --git a/web/app/components/app-sidebar/navLink.spec.tsx b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx similarity index 97% rename from web/app/components/app-sidebar/navLink.spec.tsx rename to web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx index 62ef553386..fe46290002 100644 --- a/web/app/components/app-sidebar/navLink.spec.tsx +++ b/web/app/components/app-sidebar/nav-link/__tests__/index.spec.tsx @@ -1,16 +1,16 @@ -import type { NavLinkProps } from './navLink' +import type { NavLinkProps } from '..' import { render, screen } from '@testing-library/react' import * as React from 'react' -import NavLink from './navLink' +import NavLink from '..' // Mock Next.js navigation -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useSelectedLayoutSegment: () => 'overview', })) // Mock Next.js Link component -vi.mock('next/link', () => ({ - default: function MockLink({ children, href, className, title }: any) { +vi.mock('@/next/link', () => ({ + default: function MockLink({ children, href, className, title }: { children: React.ReactNode, href: string, className?: string, title?: string }) { return ( {children} diff --git a/web/app/components/app-sidebar/navLink.tsx b/web/app/components/app-sidebar/nav-link/index.tsx similarity index 80% rename from web/app/components/app-sidebar/navLink.tsx rename to web/app/components/app-sidebar/nav-link/index.tsx index 9d5e319046..cf986a7407 100644 --- a/web/app/components/app-sidebar/navLink.tsx +++ b/web/app/components/app-sidebar/nav-link/index.tsx @@ -1,8 +1,8 @@ 'use client' import type { RemixiconComponentType } from '@remixicon/react' -import Link from 'next/link' -import { useSelectedLayoutSegment } from 'next/navigation' import * as React from 'react' +import Link from '@/next/link' +import { useSelectedLayoutSegment } from '@/next/navigation' import { cn } from '@/utils/classnames' export type NavIcon = React.ComponentType< @@ -54,7 +54,7 @@ const NavLink = ({ key={name} type="button" disabled - className={cn('system-sm-medium flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 hover:bg-components-menu-item-bg-hover', 'pl-3 pr-1')} + className={cn('flex h-8 cursor-not-allowed items-center rounded-lg text-components-menu-item-text opacity-30 system-sm-medium hover:bg-components-menu-item-bg-hover', 'pl-3 pr-1')} title={mode === 'collapse' ? name : ''} aria-disabled > @@ -75,8 +75,8 @@ const NavLink = ({ key={name} href={href} className={cn(isActive - ? 'system-sm-semibold border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only' - : 'system-sm-medium text-components-menu-item-text hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover', 'flex h-8 items-center rounded-lg pl-3 pr-1')} + ? 'border-b-[0.25px] border-l-[0.75px] border-r-[0.25px] border-t-[0.75px] border-effects-highlight-lightmode-off bg-components-menu-item-bg-active text-text-accent-light-mode-only system-sm-semibold' + : 'text-components-menu-item-text system-sm-medium hover:bg-components-menu-item-bg-hover hover:text-components-menu-item-text-hover', 'flex h-8 items-center rounded-lg pl-3 pr-1')} title={mode === 'collapse' ? name : ''} > {renderIcon()} diff --git a/web/app/components/app-sidebar/style.module.css b/web/app/components/app-sidebar/style.module.css deleted file mode 100644 index ca0978b760..0000000000 --- a/web/app/components/app-sidebar/style.module.css +++ /dev/null @@ -1,11 +0,0 @@ -.sidebar { - border-right: 1px solid #F3F4F6; -} - -.completionPic { -background-image: url('./completion.png') -} - -.expertPic { -background-image: url('./expert.png') -} diff --git a/web/app/components/app-sidebar/toggle-button.tsx b/web/app/components/app-sidebar/toggle-button.tsx index cbfbeee452..811e979b89 100644 --- a/web/app/components/app-sidebar/toggle-button.tsx +++ b/web/app/components/app-sidebar/toggle-button.tsx @@ -19,7 +19,7 @@ const TooltipContent = ({ return (
- {expand ? t('sidebar.collapseSidebar', { ns: 'layout' }) : t('sidebar.expandSidebar', { ns: 'layout' })} + {expand ? t('sidebar.collapseSidebar', { ns: 'layout' }) : t('sidebar.expandSidebar', { ns: 'layout' })}
) diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx index 6a67ba3207..55f5ee0564 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.spec.tsx @@ -1,7 +1,7 @@ import type { Props } from './csv-uploader' import { fireEvent, render, screen, waitFor } from '@testing-library/react' import * as React from 'react' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import CSVUploader from './csv-uploader' describe('CSVUploader', () => { diff --git a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx index 6d5eb1ef95..a969b3d491 100644 --- a/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx +++ b/web/app/components/app/annotation/batch-add-annotation-modal/csv-uploader.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import { Csv as CSVIcon } from '@/app/components/base/icons/src/public/files' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { cn } from '@/utils/classnames' export type Props = { @@ -94,7 +94,7 @@ const CSVUploader: FC = ({ />
{!file && ( -
+
diff --git a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx index 5c803a91f0..90cbac13a4 100644 --- a/web/app/components/app/app-access-control/add-member-or-group-pop.tsx +++ b/web/app/components/app/app-access-control/add-member-or-group-pop.tsx @@ -10,7 +10,7 @@ import { SubjectType } from '@/models/access-control' import { useSearchForWhiteListCandidates } from '@/service/access-control' import { cn } from '@/utils/classnames' import useAccessControlStore from '../../../../context/access-control-store' -import Avatar from '../../base/avatar' +import { Avatar } from '../../base/avatar' import Button from '../../base/button' import Checkbox from '../../base/checkbox' import Input from '../../base/input' @@ -203,7 +203,7 @@ function MemberItem({ member }: MemberItemProps) {
- +

{member.name}

diff --git a/web/app/components/app/app-access-control/specific-groups-or-members.tsx b/web/app/components/app/app-access-control/specific-groups-or-members.tsx index 8ca817c872..2c0e4b2694 100644 --- a/web/app/components/app/app-access-control/specific-groups-or-members.tsx +++ b/web/app/components/app/app-access-control/specific-groups-or-members.tsx @@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next' import { AccessMode } from '@/models/access-control' import { useAppWhiteListSubjects } from '@/service/access-control' import useAccessControlStore from '../../../../context/access-control-store' -import Avatar from '../../base/avatar' +import { Avatar } from '../../base/avatar' import Loading from '../../base/loading' import Tooltip from '../../base/tooltip' import AddMemberOrGroupDialog from './add-member-or-group-pop' @@ -106,7 +106,7 @@ function MemberItem({ member }: MemberItemProps) { }, [member, setSpecificMembers, specificMembers]) return ( } + icon={} onRemove={handleRemoveMember} >

{member.name}

diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx index be4377bfd9..abcf5795d0 100644 --- a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.spec.tsx @@ -2,25 +2,19 @@ import { fireEvent, render, screen } from '@testing-library/react' import * as React from 'react' import HasNotSetAPI from './has-not-set-api' -describe('HasNotSetAPI WarningMask', () => { - it('should show default title when trial not finished', () => { - render() +describe('HasNotSetAPI', () => { + it('should render the empty state copy', () => { + render() - expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument() - expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument() + expect(screen.getByText('appDebug.noModelProviderConfigured')).toBeInTheDocument() + expect(screen.getByText('appDebug.noModelProviderConfiguredTip')).toBeInTheDocument() }) - it('should show trail finished title when flag is true', () => { - render() - - expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument() - }) - - it('should call onSetting when primary button clicked', () => { + it('should call onSetting when manage models button is clicked', () => { const onSetting = vi.fn() - render() + render() - fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' })) + fireEvent.click(screen.getByRole('button', { name: 'appDebug.manageModels' })) expect(onSetting).toHaveBeenCalledTimes(1) }) }) diff --git a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx index 84323e64f5..2c5fc5ff2f 100644 --- a/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx +++ b/web/app/components/app/configuration/base/warning-mask/has-not-set-api.tsx @@ -2,38 +2,38 @@ import type { FC } from 'react' import * as React from 'react' import { useTranslation } from 'react-i18next' -import Button from '@/app/components/base/button' -import WarningMask from '.' export type IHasNotSetAPIProps = { - isTrailFinished: boolean onSetting: () => void } -const icon = ( - - - - -) - const HasNotSetAPI: FC = ({ - isTrailFinished, onSetting, }) => { const { t } = useTranslation() return ( - - {t('notSetAPIKey.settingBtn', { ns: 'appDebug' })} - {icon} - - )} - /> +
+
+
+
+ +
+
+
+
{t('noModelProviderConfigured', { ns: 'appDebug' })}
+
{t('noModelProviderConfiguredTip', { ns: 'appDebug' })}
+
+ +
+
) } export default React.memo(HasNotSetAPI) diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index d0e9eb586c..9625204d81 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -20,7 +20,7 @@ import { } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index bc94f87838..39a1699063 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -17,7 +17,7 @@ import { useFeaturesStore } from '@/app/components/base/features/hooks' import PromptEditor from '@/app/components/base/prompt-editor' import { PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER } from '@/app/components/base/prompt-editor/plugins/update-block' import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '@/app/components/base/prompt-editor/plugins/variable-block' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import Tooltip from '@/app/components/base/tooltip' import ConfigContext from '@/context/debug-configuration' import { useEventEmitterContextContext } from '@/context/event-emitter' @@ -178,7 +178,7 @@ const Prompt: FC = ({ {!noTitle && (
-
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
+
{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}
{!readonly && ( { }) render() - const input = screen.getByRole('spinbutton') as HTMLInputElement + const input = screen.getByRole('textbox') as HTMLInputElement fireEvent.change(input, { target: { value: '4' } }) const updatedFile = getLatestFileConfig() diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index e9e3b60859..f719d87261 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -12,7 +12,7 @@ import { CopyCheck, } from '@/app/components/base/icons/src/vender/line/files' import PromptEditor from '@/app/components/base/prompt-editor' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import ConfigContext from '@/context/debug-configuration' import { useModalContext } from '@/context/modal-context' import { cn } from '@/utils/classnames' @@ -96,7 +96,7 @@ const Editor: FC = ({ )}
-
+
= ({
= (
= (
-
{t('codegen.instruction', { ns: 'appDebug' })}
+
{t('codegen.instruction', { ns: 'appDebug' })}
= ( disabled={isLoading} > - {t('codegen.generate', { ns: 'appDebug' })} + {t('codegen.generate', { ns: 'appDebug' })}
diff --git a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx index 0bbed83a99..09a5ff6d07 100644 --- a/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/card-item/index.spec.tsx @@ -172,12 +172,8 @@ describe('dataset-config/card-item', () => { const [editButton] = within(card).getAllByRole('button', { hidden: true }) await user.click(editButton) - expect(screen.getByText('Mock settings modal')).toBeInTheDocument() - await waitFor(() => { - expect(screen.getByRole('dialog')).toBeVisible() - }) - - fireEvent.click(screen.getByText('Save changes')) + expect(await screen.findByText('Mock settings modal')).toBeInTheDocument() + fireEvent.click(await screen.findByText('Save changes')) await waitFor(() => { expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' })) @@ -194,7 +190,7 @@ describe('dataset-config/card-item', () => { const card = screen.getByText(dataset.name).closest('.group') as HTMLElement const buttons = within(card).getAllByRole('button', { hidden: true }) - const deleteButton = buttons[buttons.length - 1] + const deleteButton = buttons.at(-1)! expect(deleteButton.className).not.toContain('action-btn-destructive') @@ -233,7 +229,7 @@ describe('dataset-config/card-item', () => { await user.click(editButton) expect(screen.getByText('Mock settings modal')).toBeInTheDocument() - const overlay = Array.from(document.querySelectorAll('[class]')) + const overlay = [...document.querySelectorAll('[class]')] .find(element => element.className.toString().includes('bg-black/30')) expect(overlay).toBeInTheDocument() diff --git a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx index 7f71247d56..8c6e626b45 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/index.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import ContextVar from './index' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx index aa8dae813f..6704fa0afd 100644 --- a/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/context-var/var-picker.spec.tsx @@ -5,7 +5,7 @@ import * as React from 'react' import VarPicker from './var-picker' // Mock external dependencies only -vi.mock('next/navigation', () => ({ +vi.mock('@/next/navigation', () => ({ useRouter: () => ({ push: vi.fn() }), usePathname: () => '/test', })) diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index d2e4913e54..6dd03d217e 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -370,7 +370,6 @@ const ConfigContent: FC = ({ const mockedUseCurrentProviderAndModel = useCurrentProviderAndModel as MockedFunction -let toastNotifySpy: MockInstance +let toastErrorSpy: MockInstance const createDatasetConfigs = (overrides: Partial = {}): DatasetConfigs => { return { @@ -140,7 +140,7 @@ describe('dataset-config/params-config', () => { beforeEach(() => { vi.clearAllMocks() vi.useRealTimers() - toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({})) + toastErrorSpy = vi.spyOn(toast, 'error').mockImplementation(() => '') mockedUseModelListAndDefaultModelAndCurrentProviderAndModel.mockReturnValue({ modelList: [], defaultModel: undefined, @@ -154,7 +154,7 @@ describe('dataset-config/params-config', () => { }) afterEach(() => { - toastNotifySpy.mockRestore() + toastErrorSpy.mockRestore() }) // Rendering tests (REQUIRED) @@ -180,12 +180,12 @@ describe('dataset-config/params-config', () => { const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) const dialogScope = within(dialog) - const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + const incrementButtons = dialogScope.getAllByRole('button', { name: /increment/i }) await user.click(incrementButtons[0]) await waitFor(() => { - const [topKInput] = dialogScope.getAllByRole('spinbutton') - expect(topKInput).toHaveValue(5) + const [topKInput] = dialogScope.getAllByRole('textbox') + expect(topKInput).toHaveValue('5') }) await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) @@ -197,10 +197,10 @@ describe('dataset-config/params-config', () => { await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) const reopenedDialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) const reopenedScope = within(reopenedDialog) - const [reopenedTopKInput] = reopenedScope.getAllByRole('spinbutton') + const [reopenedTopKInput] = reopenedScope.getAllByRole('textbox') // Assert - expect(reopenedTopKInput).toHaveValue(5) + expect(reopenedTopKInput).toHaveValue('5') }) it('should discard changes when cancel is clicked', async () => { @@ -213,12 +213,12 @@ describe('dataset-config/params-config', () => { const dialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) const dialogScope = within(dialog) - const incrementButtons = dialogScope.getAllByRole('button', { name: 'increment' }) + const incrementButtons = dialogScope.getAllByRole('button', { name: /increment/i }) await user.click(incrementButtons[0]) await waitFor(() => { - const [topKInput] = dialogScope.getAllByRole('spinbutton') - expect(topKInput).toHaveValue(5) + const [topKInput] = dialogScope.getAllByRole('textbox') + expect(topKInput).toHaveValue('5') }) const cancelButton = await dialogScope.findByRole('button', { name: 'common.operation.cancel' }) @@ -231,10 +231,10 @@ describe('dataset-config/params-config', () => { await user.click(screen.getByRole('button', { name: 'dataset.retrievalSettings' })) const reopenedDialog = await screen.findByRole('dialog', {}, { timeout: 3000 }) const reopenedScope = within(reopenedDialog) - const [reopenedTopKInput] = reopenedScope.getAllByRole('spinbutton') + const [reopenedTopKInput] = reopenedScope.getAllByRole('textbox') // Assert - expect(reopenedTopKInput).toHaveValue(4) + expect(reopenedTopKInput).toHaveValue('4') }) it('should prevent saving when rerank model is required but invalid', async () => { @@ -254,10 +254,7 @@ describe('dataset-config/params-config', () => { await user.click(dialogScope.getByRole('button', { name: 'common.operation.save' })) // Assert - expect(toastNotifySpy).toHaveBeenCalledWith({ - type: 'error', - message: 'appDebug.datasetConfig.rerankModelRequired', - }) + expect(toastErrorSpy).toHaveBeenCalledWith('appDebug.datasetConfig.rerankModelRequired') expect(screen.getByRole('dialog')).toBeInTheDocument() }) }) diff --git a/web/app/components/app/configuration/dataset-config/params-config/index.tsx b/web/app/components/app/configuration/dataset-config/params-config/index.tsx index 692ae12022..89410203df 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/index.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/index.tsx @@ -7,7 +7,7 @@ import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import Button from '@/app/components/base/button' import Modal from '@/app/components/base/modal' -import Toast from '@/app/components/base/toast' +import { toast } from '@/app/components/base/ui/toast' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import { @@ -66,10 +66,7 @@ const ParamsConfig = ({ } } if (errMsg) { - Toast.notify({ - type: 'error', - message: errMsg, - }) + toast.error(errMsg) } return !errMsg } diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx index 40cb3ffc81..bd6c1976a6 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.spec.tsx @@ -137,4 +137,31 @@ describe('SelectDataSet', () => { expect(screen.getByRole('link', { name: 'appDebug.feature.dataSet.toCreate' })).toHaveAttribute('href', '/datasets/create') expect(screen.getByRole('button', { name: 'common.operation.add' })).toBeDisabled() }) + + it('uses selectedIds as the initial modal selection', async () => { + const datasetOne = makeDataset({ + id: 'set-1', + name: 'Dataset One', + }) + mockUseInfiniteDatasets.mockReturnValue({ + data: { pages: [{ data: [datasetOne] }] }, + isLoading: false, + isFetchingNextPage: false, + fetchNextPage: vi.fn(), + hasNextPage: false, + }) + + const onSelect = vi.fn() + await act(async () => { + render() + }) + + expect(screen.getByText('1 appDebug.feature.dataSet.selected')).toBeInTheDocument() + + await act(async () => { + fireEvent.click(screen.getByRole('button', { name: 'common.operation.add' })) + }) + + expect(onSelect).toHaveBeenCalledWith([datasetOne]) + }) }) diff --git a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx index 330223f974..8c2fb77c20 100644 --- a/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx +++ b/web/app/components/app/configuration/dataset-config/select-dataset/index.tsx @@ -2,9 +2,8 @@ import type { FC } from 'react' import type { DataSet } from '@/models/datasets' import { useInfiniteScroll } from 'ahooks' -import Link from 'next/link' import * as React from 'react' -import { useEffect, useMemo, useRef, useState } from 'react' +import { useMemo, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import AppIcon from '@/app/components/base/app-icon' import Badge from '@/app/components/base/badge' @@ -14,6 +13,7 @@ import Modal from '@/app/components/base/modal' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import FeatureIcon from '@/app/components/header/account-setting/model-provider-page/model-selector/feature-icon' import { useKnowledge } from '@/hooks/use-knowledge' +import Link from '@/next/link' import { useInfiniteDatasets } from '@/service/knowledge/use-dataset' import { cn } from '@/utils/classnames' @@ -31,17 +31,21 @@ const SelectDataSet: FC = ({ onSelect, }) => { const { t } = useTranslation() - const [selected, setSelected] = useState([]) + const [selectedIdsInModal, setSelectedIdsInModal] = useState(() => selectedIds) const canSelectMulti = true const { formatIndexingTechniqueAndMethod } = useKnowledge() const { data, isLoading, isFetchingNextPage, fetchNextPage, hasNextPage } = useInfiniteDatasets( { page: 1 }, { enabled: isShow, staleTime: 0, refetchOnMount: 'always' }, ) - const pages = data?.pages || [] const datasets = useMemo(() => { + const pages = data?.pages || [] return pages.flatMap(page => page.data.filter(item => item.indexing_technique || item.provider === 'external')) - }, [pages]) + }, [data]) + const datasetMap = useMemo(() => new Map(datasets.map(item => [item.id, item])), [datasets]) + const selected = useMemo(() => { + return selectedIdsInModal.map(id => datasetMap.get(id) || ({ id } as DataSet)) + }, [datasetMap, selectedIdsInModal]) const hasNoData = !isLoading && datasets.length === 0 const listRef = useRef(null) @@ -61,50 +65,14 @@ const SelectDataSet: FC = ({ }, ) - const prevSelectedIdsRef = useRef([]) - const hasUserModifiedSelectionRef = useRef(false) - useEffect(() => { - if (isShow) - hasUserModifiedSelectionRef.current = false - }, [isShow]) - useEffect(() => { - const prevSelectedIds = prevSelectedIdsRef.current - const idsChanged = selectedIds.length !== prevSelectedIds.length - || selectedIds.some((id, idx) => id !== prevSelectedIds[idx]) - - if (!selectedIds.length && (!hasUserModifiedSelectionRef.current || idsChanged)) { - setSelected([]) - prevSelectedIdsRef.current = selectedIds - hasUserModifiedSelectionRef.current = false - return - } - - if (!idsChanged && hasUserModifiedSelectionRef.current) - return - - setSelected((prev) => { - const prevMap = new Map(prev.map(item => [item.id, item])) - const nextSelected = selectedIds - .map(id => datasets.find(item => item.id === id) || prevMap.get(id)) - .filter(Boolean) as DataSet[] - return nextSelected - }) - prevSelectedIdsRef.current = selectedIds - hasUserModifiedSelectionRef.current = false - }, [datasets, selectedIds]) - const toggleSelect = (dataSet: DataSet) => { - hasUserModifiedSelectionRef.current = true - const isSelected = selected.some(item => item.id === dataSet.id) - if (isSelected) { - setSelected(selected.filter(item => item.id !== dataSet.id)) - } - else { - if (canSelectMulti) - setSelected([...selected, dataSet]) - else - setSelected([dataSet]) - } + setSelectedIdsInModal((prev) => { + const isSelected = prev.includes(dataSet.id) + if (isSelected) + return prev.filter(id => id !== dataSet.id) + + return canSelectMulti ? [...prev, dataSet.id] : [dataSet.id] + }) } const handleSelect = () => { @@ -126,7 +94,7 @@ const SelectDataSet: FC = ({ {hasNoData && (
= ({ key={item.id} className={cn( 'flex h-10 cursor-pointer items-center rounded-lg border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg px-2 shadow-xs hover:border-components-panel-border hover:bg-components-panel-on-panel-item-bg-hover hover:shadow-sm', - selected.some(i => i.id === item.id) && 'border-[1.5px] border-components-option-card-option-selected-border bg-state-accent-hover shadow-xs hover:border-components-option-card-option-selected-border hover:bg-state-accent-hover hover:shadow-xs', + selectedIdsInModal.includes(item.id) && 'border-[1.5px] border-components-option-card-option-selected-border bg-state-accent-hover shadow-xs hover:border-components-option-card-option-selected-border hover:bg-state-accent-hover hover:shadow-xs', !item.embedding_available && 'hover:border-components-panel-border-subtle hover:bg-components-panel-on-panel-item-bg hover:shadow-xs', )} onClick={() => { @@ -195,7 +163,7 @@ const SelectDataSet: FC = ({ )} {!isLoading && (
-
+
{selected.length > 0 && `${selected.length} ${t('feature.dataSet.selected', { ns: 'appDebug' })}`}
diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx index c4fdfb7553..ea70725ea8 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.spec.tsx @@ -3,7 +3,7 @@ import type { DataSet } from '@/models/datasets' import type { RetrievalConfig } from '@/types/app' import { render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' -import { ToastContext } from '@/app/components/base/toast' +import { ToastContext } from '@/app/components/base/toast/context' import { IndexingType } from '@/app/components/datasets/create/step-two' import { ACCOUNT_SETTING_TAB } from '@/app/components/header/account-setting/constants' import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' diff --git a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx index e910702b66..bc534599de 100644 --- a/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx +++ b/web/app/components/app/configuration/dataset-config/settings-modal/index.tsx @@ -9,7 +9,7 @@ import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import Input from '@/app/components/base/input' import Textarea from '@/app/components/base/textarea' -import { useToastContext } from '@/app/components/base/toast' +import { useToastContext } from '@/app/components/base/toast/context' import { isReRankModelSelected } from '@/app/components/datasets/common/check-rerank-model' import { IndexingType } from '@/app/components/datasets/create/step-two' import IndexMethod from '@/app/components/datasets/settings/index-method' @@ -210,7 +210,7 @@ const SettingsModal: FC = ({
-
{t('form.name', { ns: 'datasetSettings' })}
+
{t('form.name', { ns: 'datasetSettings' })}
= ({
-
{t('form.desc', { ns: 'datasetSettings' })}
+
{t('form.desc', { ns: 'datasetSettings' })}